import gradio as gr
import torch
import os
# Removed the __version__ import from optimum.onnxruntime (it is not exported there)
from transformers import AutoTokenizer, __version__ as transformers_version
from optimum.onnxruntime import ORTModelForCausalLM
# import optimum  # optional: check optimum's own version this way

# --- Configuration ---
MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
ONNX_FILE_NAME = None  # placeholder for a specific .onnx file name; unused below, the default export is loaded

print(f"Using Transformers version: {transformers_version}")
# try:
#     print(f"Using Optimum version: {optimum.__version__}") # λ‹€λ₯Έ λ°©λ²•μœΌλ‘œ 버전 확인 μ‹œλ„
# except AttributeError:
#     print("Could not determine Optimum version automatically.")
print(f"Using Gradio version: {gr.__version__}")

# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
        print("Attempting to use GPU (CUDA).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
        print("Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}. Defaulting to CPU.")
    device = "cpu"
    provider = "CPUExecutionProvider"

# --- Model and Tokenizer Loading ---
model = None
tokenizer = None
model_loaded_successfully = False

print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print("Tokenizer loaded successfully.")

    # Attempt to load the ONNX model
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider=provider,
        use_cache=True,
    )
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True

except ValueError as ve:
    # λͺ¨λΈ νƒ€μž… 미지원 였λ₯˜ 처리
    print(f"!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error message: {ve}")
    print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
    print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False

except Exception as e:
    # Handle any other loading errors
    print(f"!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
    print("Check Space resources (memory limits), network connection, or other dependencies.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
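
# The ValueError handler above refers to requirements.txt. A minimal sketch of what that file
# might contain (versions are illustrative assumptions, not pinned by this code):
#
#     transformers>=4.50.0
#     optimum[onnxruntime]   # or optimum[onnxruntime-gpu] for the CUDAExecutionProvider
#     gradio
#     torch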

# --- Chat Function ---
def chat_function(message: str, history: list):
    if not model_loaded_successfully or model is None or tokenizer is None:
        return "Error: The AI model is not loaded. Please check the application logs."

    try:
        # Convert the Gradio chat history into chat-template messages. Depending on the Gradio
        # version/config, history arrives as "messages" dicts or as legacy (user, model) tuples.
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for item in history:
            if isinstance(item, dict):  # "messages" format: {"role": ..., "content": ...}
                if item.get("content"):
                    chat_messages.append({"role": "model" if item.get("role") == "assistant" else item.get("role"), "content": item["content"]})
            else:  # legacy tuple format: (user_msg, model_msg)
                user_msg, model_msg = item
                if user_msg: chat_messages.append({"role": "user", "content": user_msg})
                if model_msg: chat_messages.append({"role": "model", "content": model_msg})
        if message: chat_messages.append({"role": "user", "content": message})

        # Build the prompt from the chat messages
        prompt = ""
        try:
            prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
        except Exception as template_error:
            print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
            prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
            for msg in chat_messages[1:]:
                prompt_parts.append(f"<start_of_turn>{msg['role']}\n{msg['content']}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)

        # Tokenize the prompt and move the tensors to the selected device
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate the response
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        print("Generation complete.")

        # Decode only the newly generated tokens
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        response = response.replace("<end_of_turn>", "").strip()
        if not response:
            print("Warning: Generated empty response.")
            response = "Sorry, I couldn't generate a response for that."
        return response

    except Exception as e:
        print(f"!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return f"Sorry, an error occurred during response generation. Please check logs."

# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
    fn=chat_function,
    title="AI Assistant (Gemma 3 1B ONNX-GQA)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
    chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
    theme=gr.themes.Soft(),
    examples=[["Hello!"], ["Write a poem about the internet."]]
)

# --- Launch App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    iface.launch()