import gradio as gr
import torch
import os
# Removed the __version__ import from optimum.onnxruntime
from transformers import AutoTokenizer, __version__ as transformers_version
from optimum.onnxruntime import ORTModelForCausalLM
# import optimum  # optionally check optimum's own version
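# A minimal requirements.txt for this Space might look like the following
# (an illustrative assumption, not taken from the original repo):
#   transformers>=4.41.0
#   optimum[onnxruntime]
#   onnxruntime
#   torch
#   gradio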
# --- Configuration ---
MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
ONNX_FILE_NAME = None
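# ONNX_FILE_NAME is left as None (and currently unused below), so Optimum falls back to its
# default ONNX file selection for the repo. To pin a specific variant (e.g. a quantized export),
# ORTModelForCausalLM.from_pretrained also accepts a file_name argument, e.g.
# file_name="model_quantized.onnx" (hypothetical name; check the repo for the actual files).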
print(f"Using Transformers version: {transformers_version}")
# try:
# print(f"Using Optimum version: {optimum.__version__}") # λ€λ₯Έ λ°©λ²μΌλ‘ λ²μ νμΈ μλ
# except AttributeError:
# print("Could not determine Optimum version automatically.")
print(f"Using Gradio version: {gr.__version__}")
# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
        print("Attempting to use GPU (CUDA).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
        print("Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}. Defaulting to CPU.")
    device = "cpu"
    provider = "CPUExecutionProvider"
# --- Model and Tokenizer Loading ---
model = None
tokenizer = None
model_loaded_successfully = False
print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print("Tokenizer loaded successfully.")
    # Attempt to load the ONNX model
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider=provider,
        use_cache=True,
    )
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True
except ValueError as ve:
    # Handle errors caused by an unrecognized model type
    print(f"!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error message: {ve}")
    print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
    print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
except Exception as e:
    # Handle any other exception
    print(f"!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
    print("Check Space resources (memory limits), network connection, or other dependencies.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
# --- Chat Function ---
def chat_function(message: str, history: list):
    if not model_loaded_successfully or model is None or tokenizer is None:
        return "Error: The AI model is not loaded. Please check the application logs."
    try:
        # Convert the chat history into the message format expected by the chat template.
        # With chatbot type="messages", Gradio passes history as a list of {"role", "content"}
        # dicts; the older (user, assistant) tuple format is handled as well.
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for turn in history:
            if isinstance(turn, dict):
                role = "model" if turn.get("role") == "assistant" else turn.get("role", "user")
                if turn.get("content"):
                    chat_messages.append({"role": role, "content": turn["content"]})
            else:
                user_msg, model_msg = turn
                if user_msg: chat_messages.append({"role": "user", "content": user_msg})
                if model_msg: chat_messages.append({"role": "model", "content": model_msg})
        if message: chat_messages.append({"role": "user", "content": message})
        # Build the prompt
        prompt = ""
        try:
            prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
        except Exception as template_error:
            print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
            # Fall back to a manually built Gemma-style prompt from the collected messages.
            prompt_parts = [
                f"<start_of_turn>{msg['role']}\n{msg['content']}<end_of_turn>"
                for msg in chat_messages
            ]
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)
        # Tokenize the input prompt
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Generate a response
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        print("Generation complete.")
        # Decode only the newly generated tokens (everything after the prompt)
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        response = response.replace("<end_of_turn>", "").strip()
        if not response:
            print("Warning: Generated empty response.")
            response = "Sorry, I couldn't generate a response for that."
        return response
    except Exception as e:
        print(f"!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return "Sorry, an error occurred during response generation. Please check logs."
# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
    fn=chat_function,
    title="AI Assistant (Gemma 3 1B ONNX-GQA)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
    chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
    theme=gr.themes.Soft(),
    examples=[["Hello!"], ["Write a poem about the internet."]]
)
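# Note: because the Chatbot above uses type="messages", recent Gradio versions pass history
# to chat_function as a list of {"role", "content"} dicts; the function also accepts the
# older (user, assistant) tuple format for compatibility.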
# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    iface.launch()