import torch
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextStreamer, GenerationConfig
import warnings
# Suppress specific warnings if needed (optional)
warnings.filterwarnings("ignore", category=UserWarning, message=".*padding_mask.*")
# --- Configuration ---
MODEL_NAME = "unsloth/gemma-3-1b-it"
MAX_SEQ_LENGTH = 4096 # Choose based on model's capabilities and your VRAM
DTYPE = None # None for auto detection, or torch.float16, torch.bfloat16
LOAD_IN_4BIT = True # Use 4-bit quantization for lower memory usage
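# If you prefer to pin the dtype explicitly instead of relying on auto-detection, one
# option (illustrative sketch, not required) is to pick bfloat16 on GPUs that support it:
# DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16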
# --- Load Model and Tokenizer ---
print(f"Loading model: {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=DTYPE,
    load_in_4bit=LOAD_IN_4BIT,
    # token = "hf_...",  # Add your Hugging Face token if needed (for gated models)
)
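# If the model were gated, one illustrative option (assumes an `import os` at the top and
# an HF_TOKEN environment variable) is to pass token=os.environ.get("HF_TOKEN") in the
# from_pretrained(...) call above instead of hard-coding the token string.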
print("Model and tokenizer loaded successfully.")
# Optimize for inference
FastLanguageModel.for_inference(model)
print("Model optimized for inference.")
# --- Generation Function ---
def generate_response(
    prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    system_prompt=None,  # Optional system prompt
):
    """
    Generates a response from the model given a prompt and parameters.
    """
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    # Apply chat template
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # Ensures the '<start_of_turn>model' token is added
            return_tensors="pt",
        ).to(model.device)  # Ensure inputs are on the same device as the model
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback: simple manual formatting if the template fails (less ideal).
        # The tokenizer adds <bos> itself, so it is omitted from the string.
        formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        # Keep only the input_ids tensor so the fallback matches the main path above.
        inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
    # --- Streaming note ---
    # Streaming works well in terminals, but Gradio only updates the UI once per yield.
    # For simplicity, this app generates the full response in one call.
    # Streaming in Gradio requires a generator function that yields partial text chunks;
    # see the illustrative streaming variant sketched after this function.
    # --- Generate Full Response ---
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,  # Use EOS token for padding
        eos_token_id=tokenizer.eos_token_id,
    )
    print("\nGenerating response...")
    print(f" Prompt length: {inputs.shape[1]} tokens")
    print(f" Max new tokens: {max_new_tokens}")
    print(f" Temperature: {temperature}")
    print(f" Top-P: {top_p}")
    print(f" Do Sample: {do_sample}")
    with torch.inference_mode():  # Ensure no gradients are computed
        outputs = model.generate(
            input_ids=inputs,
            attention_mask=torch.ones_like(inputs),  # Provide attention mask
            generation_config=generation_config,
        )
    # Decode the generated tokens, skipping the prompt part.
    # outputs[0] contains the full sequence (prompt + response).
    input_length = inputs.shape[1]
    response_tokens = outputs[0][input_length:]
    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
    print(f" Response length: {len(response_tokens)} tokens")
    print("Generation complete.")
    return response_text.strip()
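
# --- Optional: streaming variant (illustrative sketch, not wired into the UI below) ---
# This shows one common way streaming could be added: run model.generate in a background
# thread with a TextIteratorStreamer and yield the accumulated text so Gradio can update
# the output as tokens arrive. The function name and its wiring are assumptions, not part
# of the original app; to use it, pass it as `fn=` to a Gradio event instead of generate_response.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_response_stream(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True, system_prompt=None):
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so it runs in a worker thread while we consume the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text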
# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
# Unsloth Gemma 3 1B-IT Chat Interface
Interact with the `{MODEL_NAME}` model optimized with Unsloth.
Enter your prompt below and adjust the generation parameters.
*Note: Running on {'GPU' if torch.cuda.is_available() else 'CPU'}. 4-bit quantization is {'enabled' if LOAD_IN_4BIT else 'disabled'}.*
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Ask me anything...",
                lines=4,
                show_copy_button=True,
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt (Optional)",
                placeholder="Example: You are a helpful assistant.",
                lines=2,
            )
            submit_button = gr.Button("Generate Response", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("### Generation Parameters")
            max_new_tokens_slider = gr.Slider(
                minimum=32,
                maximum=2048,  # Adjust max based on VRAM and needs
                value=512,
                step=32,
                label="Max New Tokens",
                info="Maximum number of tokens to generate.",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,  # Adjusted default temperature slightly lower
                step=0.05,
                label="Temperature",
                info="Controls randomness. Lower values are more deterministic.",
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-P (Nucleus Sampling)",
                info="Considers only the most probable tokens with cumulative probability P.",
            )
            do_sample_checkbox = gr.Checkbox(
                value=True,
                label="Use Sampling",
                info="If unchecked, uses greedy decoding (picks the most likely token).",
            )
    output_textbox = gr.Markdown(label="Model Response", value="*Response will appear here...*")
# --- Connect Components ---
    submit_button.click(
        fn=generate_response,
        inputs=[
            prompt_input,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox,
            system_prompt_input,
        ],
        outputs=output_textbox,
        api_name="generate",  # Allows API access if needed
    )
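    # Illustrative client-side usage of the api_name above (run from another process once
    # the app is up; assumes the gradio_client package is installed):
    #   from gradio_client import Client
    #   client = Client("http://127.0.0.1:7860/")
    #   print(client.predict("Hello!", 256, 0.7, 0.9, True, "", api_name="/generate"))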
# --- Examples ---
    gr.Examples(
        examples=[
            ["Explain the concept of Large Language Models (LLMs) in simple terms.", 512, 0.7, 0.9, True, ""],
            ["Write a short story about a robot exploring a futuristic city.", 768, 0.8, 0.95, True, ""],
            ["Provide 5 ideas for a healthy breakfast.", 256, 0.6, 0.9, True, ""],
            ["Translate 'Hello, how are you?' to French.", 64, 0.5, 0.9, False, ""],
            ["What is the capital of Australia?", 64, 0.3, 0.9, False, "You are a factual answering bot."],
        ],
        inputs=[
            prompt_input,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox,
            system_prompt_input,
        ],
        outputs=output_textbox,  # Output examples to the main output area
        fn=generate_response,  # Make examples clickable
        cache_examples=False,  # Recalculate examples on click, or True to cache them
    )
# --- Launch the Interface ---
if __name__ == "__main__":
print("Launching Gradio interface...")
# share=True creates a public link (use with caution)
demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
# Use server_name="0.0.0.0" to make it accessible on your local network
# Use server_port=7860 (or another) to specify the port
print("Gradio interface launched. Access it at http://<your-ip-address>:7860 (or http://127.0.0.1:7860 locally)")