Ruurd committed on
Commit 0040338 · 1 Parent(s): 2372fe6

add submit button

Files changed (1):
  1. app.py +232 -36
app.py CHANGED
@@ -1,53 +1,249 @@
  import os
- import subprocess
-
- def install(package):
-     subprocess.check_call([os.sys.executable, "-m", "pip", "install", package])
-
- install("transformers")
-
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
- import spaces
-
-
- # Dictionary to store loaded models and tokenizers
- loaded_models = {}
-
- def load_model(model_name):
-     """Load the model and tokenizer if not already loaded."""
-     if model_name not in loaded_models:
-         tokenizer = AutoTokenizer.from_pretrained(model_name)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name, torch_dtype=torch.float16, device_map="auto"
-         )
-         loaded_models[model_name] = (tokenizer, model)
-     return loaded_models[model_name]

  @spaces.GPU
- def generate_text(model_name, prompt):
-     """Generate text using the selected model."""
-     tokenizer, model = load_model(model_name)
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_new_tokens=256)
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # List of models to choose from
  model_choices = [
-     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
      "meta-llama/Llama-3.2-3B-Instruct",
-     "google/gemma-7b"
  ]

- # Gradio interface setup
  with gr.Blocks() as demo:
-     gr.Markdown("## Clinical Text Analysis with Multiple Models")
-     model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
-     input_text = gr.Textbox(label="Input Clinical Text")
-     output_text = gr.Textbox(label="Generated Output")
-     analyze_button = gr.Button("Analyze")

-     analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)

  demo.launch()
 
  import os
+ import torch
+ import time
+ import gradio as gr
+ import spaces
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+ import threading
+ import queue
+
+ class RichTextStreamer(TextIteratorStreamer):
+     def __init__(self, tokenizer, prompt_len=0, **kwargs):
+         super().__init__(tokenizer, **kwargs)
+         self.token_queue = queue.Queue()
+         self.prompt_len = prompt_len
+         self.count = 0
+
+     def put(self, value):
+         # generate() first puts the prompt ids, then one tensor per new token
+         if isinstance(value, torch.Tensor):
+             token_ids = value.view(-1).tolist()
+         elif isinstance(value, list):
+             token_ids = value
+         else:
+             token_ids = [value]
+
+         for token_id in token_ids:
+             self.count += 1
+             if self.count <= self.prompt_len:
+                 continue  # skip prompt tokens
+             token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
+             is_special = token_id in self.tokenizer.all_special_ids
+             self.token_queue.put({
+                 "token_id": token_id,
+                 "token": token_str,
+                 "is_special": is_special
+             })
+
+     def end(self):
+         # generate() calls end() when it finishes; unblock the consumer
+         self.token_queue.put(self.stop_signal)
+
+     def __iter__(self):
+         while True:
+             token_info = self.token_queue.get(timeout=self.timeout)
+             if token_info is self.stop_signal:
+                 break
+             yield token_info
+
  @spaces.GPU
+ def chat_with_model(messages):
+     global current_model, current_tokenizer
+     if current_model is None or current_tokenizer is None:
+         yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
+         return
+
+     pad_id = current_tokenizer.pad_token_id
+     eos_id = current_tokenizer.eos_token_id
+     if pad_id is None:
+         pad_id = current_tokenizer.unk_token_id or 0
+
+     output_text = ""
+     in_think = False
+     max_new_tokens = 1024
+     generated_tokens = 0
+
+     prompt = format_prompt(messages)
+     device = torch.device("cuda")
+     current_model.to(device).half()
+
+     # 1. Tokenize prompt
+     inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
+     prompt_len = inputs["input_ids"].shape[-1]
+
+     # 2. Init streamer with prompt_len so prompt tokens are not re-emitted
+     streamer = RichTextStreamer(
+         tokenizer=current_tokenizer,
+         prompt_len=prompt_len,
+         skip_special_tokens=False
+     )
+
+     # 3. Build generation kwargs
+     generation_kwargs = dict(
+         **inputs,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         streamer=streamer,
+         eos_token_id=eos_id,
+         pad_token_id=pad_id
+     )
+
+     # 4. Launch generation in a background thread
+     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     messages = messages.copy()
+     messages.append({"role": "assistant", "content": ""})
+
+     print(f'Step 1: {messages}')
+
+     for token_info in streamer:
+         token_str = token_info["token"]
+         token_id = token_info["token_id"]
+         is_special = token_info["is_special"]
+
+         # Stop immediately at EOS
+         if token_id == eos_id:
+             break
+
+         # Detect reasoning block; render <think>...</think> as italics
+         if "<think>" in token_str:
+             in_think = True
+             token_str = token_str.replace("<think>", "")
+             output_text += "*"
+
+         if "</think>" in token_str:
+             in_think = False
+             token_str = token_str.replace("</think>", "")
+             output_text += token_str + "*"
+         else:
+             output_text += token_str
+
+         # Early stopping if the model starts a new user turn
+         if "\nUser" in output_text:
+             output_text = output_text.split("\nUser")[0].rstrip()
+             messages[-1]["content"] = output_text
+             break
+
+         generated_tokens += 1
+         if generated_tokens >= max_new_tokens:
+             break
+
+         messages[-1]["content"] = output_text
+         yield messages
+
+     if in_think:
+         output_text += "*"
+         messages[-1]["content"] = output_text
+
+     # Wait for the generation thread to finish, then release GPU memory
+     thread.join()
+     # current_model.to("cpu")
+     torch.cuda.empty_cache()
+
+     messages[-1]["content"] = output_text
+     print(f'Step 3: {messages}')
+
+     yield messages
+
+
+ # Globals
+ current_model = None
+ current_tokenizer = None
+
+ def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
+     global current_model, current_tokenizer
+     token = os.getenv("HF_TOKEN")
+
+     progress(0, desc="Loading tokenizer...")
+     current_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
+
+     progress(0.5, desc="Loading model...")
+     current_model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16,
+         device_map="cpu",  # loaded to CPU initially; moved to GPU inside chat_with_model
+         use_auth_token=token
+     )
+
+     progress(1, desc="Model ready.")
+     return f"{model_name} loaded and ready!"
+
+ # Format the conversation as a plain-text prompt
+ def format_prompt(messages):
+     prompt = ""
+     for msg in messages:
+         role = msg["role"]
+         if role == "user":
+             prompt += f"User: {msg['content'].strip()}\n"
+         elif role == "assistant":
+             prompt += f"Assistant: {msg['content'].strip()}\n"
+     prompt += "Assistant:"
+     return prompt
+
+ def add_user_message(user_input, history):
+     return "", history + [{"role": "user", "content": user_input}]
+
+ # Curated models
  model_choices = [
      "meta-llama/Llama-3.2-3B-Instruct",
+     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+     "google/gemma-7b",
+     "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
  ]

  with gr.Blocks() as demo:
+     gr.Markdown("## Clinical Chatbot (Streaming)")
+
+     default_model = gr.State(model_choices[0])
+
+     with gr.Row():
+         mode = gr.Radio(["Choose from list", "Enter custom model"], value="Choose from list", label="Model Input Mode")
+         model_selector = gr.Dropdown(choices=model_choices, label="Select Predefined Model")
+         model_textbox = gr.Textbox(label="Or Enter HF Model Name")
+
+     model_status = gr.Textbox(label="Model Status", interactive=False)
+     chatbot = gr.Chatbot(label="Chat", type="messages")
+     msg = gr.Textbox(label="Your message", placeholder="Enter clinical input...", show_label=False)
+     with gr.Row():
+         submit_btn = gr.Button("Submit")
+         clear = gr.Button("Clear")
+
+     def resolve_model_choice(mode, dropdown_value, textbox_value):
+         return textbox_value.strip() if mode == "Enter custom model" else dropdown_value
+
+     # Load on launch
+     demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)
+
+     # Model selection logic
+     mode.select(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
+         load_model_on_selection, inputs=default_model, outputs=model_status
+     )
+     model_selector.change(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
+         load_model_on_selection, inputs=default_model, outputs=model_status
+     )
+     model_textbox.submit(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
+         load_model_on_selection, inputs=default_model, outputs=model_status
+     )
+
+     # Submit via Enter key or button
+     msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
+         chat_with_model, chatbot, chatbot
+     )
+     submit_btn.click(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
+         chat_with_model, chatbot, chatbot
+     )
+
+     clear.click(lambda: [], None, chatbot, queue=False)
+

  demo.launch()
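
Note on the custom streamer: TextIteratorStreamer normally ends its iterator by having generate() call end(), which pushes a stop_signal sentinel onto its internal queue. Because RichTextStreamer swaps in its own token_queue, it has to override end() as well (as done above); otherwise the consuming loop can block forever once generation finishes. A stripped-down sketch of the same producer/consumer sentinel pattern, independent of transformers:

import queue
import threading

q = queue.Queue()
STOP = None  # sentinel, analogous to TextIteratorStreamer.stop_signal

def producer():
    for tok in ["Hello", ",", " world"]:
        q.put(tok)  # what put() does for each generated token
    q.put(STOP)     # what end() does when generation finishes

threading.Thread(target=producer).start()
while (item := q.get()) is not STOP:
    print(item, end="", flush=True)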
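
For a quick check of the streaming path outside the Space, a minimal sketch along these lines should work; the checkpoint name is an illustrative stand-in (not part of this commit), and RichTextStreamer is the class defined in app.py:

import threading
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "sshleifer/tiny-gpt2"  # hypothetical stand-in for a real model
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tok("User: hello\nAssistant:", return_tensors="pt")
streamer = RichTextStreamer(tok, prompt_len=inputs["input_ids"].shape[-1])

# generate() runs in a background thread; the main thread drains the streamer
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=16, streamer=streamer),
)
thread.start()
for info in streamer:
    print(info["token"], end="", flush=True)  # dict keys: token_id, token, is_special
thread.join()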
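
And a small illustration of the plain-text transcript that format_prompt builds (the conversation is made up; expected output shown as comments):

history = [
    {"role": "user", "content": "Patient reports chest pain."},
    {"role": "assistant", "content": "How long has this been going on?"},
    {"role": "user", "content": "Two days."},
]
print(format_prompt(history))
# User: Patient reports chest pain.
# Assistant: How long has this been going on?
# User: Two days.
# Assistant:

This format is also why chat_with_model truncates at "\nUser": a model that keeps completing the transcript would otherwise start writing the user's next turn.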