# Install the required transformers branch. This runs once at script startup.
import os

print("Installing required transformers branch...")
os.system("pip install git+https://github.com/shumingma/transformers.git")
print("Installation complete.")
# Import the required libraries.
import threading
import torch
import torch._dynamo
import gradio as gr
import spaces  # Hugging Face Spaces utilities (provides the @spaces.GPU decorator)

# torch._dynamo settings (optional, attempts to improve performance)
torch._dynamo.config.suppress_errors = True
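# With suppress_errors=True, torch._dynamo falls back to eager execution instead
# of raising when graph compilation fails, so custom model code keeps running.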
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# --- Model loading ---
# Model path (Hugging Face model ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Lower the logging level to minimize warnings during model loading.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# Load AutoModelForCausalLM and AutoTokenizer.
# trust_remote_code=True is required; device_map="auto" places the model on
# available devices automatically.
try:
    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # use bf16 (GPU recommended)
        device_map="auto",           # place the model on available devices automatically
        trust_remote_code=True,
    )
    print(f"Model device: {model.device}")
    print("Model loading complete.")
except Exception as e:
    print(f"Error while loading model: {e}")
    tokenizer = None
    model = None
    print("Model loading failed. The application may not work correctly.")
# --- Text generation function (for the Gradio ChatInterface) ---
# Mark this function as using GPU resources on Hugging Face Spaces.
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    if model is None or tokenizer is None:
        yield "Text generation is unavailable because the model failed to load."
        return  # this is a generator function, so a bare return ends the stream
    try:
        # Build the message list to match the model's chat template.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Streamer for token-by-token text streaming.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
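        # TextIteratorStreamer acts as a queue between threads: model.generate()
        # (started below) pushes decoded text chunks into it, and the main thread
        # iterates over them as they arrive. skip_prompt=True omits the echoed
        # input prompt from the stream.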
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # use EOS as the padding token
        )

        # Run model generation in a separate thread.
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # Read generated text from the streamer and yield it.
        response = ""
        for new_text in streamer:
            response += new_text
            yield response  # stream the partial response to the Gradio interface
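        # Join the generation thread once streaming ends (defensive cleanup;
        # assumes generate() has finished producing tokens by this point).
        thread.join()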
    except Exception as e:
        yield f"Error during text generation: {e}"
        # Optional: consider additional thread-handling logic on errors.
# --- Gradio interface setup ---
if model is not None and tokenizer is not None:
    demo = gr.ChatInterface(
        fn=respond,
        title="Bitnet-b1.58-2B-4T Chatbot",
        description="A chat demo using the Microsoft Bitnet-b1.58-2B-4T model.",
        examples=[
            [
                "Hello! Please introduce yourself.",
                "You are a capable AI assistant.",  # example system message
                512,   # example max new tokens
                0.7,   # example temperature
                0.95,  # example top-p
            ],
            [
                "Show me code for a simple web server in Python.",
                "You are a skilled AI developer.",  # example system message
                1024,  # example max new tokens
                0.8,   # example temperature
                0.9,   # example top-p
            ],
        ],
        additional_inputs=[
            gr.Textbox(
                value="You are a capable AI assistant.",  # default system message
                label="System message",
                lines=1,
            ),
            gr.Slider(
                minimum=1,
                maximum=4096,  # consider the model's maximum context length (or set higher)
                value=512,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,  # adjust the temperature range as needed
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.0,  # adjust the top-p range as needed
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
    # Launch the Gradio app.
    # On Hugging Face Spaces, public serving is handled by the platform, so
    # share=True is not needed. debug=True prints detailed logs.
    demo.launch(debug=True)
else:
    print("Cannot launch the Gradio interface because the model failed to load.")
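    # For local testing outside Spaces, passing server_name="0.0.0.0" and
    # server_port=7860 to launch() exposes the app on the local network
    # (a hedged note; not part of the Spaces deployment above).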