import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
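
# Note (assumption): BitNet support in transformers may depend on the installed
# version; check https://huggingface.co/microsoft/bitnet-b1.58-2B-4T for the
# exact requirements before running this script.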

# Function to load the model and tokenizer (only needs to run once)
def load_model():
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto"  # This will use available GPU if present
    )
    return model, tokenizer

# Load the model and tokenizer
print("Loading model, please wait...")
model, tokenizer = load_model()
print("Model loaded successfully!")

def generate_response(message, chat_history, max_new_tokens=4096):
    """
    Generates a response from the BitNet model based on the user's message
    and the accumulated chat history.
    """
    if not message.strip():
        return "", chat_history
    
    # Create a chat prompt based on the history and new message
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n\n"
    
    full_prompt += f"User: {message}\nAssistant:"
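
    # Note: if the tokenizer ships a chat template, a sketch of a more robust
    # alternative to this hand-rolled prompt (not verified against this
    # checkpoint) would be:
    #   messages = []
    #   for user_msg, bot_msg in chat_history:
    #       messages += [{"role": "user", "content": user_msg},
    #                    {"role": "assistant", "content": bot_msg}]
    #   messages.append({"role": "user", "content": message})
    #   full_prompt = tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True)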

    # Create inputs for the model
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,  # Moderate temperature for varied but coherent output
            top_p=0.95,
        )
    
    # Extract only the generated part (the response)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    
    # Update chat history
    chat_history.append((message, response.strip()))
    
    return "", chat_history

# Define the Gradio interface
def create_chat_interface():
    with gr.Blocks(title="BitNet Chat Assistant") as demo:
        gr.Markdown("# 💬 BitNet Chat Assistant")
        gr.Markdown("A lightweight chat application powered by Microsoft's BitNet b1.58 2B4T model.")
        
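        # Tuple-format history ([(user, bot), ...]); newer Gradio releases prefer
        # gr.Chatbot(type="messages") (assumption: depends on your Gradio version).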
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False
        )
        
        clear = gr.Button("Clear Conversation")
        
        def clear_convo():
            return "", []
        
        msg.submit(
            fn=generate_response,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        
        clear.click(fn=clear_convo, inputs=[], outputs=[msg, chatbot])
        
        # Add some example inputs
        examples = [
            ["Hello, how are you today?"],
            ["Can you tell me about artificial intelligence?"],
            ["What's your favorite book?"],
            ["Write a short poem about technology."],
        ]
        gr.Examples(examples=examples, inputs=[msg])
        
        gr.Markdown("""
        ## About
        This application uses Microsoft's BitNet b1.58 2B4T, a 1-bit Large Language Model, for conversational AI.
        The model runs efficiently on consumer hardware due to its 1-bit architecture, offering significant
        advantages in memory usage, energy consumption, and latency.
        
        Note: This is a demonstration of the lightweight model's capabilities.
        """)
        
    return demo

# Create and launch the Gradio interface
if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch(share=True)  # Set share=False if you don't want a public link
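
# To run this demo locally (a sketch; assumes the standard PyPI packages):
#   pip install torch transformers gradio
#   python app.py   # "app.py" is illustrative; use this script's actual filename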