Athspi committed (verified)
Commit e3d67e9 · 1 Parent(s): 751f392

Update app.py

Files changed (1):
  1. app.py  +56 -64
app.py CHANGED
@@ -5,6 +5,7 @@ import os
 from huggingface_hub import snapshot_download
 import argparse
 import logging
+import numpy as np # Import numpy
 
 # --- Logging Setup ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -67,7 +68,6 @@ def initialize_model():
     model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
     logging.info(model_status)
     try:
-        # FIX: Removed explicit DeviceType. Let the library infer or use string if needed by constructor.
         # The simple constructor often works by detecting the installed ORT package.
         logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
         model = og.Model(model_path) # Simplified model loading
@@ -107,10 +107,13 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
     logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
 
     try:
-        input_tokens = tokenizer.encode(full_prompt)
+        input_tokens_list = tokenizer.encode(full_prompt) # Encode returns a list/array
+        # Ensure input_tokens is a numpy array of the correct type (int32 is common)
+        input_tokens = np.array(input_tokens_list, dtype=np.int32)
+        # Reshape to (batch_size, sequence_length), which is (1, N) for single prompt
+        input_tokens = input_tokens.reshape((1, -1))
+
 
-        # FIX: Removed eos_token_id and pad_token_id as they are not attributes
-        # of onnxruntime_genai.Tokenizer and likely handled internally by the generator.
         search_options = {
             "max_length": max_length,
             "temperature": temperature,
@@ -121,8 +124,13 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
 
         params = og.GeneratorParams(model)
         params.set_search_options(**search_options)
-        # FIX: Use the set_inputs method as suggested by the error message
-        params.set_inputs(input_tokens)
+
+        # FIX: Create a dictionary mapping input names to tensors (numpy arrays)
+        # and pass this dictionary to set_inputs.
+        # Assuming the standard input name "input_ids".
+        inputs = {"input_ids": input_tokens}
+        logging.info(f"Setting inputs with keys: {inputs.keys()} and shape for 'input_ids': {inputs['input_ids'].shape}")
+        params.set_inputs(inputs)
 
         start_time = time.time()
         # Create generator AFTER setting parameters including inputs
@@ -134,22 +142,27 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
         token_count = 0
         # Rely primarily on generator.is_done()
         while not generator.is_done():
-            generator.compute_logits()
-            generator.generate_next_token()
-            if first_token_time is None:
-                first_token_time = time.time() # Record time to first token
+            try:
+                generator.compute_logits()
+                generator.generate_next_token()
+                if first_token_time is None:
+                    first_token_time = time.time() # Record time to first token
 
-            next_token = generator.get_next_tokens()[0]
+                next_token = generator.get_next_tokens()[0]
 
-            decoded_chunk = tokenizer.decode([next_token])
-            token_count += 1
+                decoded_chunk = tokenizer.decode([next_token])
+                token_count += 1
 
-            # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
-            if decoded_chunk == "<|end|>":
-                logging.info("Assistant explicitly generated <|end|> token string.")
-                break
+                # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
+                if decoded_chunk == "<|end|>":
+                    logging.info("Assistant explicitly generated <|end|> token string.")
+                    break
 
-            yield decoded_chunk # Yield just the text chunk
+                yield decoded_chunk # Yield just the text chunk
+            except Exception as loop_error:
+                logging.error(f"Error inside generation loop: {loop_error}", exc_info=True)
+                yield f"\n\nError during token generation: {loop_error}"
+                break # Exit loop on error
 
         end_time = time.time()
         ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
@@ -159,6 +172,12 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
         logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
         model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})" # Reset status
 
+    except TypeError as te:
+        # Catch type errors specifically during setup if the input format is still wrong
+        logging.error(f"TypeError during generation setup: {te}", exc_info=True)
+        logging.error("Check if the input format {'input_ids': token_array} is correct.")
+        model_status = f"Generation Setup TypeError: {te}"
+        yield f"\n\nSorry, a TypeError occurred setting up generation: {te}"
     except AttributeError as ae:
         # Catch potential future API changes or issues during generation setup
         logging.error(f"AttributeError during generation setup: {ae}", exc_info=True)
@@ -176,11 +195,7 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
 def add_user_message(user_message, history):
     """Adds the user's message to the chat history for display."""
     if not user_message:
-        # Returning original history prevents adding empty message
-        # Use gr.Warning or gr.Info for user feedback? Or raise gr.Error?
-        # gr.Warning("Please enter a message.") # Shows warning toast
         return "", history # Clear input, return unchanged history
-    # raise gr.Error("Please enter a message.") # Stops execution, shows error
     history = history + [[user_message, None]] # Append user message, leave bot response None
     return "", history # Clear input textbox, return updated history
 
@@ -188,20 +203,15 @@ def add_user_message(user_message, history):
 def generate_bot_response(history, max_length, temperature, top_p, top_k):
     """Generates the bot's response based on the history and streams it."""
     if not history or history[-1][1] is not None:
-        # This case means user submitted empty message or something went wrong
-        # No need to generate if the last turn isn't user's pending turn
         return history
 
     user_prompt = history[-1][0] # Get the latest user prompt
-    # Prepare history for the model (all turns *before* the current one)
-    model_history = history[:-1]
+    model_history = history[:-1] # Prepare history for the model
 
-    # Get the generator stream
     response_stream = generate_response_stream(
         user_prompt, model_history, max_length, temperature, top_p, top_k
     )
 
-    # Stream the response chunks back to Gradio
     history[-1][1] = "" # Initialize the bot response string in the history
     for chunk in response_stream:
         history[-1][1] += chunk # Append the chunk to the bot's message in history
@@ -210,12 +220,9 @@ def generate_bot_response(history, max_length, temperature, top_p, top_k):
 # 3. Function to clear chat
 def clear_chat():
     """Clears the chat history and input."""
-    global model_status # Keep model status indicator updated
-    # Reset status only if it was showing an error from generation maybe?
-    # Or just always reset to Ready if model is loaded.
+    global model_status
     if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
         model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
-    # Keep the original error if init failed, otherwise show ready status
     return None, [], model_status # Clear Textbox, Chatbot history, and update status display
 
 
@@ -224,13 +231,11 @@ try:
     initialize_model()
 except Exception as e:
     print(f"FATAL: Model initialization failed: {e}")
-    # model_status is already set inside initialize_model on error
 
 
 # --- Gradio Interface ---
 logging.info("Creating Gradio Interface...")
 
-# Select a theme
 theme = gr.themes.Soft(
     primary_hue="blue",
     secondary_hue="sky",
@@ -249,11 +254,9 @@ with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
             """)
         with gr.Column(scale=1, min_width=150):
             gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
-            # Use the global model_status variable for the initial value
             model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)
 
-
-    # Main Layout (Chat on Left, Settings on Right)
+    # Main Layout
    with gr.Row():
        # Chat Column
        with gr.Column(scale=3):
@@ -262,57 +265,47 @@ with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
                 height=600,
                 layout="bubble",
                 bubble_full_width=False,
-                avatar_images=(None, PHI_LOGO_URL) # (user, bot)
+                avatar_images=(None, PHI_LOGO_URL)
             )
             with gr.Row():
                 prompt_input = gr.Textbox(
                     label="Your Message",
                     placeholder="<|user|>\nType your message here...\n<|end|>",
                     lines=4,
-                    scale=9 # Make textbox wider
+                    scale=9
                 )
-                # Combine Send and Clear Buttons Vertically? Or keep side-by-side? Side-by-side looks better
                 with gr.Column(scale=1, min_width=120):
                     submit_button = gr.Button("Send", variant="primary", size="lg")
                     clear_button = gr.Button("🗑️ Clear Chat", variant="secondary")
 
-
        # Settings Column
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("### ⚙️ Generation Settings")
-            with gr.Group(): # Group settings visually
+            with gr.Group():
                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
                top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
-
-            gr.Markdown("---") # Separator
+            gr.Markdown("---")
            gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
            gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")
 
-
-    # Event Listeners (Connecting UI components to functions)
-
-    # Define inputs for the bot response generator
+    # Event Listeners
    bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]
 
-    # Chain actions:
-    # 1. User presses Enter or clicks Send
-    # 2. `add_user_message` updates history, clears input
-    # 3. `generate_bot_response` streams bot reply into history
    submit_event = prompt_input.submit(
        fn=add_user_message,
        inputs=[prompt_input, chatbot],
-        outputs=[prompt_input, chatbot], # Update textbox and history
-        queue=False, # Submit is fast
+        outputs=[prompt_input, chatbot],
+        queue=False,
    ).then(
-        fn=generate_bot_response, # Call the generator function
-        inputs=bot_response_inputs, # Pass history and params
-        outputs=[chatbot], # Stream output directly to chatbot
-        api_name="chat" # Optional: name for API usage
+        fn=generate_bot_response,
+        inputs=bot_response_inputs,
+        outputs=[chatbot],
+        api_name="chat"
    )
 
-    submit_button.click( # Mirror actions for button click
+    submit_button.click(
        fn=add_user_message,
        inputs=[prompt_input, chatbot],
        outputs=[prompt_input, chatbot],
@@ -321,18 +314,17 @@ with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
        fn=generate_bot_response,
        inputs=bot_response_inputs,
        outputs=[chatbot],
-        api_name=False # Don't expose button click as separate API endpoint
+        api_name=False
    )
 
-    # Clear button action
    clear_button.click(
        fn=clear_chat,
        inputs=None,
-        outputs=[prompt_input, chatbot, model_status_text], # Clear input, chat, and update status text
-        queue=False # Clearing is fast
+        outputs=[prompt_input, chatbot, model_status_text],
+        queue=False
    )
 
 # Launch the Gradio app
 logging.info("Launching Gradio App...")
-demo.queue(max_size=20) # Enable queuing with a limit
+demo.queue(max_size=20)
 demo.launch(show_error=True, max_threads=40)
 
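For reference, the streaming generation path that app.py lands on after this commit reads end-to-end roughly as the sketch below. It is a minimal illustration assembled only from the calls visible in the diff, plus `og.Tokenizer(model)` and `og.Generator(model, params)`, which the surrounding code implies but the hunks do not show. The model directory, prompt, and search options are placeholder values, and the assumption that the model's input tensor is named `input_ids` (and that `set_inputs` accepts such a dict) comes from the commit itself; the exact contract may differ across onnxruntime_genai versions.

```python
# Minimal sketch of the post-commit generation flow (assumptions noted above).
import numpy as np
import onnxruntime_genai as og

model_dir = "./phi4-mini-onnx"  # placeholder: folder containing the ONNX model files
prompt = "<|user|>\nHello!<|end|>\n<|assistant|>"  # Phi-4 instruction format used by the app

model = og.Model(model_dir)
tokenizer = og.Tokenizer(model)  # assumed tokenizer constructor

# Encode the prompt, then shape it as an int32 (batch_size, sequence_length) array.
input_ids = np.array(tokenizer.encode(prompt), dtype=np.int32).reshape((1, -1))

params = og.GeneratorParams(model)
params.set_search_options(max_length=256, temperature=0.7, top_p=0.9, top_k=50)
params.set_inputs({"input_ids": input_ids})  # assumes the input tensor is named "input_ids"

generator = og.Generator(model, params)
pieces = []
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()
    token = generator.get_next_tokens()[0]
    piece = tokenizer.decode([token])
    if piece == "<|end|>":  # same secondary stop check the app uses
        break
    pieces.append(piece)

print("".join(pieces))
```

In app.py the same loop additionally records time-to-first-token and yields each decoded chunk to the Gradio Chatbot instead of collecting the pieces in a list.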