import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer


# Function to load the model and tokenizer (only needs to run once)
def load_model():
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # Uses an available GPU if present, otherwise CPU
    )
    return model, tokenizer


# Load the model and tokenizer once at startup
print("Loading model, please wait...")
model, tokenizer = load_model()
print("Model loaded successfully!")


def generate_response(message, chat_history, max_new_tokens=4096):
    """Generate a response from the BitNet model based on the user's message."""
    if not message.strip():
        return "", chat_history

    # Build a chat prompt from the history plus the new message
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n\n"
    full_prompt += f"User: {message}\nAssistant:"

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,  # Slightly higher temperature for more creative responses
            top_p=0.95,
        )

    # Decode only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    # Update the chat history
    chat_history.append((message, response.strip()))
    return "", chat_history


# Define the Gradio interface
def create_chat_interface():
    with gr.Blocks(title="BitNet Chat Assistant") as demo:
        gr.Markdown("# 💬 BitNet Chat Assistant")
        gr.Markdown(
            "A lightweight chat application powered by Microsoft's "
            "BitNet b1.58 2B4T model."
        )

        # Tuple-style chat history; newer Gradio releases may prefer
        # message-style history instead
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        clear = gr.Button("Clear Conversation")

        def clear_convo():
            return "", []

        msg.submit(
            fn=generate_response,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        )
        clear.click(fn=clear_convo, inputs=[], outputs=[msg, chatbot])

        # Some example inputs
        examples = [
            ["Hello, how are you today?"],
            ["Can you tell me about artificial intelligence?"],
            ["What's your favorite book?"],
            ["Write a short poem about technology."],
        ]
        gr.Examples(examples=examples, inputs=[msg])

        gr.Markdown(
            """
            ## About
            This application uses Microsoft's BitNet b1.58 2B4T, a 1-bit Large
            Language Model, for conversational AI. Its 1-bit architecture lets
            the model run efficiently on consumer hardware, with significant
            advantages in memory usage, energy consumption, and latency.

            Note: This is a demonstration of the lightweight model's capabilities.
            """
        )

    return demo


# Create and launch the Gradio interface
if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch(share=True)  # Set share=False if you don't want a public link
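

# ---------------------------------------------------------------------------
# Appendix (sketch): prompt building via the tokenizer's chat template.
#
# The demo above concatenates "User:/Assistant:" strings by hand, which works
# but depends on the model tolerating that ad-hoc format. Instruction-tuned
# checkpoints usually ship a chat template, and if this tokenizer defines one,
# tokenizer.apply_chat_template() is generally the more robust way to format
# the conversation. The helper below is a hypothetical alternative to the
# prompt-building part of generate_response(); it assumes the loaded tokenizer
# provides a chat template and is not wired into the UI above.
# ---------------------------------------------------------------------------
def build_chat_inputs(message, chat_history):
    """Format the conversation with the tokenizer's chat template (a sketch)."""
    messages = []
    for user_msg, bot_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    # add_generation_prompt=True appends the assistant prefix so the model
    # continues as the assistant instead of predicting another user turn.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    return input_ids.to(model.device)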