import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Function to load the model and tokenizer (only needs to run once)
def load_model():
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",  # This will use an available GPU if present
    )
    return model, tokenizer

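# Note (a hedged caveat, not verified here): BitNet is a custom architecture,
# and a stock transformers release may not recognize it or may fall back to
# full-precision kernels. The model card for microsoft/bitnet-b1.58-2B-4T
# documents the exact transformers version/fork to install if
# from_pretrained fails or the expected efficiency gains don't materialize.
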
# Load the model and tokenizer
print("Loading model, please wait...")
model, tokenizer = load_model()
print("Model loaded successfully!")

def generate_response(message, chat_history, max_new_tokens=4096):
    """
    Generates a response from the BitNet model based on the user's message.
    """
    if not message.strip():
        return "", chat_history

    # Build the chat prompt from the history plus the new message
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n\n"
    full_prompt += f"User: {message}\nAssistant:"

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,  # Slightly higher temperature for more creative responses
            top_p=0.95,
        )

    # Keep only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    # Update the chat history
    chat_history.append((message, response.strip()))
    return "", chat_history


# Define the Gradio interface
def create_chat_interface():
    with gr.Blocks(title="BitNet Chat Assistant") as demo:
        gr.Markdown("# 💬 BitNet Chat Assistant")
        gr.Markdown("A lightweight chat application powered by Microsoft's BitNet b1.58 2B4T model.")

        # Tuple-style history: a list of (user_message, bot_message) pairs
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        clear = gr.Button("Clear Conversation")

        # Reset both the textbox and the chat history
        def clear_convo():
            return "", []

        msg.submit(
            fn=generate_response,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot],
        )
        clear.click(fn=clear_convo, inputs=[], outputs=[msg, chatbot])

        # Add some example inputs
        examples = [
            ["Hello, how are you today?"],
            ["Can you tell me about artificial intelligence?"],
            ["What's your favorite book?"],
            ["Write a short poem about technology."],
        ]
        gr.Examples(examples=examples, inputs=[msg])

        gr.Markdown("""
        ## About
        This application uses Microsoft's BitNet b1.58 2B4T, a 1-bit Large Language Model, for conversational AI.
        The model runs efficiently on consumer hardware due to its 1-bit architecture, offering significant
        advantages in memory usage, energy consumption, and latency.

        Note: This is a demonstration of the lightweight model's capabilities.
        """)
    return demo


# Create and launch the Gradio interface
if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch(share=True)  # Set share=False if you don't want a public link
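
# Usage (assuming this file is saved as app.py):
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) and, because
# share=True, a temporary public gradio.live link as well.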