yeonsoo committed on
Commit
bbf9c8b
·
1 Parent(s): a4fc148
Files changed (2) hide show
  1. app.py +36 -16
  2. requirements.txt +4 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
3
  from datasets import load_dataset
4
 
@@ -14,29 +15,38 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
14
  def tokenize_function(examples):
15
  return tokenizer(examples["text"], padding="max_length", truncation=True)
16
 
17
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
18
 
19
- # ํ›ˆ๋ จ ์„ค์ •
20
  training_args = TrainingArguments(
21
  output_dir="./results", # ๊ฒฐ๊ณผ ์ €์žฅ ๊ฒฝ๋กœ
22
- num_train_epochs=3, # ํ›ˆ๋ จ ์—ํญ ์ˆ˜
23
- per_device_train_batch_size=8, # ๋ฐฐ์น˜ ํฌ๊ธฐ
24
- per_device_eval_batch_size=8, # ๊ฒ€์ฆ ๋ฐฐ์น˜ ํฌ๊ธฐ
25
  evaluation_strategy="epoch", # ์—ํญ๋งˆ๋‹ค ๊ฒ€์ฆ
26
  logging_dir="./logs", # ๋กœ๊ทธ ์ €์žฅ ๊ฒฝ๋กœ
 
 
 
27
  )
28
 
29
- trainer = Trainer(
30
- model=model, # ํ›ˆ๋ จํ•  ๋ชจ๋ธ
31
- args=training_args, # ํ›ˆ๋ จ ์ธ์ž
32
- train_dataset=tokenized_datasets["train"], # ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์…‹
33
- eval_dataset=tokenized_datasets["test"], # ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์…‹
34
- )
 
 
 
35
 
36
- # ํ›ˆ๋ จ ์‹œ์ž‘
37
- trainer.train()
 
 
38
 
39
- # ๊ทธ๋ผ๋””์˜ค ์ธํ„ฐํŽ˜์ด์Šค๋กœ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ UI์— ์—ฐ๊ฒฐ
40
  def classify_text(text):
41
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
42
  outputs = model(**inputs)
@@ -44,7 +54,17 @@ def classify_text(text):
44
  predicted_class = logits.argmax(-1).item()
45
  return predicted_class
46
 
 
47
  demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")
48
 
49
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰ (ํ›ˆ๋ จ ํ›„)
50
- demo.launch()
 
 
 
 
 
 
 
 
 
 
import threading

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
5
 
 
15
def tokenize_function(examples):
    """Batch-tokenize the ``"text"`` column, truncating and padding every
    example to the tokenizer's max length so tensors have a uniform shape."""
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
    )
17
 
18
# Tokenize each split once up front; Trainer consumes the mapped datasets directly.
tokenized_train_datasets = dataset["train"].map(tokenize_function, batched=True)
tokenized_test_datasets = dataset["test"].map(tokenize_function, batched=True)
20
 
21
# Training configuration (kept to one quick epoch for fast iteration/testing).
training_args = TrainingArguments(
    output_dir="./results",          # where checkpoints and results are written
    num_train_epochs=1,              # single epoch — fast smoke-test training
    per_device_train_batch_size=16,  # larger batch for throughput
    per_device_eval_batch_size=16,   # larger eval batch for throughput
    evaluation_strategy="epoch",     # evaluate at the end of each epoch
    # FIX: load_best_model_at_end=True requires the save strategy to match the
    # eval strategy; with the default save_strategy="steps" TrainingArguments
    # raises a ValueError at construction time in transformers 4.28.x.
    save_strategy="epoch",
    logging_dir="./logs",            # TensorBoard log directory
    logging_steps=100,               # emit a log record every 100 steps
    report_to="tensorboard",         # send logs to TensorBoard
    load_best_model_at_end=True,     # reload the best checkpoint after training
)
33
 
34
def train_model():
    """Fine-tune the module-level ``model`` on the tokenized splits.

    Blocking call: builds a Trainer from the module-level configuration and
    datasets, then runs the full training loop.
    """
    fine_tuner = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_datasets,
        eval_dataset=tokenized_test_datasets,
    )
    fine_tuner.train()
43
 
44
+ # ํ›ˆ๋ จ์„ ๋ณ„๋„์˜ ์Šค๋ ˆ๋“œ์—์„œ ์‹คํ–‰
45
+ def start_training():
46
+ train_thread = threading.Thread(target=train_model)
47
+ train_thread.start()
48
 
49
+ # ๊ทธ๋ผ๋””์–ธํŠธ ๊ธฐ๋ฐ˜ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ UI์— ์—ฐ๊ฒฐ
50
  def classify_text(text):
51
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
52
  outputs = model(**inputs)
 
54
  predicted_class = logits.argmax(-1).item()
55
  return predicted_class
56
 
57
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
58
  demo = gr.Interface(fn=classify_text, inputs="text", outputs="text")
59
 
60
# Start training and run the Gradio UI.
def launch_app():
    """Begin background fine-tuning, then serve the Gradio interface."""
    start_training()  # non-blocking: training continues on its own thread
    demo.launch()     # blocking: serves the web UI
67
+
68
+ # ํ—ˆ๊น…ํŽ˜์ด์Šค Spaces์— ์—…๋กœ๋“œ ํ•  ๋•Œ๋Š” ์ด ๋ถ€๋ถ„์„ ์‹คํ–‰ํ•˜๋„๋ก ์„ค์ •
69
+ if __name__ == "__main__":
70
+ launch_app()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==3.0.12
2
+ transformers==4.28.1
3
+ datasets==2.13.1
4
+ torch==1.13.1