Update app.py
app.py
CHANGED
@@ -1,330 +1,210 @@
 import gradio as gr
-
-import
-import
-from huggingface_hub import snapshot_download
-import argparse
-import logging
-import numpy as np # Import numpy
 
-#
-
 
 # --- Configuration ---
-
-# ---
-
-logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
-try:
-    snapshot_download(
-        MODEL_REPO,
-        allow_patterns=[MODEL_VARIANT_GLOB],
-        local_dir=LOCAL_MODEL_DIR,
-        local_dir_use_symlinks=False
-    )
-    model_path = model_variant_dir
-    logging.info(f"Model downloaded to: {model_path}")
-except Exception as e:
-    logging.error(f"Error downloading model: {e}", exc_info=True)
-    model_status = f"Error downloading model: {e}"
-    raise RuntimeError(f"Failed to download model: {e}")
-
-# --- Load ---
-model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
-logging.info(model_status)
 try:
-
-except AttributeError as ae:
-    logging.error(f"AttributeError during model/tokenizer init: {ae}", exc_info=True)
-    logging.error("This might indicate an installation issue or version incompatibility with onnxruntime_genai.")
-    model_status = f"Init Error: {ae}"
-    raise RuntimeError(f"Failed to initialize model/tokenizer: {ae}")
 except Exception as e:
-
-    # Add the current user prompt and the trigger for the assistant's response
-    full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
-
-    logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
-
-    try:
-        input_tokens_list = tokenizer.encode(full_prompt) # Encode returns a list/array
-        # Ensure input_tokens is a numpy array of the correct type (int32 is common)
-        input_tokens_np = np.array(input_tokens_list, dtype=np.int32)
-        # Reshape to (batch_size, sequence_length), which is (1, N) for single prompt
-        input_tokens_np = input_tokens_np.reshape((1, -1))
-        logging.info(f"Prepared input_tokens shape: {input_tokens_np.shape}, dtype: {input_tokens_np.dtype}")
-
-        search_options = {
-            "max_length": max_length,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "do_sample": True,
-        }
-
-        params = og.GeneratorParams(model)
-        params.set_search_options(**search_options)
-
-        # FIX: Reverting to direct assignment based on official examples,
-        # ensuring the numpy array is correctly shaped *before* assignment.
-        logging.info("Attempting direct assignment: params.input_ids = input_tokens_np")
-        params.input_ids = input_tokens_np # Use the reshaped numpy array
-
-        start_time = time.time()
-        # Create generator AFTER setting parameters including input_ids
-        generator = og.Generator(model, params)
-        model_status = "Generating..." # Update status indicator
-        logging.info("Streaming response...")
-
-        first_token_time = None
-        token_count = 0
-        # Rely primarily on generator.is_done()
-        while not generator.is_done():
-            try:
-                generator.compute_logits()
-                generator.generate_next_token()
-                if first_token_time is None:
-                    first_token_time = time.time() # Record time to first token
-
-                next_token = generator.get_next_tokens()[0]
-
-                decoded_chunk = tokenizer.decode([next_token])
-                token_count += 1
-
-                # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
-                if decoded_chunk == "<|end|>":
-                    logging.info("Assistant explicitly generated <|end|> token string.")
-                    break
-
-                yield decoded_chunk # Yield just the text chunk
-            except Exception as loop_error:
-                logging.error(f"Error inside generation loop: {loop_error}", exc_info=True)
-                yield f"\n\nError during token generation: {loop_error}"
-                break # Exit loop on error
-
-        end_time = time.time()
-        ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
-        total_time = end_time - start_time
-        tps = (token_count / total_time) if total_time > 0 else 0
-
-        logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
-        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})" # Reset status
-
-    except AttributeError as ae:
-        # Catching this specifically again after trying direct assignment
-        logging.error(f"AttributeError during generation setup (using params.input_ids): {ae}", exc_info=True)
-        logging.error("This suggests the 'input_ids' attribute is not available in this version, despite examples.")
-        model_status = f"Generation Setup AttributeError: {ae}"
-        yield f"\n\nSorry, an AttributeError occurred setting up generation: {ae}"
-    except TypeError as te:
-        # Catch type errors specifically during setup if the input format is still wrong
-        logging.error(f"TypeError during generation setup: {te}", exc_info=True)
-        logging.error("Check input data types and shapes if this occurs.")
-        model_status = f"Generation Setup TypeError: {te}"
-        yield f"\n\nSorry, a TypeError occurred setting up generation: {te}"
-    except Exception as e:
-        logging.error(f"Error during generation: {e}", exc_info=True)
-        model_status = f"Error during generation: {e}"
-        yield f"\n\nSorry, an error occurred during generation: {e}" # Yield error message
-
-# --- Gradio Interface Functions ---
-
-# 1. Function to add user message to chat history
-def add_user_message(user_message, history):
-    """Adds the user's message to the chat history for display."""
-    if not user_message:
-        return "", history # Clear input, return unchanged history
-    history = history + [[user_message, None]] # Append user message, leave bot response None
-    return "", history # Clear input textbox, return updated history
-
-# 2. Function to handle bot response generation and streaming
-def generate_bot_response(history, max_length, temperature, top_p, top_k):
-    """Generates the bot's response based on the history and streams it."""
-    if not history or history[-1][1] is not None:
-        return history
-
-    user_prompt = history[-1][0] # Get the latest user prompt
-    model_history = history[:-1] # Prepare history for the model
-
-    response_stream = generate_response_stream(
-        user_prompt, model_history, max_length, temperature, top_p, top_k
 )
 
-
 
-
-    return None, [], model_status # Clear Textbox, Chatbot history, and update status display
 
 
-
-try:
-    initialize_model()
-except Exception as e:
-    print(f"FATAL: Model initialization failed: {e}")
 
 
 # --- Gradio Interface ---
-
-theme
-
-with gr.Row(equal_height=False):
-    with gr.Column(scale=3):
-        gr.Markdown(f"""
-        # Phi-4 Mini Instruct ONNX Chat 🤖
-        Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL})
-        running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}) ({EXECUTION_PROVIDER.upper()}).
-        """)
-    with gr.Column(scale=1, min_width=150):
-        gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
-        model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)
 
-# Main Layout
 with gr.Row():
-
 )
-with gr.Row():
-    prompt_input = gr.Textbox(
-        label="Your Message",
-        placeholder="<|user|>\nType your message here...\n<|end|>",
-        lines=4,
-        scale=9
-    )
-    with gr.Column(scale=1, min_width=120):
-        submit_button = gr.Button("Send", variant="primary", size="lg")
-        clear_button = gr.Button("🗑️ Clear Chat", variant="secondary")
-
-# Settings Column
-with gr.Column(scale=1, min_width=250):
-    gr.Markdown("### ⚙️ Generation Settings")
-    with gr.Group():
-        max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
-        temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
-        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
-        top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
-    gr.Markdown("---")
-    gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
-    gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")
-
-# Event Listeners
-bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]
 
-
-    fn=add_user_message,
-    inputs=[prompt_input, chatbot],
-    outputs=[prompt_input, chatbot],
-    queue=False,
-).then(
-    fn=generate_bot_response,
-    inputs=bot_response_inputs,
-    outputs=[chatbot],
-    api_name="chat"
-)
 
 submit_button.click(
-    fn=
-    inputs=[
-
 )
 
-
 )
 
-# Launch the
-
+import torch
 import gradio as gr
+from unsloth import FastLanguageModel
+from transformers import TextStreamer, GenerationConfig
+import warnings
 
+# Suppress specific warnings if needed (optional)
+warnings.filterwarnings("ignore", category=UserWarning, message=".*padding_mask.*")
 
 # --- Configuration ---
+MODEL_NAME = "unsloth/gemma-3-1b-it"
+MAX_SEQ_LENGTH = 4096 # Choose based on the model's capabilities and your VRAM
+DTYPE = None # None for auto detection, or torch.float16 / torch.bfloat16
+LOAD_IN_4BIT = True # Use 4-bit quantization for lower memory usage
+
+# --- Load Model and Tokenizer ---
+print(f"Loading model: {MODEL_NAME}")
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name=MODEL_NAME,
+    max_seq_length=MAX_SEQ_LENGTH,
+    dtype=DTYPE,
+    load_in_4bit=LOAD_IN_4BIT,
+    # token = "hf_...", # Add your Hugging Face token if needed (for gated models)
+)
+print("Model and tokenizer loaded successfully.")
+
+# Optimize for inference
+FastLanguageModel.for_inference(model)
+print("Model optimized for inference.")
+
+# --- Generation Function ---
+def generate_response(
+    prompt,
+    max_new_tokens=512,
+    temperature=0.7,
+    top_p=0.9,
+    do_sample=True,
+    system_prompt=None # Optional system prompt
+):
+    """
+    Generates a response from the model given a prompt and parameters.
+    """
+    messages = []
+    if system_prompt and system_prompt.strip():
+        messages.append({"role": "system", "content": system_prompt})
+    messages.append({"role": "user", "content": prompt})
+
+    # Apply chat template
     try:
+        inputs = tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True, # Ensures the '<start_of_turn>model' token is added
+            return_tensors="pt",
+        ).to(model.device) # Ensure inputs are on the same device as the model
     except Exception as e:
+        print(f"Error applying chat template: {e}")
+        # Fallback or simple concatenation if template fails (less ideal)
+        formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+        inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
+
+
+    # --- Use a TextStreamer for Gradio ---
+    # While streaming works well in terminals, Gradio updates per yield.
+    # For a simpler Gradio experience, we'll generate the full response at once.
+    # If you want streaming in Gradio, it requires more complex handling
+    # with gr.Textbox(interactive=False) and yielding chunks.
+
+    # --- Generate Full Response ---
+    generation_config = GenerationConfig(
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        do_sample=do_sample,
+        pad_token_id=tokenizer.eos_token_id, # Use EOS token for padding
+        eos_token_id=tokenizer.eos_token_id,
     )
 
+    print("\nGenerating response...")
+    print(f" Prompt length: {inputs.shape[1]} tokens")
+    print(f" Max new tokens: {max_new_tokens}")
+    print(f" Temperature: {temperature}")
+    print(f" Top-P: {top_p}")
+    print(f" Do Sample: {do_sample}")
 
+    with torch.inference_mode(): # Ensure no gradients are computed
+        outputs = model.generate(
+            input_ids=inputs,
+            attention_mask=torch.ones_like(inputs), # Provide attention mask
+            generation_config=generation_config,
+        )
 
+    # Decode the generated tokens, skipping the prompt part
+    # outputs[0] contains the full sequence (prompt + response)
+    input_length = inputs.shape[1]
+    response_tokens = outputs[0][input_length:]
+    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
+    print(f" Response length: {len(response_tokens)} tokens")
+    print("Generation complete.")
 
+    return response_text.strip()
 
 
 # --- Gradio Interface ---
+print("Creating Gradio interface...")
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        f"""
+        # Unsloth Gemma 3 1B-IT Chat Interface
+        Interact with the `{MODEL_NAME}` model optimized with Unsloth.
+        Enter your prompt below and adjust the generation parameters.
+        *Note: Running on {'GPU' if torch.cuda.is_available() else 'CPU'}. 4-bit quantization is {'enabled' if LOAD_IN_4BIT else 'disabled'}.*
+        """
+    )
 
     with gr.Row():
+        with gr.Column(scale=2):
+            prompt_input = gr.Textbox(
+                label="Your Prompt",
+                placeholder="Ask me anything...",
+                lines=4,
+                show_copy_button=True,
+            )
+            system_prompt_input = gr.Textbox(
+                label="System Prompt (Optional)",
+                placeholder="Example: You are a helpful assistant.",
+                lines=2
+            )
+            submit_button = gr.Button("Generate Response", variant="primary")
+
+        with gr.Column(scale=1):
+            gr.Markdown("### Generation Parameters")
+            max_new_tokens_slider = gr.Slider(
+                minimum=32,
+                maximum=2048, # Adjust max based on VRAM and needs
+                value=512,
+                step=32,
+                label="Max New Tokens",
+                info="Maximum number of tokens to generate."
+            )
+            temperature_slider = gr.Slider(
+                minimum=0.1,
+                maximum=1.5,
+                value=0.6, # Adjusted default temperature slightly lower
+                step=0.05,
+                label="Temperature",
+                info="Controls randomness. Lower values are more deterministic."
+            )
+            top_p_slider = gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.9,
+                step=0.05,
+                label="Top-P (Nucleus Sampling)",
+                info="Considers only the most probable tokens with cumulative probability P."
+            )
+            do_sample_checkbox = gr.Checkbox(
+                value=True,
+                label="Use Sampling",
+                info="If unchecked, uses greedy decoding (picks the most likely token)."
             )
 
+    output_textbox = gr.Markdown(label="Model Response", value="*Response will appear here...*")
 
+    # --- Connect Components ---
     submit_button.click(
+        fn=generate_response,
+        inputs=[
+            prompt_input,
+            max_new_tokens_slider,
+            temperature_slider,
+            top_p_slider,
+            do_sample_checkbox,
+            system_prompt_input,
+        ],
+        outputs=output_textbox,
+        api_name="generate" # Allows API access if needed
     )
 
+    # --- Examples ---
+    gr.Examples(
+        examples=[
+            ["Explain the concept of Large Language Models (LLMs) in simple terms.", 512, 0.7, 0.9, True, ""],
+            ["Write a short story about a robot exploring a futuristic city.", 768, 0.8, 0.95, True, ""],
+            ["Provide 5 ideas for a healthy breakfast.", 256, 0.6, 0.9, True, ""],
+            ["Translate 'Hello, how are you?' to French.", 64, 0.5, 0.9, False, ""],
+            ["What is the capital of Australia?", 64, 0.3, 0.9, False, "You are a factual answering bot."]
+        ],
+        inputs=[
+            prompt_input,
+            max_new_tokens_slider,
+            temperature_slider,
+            top_p_slider,
+            do_sample_checkbox,
+            system_prompt_input,
+        ],
+        outputs=output_textbox, # Output examples to the main output area
+        fn=generate_response, # Make examples clickable
+        cache_examples=False, # Recalculate examples on click if needed, or True to cache
     )
 
+# --- Launch the Interface ---
+if __name__ == "__main__":
+    print("Launching Gradio interface...")
+    # share=True creates a public link (use with caution)
+    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
+    # Use server_name="0.0.0.0" to make it accessible on your local network
+    # Use server_port=7860 (or another) to specify the port
+    print("Gradio interface launched. Access it at http://<your-ip-address>:7860 (or http://127.0.0.1:7860 locally)")
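
Review note: the rewritten generate_response() deliberately returns the whole completion at once, and the comments above the GenerationConfig block explain that streaming in Gradio would need chunk-by-chunk yielding. For reference, a minimal sketch of such a streaming variant, assuming the same `model` and `tokenizer` objects loaded above and using transformers' TextIteratorStreamer; none of this is part of the commit:

# Hypothetical streaming variant -- a sketch, not part of this commit.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_response_stream(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=True, add_generation_prompt=True, return_tensors="pt",
    ).to(model.device)
    # skip_prompt drops the echoed input; tokens are decoded incrementally by the streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # generate() blocks, so run it on a worker thread and drain the streamer here
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial  # Gradio re-renders the bound output component on every yield

Wiring this in would only require passing generate_response_stream as fn in submit_button.click(); Gradio treats generator functions as streaming handlers.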
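
Because the click handler is registered with api_name="generate", the running app can also be called programmatically. A sketch using gradio_client; the URL below is a placeholder for wherever the app is actually served, and the positional arguments follow the order of the click handler's inputs list:

# Hypothetical client-side call -- adjust the URL to where the app runs.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
reply = client.predict(
    "Explain the concept of Large Language Models (LLMs) in simple terms.",  # prompt_input
    256,    # max_new_tokens_slider
    0.7,    # temperature_slider
    0.9,    # top_p_slider
    True,   # do_sample_checkbox
    "",     # system_prompt_input
    api_name="/generate",
)
print(reply)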