File size: 4,438 Bytes
c84e5ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uuid

# Process-wide cache keyed by model id, so the checkpoint and tokenizer are
# loaded once instead of on every call.
_MODEL_CACHE = {}


def load_model():
    """Return the BitNet model and its tokenizer, loading them on first use.

    The first call loads the tokenizer and the causal-LM weights for
    ``microsoft/bitnet-b1.58-2B-4T`` (weights requested in bfloat16) and
    stores the pair in ``_MODEL_CACHE``. Later calls return the same objects
    without touching ``from_pretrained`` again.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model_id = "microsoft/bitnet-b1.58-2B-4T"
    if model_id not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        # Store only after both loads succeeded so a failed load is retried
        # on the next call instead of caching a partial result.
        _MODEL_CACHE[model_id] = (model, tokenizer)
    return _MODEL_CACHE[model_id]

def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
    """Generate one assistant reply and append the exchange to the history.

    Args:
        user_input: Text of the new user message.
        system_prompt: System message placed first in the prompt.
        max_new_tokens: Upper bound on newly generated tokens (cast to int).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff (cast to int).
        history: List of ``{"role": ..., "content": ...}`` dicts from earlier
            turns, or ``None``/empty for a new conversation.

    Returns:
        tuple: The updated history twice, ``(history, history)``, matching
        the two outputs wired to the click handler (chatbot and state).
    """
    # Work on a copy so the caller's list is never mutated in place.
    history = list(history or [])

    # A blank message has nothing to answer; leave the conversation as is.
    if not user_input or not user_input.strip():
        return history, history

    model, tokenizer = load_model()

    # Replay earlier turns so the model sees the whole conversation rather
    # than only the latest message.
    # NOTE(review): no truncation is applied here; a long conversation can
    # outgrow the model's context window.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in history
    )
    messages.append({"role": "user", "content": user_input})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = chat_input["input_ids"].shape[-1]

    # Inference only: skip autograd bookkeeping during generation.
    with torch.inference_mode():
        chat_outputs = model.generate(
            **chat_input,
            # UI sliders may deliver floats; these two settings must be ints.
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            do_sample=True,
        )

    # The output sequence starts with the prompt tokens; decode only the
    # newly generated tail.
    response = tokenizer.decode(chat_outputs[0][prompt_length:], skip_special_tokens=True)

    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return history, history

# Gradio interface.
# Layout: a title, one row of two informational Markdown panels, then one row
# with the input controls on the left and the conversation view on the right.
# Components are placed in the order they are created inside each Row/Column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# BitNet b1.58 2B4T Demo")
    
    # Top row: static descriptive text only.
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            ## About BitNet b1.58 2B4T
            BitNet b1.58 2B4T is the first open-source, native 1-bit Large Language Model with 2 billion parameters, developed by Microsoft Research. Trained on 4 trillion tokens, it matches the performance of full-precision models while offering significant efficiency gains in memory, energy, and latency. Features include:
            - Transformer-based architecture with BitLinear layers
            - Native 1.58-bit weights and 8-bit activations
            - Maximum context length of 4096 tokens
            - Optimized for efficient inference with bitnet.cpp
            """)
        
        with gr.Column():
            gr.Markdown("""
            ## About Tonic AI
            Tonic AI is a vibrant community of AI enthusiasts and developers always building cool demos and pushing the boundaries of what's possible with AI. We're passionate about creating innovative, accessible, and engaging AI experiences for everyone. Join us in exploring the future of AI!
            """)
    
    # Bottom row: message/system-prompt inputs and sampling controls (left),
    # conversation display (right).
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                placeholder="Enter system prompt..."
            )
            
            # Sampling parameters forwarded to generate_response; the
            # accordion starts collapsed.
            with gr.Accordion("Advanced Options", open=False):
                max_new_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Max New Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K"
                )
            
            submit_btn = gr.Button("Send")
        
        with gr.Column():
            # type="messages": the chatbot is fed a list of
            # {"role": ..., "content": ...} dicts, the same shape that
            # generate_response appends to the history.
            chatbot = gr.Chatbot(label="Conversation", type="messages")
    
    # Conversation state, starting as an empty list. It is passed to
    # generate_response as `history` and overwritten with its second return
    # value.
    chat_history = gr.State([])
    
    # The order of `inputs` must match generate_response's positional
    # parameters; its two return values map onto `outputs` in order.
    submit_btn.click(
        fn=generate_response,
        inputs=[
            user_input,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            chat_history
        ],
        outputs=[chatbot, chat_history]
    )

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # NOTE(review): ssr_mode=False presumably turns off Gradio's server-side
    # rendering — confirm against the Gradio version in use.
    demo.launch(ssr_mode=False)