Ruurd committed
Commit 0c58ef6 · 1 Parent(s): 713dc22

Initialize with patient information

Files changed (1): app.py (+93 −41)
app.py CHANGED
@@ -5,10 +5,54 @@ import torch
 import time
 import gradio as gr
 import spaces
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, LlamaTokenizer, TextIteratorStreamer
 import threading
 import queue
 
+# Globals
+current_model = None
+current_tokenizer = None
+
+# Curated models
+model_choices = [
+    "meta-llama/Llama-3.2-3B-Instruct",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    "google/gemma-7b-it",
+    "mistralai/Mistral-Nemo-Instruct-FP8-2407"
+]
+
+# Example patient database
+patient_db = {
+    "001 - John Doe": {
+        "name": "John Doe",
+        "age": "45",
+        "id": "001",
+        "notes": "History of chest pain and hypertension. No prior surgeries."
+    },
+    "002 - Maria Sanchez": {
+        "name": "Maria Sanchez",
+        "age": "62",
+        "id": "002",
+        "notes": "Suspected pulmonary embolism. Shortness of breath, tachycardia."
+    },
+    "003 - Ahmed Al-Farsi": {
+        "name": "Ahmed Al-Farsi",
+        "age": "29",
+        "id": "003",
+        "notes": "Persistent migraines. MRI scheduled for brain imaging."
+    },
+    "004 - Lin Wei": {
+        "name": "Lin Wei",
+        "age": "51",
+        "id": "004",
+        "notes": "Annual screening. Family history of breast cancer."
+    }
+}
+
+# Store conversations per patient
+patient_conversations = {}
+
+
 class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, prompt_len=0, **kwargs):
         super().__init__(tokenizer, **kwargs)
@@ -169,6 +213,10 @@ def chat_with_model(messages):
 
     messages[-1]["content"] = output_text
 
+    current_id = patient_id.value
+    if current_id:
+        patient_conversations[current_id] = messages
+
     yield messages
 
     if in_think:
@@ -182,11 +230,7 @@ def chat_with_model(messages):
 
 
 
-# Globals
-current_model = None
-current_tokenizer = None
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, LlamaTokenizer
 
 def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
     global current_model, current_tokenizer
@@ -198,7 +242,7 @@ def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
     progress(0.2, desc="Loading tokenizer...")
 
     # Default
-    current_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
+    current_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
 
     progress(0.5, desc="Loading model...")
     current_model = AutoModelForCausalLM.from_pretrained(
@@ -225,50 +269,39 @@ def format_prompt(messages):
     return prompt
 
 def add_user_message(user_input, history):
-    return "", history + [{"role": "user", "content": user_input}]
+    current_id = patient_id.value
+    if current_id:
+        conversation = patient_conversations.get(current_id, [])
+        conversation.append({"role": "user", "content": user_input})
+        patient_conversations[current_id] = conversation
+        return "", patient_conversations[current_id]
 
-# Curated models
-model_choices = [
-    "meta-llama/Llama-3.2-3B-Instruct",
-    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-    "google/gemma-7b",
-    "mistralai/Mistral-Nemo-Instruct-FP8-2407"
-]
 
-# Example patient database
-patient_db = {
-    "001 - John Doe": {
-        "name": "John Doe",
-        "age": "45",
-        "id": "001",
-        "notes": "History of chest pain and hypertension. No prior surgeries."
-    },
-    "002 - Maria Sanchez": {
-        "name": "Maria Sanchez",
-        "age": "62",
-        "id": "002",
-        "notes": "Suspected pulmonary embolism. Shortness of breath, tachycardia."
-    },
-    "003 - Ahmed Al-Farsi": {
-        "name": "Ahmed Al-Farsi",
-        "age": "29",
-        "id": "003",
-        "notes": "Persistent migraines. MRI scheduled for brain imaging."
-    },
-    "004 - Lin Wei": {
-        "name": "Lin Wei",
-        "age": "51",
-        "id": "004",
-        "notes": "Annual screening. Family history of breast cancer."
-    }
-}
 
 def autofill_patient(patient_key):
     if patient_key in patient_db:
         info = patient_db[patient_key]
+
+        # Init conversation if not existing
+        if info["id"] not in patient_conversations:
+            welcome_message = (
+                "**Welcome to the Radiologist's Companion!**\n\n"
+                "You can ask me about the patient's medical history or available imaging data.\n"
+                "- I can summarize key details from the EHR.\n"
+                "- I can tell you which medical images are available.\n"
+                "- If you'd like an organ segmentation (e.g. spleen, liver, kidney_left, colon, femur_right) on an abdominal CT scan, just ask!\n\n"
+                "**Example Requests:**\n"
+                "- \"What do we know about this patient?\"\n"
+                "- \"Which images are available for this patient?\"\n"
+                "- \"Can you segment the spleen from the CT scan?\"\n"
+            )
+
+            patient_conversations[info["id"]] = [{"role": "assistant", "content": welcome_message}]
+
         return info["name"], info["age"], info["id"], info["notes"]
     return "", "", "", ""
 
+
 with gr.Blocks(css=".gradio-container {height: 100vh; overflow: hidden;}") as demo:
     gr.Markdown("<h2 style='text-align: center;'>Radiologist's Companion</h2>")
 
@@ -311,6 +344,25 @@ with gr.Blocks(css=".gradio-container {height: 100vh; overflow: hidden;}") as demo:
         outputs=[patient_name, patient_age, patient_id, patient_notes]
     )
 
+    # After patient selected, load their conversation into chatbot
+    def load_patient_conversation(patient_key):
+        if patient_key in patient_db:
+            patient_id = patient_db[patient_key]["id"]
+            history = patient_conversations.get(patient_id, [])
+            return history
+        return []
+
+    patient_selector.change(
+        autofill_patient,
+        inputs=[patient_selector],
+        outputs=[patient_name, patient_age, patient_id, patient_notes]
+    ).then(
+        load_patient_conversation,
+        inputs=[patient_selector],
+        outputs=[chatbot]
+    )
+
+
     # Load on launch
     demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)
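A note on the two `patient_id.value` reads introduced above (in `chat_with_model` and `add_user_message`): in Gradio, a component's `.value` attribute is the value it was constructed with, not the live UI state, so both reads see the initial (empty) patient ID rather than the currently selected one. The usual pattern is to pass the component as an event input. Also, `add_user_message` has no return path when `current_id` is falsy, which Gradio surfaces as an error. A minimal sketch of both fixes, assuming a message textbox `user_input` and the existing `chatbot` and `patient_id` components inside the same `gr.Blocks` (names are illustrative, not the commit's actual wiring):

    # Hypothetical wiring; `user_input` and `chatbot` stand in for the
    # app's real components. Gradio injects the components' current
    # values into the callback, so the handler never reads .value,
    # and it returns on every path.
    def add_user_message(message, current_id):
        conversation = patient_conversations.get(current_id, []) if current_id else []
        conversation = conversation + [{"role": "user", "content": message}]
        if current_id:
            patient_conversations[current_id] = conversation
        return "", conversation

    user_input.submit(
        add_user_message,
        inputs=[user_input, patient_id],  # pass the component as an input
        outputs=[user_input, chatbot],
    )

Separately, recent transformers releases deprecate the `use_auth_token` argument seen in the `from_pretrained` calls in favor of `token=`.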
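A related caveat on `patient_conversations` itself: as a module-level dict it is shared by every visitor to the Space, so two concurrent users would read and overwrite each other's histories. `gr.State` keeps data per browser session; a rough, self-contained sketch of the same bookkeeping with session scope (component names again illustrative, not from the commit):

    import gradio as gr

    with gr.Blocks() as demo:
        # Illustrative stand-ins for the app's real widgets.
        patient_id = gr.Textbox(label="Patient ID")
        chatbot = gr.Chatbot(type="messages")
        user_input = gr.Textbox(label="Message")

        # One {patient_id: [messages]} dict per browser session.
        conversations = gr.State({})

        def add_user_message(message, current_id, convs):
            history = convs.get(current_id, []) + [{"role": "user", "content": message}]
            convs[current_id] = history
            return "", history, convs

        user_input.submit(
            add_user_message,
            inputs=[user_input, patient_id, conversations],
            outputs=[user_input, chatbot, conversations],
        )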