import gradio as gr
import torch
import os
# Removed the __version__ import from optimum.onnxruntime
from transformers import AutoTokenizer, __version__ as transformers_version
from optimum.onnxruntime import ORTModelForCausalLM
# import optimum  # optional: check the version of optimum itself

# --- Configuration ---
MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
ONNX_FILE_NAME = None
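# NOTE: ONNX_FILE_NAME is currently unused. If the repository ships several ONNX variants
# (e.g. quantized files), a specific one could be selected by passing file_name=ONNX_FILE_NAME
# to ORTModelForCausalLM.from_pretrained below (check that the installed optimum version
# supports the file_name argument before relying on this).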
print(f"Using Transformers version: {transformers_version}") | |
# try: | |
# print(f"Using Optimum version: {optimum.__version__}") # λ€λ₯Έ λ°©λ²μΌλ‘ λ²μ νμΈ μλ | |
# except AttributeError: | |
# print("Could not determine Optimum version automatically.") | |
print(f"Using Gradio version: {gr.__version__}") | |

# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
        print("Attempting to use GPU (CUDA).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
        print("Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}. Defaulting to CPU.")
    device = "cpu"
    provider = "CPUExecutionProvider"

# --- Model and Tokenizer Loading ---
model = None
tokenizer = None
model_loaded_successfully = False

print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print("Tokenizer loaded successfully.")

    # Attempt to load the pre-exported ONNX model
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider=provider,
        use_cache=True,
    )
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True
except ValueError as ve:
    # Handle errors caused by an unsupported model type
    print("!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error message: {ve}")
    print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
    print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
except Exception as e:
    # Handle any other loading failure
    print("!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
    print("Check Space resources (memory limits), network connection, or other dependencies.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
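
# Optional warm-up (a minimal, commented-out sketch): running one tiny generation at startup would
# surface provider/session problems in the Space logs before the first user request arrives.
# if model_loaded_successfully:
#     _probe = tokenizer("Hello", return_tensors="pt").to(device)
#     _ = model.generate(**_probe, max_new_tokens=1)
#     print("Warm-up generation succeeded.")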

# --- Chat Function ---
def chat_function(message: str, history: list):
    if not model_loaded_successfully or model is None or tokenizer is None:
        return "Error: The AI model is not loaded. Please check the application logs."

    try:
        # Convert the Gradio chat history into chat-template messages. Depending on the Gradio
        # version and chatbot type, history arrives either as role/content dicts or as
        # (user, bot) pairs, so handle both formats.
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for turn in history:
            if isinstance(turn, dict):
                role = "model" if turn.get("role") == "assistant" else turn.get("role", "user")
                if turn.get("content"): chat_messages.append({"role": role, "content": turn["content"]})
            else:
                user_msg, model_msg = turn
                if user_msg: chat_messages.append({"role": "user", "content": user_msg})
                if model_msg: chat_messages.append({"role": "model", "content": model_msg})
        if message: chat_messages.append({"role": "user", "content": message})

        # Build the prompt with the tokenizer's chat template
        prompt = ""
        try:
            prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
        except Exception as template_error:
            print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
            # Fall back to hand-building a Gemma-style prompt from the normalized messages
            prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
            for msg in chat_messages[1:]:  # skip the system message already added above
                prompt_parts.append(f"<start_of_turn>{msg['role']}\n{msg['content']}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)

        # Tokenize the prompt and move the tensors to the selected device
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate the response
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,                  # cap on the length of the reply
                do_sample=True,                      # sample instead of greedy decoding
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        print("Generation complete.")

        # Decode only the newly generated tokens (everything after the prompt)
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        response = response.replace("<end_of_turn>", "").strip()

        if not response:
            print("Warning: Generated empty response.")
            response = "Sorry, I couldn't generate a response for that."

        return response

    except Exception as e:
        print("!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return "Sorry, an error occurred during response generation. Please check the logs."

# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
    fn=chat_function,
    title="AI Assistant (Gemma 3 1B ONNX-GQA)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
    # With a "messages"-type chatbot, recent Gradio versions pass history as role/content dicts
    # (handled in chat_function above).
    chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
    theme=gr.themes.Soft(),
    examples=[["Hello!"], ["Write a poem about the internet."]]
)

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    iface.launch()