"""Gradio chat UI for Hugging Face text-generation models with optional web search.

The app streams model output token by token, can prepend DuckDuckGo search
snippets to the system prompt, and supports cancelling a running generation.
"""
import gc
import logging
import os
import threading
import time
from datetime import datetime
from itertools import islice

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)
from duckduckgo_search import DDGS
import spaces  # Import spaces early to enable ZeroGPU support

# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

logger = logging.getLogger(__name__)

# ------------------------------
# Global Cancellation Event
# ------------------------------
# Shared by every request: set by cancel_generation(), cleared at the start of
# each chat_response() call.
cancel_event = threading.Event()

# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
# Insertion order matters: the first key is the dropdown's default selection.
MODELS = {
    "Taiwan-ELM-1_1B-Instruct": {
        "repo_id": "liswei/Taiwan-ELM-1_1B-Instruct",
        "description": "Taiwan-ELM-1_1B-Instruct",
    },
    "Taiwan-ELM-270M-Instruct": {
        "repo_id": "liswei/Taiwan-ELM-270M-Instruct",
        "description": "Taiwan-ELM-270M-Instruct",
    },
    "Qwen3-8B": {"repo_id": "Qwen/Qwen3-8B", "description": "Qwen3-8B"},
    "Qwen3-4B": {"repo_id": "Qwen/Qwen3-4B", "description": "Qwen3-4B"},
    # Fixed: the repo id previously read "Qwen/Qwen3-1,7B" (comma instead of dot).
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Qwen3-1.7B"},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Qwen3-0.6B"},
    "Gemma-3-4B-IT": {
        "repo_id": "unsloth/gemma-3-4b-it",
        "description": "Gemma-3-4B-IT",
    },
    "SmolLM2-135M-Instruct-TaiwanChat": {
        "repo_id": "Luigi/SmolLM2-135M-Instruct-TaiwanChat",
        "description": "SmolLM2‑135M Instruct fine-tuned on TaiwanChat",
    },
    "SmolLM2-135M-Instruct": {
        "repo_id": "HuggingFaceTB/SmolLM2-135M-Instruct",
        "description": "Original SmolLM2‑135M Instruct",
    },
    "SmolLM2-360M-Instruct-TaiwanChat": {
        "repo_id": "Luigi/SmolLM2-360M-Instruct-TaiwanChat",
        "description": "SmolLM2‑360M Instruct fine-tuned on TaiwanChat",
    },
    "Llama-3.2-Taiwan-3B-Instruct": {
        "repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct",
        "description": "Llama-3.2-Taiwan-3B-Instruct",
    },
    "MiniCPM3-4B": {"repo_id": "openbmb/MiniCPM3-4B", "description": "MiniCPM3-4B"},
    "Qwen2.5-3B-Instruct": {
        "repo_id": "Qwen/Qwen2.5-3B-Instruct",
        "description": "Qwen2.5-3B-Instruct",
    },
    "Qwen2.5-7B-Instruct": {
        "repo_id": "Qwen/Qwen2.5-7B-Instruct",
        "description": "Qwen2.5-7B-Instruct",
    },
    "Phi-4-mini-Instruct": {
        "repo_id": "unsloth/Phi-4-mini-instruct",
        "description": "Phi-4-mini-Instruct",
    },
    "Meta-Llama-3.1-8B-Instruct": {
        "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct",
        "description": "Meta-Llama-3.1-8B-Instruct",
    },
    "DeepSeek-R1-Distill-Llama-8B": {
        "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B",
        "description": "DeepSeek-R1-Distill-Llama-8B",
    },
    "Mistral-7B-Instruct-v0.3": {
        "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3",
        "description": "Mistral-7B-Instruct-v0.3",
    },
    "Qwen2.5-Coder-7B-Instruct": {
        "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct",
        "description": "Qwen2.5-Coder-7B-Instruct",
    },
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}


