import torch
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextStreamer, GenerationConfig
import warnings

# Suppress specific warnings if needed (optional)
warnings.filterwarnings("ignore", category=UserWarning, message=".*padding_mask.*")

# --- Configuration ---
MODEL_NAME = "unsloth/gemma-3-1b-it"
MAX_SEQ_LENGTH = 4096  # Choose based on the model's capabilities and your VRAM
DTYPE = None  # None for auto detection, or torch.float16 / torch.bfloat16
LOAD_IN_4BIT = True  # Use 4-bit quantization for lower memory usage

# --- Load Model and Tokenizer ---
print(f"Loading model: {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=DTYPE,
    load_in_4bit=LOAD_IN_4BIT,
    # token = "hf_...",  # Add your Hugging Face token if needed (for gated models)
)
print("Model and tokenizer loaded successfully.")

# Optimize for inference
FastLanguageModel.for_inference(model)
print("Model optimized for inference.")


# --- Generation Function ---
def generate_response(
    prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    system_prompt=None,  # Optional system prompt
):
    """
    Generates a response from the model given a prompt and parameters.
    """
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    # Apply chat template
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # Ensures the 'model' turn marker is added
            return_tensors="pt",
        ).to(model.device)  # Ensure inputs are on the same device as the model
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback: manually build the Gemma chat format if the template fails (less ideal)
        formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    # --- Note on streaming ---
    # While streaming works well in terminals, Gradio updates per yield.
    # For a simpler Gradio experience, we generate the full response at once.
    # Streaming in Gradio requires a generator function that yields partial
    # text chunks (see the optional sketch after this function).

    # --- Generate Full Response ---
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,  # Use the EOS token for padding
        eos_token_id=tokenizer.eos_token_id,
    )

    print("\nGenerating response...")
    print(f"  Prompt length: {inputs.shape[1]} tokens")
    print(f"  Max new tokens: {max_new_tokens}")
    print(f"  Temperature: {temperature}")
    print(f"  Top-P: {top_p}")
    print(f"  Do Sample: {do_sample}")

    with torch.inference_mode():  # Ensure no gradients are computed
        outputs = model.generate(
            input_ids=inputs,
            attention_mask=torch.ones_like(inputs),  # Provide an attention mask
            generation_config=generation_config,
        )

    # Decode the generated tokens, skipping the prompt part.
    # outputs[0] contains the full sequence (prompt + response).
    input_length = inputs.shape[1]
    response_tokens = outputs[0][input_length:]
    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)

    print(f"  Response length: {len(response_tokens)} tokens")
    print("Generation complete.")

    return response_text.strip()
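
# Optional: a streaming variant (a minimal sketch, not wired into the UI below).
# It assumes transformers' TextIteratorStreamer and runs model.generate() in a
# background thread, yielding the accumulated text so a Gradio generator
# callback can update the output incrementally. The function name and its
# reuse of the chat-template logic are illustrative additions, not part of the
# original script. To use it, pass generate_response_stream instead of
# generate_response as fn to submit_button.click below.
def generate_response_stream(
    prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    system_prompt=None,
):
    from threading import Thread
    from transformers import TextIteratorStreamer

    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks, so run it in a thread and consume tokens as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text  # Gradio renders each yielded string as it arrives
    thread.join()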

# --- Gradio Interface ---
print("Creating Gradio interface...")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # Unsloth Gemma 3 1B-IT Chat Interface
        Interact with the `{MODEL_NAME}` model optimized with Unsloth.
        Enter your prompt below and adjust the generation parameters.
        *Note: Running on {'GPU' if torch.cuda.is_available() else 'CPU'}.
        4-bit quantization is {'enabled' if LOAD_IN_4BIT else 'disabled'}.*
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Ask me anything...",
                lines=4,
                show_copy_button=True,
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt (Optional)",
                placeholder="Example: You are a helpful assistant.",
                lines=2,
            )
            submit_button = gr.Button("Generate Response", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("### Generation Parameters")
            max_new_tokens_slider = gr.Slider(
                minimum=32,
                maximum=2048,  # Adjust the maximum based on VRAM and needs
                value=512,
                step=32,
                label="Max New Tokens",
                info="Maximum number of tokens to generate.",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,  # Default temperature adjusted slightly lower
                step=0.05,
                label="Temperature",
                info="Controls randomness. Lower values are more deterministic.",
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-P (Nucleus Sampling)",
                info="Considers only the most probable tokens with cumulative probability P.",
            )
            do_sample_checkbox = gr.Checkbox(
                value=True,
                label="Use Sampling",
                info="If unchecked, uses greedy decoding (picks the most likely token).",
            )

    output_textbox = gr.Markdown(label="Model Response", value="*Response will appear here...*")

    # --- Connect Components ---
    submit_button.click(
        fn=generate_response,
        inputs=[
            prompt_input,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox,
            system_prompt_input,
        ],
        outputs=output_textbox,
        api_name="generate",  # Allows API access if needed (see the example helper below)
    )

    # --- Examples ---
    gr.Examples(
        examples=[
            ["Explain the concept of Large Language Models (LLMs) in simple terms.", 512, 0.7, 0.9, True, ""],
            ["Write a short story about a robot exploring a futuristic city.", 768, 0.8, 0.95, True, ""],
            ["Provide 5 ideas for a healthy breakfast.", 256, 0.6, 0.9, True, ""],
            ["Translate 'Hello, how are you?' to French.", 64, 0.5, 0.9, False, ""],
            ["What is the capital of Australia?", 64, 0.3, 0.9, False, "You are a factual answering bot."],
        ],
        inputs=[
            prompt_input,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox,
            system_prompt_input,
        ],
        outputs=output_textbox,  # Output examples to the main output area
        fn=generate_response,  # Make examples clickable
        cache_examples=False,  # Set True to precompute and cache example outputs
    )
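
# Optional: calling the app programmatically once it is running (a minimal
# sketch; assumes the `gradio_client` package is installed and the server is
# reachable at the default local address). `call_generate_api` is an
# illustrative helper, not part of the original script, and is never invoked
# here; run it from a separate process or session.
def call_generate_api(prompt_text, server_url="http://127.0.0.1:7860/"):
    from gradio_client import Client

    client = Client(server_url)
    # Positional arguments mirror the `inputs` list of submit_button.click above;
    # the endpoint name matches api_name="generate".
    return client.predict(
        prompt_text,  # prompt
        256,          # max_new_tokens
        0.7,          # temperature
        0.9,          # top_p
        True,         # do_sample
        "",           # system_prompt (empty = none)
        api_name="/generate",
    )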

# --- Launch the Interface ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    print("Access it at http://127.0.0.1:7860 locally, or http://<your-machine-ip>:7860 on your local network.")
    # share=True creates a public link (use with caution)
    # server_name="0.0.0.0" makes the app accessible on your local network
    # server_port=7860 (or another port) specifies the port
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)