import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Function to load the model and tokenizer (only needs to run once)
def load_model():
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # uses an available GPU if present
    )
    return model, tokenizer
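# Note: device_map="auto" requires the `accelerate` package to be installed;
# transformers raises an error without it. bfloat16 also assumes hardware
# support — falling back to torch_dtype=torch.float32 is safer on older machines.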
# Load the model and tokenizer
print("Loading model, please wait...")
model, tokenizer = load_model()
print("Model loaded successfully!")
def generate_response(message, chat_history, max_new_tokens=4096):
    """
    Generate a response from the BitNet model for the user's message,
    conditioned on the accumulated chat history.
    """
    if not message.strip():
        return "", chat_history

    # Rebuild the chat prompt from the history plus the new message
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n\n"
    full_prompt += f"User: {message}\nAssistant:"

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate the continuation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,  # slightly higher temperature for more creative responses
            top_p=0.95,
        )

    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    # Update the chat history (Gradio's tuple format: (user, assistant))
    chat_history.append((message, response.strip()))
    return "", chat_history
# Define the Gradio interface
def create_chat_interface():
    with gr.Blocks(title="BitNet Chat Assistant") as demo:
        gr.Markdown("# 💬 BitNet Chat Assistant")
        gr.Markdown("A lightweight chat application powered by Microsoft's BitNet b1.58 2B4T model.")

        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        clear = gr.Button("Clear Conversation")

        def clear_convo():
            # Reset both the input box and the chat history
            return "", []

        msg.submit(
            fn=generate_response,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        )
        clear.click(fn=clear_convo, inputs=[], outputs=[msg, chatbot])

        # Add some example inputs
        examples = [
            ["Hello, how are you today?"],
            ["Can you tell me about artificial intelligence?"],
            ["What's your favorite book?"],
            ["Write a short poem about technology."],
        ]
        gr.Examples(examples=examples, inputs=[msg])

        gr.Markdown("""
        ## About

        This application uses Microsoft's BitNet b1.58 2B4T, a 1-bit Large Language Model, for conversational AI.
        The model runs efficiently on consumer hardware due to its 1-bit architecture, offering significant
        advantages in memory usage, energy consumption, and latency.

        Note: This is a demonstration of the lightweight model's capabilities.
        """)

    return demo
# Create and launch the Gradio interface
if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch(share=True)  # set share=False if you don't want a public link
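# Optional (not part of the original app): Gradio's request queue can be
# enabled by chaining demo.queue() before launch(), e.g.
#   demo.queue().launch(share=True)
# so that concurrent generation requests are processed one at a time.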