Update app.py
app.py
CHANGED
@@ -5,6 +5,7 @@ import os
 from huggingface_hub import snapshot_download
 import argparse
 import logging
+import numpy as np # Import numpy
 
 # --- Logging Setup ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -67,7 +68,6 @@ def initialize_model():
     model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
     logging.info(model_status)
     try:
-        # FIX: Removed explicit DeviceType. Let the library infer or use string if needed by constructor.
         # The simple constructor often works by detecting the installed ORT package.
         logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
         model = og.Model(model_path) # Simplified model loading
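For orientation, the simplified loading path this hunk settles on needs nothing beyond the onnxruntime-genai constructors; the execution provider follows from whichever onnxruntime-genai package (CPU, CUDA, DirectML) is installed. A minimal sketch under that assumption, reusing the og and model_path names from the diff; the og.Tokenizer(model) call is an assumption based on the tokenizer global used later in app.py, not something this hunk shows:

import onnxruntime_genai as og

def load_phi4_onnx(model_path: str):
    # og.Model infers the execution provider from the installed package
    model = og.Model(model_path)
    # assumed companion constructor for the tokenizer used elsewhere in app.py
    tokenizer = og.Tokenizer(model)
    return model, tokenizer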
@@ -107,10 +107,13 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
     logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
 
     try:
-
+        input_tokens_list = tokenizer.encode(full_prompt) # Encode returns a list/array
+        # Ensure input_tokens is a numpy array of the correct type (int32 is common)
+        input_tokens = np.array(input_tokens_list, dtype=np.int32)
+        # Reshape to (batch_size, sequence_length), which is (1, N) for single prompt
+        input_tokens = input_tokens.reshape((1, -1))
+
 
-        # FIX: Removed eos_token_id and pad_token_id as they are not attributes
-        # of onnxruntime_genai.Tokenizer and likely handled internally by the generator.
         search_options = {
             "max_length": max_length,
             "temperature": temperature,
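To make the added reshape concrete: tokenizer.encode() returns a flat sequence of N token ids, while the model expects a (batch_size, sequence_length) array, so a single prompt becomes shape (1, N). A tiny illustration with plain numpy; the token ids below are made up:

import numpy as np

token_ids = [15, 8017, 291, 42]             # hypothetical output of tokenizer.encode(...)
arr = np.array(token_ids, dtype=np.int32)   # shape (4,)
batched = arr.reshape((1, -1))              # shape (1, 4): a batch holding one sequence
print(arr.shape, batched.shape)             # -> (4,) (1, 4)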
@@ -121,8 +124,13 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
 
         params = og.GeneratorParams(model)
         params.set_search_options(**search_options)
-
-
+
+        # FIX: Create a dictionary mapping input names to tensors (numpy arrays)
+        # and pass this dictionary to set_inputs.
+        # Assuming the standard input name "input_ids".
+        inputs = {"input_ids": input_tokens}
+        logging.info(f"Setting inputs with keys: {inputs.keys()} and shape for 'input_ids': {inputs['input_ids'].shape}")
+        params.set_inputs(inputs)
 
         start_time = time.time()
         # Create generator AFTER setting parameters including inputs
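Taken together with the previous hunk, the setup sequence the patch arrives at is: encode the prompt, batch it as int32, configure the search options, then hand the named input tensor to the generator parameters. A condensed sketch of that sequence, assuming "input_ids" is the correct input name (the added comment itself flags this as an assumption) and passing model and tokenizer in explicitly instead of using the app's globals:

import numpy as np
import onnxruntime_genai as og

def build_generator_params(model, tokenizer, full_prompt,
                           max_length, temperature, top_p, top_k):
    # encode and batch the prompt as a (1, N) int32 array
    tokens = np.array(tokenizer.encode(full_prompt), dtype=np.int32).reshape((1, -1))

    params = og.GeneratorParams(model)
    params.set_search_options(max_length=max_length, temperature=temperature,
                              top_p=top_p, top_k=top_k)
    # named input tensors go in as a dict, per the fix in this commit
    params.set_inputs({"input_ids": tokens})
    return params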
@@ -134,22 +142,27 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
         token_count = 0
         # Rely primarily on generator.is_done()
         while not generator.is_done():
-
-
-
-
+            try:
+                generator.compute_logits()
+                generator.generate_next_token()
+                if first_token_time is None:
+                    first_token_time = time.time() # Record time to first token
 
-
+                next_token = generator.get_next_tokens()[0]
 
-
-
+                decoded_chunk = tokenizer.decode([next_token])
+                token_count += 1
 
-
-
-
-
+                # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
+                if decoded_chunk == "<|end|>":
+                    logging.info("Assistant explicitly generated <|end|> token string.")
+                    break
 
-
+                yield decoded_chunk # Yield just the text chunk
+            except Exception as loop_error:
+                logging.error(f"Error inside generation loop: {loop_error}", exc_info=True)
+                yield f"\n\nError during token generation: {loop_error}"
+                break # Exit loop on error
 
         end_time = time.time()
         ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
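The rewritten loop is the core streaming pattern: advance the generator by one token, decode that token on its own, and yield the text fragment immediately so the UI can render it. A condensed sketch of just that pattern; the og.Generator(model, params) constructor is implied by the "Create generator AFTER setting parameters" context line but not visible in this hunk, so treat it as an assumption:

import onnxruntime_genai as og

def stream_tokens(model, tokenizer, params):
    generator = og.Generator(model, params)   # assumed constructor from the elided line
    while not generator.is_done():
        generator.compute_logits()            # run the model for the next position
        generator.generate_next_token()       # sample using the configured search options
        next_token = generator.get_next_tokens()[0]
        chunk = tokenizer.decode([next_token])
        if chunk == "<|end|>":                # stop if the end marker leaks out as text
            break
        yield chunk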
@@ -159,6 +172,12 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
         logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
         model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})" # Reset status
 
+    except TypeError as te:
+        # Catch type errors specifically during setup if the input format is still wrong
+        logging.error(f"TypeError during generation setup: {te}", exc_info=True)
+        logging.error("Check if the input format {'input_ids': token_array} is correct.")
+        model_status = f"Generation Setup TypeError: {te}"
+        yield f"\n\nSorry, a TypeError occurred setting up generation: {te}"
     except AttributeError as ae:
         # Catch potential future API changes or issues during generation setup
         logging.error(f"AttributeError during generation setup: {ae}", exc_info=True)
@@ -176,11 +195,7 @@ def add_user_message(user_message, history):
 def add_user_message(user_message, history):
     """Adds the user's message to the chat history for display."""
     if not user_message:
-        # Returning original history prevents adding empty message
-        # Use gr.Warning or gr.Info for user feedback? Or raise gr.Error?
-        # gr.Warning("Please enter a message.") # Shows warning toast
         return "", history # Clear input, return unchanged history
-    # raise gr.Error("Please enter a message.") # Stops execution, shows error
     history = history + [[user_message, None]] # Append user message, leave bot response None
     return "", history # Clear input textbox, return updated history
 
@@ -188,20 +203,15 @@ def add_user_message(user_message, history):
 def generate_bot_response(history, max_length, temperature, top_p, top_k):
     """Generates the bot's response based on the history and streams it."""
     if not history or history[-1][1] is not None:
-        # This case means user submitted empty message or something went wrong
-        # No need to generate if the last turn isn't user's pending turn
        return history
 
     user_prompt = history[-1][0] # Get the latest user prompt
-    # Prepare history for the model
-    model_history = history[:-1]
+    model_history = history[:-1] # Prepare history for the model
 
-    # Get the generator stream
     response_stream = generate_response_stream(
         user_prompt, model_history, max_length, temperature, top_p, top_k
     )
 
-    # Stream the response chunks back to Gradio
     history[-1][1] = "" # Initialize the bot response string in the history
     for chunk in response_stream:
         history[-1][1] += chunk # Append the chunk to the bot's message in history
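generate_bot_response is written as a Gradio generator callback: each chunk is appended to the last bot turn, and the updated history is presumably yielded just past this hunk so the Chatbot re-renders as tokens arrive. A minimal, self-contained sketch of that streaming shape with a fake token source; every name here is hypothetical, not the app's code:

import time

def fake_stream(prompt):
    for word in ("this", "is", "a", "streamed", "reply"):
        time.sleep(0.1)
        yield word + " "

def respond(history):
    history[-1][1] = ""                       # start an empty bot turn
    for chunk in fake_stream(history[-1][0]):
        history[-1][1] += chunk
        yield history                         # each yield refreshes the Chatbot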
@@ -210,12 +220,9 @@ def generate_bot_response(history, max_length, temperature, top_p, top_k):
 # 3. Function to clear chat
 def clear_chat():
     """Clears the chat history and input."""
-    global model_status
-    # Reset status only if it was showing an error from generation maybe?
-    # Or just always reset to Ready if model is loaded.
+    global model_status
     if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
         model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
-    # Keep the original error if init failed, otherwise show ready status
     return None, [], model_status # Clear Textbox, Chatbot history, and update status display
 
 
@@ -224,13 +231,11 @@ try:
     initialize_model()
 except Exception as e:
     print(f"FATAL: Model initialization failed: {e}")
-    # model_status is already set inside initialize_model on error
 
 
 # --- Gradio Interface ---
 logging.info("Creating Gradio Interface...")
 
-# Select a theme
 theme = gr.themes.Soft(
     primary_hue="blue",
     secondary_hue="sky",
@@ -249,11 +254,9 @@ with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
             """)
         with gr.Column(scale=1, min_width=150):
             gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
-            # Use the global model_status variable for the initial value
             model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)
 
-
-    # Main Layout (Chat on Left, Settings on Right)
+    # Main Layout
     with gr.Row():
         # Chat Column
         with gr.Column(scale=3):
@@ -262,57 +265,47 @@ with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
                 height=600,
                 layout="bubble",
                 bubble_full_width=False,
-                avatar_images=(None, PHI_LOGO_URL)
+                avatar_images=(None, PHI_LOGO_URL)
             )
             with gr.Row():
                 prompt_input = gr.Textbox(
                     label="Your Message",
                     placeholder="<|user|>\nType your message here...\n<|end|>",
                     lines=4,
-                    scale=9
+                    scale=9
                 )
-                # Combine Send and Clear Buttons Vertically? Or keep side-by-side? Side-by-side looks better
                 with gr.Column(scale=1, min_width=120):
                     submit_button = gr.Button("Send", variant="primary", size="lg")
                     clear_button = gr.Button("🗑️ Clear Chat", variant="secondary")
 
-
         # Settings Column
         with gr.Column(scale=1, min_width=250):
             gr.Markdown("### ⚙️ Generation Settings")
-            with gr.Group():
+            with gr.Group():
                 max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
                 temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
                 top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
                 top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
-
-            gr.Markdown("---") # Separator
+            gr.Markdown("---")
             gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
             gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")
 
-
-    # Event Listeners (Connecting UI components to functions)
-
-    # Define inputs for the bot response generator
+    # Event Listeners
    bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]
 
-    # Chain actions:
-    # 1. User presses Enter or clicks Send
-    # 2. `add_user_message` updates history, clears input
-    # 3. `generate_bot_response` streams bot reply into history
     submit_event = prompt_input.submit(
         fn=add_user_message,
         inputs=[prompt_input, chatbot],
-        outputs=[prompt_input, chatbot],
-        queue=False,
+        outputs=[prompt_input, chatbot],
+        queue=False,
     ).then(
-        fn=generate_bot_response,
-        inputs=bot_response_inputs,
-        outputs=[chatbot],
-        api_name="chat"
+        fn=generate_bot_response,
+        inputs=bot_response_inputs,
+        outputs=[chatbot],
+        api_name="chat"
     )
 
-    submit_button.click(
+    submit_button.click(
         fn=add_user_message,
         inputs=[prompt_input, chatbot],
         outputs=[prompt_input, chatbot],
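The event wiring follows the usual two-step chain: a quick, unqueued step moves the user message into the history and clears the textbox, then a queued step streams the model's reply into the same Chatbot. A stripped-down sketch of the same chaining with hypothetical component and callback names (launching and queuing are handled in the final hunk below):

import gradio as gr

def take_message(msg, history):
    return "", history + [[msg, None]]        # clear the box, append a pending turn

def answer(history):
    history[-1][1] = "echo: " + history[-1][0]
    return history

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    box = gr.Textbox()
    box.submit(fn=take_message, inputs=[box, chat], outputs=[box, chat], queue=False) \
       .then(fn=answer, inputs=chat, outputs=chat)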
@@ -321,18 +314,17 @@ with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
         fn=generate_bot_response,
         inputs=bot_response_inputs,
         outputs=[chatbot],
-        api_name=False
+        api_name=False
     )
 
-    # Clear button action
    clear_button.click(
         fn=clear_chat,
         inputs=None,
-        outputs=[prompt_input, chatbot, model_status_text],
-        queue=False
+        outputs=[prompt_input, chatbot, model_status_text],
+        queue=False
     )
 
 # Launch the Gradio app
 logging.info("Launching Gradio App...")
-demo.queue(max_size=20)
+demo.queue(max_size=20)
 demo.launch(show_error=True, max_threads=40)
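The clear_button wiring also illustrates the output-mapping convention: clear_chat returns three values (None, [], model_status), and Gradio assigns them positionally to the three components listed in outputs. A tiny sketch of that convention with hypothetical components:

import gradio as gr

def reset():
    # the three returned values map positionally onto the outputs list below
    return None, [], "Ready"

with gr.Blocks() as demo:
    box = gr.Textbox()
    chat = gr.Chatbot()
    status = gr.Textbox(label="Status")
    gr.Button("Clear").click(fn=reset, inputs=None, outputs=[box, chat, status], queue=False)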