import gradio as gr
import torch
# Removed the __version__ import from optimum.onnxruntime
from transformers import AutoTokenizer, __version__ as transformers_version
from optimum.onnxruntime import ORTModelForCausalLM
# import optimum  # optionally, check optimum's own version

# --- Configuration ---
MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
ONNX_FILE_NAME = None  # currently unused; let optimum resolve the ONNX file

print(f"Using Transformers version: {transformers_version}")
# try:
#     print(f"Using Optimum version: {optimum.__version__}")  # alternative version check
# except AttributeError:
#     print("Could not determine Optimum version automatically.")
print(f"Using Gradio version: {gr.__version__}")

# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
        print("Attempting to use GPU (CUDA).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
        print("Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}. Defaulting to CPU.")
    device = "cpu"
    provider = "CPUExecutionProvider"

# --- Model and Tokenizer Loading ---
model = None
tokenizer = None
model_loaded_successfully = False

print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print("Tokenizer loaded successfully.")

    # Attempt to load the ONNX model
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider=provider,
        use_cache=True,
    )
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True

except ValueError as ve:
    # Handle unsupported model architecture errors
    print("!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error message: {ve}")
    print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
    print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False

except Exception as e:
    # Handle all other loading errors
    print("!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
    print("Check Space resources (memory limits), network connection, or other dependencies.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
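# For reference: if apply_chat_template fails, chat_function below falls back
# to hand-building a Gemma-style prompt. A single-turn prompt in that format
# looks roughly like the sketch here (assumed standard Gemma turn markers;
# the authoritative format is the tokenizer's own chat template):
#
#   <start_of_turn>user
#   Hello!<end_of_turn>
#   <start_of_turn>model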
# --- Chat Function ---
def chat_function(message: str, history: list):
    if not model_loaded_successfully or model is None or tokenizer is None:
        return "Error: The AI model is not loaded. Please check the application logs."

    try:
        # Convert the chat history. With a "messages"-type chatbot, Gradio passes
        # history as a list of {"role": ..., "content": ...} dicts; Gemma names
        # the assistant role "model".
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for msg in history:
            role = "model" if msg["role"] == "assistant" else msg["role"]
            if msg.get("content"):
                chat_messages.append({"role": role, "content": msg["content"]})
        if message:
            chat_messages.append({"role": "user", "content": message})

        # Build the prompt
        prompt = ""
        try:
            prompt = tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
        except Exception as template_error:
            print(f"Warning: Failed to apply chat template ({template_error}). "
                  "Using manual prompt construction.")
            prompt_parts = []
            for msg in chat_messages:
                prompt_parts.append(
                    f"<start_of_turn>{msg['role']}\n{msg['content']}<end_of_turn>"
                )
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)

        # Tokenize the input
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate the response
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
        print("Generation complete.")

        # Decode only the newly generated tokens (skip the prompt)
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        response = response.replace("<end_of_turn>", "").strip()

        if not response:
            print("Warning: Generated empty response.")
            response = "Sorry, I couldn't generate a response for that."

        return response

    except Exception as e:
        print("!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return "Sorry, an error occurred during response generation. Please check logs."

# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
    fn=chat_function,
    title="AI Assistant (Gemma 3 1B ONNX-GQA)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
    chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
    theme=gr.themes.Soft(),
    examples=[["Hello!"], ["Write a poem about the internet."]],
)

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    iface.launch()
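# --- Environment notes ---
# A minimal requirements.txt sketch for this app; the pins below are
# assumptions inferred from the error messages above, not tested versions:
#   transformers>=4.41.0
#   optimum[onnxruntime]   # or optimum[onnxruntime-gpu] for CUDAExecutionProvider
#   gradio
#   torch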