Ruurd committed on
Commit 4aa916f · 1 Parent(s): 460c969
Files changed (1)
  1. app.py +20 -33
app.py CHANGED
@@ -98,41 +98,26 @@ def chat_with_model(messages):
         return
 
     current_id = patient_id.value
-    if current_id is None:
+    if not current_id:
         yield messages
         return
 
-    # 🛠 Missing variable initializations
     max_new_tokens = 1024
     output_text = ""
     in_think = False
     generated_tokens = 0
 
-    pad_id = current_tokenizer.pad_token_id
+    pad_id = current_tokenizer.pad_token_id or current_tokenizer.unk_token_id or 0
     eos_id = current_tokenizer.eos_token_id
-    if pad_id is None:
-        pad_id = current_tokenizer.unk_token_id or 0
-
-    # Remove the initial welcome if present
-    filtered_messages = [msg for msg in messages if not (msg["role"] == "assistant" and "Welcome to the Radiologist's Companion" in msg["content"])]
 
-    # Build system context
+    # --- Build system context
    system_messages = [
        {
            "role": "system",
            "content": (
                "You are a radiologist's companion, here to answer questions about the patient and assist in the diagnosis if asked to do so. "
                "You are able to call specialized tools. "
-                "At the moment, you have one tool available: an organ segmentation algorithm for abdominal CTs.\n\n"
-                "If the user requests an organ segmentation, output a JSON object in this structure:\n"
-                "{\n"
-                "  \"function\": \"segment_organ\",\n"
-                "  \"arguments\": {\n"
-                "    \"scan_path\": \"<path_to_ct_scan>\",\n"
-                "    \"organ\": \"<organ_name>\"\n"
-                "  }\n"
-                "}\n\n"
-                "Once you call the function, the app will execute it and return the result."
+                "At the moment, you have one tool available: an organ segmentation algorithm for abdominal CTs."
            )
        },
        {
@@ -141,8 +126,13 @@ def chat_with_model(messages):
        }
    ]
 
+    # Remove welcome message (only once shown)
+    # filtered_messages = [msg for msg in messages if not (msg["role"] == "assistant" and "Welcome to the Radiologist's Companion" in msg["content"])]
+
+    # FULL conversation
    full_messages = system_messages + filtered_messages
 
+    # --- Generate from full context
    prompt = format_prompt(full_messages)
 
    device = torch.device("cuda")
@@ -169,15 +159,13 @@
    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
    thread.start()
 
-    messages = full_messages.copy()
-    messages.append({"role": "assistant", "content": ""})
-
-    print(messages)
+    # Now extend previous messages
+    updated_messages = messages.copy()
+    updated_messages.append({"role": "assistant", "content": ""})
 
    for token_info in streamer:
        token_str = token_info["token"]
        token_id = token_info["token_id"]
-        is_special = token_info["is_special"]
 
        if token_id == eos_id:
            break
@@ -196,27 +184,26 @@
 
        if "\nUser" in output_text:
            output_text = output_text.split("\nUser")[0].rstrip()
-            messages[-1]["content"] = output_text
+            updated_messages[-1]["content"] = output_text
            break
 
        generated_tokens += 1
        if generated_tokens >= max_new_tokens:
            break
 
-    messages[-1]["content"] = output_text
+    updated_messages[-1]["content"] = output_text
 
-    # Save conversation per patient
-    patient_conversations[current_id] = messages
-
-    yield messages
+    patient_conversations[current_id] = updated_messages
+    yield updated_messages
 
    if in_think:
        output_text += "*"
-        messages[-1]["content"] = output_text
+        updated_messages[-1]["content"] = output_text
 
    torch.cuda.empty_cache()
-    messages[-1]["content"] = output_text
-    return messages
+    updated_messages[-1]["content"] = output_text
+    return updated_messages
+
 
 
 def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
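Note on the new pad_id fallback: chaining with "or" treats any falsy id as missing, so a tokenizer whose pad_token_id is legitimately 0 would silently fall through to unk_token_id. A minimal sketch of an explicit None check that avoids that edge case (the helper name is illustrative, not part of the commit):

# Sketch: resolve a padding id without treating a real id of 0 as "missing".
# The committed one-liner (pad or unk or 0) skips pad_token_id == 0,
# because 0 is falsy in Python.
def resolve_pad_id(tokenizer) -> int:
    for candidate in (tokenizer.pad_token_id, tokenizer.unk_token_id):
        if candidate is not None:
            return candidate
    return 0  # final fallback, as in the commit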
 
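The patient_conversations[current_id] = updated_messages / yield updated_messages pair follows the usual Gradio streaming pattern: the handler is a generator, each yield of the messages list re-renders the Chatbot, and a plain dict keeps one history per patient. A self-contained sketch under that assumption (component names and the demo tokens are illustrative, not from app.py):

# Sketch: generator-based streaming into a Gradio Chatbot with per-patient
# history. Every yield pushes the partially built reply to the UI.
import gradio as gr

conversations = {}  # per-patient history, keyed by patient id

def respond(history, patient):
    history = history + [{"role": "assistant", "content": ""}]
    for chunk in ("Scanning", " the", " report..."):  # stand-in for streamed tokens
        history[-1]["content"] += chunk
        yield history
    conversations[patient] = history  # persist the finished turn

with gr.Blocks() as demo:
    pid = gr.Textbox(value="patient-1", label="Patient ID")
    chat = gr.Chatbot(type="messages")
    gr.Button("Ask").click(respond, [chat, pid], chat)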
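For context, the threading.Thread(target=current_model.generate, ...) plus streamer-iteration idiom in the unchanged lines mirrors the stock transformers pattern; the app's custom streamer yields token-info dicts ("token", "token_id", "is_special"), whereas the built-in TextIteratorStreamer yields decoded strings. A runnable sketch of the standard form (model choice illustrative):

# Sketch: run generate() on a worker thread and consume text as it arrives.
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")  # small illustrative model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The scan shows", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=32)

thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)  # generate() blocks
thread.start()
for chunk in streamer:  # iteration ends when generation finishes
    print(chunk, end="", flush=True)
thread.join()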