import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import os
import time
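
# Dependencies inferred from the imports above; a typical install line (assumed, adjust as needed):
#   pip install gradio transformers onnxruntime numpy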

print("Loading libraries...")

# Local path to the INT4 CPU variant of microsoft/Phi-4-mini-instruct-onnx.
model_dir = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
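# The model files are not bundled with this script. One way to fetch them, assuming the
# Hugging Face CLI is installed, is roughly:
#   huggingface-cli download microsoft/Phi-4-mini-instruct-onnx \
#       --include "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*" --local-dir .
# However you obtain them, the directory must end up containing at least model.onnx,
# model.onnx.data, and the tokenizer files.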

# Globals populated at startup; generate_response() checks them before running.
tokenizer = None
session = None
model_load_error = None

if not os.path.isdir(model_dir):
    model_load_error = (
        f"Error: Model directory not found at '{os.path.abspath(model_dir)}'\n"
        "Please ensure you have created the directory structure\n"
        f"'./{model_dir}' relative to this script ({os.path.basename(__file__)})\n"
        "and downloaded ALL the required model files into it from:\n"
        "https://huggingface.co./microsoft/Phi-4-mini-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
    )
    print(model_load_error)
else:
    print(f"Found model directory: {os.path.abspath(model_dir)}")
    print("Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        print("Tokenizer loaded successfully.")
    except Exception as e:
        model_load_error = f"Error loading tokenizer from {model_dir}: {e}"
        print(model_load_error)

# Load the ONNX Runtime session only if the tokenizer loaded successfully.
if tokenizer:
    print("Loading ONNX model session...")
    model_path = os.path.join(model_dir, "model.onnx")
    model_data_path = os.path.join(model_dir, "model.onnx.data")

    if not os.path.exists(model_path):
        model_load_error = f"Error: 'model.onnx' not found in {model_dir}"
        print(model_load_error)
    elif not os.path.exists(model_data_path):
        model_load_error = (
            f"Error: 'model.onnx.data' not found in {model_dir}. "
            "This large file contains the model weights and is required."
        )
        print(model_load_error)
    else:
        try:
            start_time = time.time()
            session = ort.InferenceSession(
                model_path,
                providers=["CPUExecutionProvider"]
            )
            end_time = time.time()
            print(f"ONNX model session loaded successfully using CPU provider in {end_time - start_time:.2f} seconds.")
        except Exception as e:
            model_load_error = f"Error loading ONNX session from {model_path}: {e}\n"
            model_load_error += (
                "Ensure the 'onnxruntime' library is installed correctly and that both "
                "'model.onnx' and 'model.onnx.data' are valid files."
            )
            print(model_load_error)
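# Optional diagnostics and tuning (not required for the app to run):
#   session.get_inputs() / session.get_outputs() list the tensor names and shapes the
#   export expects and produces; worth checking before calling session.run().
#   CPU thread usage can be tuned through SessionOptions, e.g. (example value, adjust to taste):
#       opts = ort.SessionOptions()
#       opts.intra_op_num_threads = 4
#       session = ort.InferenceSession(model_path, opts, providers=["CPUExecutionProvider"])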


def generate_response(prompt):
    """
    Generates a response from the loaded ONNX model based on the user prompt.
    """
    global tokenizer, session, model_load_error

    if model_load_error:
        return model_load_error
    if not tokenizer or not session:
        return "Error: Model or Tokenizer is not loaded correctly. Check console output."

    print(f"\nReceived prompt: {prompt}")
    start_time = time.time()

    # Wrap the raw prompt in a simple Phi-style chat layout before tokenizing.
    full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
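    # Note: this hand-built template is an approximation. If the tokenizer that ships with
    # the model defines a chat template, transformers can build the prompt for you
    # (a sketch, assuming the template is present in the downloaded tokenizer files):
    #   full_prompt = tokenizer.apply_chat_template(
    #       [{"role": "user", "content": prompt}],
    #       tokenize=False,
    #       add_generation_prompt=True,
    #   )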
    print("Tokenizing input...")

    try:
        inputs = tokenizer(full_prompt, return_tensors="np")

        # ONNX Runtime expects int64 tensors for these inputs.
        ort_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
        print("Running model inference...")
        inference_start_time = time.time()

        # A single forward pass through the ONNX graph; the first output is treated as
        # the sequence of token IDs to decode below.
        outputs = session.run(None, ort_inputs)
        generated_ids = outputs[0]

        inference_end_time = time.time()
        print(f"Inference complete in {inference_end_time - inference_start_time:.2f} seconds.")

        print("Decoding response...")
        decoding_start_time = time.time()

        # The output may be batched (2-D) or already a flat sequence of token IDs.
        output_ids = generated_ids[0] if generated_ids.ndim == 2 else generated_ids
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        decoding_end_time = time.time()
        print(f"Decoding complete in {decoding_end_time - decoding_start_time:.2f} seconds.")

        # Strip the prompt context so only the assistant's reply is returned.
        # Note: with skip_special_tokens=True the "<|assistant|>" marker is often removed
        # during decoding, in which case the fallback branches below are used.
        assistant_marker = "<|assistant|>"
        assistant_pos = response.find(assistant_marker)

        if assistant_pos != -1:
            cleaned_response = response[assistant_pos + len(assistant_marker):].strip()
        else:
            # Fall back to removing the user-side portion of the prompt, if it survived decoding.
            prompt_part_to_remove = full_prompt.rsplit(assistant_marker, 1)[0]
            if response.startswith(prompt_part_to_remove):
                cleaned_response = response[len(prompt_part_to_remove):].strip()
            else:
                cleaned_response = response.strip()
                print("Warning: Could not reliably clean the prompt context from the response.")

        total_time = time.time() - start_time
        print(f"Generated response: {cleaned_response}")
        print(f"Total processing time for this prompt: {total_time:.2f} seconds.")
        return cleaned_response

    except Exception as e:
        print(f"Error during model inference or decoding: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during generation: {e}"
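

# A note on generation: this script makes a single session.run() call and decodes whatever
# the first output contains. The exports in the Phi-4-mini-instruct-onnx repository are
# normally driven with the `onnxruntime-genai` package, which runs the token-by-token
# generation loop and KV cache for you. A rough sketch, assuming that package is installed
# (its API has shifted between releases, so verify against its documentation):
#
#   import onnxruntime_genai as og
#
#   og_model = og.Model(model_dir)
#   og_tokenizer = og.Tokenizer(og_model)
#   params = og.GeneratorParams(og_model)
#   params.set_search_options(max_length=512)
#   generator = og.Generator(og_model, params)
#   generator.append_tokens(og_tokenizer.encode(full_prompt))
#   while not generator.is_done():
#       generator.generate_next_token()
#   reply = og_tokenizer.decode(generator.get_sequence(0))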


print("Setting up Gradio interface...")

# Custom CSS to give the input and output boxes a comfortable height.
css = """
#output_textbox textarea {
    min-height: 300px; /* Make output box taller */
}
#input_textbox textarea {
    min-height: 100px; /* Adjust input box height */
}
"""

demo = gr.Blocks(css=css, theme=gr.themes.Default())

with demo:
    gr.Markdown(
        """
        # Phi-4-Mini ONNX Chatbot (Local CPU)
        Interact with the `microsoft/Phi-4-mini-instruct-onnx` model variant
        (`cpu-int4-rtn-block-32-acc-level-4`) running locally using ONNX Runtime on your CPU.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            input_textbox = gr.Textbox(
                label="Your Prompt",
                placeholder="Type your question or instruction here...",
                lines=4,
                elem_id="input_textbox"
            )
            submit_button = gr.Button("Generate Response", variant="primary")
        with gr.Column(scale=3):
            output_textbox = gr.Textbox(
                label="AI Response",
                lines=10,
                interactive=False,
                elem_id="output_textbox"
            )

    # Surface the load status directly in the UI.
    if model_load_error:
        gr.Markdown(f"**<font color='red'>Model Loading Error:</font>**\n```\n{model_load_error}\n```")
    elif session is None or tokenizer is None:
        gr.Markdown("**<font color='orange'>Warning:</font>** Model or tokenizer did not load correctly. Check console logs.")
    else:
        gr.Markdown("**<font color='green'>Model and Tokenizer Loaded Successfully.</font>**")

    # Wire both the button click and Enter-in-textbox to the same handler.
    submit_button.click(
        fn=generate_response,
        inputs=input_textbox,
        outputs=output_textbox
    )

    input_textbox.submit(
        fn=generate_response,
        inputs=input_textbox,
        outputs=output_textbox
    )


print("-" * 50)
print("Launching Gradio app...")
print("You can access it in your browser at the URL provided below (usually http://127.0.0.1:7860).")
print("Press CTRL+C in this terminal to stop the application.")
print("-" * 50)
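# A few launch() options that may be useful here (check the Gradio docs for your version):
# share=True requests a temporary public URL, server_name="0.0.0.0" exposes the app on the
# local network, and server_port selects a fixed port. Calling demo.queue() before launch()
# lets incoming requests wait in line instead of overlapping.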

demo.launch(share=False, debug=False)

print("Gradio app closed.")