yeonsoo
commited on
Commit
ยท
bbf9c8b
1
Parent(s):
a4fc148
dif
Browse files- app.py +36 -16
- requirements.txt +4 -0
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
3 |
from datasets import load_dataset
|
4 |
|
@@ -14,29 +15,38 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
14 |
def tokenize_function(examples):
|
15 |
return tokenizer(examples["text"], padding="max_length", truncation=True)
|
16 |
|
17 |
-
|
|
|
18 |
|
19 |
-
# ํ๋ จ ์ค์
|
20 |
training_args = TrainingArguments(
|
21 |
output_dir="./results", # ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก
|
22 |
-
num_train_epochs=
|
23 |
-
per_device_train_batch_size=
|
24 |
-
per_device_eval_batch_size=
|
25 |
evaluation_strategy="epoch", # ์ํญ๋ง๋ค ๊ฒ์ฆ
|
26 |
logging_dir="./logs", # ๋ก๊ทธ ์ ์ฅ ๊ฒฝ๋ก
|
|
|
|
|
|
|
27 |
)
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
-
#
|
37 |
-
|
|
|
|
|
38 |
|
39 |
-
#
|
40 |
def classify_text(text):
|
41 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
42 |
outputs = model(**inputs)
|
@@ -44,7 +54,17 @@ def classify_text(text):
|
|
44 |
predicted_class = logits.argmax(-1).item()
|
45 |
return predicted_class
|
46 |
|
|
|
47 |
demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")
|
48 |
|
49 |
-
# Gradio
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import threading
|
3 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
|
4 |
from datasets import load_dataset
|
5 |
|
|
|
15 |
def tokenize_function(examples):
|
16 |
return tokenizer(examples["text"], padding="max_length", truncation=True)
|
17 |
|
18 |
+
tokenized_train_datasets = dataset["train"].map(tokenize_function, batched=True)
|
19 |
+
tokenized_test_datasets = dataset["test"].map(tokenize_function, batched=True)
|
20 |
|
21 |
+
# ํ๋ จ ์ค์ (๋น ๋ฅด๊ฒ ํ๋ จํ๊ธฐ ์ํด ์ํญ ์๋ฅผ ์ค์)
|
22 |
training_args = TrainingArguments(
|
23 |
output_dir="./results", # ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก
|
24 |
+
num_train_epochs=1, # ํ๋ จ ์ํญ ์ 1๋ก ์ค์ (๋น ๋ฅด๊ฒ ํ
์คํธ)
|
25 |
+
per_device_train_batch_size=16, # ๋ฐฐ์น ํฌ๊ธฐ ์ฆ๊ฐ
|
26 |
+
per_device_eval_batch_size=16, # ๋ฐฐ์น ํฌ๊ธฐ ์ฆ๊ฐ
|
27 |
evaluation_strategy="epoch", # ์ํญ๋ง๋ค ๊ฒ์ฆ
|
28 |
logging_dir="./logs", # ๋ก๊ทธ ์ ์ฅ ๊ฒฝ๋ก
|
29 |
+
logging_steps=100, # 100 ์คํ
๋ง๋ค ๋ก๊ทธ ์ถ๋ ฅ
|
30 |
+
report_to="tensorboard", # ํ
์๋ณด๋๋ก ๋ก๊ทธ ๋ณด๊ณ
|
31 |
+
load_best_model_at_end=True, # ์ต์์ ๋ชจ๋ธ๋ก ์ข
๋ฃ
|
32 |
)
|
33 |
|
34 |
+
# ํ๋ จ ํจ์
|
35 |
+
def train_model():
|
36 |
+
trainer = Trainer(
|
37 |
+
model=model, # ํ๋ จํ ๋ชจ๋ธ
|
38 |
+
args=training_args, # ํ๋ จ ์ธ์
|
39 |
+
train_dataset=tokenized_train_datasets, # ํ๋ จ ๋ฐ์ดํฐ์
|
40 |
+
eval_dataset=tokenized_test_datasets, # ํ๊ฐ ๋ฐ์ดํฐ์
|
41 |
+
)
|
42 |
+
trainer.train()
|
43 |
|
44 |
+
# ํ๋ จ์ ๋ณ๋์ ์ค๋ ๋์์ ์คํ
|
45 |
+
def start_training():
|
46 |
+
train_thread = threading.Thread(target=train_model)
|
47 |
+
train_thread.start()
|
48 |
|
49 |
+
# ๊ทธ๋ผ๋์ธํธ ๊ธฐ๋ฐ ํ๋ จ๋ ๋ชจ๋ธ์ UI์ ์ฐ๊ฒฐ
|
50 |
def classify_text(text):
|
51 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
|
52 |
outputs = model(**inputs)
|
|
|
54 |
predicted_class = logits.argmax(-1).item()
|
55 |
return predicted_class
|
56 |
|
57 |
+
# Gradio ์ธํฐํ์ด์ค ์ค์
|
58 |
demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")
|
59 |
|
60 |
+
# ํ๋ จ ์์๊ณผ Gradio UI ์คํ
|
61 |
+
def launch_app():
|
62 |
+
# ํ๋ จ์ ์์
|
63 |
+
start_training()
|
64 |
+
|
65 |
+
# Gradio ์ธํฐํ์ด์ค ์คํ
|
66 |
+
demo.launch()
|
67 |
+
|
68 |
+
# ํ๊น
ํ์ด์ค Spaces์ ์
๋ก๋ ํ ๋๋ ์ด ๋ถ๋ถ์ ์คํํ๋๋ก ์ค์
|
69 |
+
if __name__ == "__main__":
|
70 |
+
launch_app()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==3.0.12
|
2 |
+
transformers==4.28.1
|
3 |
+
datasets==2.13.1
|
4 |
+
torch==1.13.1
|