def load_pipeline(model_name):
    """Load and cache a transformers text-generation pipeline.

    Tries bfloat16, then float16, then float32; if all three fail, makes one
    last attempt without an explicit dtype (which raises if it fails too).

    Args:
        model_name: Key into ``MODELS``.

    Returns:
        The cached or freshly created pipeline.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    repo = MODELS[model_name]["repo_id"]
    # NOTE(security): trust_remote_code executes Python shipped in the model
    # repository. The tokenizer now uses the same setting as the pipeline so
    # repos with custom tokenizer code load consistently.
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

    pipe = None
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=tokenizer,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            break
        except Exception:
            # Best-effort fallback chain: record why this dtype failed instead
            # of discarding the error, then try the next one.
            logger.warning(
                "Loading %s with dtype %s failed; trying next option",
                repo,
                dtype,
                exc_info=True,
            )

    if pipe is None:
        # Final fallback: let transformers choose the dtype.
        pipe = pipeline(
            task="text-generation",
            model=repo,
            tokenizer=tokenizer,
            trust_remote_code=True,
            device_map="auto",
        )

    PIPELINES[model_name] = pipe
    return pipe


def retrieve_context(query, max_results=6, max_chars=600):
    """Retrieve search snippets from DuckDuckGo.

    Args:
        query: Search query string.
        max_results: Maximum number of results to return.
        max_chars: Maximum number of characters kept from each result body.

    Returns:
        A list of strings formatted as ``"<n>. <title> - <body>"``. Any
        failure yields an empty list (search is best-effort).
    """
    try:
        with DDGS() as ddgs:
            hits = islice(
                ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"),
                max_results,
            )
            return [
                # `or ''` keeps a result whose body is None instead of failing
                # the whole search on the slice.
                f"{i+1}. {r.get('title','No Title')} - {(r.get('body') or '')[:max_chars]}"
                for i, r in enumerate(hits)
            ]
    except Exception:
        logger.warning("Web search failed for query %r", query, exc_info=True)
        return []


def format_conversation(history, system_prompt):
    """Flatten chat history and system prompt into a single prompt string.

    Args:
        history: List of dicts with ``'role'`` and ``'content'`` keys.
        system_prompt: Text placed on the first line of the prompt.

    Returns:
        The prompt: the stripped system prompt, then one line per message
        (``"User: ..."``, ``"Assistant: ..."``, or the bare content for other
        roles). ``"Assistant: "`` is appended unless the prompt already ends
        with an ``Assistant:`` cue.
    """
    parts = [system_prompt.strip() + "\n"]
    for msg in history:
        content = msg['content'].strip()
        if msg['role'] == 'user':
            parts.append("User: " + content + "\n")
        elif msg['role'] == 'assistant':
            parts.append("Assistant: " + content + "\n")
        else:
            parts.append(content + "\n")
    prompt = "".join(parts)
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt


class _CancelCriteria(StoppingCriteria):
    """Stopping criterion that ends generation once an event is set."""

    def __init__(self, event):
        self._event = event

    def __call__(self, input_ids, scores, **kwargs):
        return self._event.is_set()


def _search_into(results, query, max_results, max_chars):
    """Run a web search and append its snippets to ``results`` (thread target)."""
    results.extend(retrieve_context(query, int(max_results), int(max_chars)))


@spaces.GPU(duration=60)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty):
    """Generate a streaming chat response, optionally enriched by web search.

    Args:
        user_msg: The new user message.
        chat_history: Previous messages (list of role/content dicts) or None.
        system_prompt: System prompt text.
        enable_search: Whether to run a background web search.
        max_results: Maximum number of search results.
        max_chars: Maximum characters kept per search result.
        model_name: Key into ``MODELS``.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling cutoff.
        repeat_penalty: Repetition penalty.

    Yields:
        ``(history, debug)`` tuples: the updated message list and a Markdown
        debug string describing the search outcome.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    # Launch web search if enabled
    debug = ''
    search_results = []
    thread_search = None
    if enable_search:
        debug = 'Search task started.'
        thread_search = threading.Thread(
            target=_search_into,
            args=(search_results, user_msg, max_results, max_chars),
            daemon=True,
        )
        thread_search.start()
    else:
        debug = 'Web search disabled.'

    # Prepare assistant placeholder
    history.append({'role': 'assistant', 'content': ''})

    try:
        # Wait up to 1s for snippets, then replace debug with them. The
        # results are only read after the join; reading them earlier would
        # almost always see an empty list.
        snippets = []
        if thread_search is not None:
            thread_search.join(timeout=1.0)
            # Snapshot so a late-finishing search cannot change the list
            # between building the debug text and building the prompt.
            snippets = list(search_results)
            if snippets:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in snippets
                )
            else:
                debug = "*No web search results found.*"

        # Merge fetched snippets into the system prompt
        if snippets:
            enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(snippets)
        else:
            enriched = system_prompt

        prompt = format_conversation(history, enriched)

        pipe = load_pipeline(model_name)
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_kwargs = {
            # Sliders may deliver floats; these two must be integers.
            'max_new_tokens': int(max_tokens),
            'top_k': int(top_k),
            # Sampling must be switched on explicitly, otherwise the
            # temperature/top-k/top-p settings are ignored for models whose
            # default generation config is greedy.
            'do_sample': True,
            'temperature': temperature,
            'top_p': top_p,
            'repetition_penalty': repeat_penalty,
            'streamer': streamer,
            'return_full_text': False,
            # Lets a cancel request stop the model itself, not just the UI
            # loop, so the join below returns promptly.
            'stopping_criteria': StoppingCriteriaList([_CancelCriteria(cancel_event)]),
        }
        worker_errors = []

        def _run_generation():
            try:
                pipe(prompt, **gen_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Close the stream so the consumer loop does not block forever
                # waiting for tokens that will never arrive.
                streamer.end()

        gen_thread = threading.Thread(target=_run_generation)
        gen_thread.start()

        assistant_text = ''
        for chunk in streamer:
            if cancel_event.is_set():
                break
            assistant_text += chunk
            history[-1]['content'] = assistant_text
            yield history, debug
        gen_thread.join()

        if worker_errors:
            # Surface the worker failure through the handler below.
            raise worker_errors[0]
        # Emit the final state once more so the UI updates even when the
        # generation produced no chunks or was cancelled.
        yield history, debug
    except Exception as e:
        history[-1]['content'] = f"Error: {e}"
        yield history, debug
    finally:
        gc.collect()


def cancel_generation():
    """Signal the running generation to stop and return a status message."""
    cancel_event.set()
    return 'Generation cancelled.'


def update_default_prompt(enable_search):
    """Return the default system prompt containing today's date.

    Args:
        enable_search: Currently unused; accepted because the function is
            wired to the search checkbox's change event.
    """
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."


# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
    gr.Markdown("Interact with the model. Select parameters and chat below.")
    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(label="Select Model",
                                   choices=list(MODELS.keys()),
                                   value=list(MODELS.keys())[0])
            search_chk = gr.Checkbox(label="Enable Web Search", value=True)
            sys_prompt = gr.Textbox(label="System Prompt", lines=3,
                                    value=update_default_prompt(search_chk.value))
            gr.Markdown("### Generation Parameters")
            max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
            p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            gr.Markdown("### Web Search Settings")
            mr = gr.Number(value=6, precision=0, label="Max Results")
            mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
            clr = gr.Button("Clear Chat")
            cnl = gr.Button("Cancel Generation")
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages")
            txt = gr.Textbox(placeholder="Type your message and press Enter...")
            dbg = gr.Markdown()

    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    txt.submit(fn=chat_response,
               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                       model_dd, max_tok, temp, k, p, rp],
               outputs=[chat, dbg])

# Launch only when run as a script so importing this module has no side effect
# beyond building the UI.
if __name__ == "__main__":
    demo.launch()