# Install the required transformers branch.
# This runs once, when the script starts.
import os
print("Installing required transformers branch...")
os.system("pip install git+https://github.com/shumingma/transformers.git")
print("Installation complete.")
# Import the required libraries.
import threading
import torch
import torch._dynamo
import gradio as gr
import spaces  # Hugging Face Spaces utilities (GPU allocation)
# torch._dynamo setting (optional): fall back to eager mode instead of raising on compile errors.
torch._dynamo.config.suppress_errors = True
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# --- ๋ชจ๋ธ ๋กœ๋“œ ---
# ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ • (Hugging Face ๋ชจ๋ธ ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"
# ๋ชจ๋ธ ๋กœ๋“œ ์‹œ ๊ฒฝ๊ณ  ๋ฉ”์‹œ์ง€๋ฅผ ์ตœ์†Œํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋กœ๊น… ๋ ˆ๋ฒจ ์„ค์ •
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
# AutoModelForCausalLM๊ณผ AutoTokenizer๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
# trust_remote_code=True๊ฐ€ ํ•„์š”ํ•˜๋ฉฐ, device_map="auto"๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ž๋™์œผ๋กœ ๋””๋ฐ”์ด์Šค ์„ค์ •
try:
    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # use bf16 (GPU recommended)
        device_map="auto",           # place the model on available devices automatically
        trust_remote_code=True,
    )
    print(f"Model device: {model.device}")
    print("Model loaded.")
except Exception as e:
    print(f"Error while loading model: {e}")
    tokenizer = None
    model = None
    print("Model load failed. The application may not work correctly.")
# --- Text generation function (for Gradio ChatInterface) ---
@spaces.GPU  # run this function on GPU hardware (Hugging Face Spaces)
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    if model is None or tokenizer is None:
        yield "Cannot generate text because the model failed to load."
        return  # this is a generator, so a bare return ends the stream
    try:
        # Build the message list to match the model's chat template.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Streamer for token-by-token text streaming.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # use the EOS token for padding
        )
# ๋ชจ๋ธ ์ƒ์„ฑ์„ ๋ณ„๋„์˜ ์Šค๋ ˆ๋“œ์—์„œ ์‹คํ–‰
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
# ์ŠคํŠธ๋ฆฌ๋จธ์—์„œ ์ƒ์„ฑ๋œ ํ…์ŠคํŠธ๋ฅผ ์ฝ์–ด์™€ yield
response = ""
for new_text in streamer:
response += new_text
yield response # ์‹ค์‹œ๊ฐ„์œผ๋กœ ์‘๋‹ต์„ Gradio ์ธํ„ฐํŽ˜์ด์Šค๋กœ ์ „๋‹ฌ
except Exception as e:
yield f"ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}"
# ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ ์Šค๋ ˆ๋“œ ์ฒ˜๋ฆฌ ๋กœ์ง ์ถ”๊ฐ€ ๊ณ ๋ ค ํ•„์š” (์„ ํƒ ์‚ฌํ•ญ)
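# A minimal smoke test for respond() outside Gradio (illustrative sketch; the
# generator yields the accumulated response, so the last chunk is the full reply):
#
#   final = ""
#   for partial in respond("Hello!", [], "You are a helpful AI assistant.", 64, 0.7, 0.95):
#       final = partial
#   print(final)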
# --- Gradio interface setup ---
if model is not None and tokenizer is not None:
    demo = gr.ChatInterface(
        fn=respond,
        title="Bitnet-b1.58-2B-4T Chatbot",
        description="A chat demo using the Microsoft Bitnet-b1.58-2B-4T model.",
        examples=[
            [
                "Hello! Please introduce yourself.",
                "You are a helpful AI assistant.",  # example system message
                512,   # example max new tokens
                0.7,   # example temperature
                0.95,  # example top-p
            ],
            [
                "Show me Python code for a simple web server.",
                "You are a skilled AI developer.",  # example system message
                1024,  # example max new tokens
                0.8,   # example temperature
                0.9,   # example top-p
            ],
        ],
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant.",  # default system message
                label="System message",
                lines=1,
            ),
            gr.Slider(
                minimum=1,
                maximum=4096,  # bounded by the model's context length (adjust if needed)
                value=512,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,  # adjust the temperature range if needed
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.0,  # adjust the top-p range if needed
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
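    # Note: the order of additional_inputs must match the extra parameters of
    # respond() (system_message, max_tokens, temperature, top_p), and each row in
    # `examples` lists the user message first, then values for those inputs in order.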
    # Launch the Gradio app.
    # On Hugging Face Spaces the app is served automatically, so share=True is not needed.
    # debug=True prints detailed logs.
    demo.launch(debug=True)
else:
    print("Cannot launch the Gradio interface because the model failed to load.")