import threading

import torch
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ
dataset = load_dataset("imdb")
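
# The full IMDB train split has 25,000 reviews, so even one epoch is slow on
# CPU. A minimal sketch for a quicker smoke test (the 2,000-example subset
# size is an assumption, not part of the original):
# dataset["train"] = dataset["train"].shuffle(seed=42).select(range(2000))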

# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# ๋ฐ์ดํ„ฐ์…‹์„ ๋ชจ๋ธ์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌ
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_train_datasets = dataset["train"].map(tokenize_function, batched=True)
tokenized_test_datasets = dataset["test"].map(tokenize_function, batched=True)
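
# Note: map() adds input_ids/attention_mask next to the original "text" and
# "label" columns; by default Trainer drops columns the model doesn't accept
# (remove_unused_columns=True), so no manual cleanup is needed here.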

# Training configuration (a single epoch keeps the run fast)
training_args = TrainingArguments(
    output_dir="./results",           # where checkpoints and results are written
    num_train_epochs=1,               # one epoch only, for a quick test run
    per_device_train_batch_size=16,   # training batch size
    per_device_eval_batch_size=16,    # evaluation batch size
    evaluation_strategy="epoch",      # evaluate at the end of each epoch
    save_strategy="epoch",            # must match evaluation_strategy for load_best_model_at_end
    logging_dir="./logs",             # where logs are written
    logging_steps=100,                # log every 100 steps
    report_to="tensorboard",          # send logs to TensorBoard
    load_best_model_at_end=True,      # reload the best checkpoint after training
)
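
# Note: recent transformers releases rename `evaluation_strategy` to
# `eval_strategy`; use whichever name matches the installed version.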

# Training routine
def train_model():
    trainer = Trainer(
        model=model,                             # model to fine-tune
        args=training_args,                      # training arguments
        train_dataset=tokenized_train_datasets,  # training split
        eval_dataset=tokenized_test_datasets,    # evaluation split
    )
    trainer.train()

# ํ›ˆ๋ จ์„ ๋ณ„๋„์˜ ์Šค๋ ˆ๋“œ์—์„œ ์‹คํ–‰
def start_training():
    train_thread = threading.Thread(target=train_model)
    train_thread.start()
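
# Caveat: the Gradio handler below reads the same `model` object that the
# training thread is mutating, so predictions drift (and may be momentarily
# inconsistent) while training runs; fine for a demo, racy in general.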

# ๊ทธ๋ผ๋””์–ธํŠธ ๊ธฐ๋ฐ˜ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ UI์— ์—ฐ๊ฒฐ
def classify_text(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = logits.argmax(-1).item()
    return predicted_class
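
# The raw class index is not very readable. A hedged sketch mapping it to a
# label name (for IMDB, 0 = negative and 1 = positive):
# id2label = {0: "negative", 1: "positive"}
# return id2label[predicted_class]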

# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")

# Start training and launch the Gradio UI
def launch_app():
    # Kick off training in the background
    start_training()

    # Launch the Gradio interface
    demo.launch()

# ํ—ˆ๊น…ํŽ˜์ด์Šค Spaces์— ์—…๋กœ๋“œ ํ•  ๋•Œ๋Š” ์ด ๋ถ€๋ถ„์„ ์‹คํ–‰ํ•˜๋„๋ก ์„ค์ •
if __name__ == "__main__":
    launch_app()
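
# For a Hugging Face Space, a requirements.txt along these lines would also be
# needed (the exact package list is an assumption based on the imports above):
#   torch
#   transformers
#   datasets
#   gradio
#   tensorboard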