Spaces:
Running
Running
File size: 3,337 Bytes
9d9cc80 b9185e7 4bf6d97 7ca8994 f498762 9d9cc80 f498762 b9185e7 f498762 b9185e7 4bf6d97 7ca8994 b9185e7 0c5d476 b9185e7 7ca8994 b9185e7 7ca8994 b9185e7 7ca8994 b9185e7 7ca8994 4bf6d97 7ca8994 b9185e7 7ca8994 9d9cc80 7ca8994 b9185e7 7ca8994 9d9cc80 7ca8994 b9185e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import subprocess
import sys

# Install the transformers fork this Space relies on. It presumably adds support
# for the BitNet checkpoint loaded below (TODO confirm). It must be installed
# before `transformers` is imported further down, so this stays above that import.
# pip is invoked through the running interpreter (`python -m pip`) so the package
# lands in the environment this process imports from. `check=True` makes a failed
# install stop the app here instead of surfacing later as a model-loading error.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/shumingma/transformers.git",
    ],
    check=True,
)
import threading
import torch
import torch._dynamo

# Ask torch dynamo to suppress compilation errors rather than raise them
# (it falls back to eager execution).
torch._dynamo.config.suppress_errors = True
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import gradio as gr
import spaces
# Hugging Face Hub checkpoint served by this Space.
model_id = "microsoft/bitnet-b1.58-2B-4T"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are loaded in bfloat16; device_map="auto" lets transformers decide
# where to place them.
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)

# Log the device the model ended up on.
print(model.device)
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Two history shapes are accepted:

    * ``(user, assistant)`` pairs -- the format this app was written against.
      Falsy halves (e.g. a ``None`` reply) are skipped.
    * ``{"role": ..., "content": ...}`` dicts -- Gradio's "messages" format.
      Entries whose content is not a non-empty string (e.g. attachments) are
      skipped, because a text chat template cannot render them.

    Args:
        history: Previous turns as supplied by the chat UI. ``None`` is treated
            as an empty history.

    Returns:
        A new list of ``{"role", "content"}`` dicts in conversation order.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if isinstance(content, str) and content:
                messages.append({"role": turn["role"], "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a chat response using streaming with TextIteratorStreamer.

    `model.generate` runs on a background thread and pushes decoded text into
    the streamer. This generator drains the streamer and yields the accumulated
    reply after every chunk.

    Args:
        message: User's current message.
        history: Previous turns as (user, assistant) tuples. "messages"-style
            {"role", "content"} dicts are also accepted.
        system_message: Initial system prompt guiding the assistant.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.

    Yields:
        The growing response text as new tokens are generated.

    Raises:
        Exception: Whatever `model.generate` raised on the background thread,
            re-raised here after the stream has been closed.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered chat template is expected to already contain the special
    # tokens it needs (e.g. BOS). Tokenizing it again with the default
    # add_special_tokens=True can prepend a duplicate BOS, so it is disabled.
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    errors = []

    def _generate():
        # If generate() raises, the streamer never receives its end-of-stream
        # signal and the loop below would block forever. Record the error and
        # close the stream so the consumer can finish and re-raise it.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    # The stream is closed, so generation has finished or failed. Joining keeps
    # the worker from outliving the request on the normal path.
    thread.join()
    if errors:
        raise errors[0]
demo = gr.ChatInterface(
    fn=respond,
    title="Bitnet-b1.58-2B-4T",
    description="Bitnet-b1.58-2B-4T",
    # Each example row is [message, system message, max new tokens,
    # temperature, top-p], matching the order of `additional_inputs` below.
    examples=[
        [prompt, "You are a helpful AI.", token_budget, 0.7, 0.95]
        for prompt, token_budget in (
            ("Hello!", 512),
            ("Can you code a snake game?", 2048),
        )
    ],
    # Extra controls passed to `respond` after (message, history).
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are a helpful AI assistant.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=8192,
            step=1,
            value=2048,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()