# app.py
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import os    # Used to check that the model directory and files exist
import time  # To measure performance (optional)

print("Loading libraries...")

# --- Configuration ---
# Define the local directory where the downloaded model files are stored.
# This path MUST match where you downloaded the model files relative to this script.
model_dir = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"

# --- Model Loading ---
tokenizer = None
session = None
model_load_error = None

# Check if the model directory exists before attempting to load
if not os.path.isdir(model_dir):
    model_load_error = (
        f"Error: Model directory not found at '{os.path.abspath(model_dir)}'\n"
        "Please ensure you have created the directory structure\n"
        f"'./{model_dir}' relative to this script ({os.path.basename(__file__)})\n"
        "and downloaded ALL the required model files into it from:\n"
        "https://huggingface.co./microsoft/Phi-4-mini-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
    )
    print(model_load_error)
else:
    print(f"Found model directory: {os.path.abspath(model_dir)}")
    print("Loading tokenizer...")
    try:
        # Load the tokenizer associated with the Phi-4 model variant
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        print("Tokenizer loaded successfully.")
    except Exception as e:
        model_load_error = f"Error loading tokenizer from {model_dir}: {e}"
        print(model_load_error)

    # Only attempt to load the session if the tokenizer loaded successfully
    if tokenizer:
        print("Loading ONNX model session...")
        model_path = os.path.join(model_dir, "model.onnx")
        model_data_path = os.path.join(model_dir, "model.onnx.data")

        if not os.path.exists(model_path):
            model_load_error = f"Error: 'model.onnx' not found in {model_dir}"
            print(model_load_error)
        elif not os.path.exists(model_data_path):
            model_load_error = (
                f"Error: 'model.onnx.data' not found in {model_dir}. "
                "This large file contains the model weights and is required."
            )
            print(model_load_error)
        else:
            try:
                # Load the ONNX model using ONNX Runtime for CPU execution
                start_time = time.time()
                # You can configure session options for performance if needed
                # sess_options = ort.SessionOptions()
                # sess_options.intra_op_num_threads = 4  # Example: limit threads
                session = ort.InferenceSession(
                    model_path,
                    providers=["CPUExecutionProvider"]
                    # sess_options=sess_options  # Uncomment to use options
                )
                end_time = time.time()
                print(f"ONNX model session loaded successfully using CPU provider in {end_time - start_time:.2f} seconds.")
            except Exception as e:
                model_load_error = f"Error loading ONNX session from {model_path}: {e}\n"
                model_load_error += (
                    "Ensure the 'onnxruntime' library is installed correctly and that both "
                    "'model.onnx' and 'model.onnx.data' are valid files."
                )
                print(model_load_error)


# --- Inference Function ---
def generate_response(prompt):
    """
    Generates a response from the loaded ONNX model based on the user prompt.
    """
    global tokenizer, session, model_load_error  # Read the module-level objects loaded above

    # Check if model loading failed earlier
    if model_load_error:
        return model_load_error
    if not tokenizer or not session:
        return "Error: Model or Tokenizer is not loaded correctly. Check console output."
print(f"\nReceived prompt: {prompt}") start_time = time.time() # Format the prompt with specific markers for instruction following full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" print("Tokenizing input...") try: # Tokenize the formatted prompt inputs = tokenizer(full_prompt, return_tensors="np") # Prepare inputs for the ONNX model ort_inputs = { "input_ids": inputs["input_ids"].astype(np.int64), "attention_mask": inputs["attention_mask"].astype(np.int64) } print("Running model inference...") inference_start_time = time.time() # Run the ONNX model inference outputs = session.run(None, ort_inputs) generated_ids = outputs[0] # Assuming the first output contains the generated IDs inference_end_time = time.time() print(f"Inference complete in {inference_end_time - inference_start_time:.2f} seconds.") # Decode the generated token IDs back into text print("Decoding response...") decoding_start_time = time.time() # Ensure generated_ids is 1D if necessary, might be shape (1, sequence_length) output_ids = generated_ids[0] if generated_ids.ndim == 2 else generated_ids response = tokenizer.decode(output_ids, skip_special_tokens=True) decoding_end_time = time.time() print(f"Decoding complete in {decoding_end_time - decoding_start_time:.2f} seconds.") # --- Response Cleaning --- # 1. Find the start of the assistant's response assistant_marker = "<|assistant|>" assistant_pos = response.find(assistant_marker) if assistant_pos != -1: # If marker found, take text after it cleaned_response = response[assistant_pos + len(assistant_marker):].strip() else: # Fallback: If marker isn't perfectly decoded, try removing the original input prompt # This assumes the model might prepend the input sometimes. # Remove the prompt part *without* the final <|assistant|> tag prompt_part_to_remove = full_prompt.rsplit(assistant_marker, 1)[0] if response.startswith(prompt_part_to_remove): cleaned_response = response[len(prompt_part_to_remove):].strip() else: # If neither works well, return the raw response (might contain parts of the prompt) cleaned_response = response.strip() print("Warning: Could not reliably clean the prompt context from the response.") total_time = time.time() - start_time print(f"Generated response: {cleaned_response}") print(f"Total processing time for this prompt: {total_time:.2f} seconds.") return cleaned_response except Exception as e: print(f"Error during model inference or decoding: {e}") import traceback traceback.print_exc() # Print detailed traceback for debugging return f"Error during generation: {e}" # --- Gradio Interface Setup --- print("Setting up Gradio interface...") # Define CSS for better layout (optional) css = """ #output_textbox textarea { min-height: 300px; /* Make output box taller */ } #input_textbox textarea { min-height: 100px; /* Adjust input box height */ } """ demo = gr.Blocks(css=css, theme=gr.themes.Default()) # Use Blocks for more layout control with demo: gr.Markdown( """ # Phi-4-Mini ONNX Chatbot (Local CPU) Interact with the `microsoft/Phi-4-mini-instruct-onnx` model variant (`cpu-int4-rtn-block-32-acc-level-4`) running locally using ONNX Runtime on your CPU. 
""" ) with gr.Row(): with gr.Column(scale=2): # Input column input_textbox = gr.Textbox( label="Your Prompt", placeholder="Type your question or instruction here...", lines=4, # Initial lines, resizable elem_id="input_textbox" # Assign ID for CSS ) submit_button = gr.Button("Generate Response", variant="primary") with gr.Column(scale=3): # Output column output_textbox = gr.Textbox( label="AI Response", lines=10, # Initial lines, resizable interactive=False, # User cannot type in the output box elem_id="output_textbox" # Assign ID for CSS ) # Display model loading status/errors if model_load_error: gr.Markdown(f"**Model Loading Error:**\n```\n{model_load_error}\n```") elif session is None or tokenizer is None: gr.Markdown("**Warning:** Model or tokenizer did not load correctly. Check console logs.") else: gr.Markdown("**Model and Tokenizer Loaded Successfully.**") # Connect button click to the function submit_button.click( fn=generate_response, inputs=input_textbox, outputs=output_textbox ) # Allow submitting by pressing Enter in the input textbox input_textbox.submit( fn=generate_response, inputs=input_textbox, outputs=output_textbox ) # --- Launch the Application --- print("-" * 50) print("Launching Gradio app...") print("You can access it in your browser at the URL provided below (usually http://127.0.0.1:7860).") print("Press CTRL+C in this terminal to stop the application.") print("-" * 50) # share=True creates a temporary public link (use with caution). # Set debug=True for more detailed Gradio errors if needed. demo.launch(share=False, debug=False) print("Gradio app closed.")