Ruurd committed
Commit 713dc22 · 1 Parent(s): 2ad2507

Change tokenizer selection


Add system prompt
Add start message
Add incorporation of patient data
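Note: the new system prompt instructs the model to emit a segment_organ JSON object and states that the app will execute it, but the dispatch code itself is not part of this commit. Below is a minimal sketch of how such a call could be detected in the assistant's output; the helper name maybe_parse_tool_call and the regex-based extraction are illustrative assumptions, not code from app.py.

    import json
    import re

    def maybe_parse_tool_call(output_text):
        # Look for a JSON object in the assistant's reply and check whether it is
        # the segment_organ call described in the new system prompt.
        match = re.search(r"\{.*\}", output_text, re.DOTALL)
        if not match:
            return None
        try:
            call = json.loads(match.group(0))
        except json.JSONDecodeError:
            return None
        if call.get("function") == "segment_organ":
            return call.get("arguments", {})
        return None

    # Example: a reply following the structure requested by the system prompt.
    reply = '{"function": "segment_organ", "arguments": {"scan_path": "<path_to_ct_scan>", "organ": "spleen"}}'
    print(maybe_parse_tool_call(reply))  # {'scan_path': '<path_to_ct_scan>', 'organ': 'spleen'}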

Files changed (1)
1. app.py +61 -68
app.py CHANGED
@@ -9,51 +9,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import threading
 import queue
 
-class RichTextStreamer(TextIteratorStreamer):
-    def __init__(self, tokenizer, prompt_len=0, **kwargs):
-        super().__init__(tokenizer, **kwargs)
-        self.token_queue = queue.Queue()
-        self.prompt_len = prompt_len
-        self.count = 0
-
-    def put(self, value):
-        if isinstance(value, torch.Tensor):
-            token_ids = value.view(-1).tolist()
-        elif isinstance(value, list):
-            token_ids = value
-        else:
-            token_ids = [value]
-
-        for token_id in token_ids:
-            self.count += 1
-            if self.count <= self.prompt_len:
-                continue # skip prompt tokens
-            token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
-            is_special = token_id in self.tokenizer.all_special_ids
-            self.token_queue.put({
-                "token_id": token_id,
-                "token": token_str,
-                "is_special": is_special
-            })
-
-    def __iter__(self):
-        while True:
-            try:
-                token_info = self.token_queue.get(timeout=self.timeout)
-                yield token_info
-            except queue.Empty:
-                if self.end_of_generation.is_set():
-                    break
-
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-import threading
-
-from transformers import TextIteratorStreamer
-import threading
-
-from transformers import TextIteratorStreamer
-import queue
-
 class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, prompt_len=0, **kwargs):
         super().__init__(tokenizer, **kwargs)
@@ -108,22 +63,66 @@ def chat_with_model(messages):
     max_new_tokens = 1024
     generated_tokens = 0
 
-    prompt = format_prompt(messages)
+    # PREPARE SYSTEM + INITIAL MESSAGES
+    system_messages = [
+        {
+            "role": "system",
+            "content": (
+                "You are a radiologist's companion, here to answer questions about the patient and assist in the diagnosis if asked to do so. "
+                "You are able to call specialized tools. "
+                "At the moment, you have one tool available: an organ segmentation algorithm for abdominal CTs.\n\n"
+                "If the user requests an organ segmentation, output a JSON object in this structure:\n"
+                "{\n"
+                " \"function\": \"segment_organ\",\n"
+                " \"arguments\": {\n"
+                " \"scan_path\": \"<path_to_ct_scan>\",\n"
+                " \"organ\": \"<organ_name>\"\n"
+                " }\n"
+                "}\n\n"
+                "Once you call the function, the app will execute it and return the result."
+            )
+        },
+        {
+            "role": "system",
+            "content": f"Patient Information:\nName: {patient_name.value}\nAge: {patient_age.value}\nID: {patient_id.value}\nNotes: {patient_notes.value}"
+        }
+    ]
+
+    # Optional: if you later add available_images, you could append another system message.
+
+    welcome_message = (
+        "**Welcome to the Radiologist's Companion!**\n\n"
+        "You can ask me about the patient's medical history or available imaging data.\n"
+        "- I can summarize key details from the EHR.\n"
+        "- I can tell you which medical images are available.\n"
+        "- If you'd like an organ segmentation (e.g. spleen, liver, kidney_left, colon, femur_right) on an abdominal CT scan, just ask!\n\n"
+        "**Example Requests:**\n"
+        "- \"What do we know about this patient?\"\n"
+        "- \"Which images are available for this patient?\"\n"
+        "- \"Can you segment the spleen from the CT scan?\"\n"
+    )
+
+    # If it's the first user message (i.e., no assistant yet), prepend welcome
+    if len(messages) == 1 and messages[0]['role'] == 'user':
+        messages = [{"role": "assistant", "content": welcome_message}] + messages
+
+    # Merge full conversation
+    full_messages = system_messages + messages
+
+    prompt = format_prompt(full_messages)
+
     device = torch.device("cuda")
     current_model.to(device).half()
 
-    # 1. Tokenize prompt
     inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
     prompt_len = inputs["input_ids"].shape[-1]
 
-    # 2. Init streamer with prompt_len
     streamer = RichTextStreamer(
         tokenizer=current_tokenizer,
         prompt_len=prompt_len,
         skip_special_tokens=False
     )
 
-    # 3. Build generation kwargs
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=max_new_tokens,
@@ -133,27 +132,20 @@ def chat_with_model(messages):
         pad_token_id=pad_id
     )
 
-    # 4. Launch generation in a thread
     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
     messages = messages.copy()
     messages.append({"role": "assistant", "content": ""})
 
-    print(f'Step 1: {messages}')
-
-    prompt_text = current_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=False)
-
     for token_info in streamer:
         token_str = token_info["token"]
         token_id = token_info["token_id"]
         is_special = token_info["is_special"]
 
-        # Stop immediately at EOS
         if token_id == eos_id:
             break
 
-        # Detect reasoning block
         if "<think>" in token_str:
            in_think = True
            token_str = token_str.replace("<think>", "")
@@ -166,7 +158,6 @@ def chat_with_model(messages):
         else:
             output_text += token_str
 
-        # Early stopping if user reappears
         if "\nUser" in output_text:
             output_text = output_text.split("\nUser")[0].rstrip()
             messages[-1]["content"] = output_text
@@ -178,34 +169,35 @@ def chat_with_model(messages):
 
         messages[-1]["content"] = output_text
 
-        print(f'Step 2: {messages}')
-
         yield messages
 
     if in_think:
         output_text += "*"
         messages[-1]["content"] = output_text
-
-    # Wait for thread to finish
-    # current_model.to("cpu")
-    torch.cuda.empty_cache()
 
+    torch.cuda.empty_cache()
     messages[-1]["content"] = output_text
-    print(f'Step 3: {messages}')
-
     return messages
 
 
 
+
 # Globals
 current_model = None
 current_tokenizer = None
 
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, LlamaTokenizer
+
 def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
     global current_model, current_tokenizer
     token = os.getenv("HF_TOKEN")
 
-    progress(0, desc="Loading tokenizer...")
+    progress(0, desc="Loading config...")
+    config = AutoConfig.from_pretrained(model_name, use_auth_token=token)
+
+    progress(0.2, desc="Loading tokenizer...")
+
+    # Default
     current_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
 
     progress(0.5, desc="Loading model...")
@@ -219,6 +211,7 @@ def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
     progress(1, desc="Model ready.")
     return f"{model_name} loaded and ready!"
 
+
 # Format conversation as plain text
 def format_prompt(messages):
     prompt = ""
@@ -239,7 +232,7 @@ model_choices = [
     "meta-llama/Llama-3.2-3B-Instruct",
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
     "google/gemma-7b",
-    "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+    "mistralai/Mistral-Nemo-Instruct-FP8-2407"
 ]
 
 # Example patient database
@@ -277,7 +270,7 @@ def autofill_patient(patient_key):
         return "", "", "", ""
 
 with gr.Blocks(css=".gradio-container {height: 100vh; overflow: hidden;}") as demo:
-    gr.Markdown("## Radiologist's Companion")
+    gr.Markdown("<h2 style='text-align: center;'>Radiologist's Companion</h2>")
 
     default_model = gr.State(model_choices[0])
 
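Note on the tokenizer change: the commit now loads an AutoConfig before the tokenizer and imports LlamaTokenizer, but only the default AutoTokenizer path appears in this diff. A minimal sketch of what a config-based selection could look like follows; the model_type check and the LlamaTokenizer fallback are assumptions for illustration, not code from this commit.

    from transformers import AutoConfig, AutoTokenizer, LlamaTokenizer

    def pick_tokenizer(model_name, token=None):
        # Inspect the model config first, as load_model_on_selection now does.
        config = AutoConfig.from_pretrained(model_name, use_auth_token=token)

        # Hypothetical branch: for Llama-family checkpoints, fall back to the slow
        # LlamaTokenizer if the fast tokenizer cannot be loaded.
        if getattr(config, "model_type", None) == "llama":
            try:
                return AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
            except Exception:
                return LlamaTokenizer.from_pretrained(model_name, use_auth_token=token)

        # Default path, matching the line kept in this commit.
        return AutoTokenizer.from_pretrained(model_name, use_auth_token=token)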