# app.py
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import os # Import os module to check if model directory exists
import time # To measure performance (optional)
print("Loading libraries...")
# --- Configuration ---
# Define the local directory where the downloaded model files are stored.
# This path MUST match where you downloaded the model files relative to this script.
model_dir = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
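
# If you have not downloaded the files yet, the huggingface_hub CLI can fetch just this
# variant (a suggested command, not part of the original instructions; adjust the paths to
# match your layout):
#   huggingface-cli download microsoft/Phi-4-mini-instruct-onnx \
#       --include "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*" --local-dir .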
# --- Model Loading ---
tokenizer = None
session = None
model_load_error = None
# Check if the model directory exists before attempting to load
if not os.path.isdir(model_dir):
    model_load_error = (
        f"Error: Model directory not found at '{os.path.abspath(model_dir)}'\n"
        "Please ensure you have created the directory structure\n"
        f"'./{model_dir}' relative to this script ({os.path.basename(__file__)})\n"
        "and downloaded ALL the required model files into it from:\n"
        "https://huggingface.co./microsoft/Phi-4-mini-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
    )
    print(model_load_error)
else:
    print(f"Found model directory: {os.path.abspath(model_dir)}")
    print("Loading tokenizer...")
    try:
        # Load the tokenizer associated with this Phi-4 model variant
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        print("Tokenizer loaded successfully.")
    except Exception as e:
        model_load_error = f"Error loading tokenizer from {model_dir}: {e}"
        print(model_load_error)

# Only attempt to load the ONNX session if the tokenizer loaded successfully
if tokenizer:
    print("Loading ONNX model session...")
    model_path = os.path.join(model_dir, "model.onnx")
    model_data_path = os.path.join(model_dir, "model.onnx.data")
    if not os.path.exists(model_path):
        model_load_error = f"Error: 'model.onnx' not found in {model_dir}"
        print(model_load_error)
    elif not os.path.exists(model_data_path):
        model_load_error = (
            f"Error: 'model.onnx.data' not found in {model_dir}. "
            "This large file contains the model weights and is required."
        )
        print(model_load_error)
    else:
        try:
            # Load the ONNX model with ONNX Runtime for CPU execution
            start_time = time.time()
            # Session options can be configured for performance if needed:
            # sess_options = ort.SessionOptions()
            # sess_options.intra_op_num_threads = 4  # Example: limit threads
            session = ort.InferenceSession(
                model_path,
                providers=["CPUExecutionProvider"],
                # sess_options=sess_options,  # Uncomment to use the options above
            )
            end_time = time.time()
            print(f"ONNX model session loaded successfully using the CPU provider in {end_time - start_time:.2f} seconds.")
        except Exception as e:
            model_load_error = f"Error loading ONNX session from {model_path}: {e}\n"
            model_load_error += "Ensure the 'onnxruntime' library is installed correctly and that both 'model.onnx' and 'model.onnx.data' are valid files."
            print(model_load_error)
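
# Optional sanity check (a small sketch, not in the original script): print the input and
# output names the exported graph actually expects. Some Phi ONNX exports require
# past_key_values.* inputs in addition to input_ids/attention_mask, so it is worth
# confirming what this particular model.onnx wants before running inference.
if session is not None:
    print("Model inputs: ", [inp.name for inp in session.get_inputs()])
    print("Model outputs:", [out.name for out in session.get_outputs()])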
# --- Inference Function ---
def generate_response(prompt):
    """
    Generate a response from the loaded ONNX model for the given user prompt.
    """
    global tokenizer, session, model_load_error  # Read the globals set during model loading
    # If model loading failed earlier, surface that error in the UI
    if model_load_error:
        return model_load_error
    if not tokenizer or not session:
        return "Error: Model or tokenizer is not loaded correctly. Check the console output."
    print(f"\nReceived prompt: {prompt}")
    start_time = time.time()
    # Format the prompt with the chat markers the instruct model expects
    full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    print("Tokenizing input...")
    try:
        # Tokenize the formatted prompt
        inputs = tokenizer(full_prompt, return_tensors="np")
        # Prepare the inputs for the ONNX model
        ort_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
        print("Running model inference...")
        inference_start_time = time.time()
        # Run the ONNX model inference
        outputs = session.run(None, ort_inputs)
        # Assumes the first output contains the generated token IDs; if your export returns
        # logits instead, a decoding loop is needed (see the note after this function).
        generated_ids = outputs[0]
        inference_end_time = time.time()
        print(f"Inference complete in {inference_end_time - inference_start_time:.2f} seconds.")
        # Decode the generated token IDs back into text
        print("Decoding response...")
        decoding_start_time = time.time()
        # The output may be shaped (1, sequence_length); take the first row if so
        output_ids = generated_ids[0] if generated_ids.ndim == 2 else generated_ids
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        decoding_end_time = time.time()
        print(f"Decoding complete in {decoding_end_time - decoding_start_time:.2f} seconds.")
        # --- Response Cleaning ---
        # 1. Find the start of the assistant's response
        assistant_marker = "<|assistant|>"
        assistant_pos = response.find(assistant_marker)
        if assistant_pos != -1:
            # If the marker is found, take the text after it
            cleaned_response = response[assistant_pos + len(assistant_marker):].strip()
        else:
            # Fallback: if the marker was not decoded intact, try removing the original
            # prompt text (the model sometimes prepends the input). Strip the prompt part
            # *without* the final <|assistant|> tag.
            prompt_part_to_remove = full_prompt.rsplit(assistant_marker, 1)[0]
            if response.startswith(prompt_part_to_remove):
                cleaned_response = response[len(prompt_part_to_remove):].strip()
            else:
                # If neither works, return the raw response (it may still contain parts of the prompt)
                cleaned_response = response.strip()
                print("Warning: Could not reliably clean the prompt context from the response.")
        total_time = time.time() - start_time
        print(f"Generated response: {cleaned_response}")
        print(f"Total processing time for this prompt: {total_time:.2f} seconds.")
        return cleaned_response
    except Exception as e:
        print(f"Error during model inference or decoding: {e}")
        import traceback
        traceback.print_exc()  # Print a detailed traceback for debugging
        return f"Error during generation: {e}"
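
# Note: the raw-ORT approach above performs a single forward pass and assumes the export
# emits token IDs directly. The Phi-4-mini ONNX repo is primarily intended to be run with
# the `onnxruntime-genai` package, which manages the KV cache and the token-by-token
# decoding loop. The function below is only a hedged sketch of that alternative: it is not
# wired into the Gradio UI, and the exact API can differ between onnxruntime-genai
# versions, so verify it against your installed version before relying on it.
def generate_response_genai(prompt, max_length=512):
    """Sketch: generate a reply with onnxruntime-genai instead of a raw InferenceSession."""
    import onnxruntime_genai as og  # requires: pip install onnxruntime-genai

    og_model = og.Model(model_dir)  # reads model.onnx plus its genai config from the directory
    og_tokenizer = og.Tokenizer(og_model)
    input_tokens = og_tokenizer.encode(f"<|user|>\n{prompt}\n<|assistant|>\n")

    params = og.GeneratorParams(og_model)
    params.set_search_options(max_length=max_length)
    generator = og.Generator(og_model, params)
    generator.append_tokens(input_tokens)  # recent API; older releases set params.input_ids instead

    # Decode token by token until the generator reports it is done
    stream = og_tokenizer.create_stream()
    chunks = []
    while not generator.is_done():
        generator.generate_next_token()
        chunks.append(stream.decode(generator.get_next_tokens()[0]))
    return "".join(chunks)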
# --- Gradio Interface Setup ---
print("Setting up Gradio interface...")
# Define CSS for better layout (optional)
css = """
#output_textbox textarea {
    min-height: 300px; /* Make the output box taller */
}
#input_textbox textarea {
    min-height: 100px; /* Adjust the input box height */
}
"""
demo = gr.Blocks(css=css, theme=gr.themes.Default()) # Use Blocks for more layout control
with demo:
    gr.Markdown(
        """
        # Phi-4-Mini ONNX Chatbot (Local CPU)
        Interact with the `microsoft/Phi-4-mini-instruct-onnx` model variant
        (`cpu-int4-rtn-block-32-acc-level-4`) running locally using ONNX Runtime on your CPU.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):  # Input column
            input_textbox = gr.Textbox(
                label="Your Prompt",
                placeholder="Type your question or instruction here...",
                lines=4,  # Initial lines, resizable
                elem_id="input_textbox"  # Assign ID for CSS
            )
            submit_button = gr.Button("Generate Response", variant="primary")
        with gr.Column(scale=3):  # Output column
            output_textbox = gr.Textbox(
                label="AI Response",
                lines=10,  # Initial lines, resizable
                interactive=False,  # The user cannot type in the output box
                elem_id="output_textbox"  # Assign ID for CSS
            )
    # Display the model loading status/errors
    if model_load_error:
        gr.Markdown(f"**<font color='red'>Model Loading Error:</font>**\n```\n{model_load_error}\n```")
    elif session is None or tokenizer is None:
        gr.Markdown("**<font color='orange'>Warning:</font>** Model or tokenizer did not load correctly. Check the console logs.")
    else:
        gr.Markdown("**<font color='green'>Model and Tokenizer Loaded Successfully.</font>**")
    # Connect the button click to the generation function
    submit_button.click(
        fn=generate_response,
        inputs=input_textbox,
        outputs=output_textbox
    )
    # Also allow submitting by pressing Enter in the input textbox
    input_textbox.submit(
        fn=generate_response,
        inputs=input_textbox,
        outputs=output_textbox
    )
# --- Launch the Application ---
print("-" * 50)
print("Launching Gradio app...")
print("You can access it in your browser at the URL provided below (usually http://127.0.0.1:7860).")
print("Press CTRL+C in this terminal to stop the application.")
print("-" * 50)
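
# Optional (a suggestion, not part of the original app): CPU generation can take a while,
# so enabling Gradio's request queue keeps the UI responsive while a response is computed.
# demo.queue()  # call before launch(); see the Gradio docs for concurrency options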
# share=True creates a temporary public link (use with caution).
# Set debug=True for more detailed Gradio errors if needed.
demo.launch(share=False, debug=False)
print("Gradio app closed.")