import torch
import gradio as gr
from unsloth import FastLanguageModel
from transformers import TextStreamer, GenerationConfig
import warnings

# Suppress the noisy padding_mask UserWarning emitted during generation.
warnings.filterwarnings("ignore", category=UserWarning, message=".*padding_mask.*")

MODEL_NAME = "unsloth/gemma-3-1b-it"
MAX_SEQ_LENGTH = 4096
DTYPE = None          # None lets Unsloth auto-detect float16 / bfloat16.
LOAD_IN_4BIT = True   # Load with 4-bit quantization to reduce memory use.

print(f"Loading model: {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=DTYPE,
    load_in_4bit=LOAD_IN_4BIT,
)
print("Model and tokenizer loaded successfully.")

# Enable Unsloth's optimized inference path.
FastLanguageModel.for_inference(model)
print("Model optimized for inference.")
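# Optional: TextStreamer (imported above) can echo tokens to the console as they
# are generated. A minimal sketch, not wired into the Gradio UI; uncomment and
# pass `streamer=console_streamer` to `model.generate` below to try it.
# console_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)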
def generate_response(
    prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    system_prompt=None,
):
    """
    Generates a response from the model given a prompt and parameters.

    The argument order matches the Gradio `inputs` lists defined below.
    """
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    try:
        # Preferred path: let the tokenizer's chat template format the turns.
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback: build Gemma's turn format by hand (the system prompt is
        # dropped on this path). Take .input_ids so `inputs` is a tensor, as the
        # code below expects, and skip automatic special tokens because <bos>
        # is already written into the string.
        formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        inputs = tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(model.device)
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        # Stop on Gemma's <end_of_turn> marker as well as the regular EOS token
        # so the model does not keep generating past the end of its reply.
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")],
    )

    print("\nGenerating response...")
    print(f" Prompt length: {inputs.shape[1]} tokens")
    print(f" Max new tokens: {max_new_tokens}")
    print(f" Temperature: {temperature}")
    print(f" Top-P: {top_p}")
    print(f" Do Sample: {do_sample}")

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs,
            attention_mask=torch.ones_like(inputs),
            generation_config=generation_config,
        )

    # Decode only the newly generated tokens, skipping the prompt.
    input_length = inputs.shape[1]
    response_tokens = outputs[0][input_length:]
    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
    print(f" Response length: {len(response_tokens)} tokens")
    print("Generation complete.")

    return response_text.strip()
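# Quick sanity check (hypothetical prompt; uncomment to test generation from the
# command line before launching the UI).
# print(generate_response("Explain what Unsloth does in one sentence.", max_new_tokens=64))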
print("Creating Gradio interface...")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # Chat: Unsloth Gemma 3 1B-IT Interface
        Interact with the `{MODEL_NAME}` model optimized with Unsloth.
        Enter your prompt below and adjust the generation parameters.
        *Note: Running on {'GPU' if torch.cuda.is_available() else 'CPU'}. 4-bit quantization is {'enabled' if LOAD_IN_4BIT else 'disabled'}.*
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Ask me anything...",
                lines=4,
                show_copy_button=True,
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt (Optional)",
                placeholder="Example: You are a helpful assistant.",
                lines=2,
            )
            submit_button = gr.Button("Generate Response", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("### Generation Parameters")
            max_new_tokens_slider = gr.Slider(
                minimum=32,
                maximum=2048,
                value=512,
                step=32,
                label="Max New Tokens",
                info="Maximum number of tokens to generate.",
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.6,
                step=0.05,
                label="Temperature",
                info="Controls randomness. Lower values are more deterministic.",
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-P (Nucleus Sampling)",
                info="Considers only the most probable tokens with cumulative probability P.",
            )
            do_sample_checkbox = gr.Checkbox(
                value=True,
                label="Use Sampling",
                info="If unchecked, uses greedy decoding (picks the most likely token).",
            )

    output_textbox = gr.Markdown(label="Model Response", value="*Response will appear here...*")
    submit_button.click(
        fn=generate_response,
        inputs=[
            prompt_input,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox,
            system_prompt_input,
        ],
        outputs=output_textbox,
        api_name="generate",
    )

    gr.Examples(
        examples=[
            ["Explain the concept of Large Language Models (LLMs) in simple terms.", 512, 0.7, 0.9, True, ""],
            ["Write a short story about a robot exploring a futuristic city.", 768, 0.8, 0.95, True, ""],
            ["Provide 5 ideas for a healthy breakfast.", 256, 0.6, 0.9, True, ""],
            ["Translate 'Hello, how are you?' to French.", 64, 0.5, 0.9, False, ""],
            ["What is the capital of Australia?", 64, 0.3, 0.9, False, "You are a factual answering bot."],
        ],
        inputs=[
            prompt_input,
            max_new_tokens_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox,
            system_prompt_input,
        ],
        outputs=output_textbox,
        fn=generate_response,
        cache_examples=False,
    )
if __name__ == "__main__":
    print("Launching Gradio interface...")
    print("Access it at http://<your-ip-address>:7860 (or http://127.0.0.1:7860 locally).")
    # launch() blocks until the server is shut down, so print the address first.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)