# learn / app.py  (yeonsoo, revision bbf9c8b, 2.64 kB)
import inspect
import threading

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
# --- Data -------------------------------------------------------------------
# IMDB dataset, fetched via `datasets.load_dataset`.
dataset = load_dataset("imdb")

# --- Model and tokenizer ----------------------------------------------------
# Both are loaded from the same pretrained checkpoint name.
model_name = "distilbert-base-uncased"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
# Preprocess the dataset into inputs the model can consume.
def tokenize_function(examples):
    """Tokenize one batch of examples.

    Args:
        examples: A batch as passed by ``Dataset.map(..., batched=True)``;
            ``examples["text"]`` holds the raw texts of the batch.

    Returns:
        The tokenizer's encoding of those texts, truncated and padded with
        ``padding="max_length"`` (no explicit ``max_length`` is given, so the
        tokenizer's own maximum length is used).
    """
    batch_texts = examples["text"]
    encoded = tokenizer(batch_texts, truncation=True, padding="max_length")
    return encoded
# Tokenize the train and test splits (in that order), batch by batch.
tokenized_train_datasets, tokenized_test_datasets = (
    dataset[split_name].map(tokenize_function, batched=True)
    for split_name in ("train", "test")
)
# Training configuration (the epoch count is kept low so a run finishes quickly).
#
# `load_best_model_at_end=True` requires the checkpoint (save) strategy to
# match the evaluation strategy; with the default save strategy ("steps")
# TrainingArguments raises a ValueError, so both are set to "epoch" here.
#
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (the old keyword was deprecated and later dropped), so use whichever keyword
# the installed version's TrainingArguments accepts.
_EVAL_STRATEGY_KEY = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",           # where results/checkpoints are written
    num_train_epochs=1,               # a single epoch, for quick testing
    per_device_train_batch_size=16,   # increased training batch size
    per_device_eval_batch_size=16,    # increased evaluation batch size
    save_strategy="epoch",            # must match the evaluation strategy (see above)
    logging_dir="./logs",             # where logs are written
    logging_steps=100,                # emit a log entry every 100 steps
    report_to="tensorboard",          # report logs to TensorBoard
    load_best_model_at_end=True,      # finish with the best checkpoint loaded
    **{_EVAL_STRATEGY_KEY: "epoch"},  # evaluate once per epoch
)
# Training routine.
def train_model():
    """Fine-tune the module-level ``model`` on the tokenized train split.

    The tokenized test split is used as the evaluation set, and every
    hyper-parameter comes from the module-level ``training_args``. The same
    ``model`` object is the one ``classify_text`` runs inference with.
    """
    trainer_kwargs = dict(
        model=model,                               # model to train
        args=training_args,                        # training arguments
        train_dataset=tokenized_train_datasets,    # training data
        eval_dataset=tokenized_test_datasets,      # evaluation data
    )
    Trainer(**trainer_kwargs).train()
# Run training on a separate thread so the UI can start while training runs.
def start_training():
    """Start ``train_model`` on a new background thread.

    Returns:
        threading.Thread: The already-started training thread, so callers can
        ``join()`` it or poll ``is_alive()``. Callers that ignore the return
        value behave exactly as before.
    """
    # Not a daemon thread (the default when created from the main thread):
    # the interpreter waits for training to finish before exiting.
    train_thread = threading.Thread(target=train_model, name="train-model")
    train_thread.start()
    return train_thread
# Inference function wired into the Gradio UI. It uses the same `model` object
# that the background thread is fine-tuning, so predictions reflect whatever
# weights the model holds at the moment of the call.
def classify_text(text):
    """Return the predicted class index for *text*.

    Args:
        text: Raw input string from the Gradio text box.

    Returns:
        int: ``argmax`` of the model's logits. Presumably 0 = negative and
        1 = positive for IMDB; confirm against the dataset's label names.
    """
    # The Trainer may place `model` on an accelerator, while the tokenizer
    # creates CPU tensors; move the encoding to wherever the model lives now
    # to avoid a device-mismatch error.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = inputs.to(model.device)

    # Inference only: skip autograd graph construction. Grad mode is
    # thread-local in PyTorch, so this does not affect the training thread.
    with torch.no_grad():
        logits = model(**inputs).logits

    # NOTE(review): training runs concurrently on this same module, which the
    # Trainer presumably keeps in train mode (dropout active), so results may
    # vary between calls until training has finished.
    return logits.argmax(-1).item()
# Gradio interface: a text box in, the classifier's prediction out as text.
demo = gr.Interface(
    fn=classify_text,
    inputs="text",
    outputs="text",
)
# Start training, then bring up the Gradio UI.
def launch_app():
    """Kick off background fine-tuning and launch the Gradio demo.

    Training is started first on its own thread, so the UI comes up without
    waiting for it to finish.
    """
    start_training()  # non-blocking: training runs on a background thread
    demo.launch()     # serve the Gradio interface


# Script entry point (intended to run when the app is deployed to
# Hugging Face Spaces).
if __name__ == "__main__":
    launch_app()