Kevin Fink commited on
Commit
f7961d6
·
1 Parent(s): 871c25a

gradio fix

Browse files
Files changed (1) hide show
  1. app.py +50 -45
app.py CHANGED
@@ -1,52 +1,57 @@
1
  import gradio as gr
2
  from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
 
4
 
5
  def fine_tune_model(model_name, dataset_name, hub_id, num_epochs, batch_size, lr, grad):
6
- # Load the dataset
7
- dataset = load_dataset(dataset_name)
8
-
9
- # Load the model and tokenizer
10
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, num_labels=2)
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
-
13
- # Tokenize the dataset
14
- def tokenize_function(examples):
15
- return tokenizer(examples['text'], padding="max_length", truncation=True)
16
-
17
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
18
-
19
- # Set training arguments
20
- training_args = TrainingArguments(
21
- output_dir='./results',
22
- evaluation_strategy="epoch",
23
- learning_rate=lr,
24
- per_device_train_batch_size=batch_size,
25
- per_device_eval_batch_size=batch_size,
26
- num_train_epochs=num_epochs,
27
- weight_decay=0.01,
28
- evaluation_strategy='epoch',
29
- gradient_accumulation_steps=grad,
30
- load_best_model_at_end=True,
31
- metric_for_best_model="accuracy",
32
- greater_is_better=True,
33
- logging_dir='./logs',
34
- logging_steps=10,
35
- push_to_hub=True,
36
- hub_model_id=hub_id,
37
- )
38
-
39
- # Create Trainer
40
- trainer = Trainer(
41
- model=model,
42
- args=training_args,
43
- train_dataset=tokenized_datasets['train'],
44
- eval_dataset=tokenized_datasets['validation'],
45
- )
46
-
47
- # Fine-tune the model
48
- trainer.train()
49
- trainer.push_to_hub(commit_message="Training complete!")
 
 
 
 
50
  return 'DONE!'#model
51
  '''
52
  # Define Gradio interface
@@ -58,7 +63,7 @@ def predict(text):
58
  '''
59
  # Create Gradio interface
60
  iface = gr.Interface(
61
- fn=fine_tune_model,
62
  inputs=[
63
  gr.inputs.Textbox(label="Model Name (e.g., 'google/t5-efficient-tiny-nh8')"),
64
  gr.inputs.Textbox(label="Dataset Name (e.g., 'imdb')"),
 
1
  import gradio as gr
2
  from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
+ import traceback
5
 
6
  def fine_tune_model(model_name, dataset_name, hub_id, num_epochs, batch_size, lr, grad):
7
+ try:
8
+
9
+ # Load the dataset
10
+ dataset = load_dataset(dataset_name)
11
+
12
+ # Load the model and tokenizer
13
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, num_labels=2)
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+
16
+ # Tokenize the dataset
17
+ def tokenize_function(examples):
18
+ return tokenizer(examples['text'], padding="max_length", truncation=True)
19
+
20
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
21
+
22
+ # Set training arguments
23
+ training_args = TrainingArguments(
24
+ output_dir='./results',
25
+ evaluation_strategy="epoch",
26
+ learning_rate=lr,
27
+ per_device_train_batch_size=batch_size,
28
+ per_device_eval_batch_size=batch_size,
29
+ num_train_epochs=num_epochs,
30
+ weight_decay=0.01,
31
+ evaluation_strategy='epoch',
32
+ gradient_accumulation_steps=grad,
33
+ load_best_model_at_end=True,
34
+ metric_for_best_model="accuracy",
35
+ greater_is_better=True,
36
+ logging_dir='./logs',
37
+ logging_steps=10,
38
+ push_to_hub=True,
39
+ hub_model_id=hub_id,
40
+ )
41
+
42
+ # Create Trainer
43
+ trainer = Trainer(
44
+ model=model,
45
+ args=training_args,
46
+ train_dataset=tokenized_datasets['train'],
47
+ eval_dataset=tokenized_datasets['validation'],
48
+ )
49
+
50
+ # Fine-tune the model
51
+ trainer.train()
52
+ trainer.push_to_hub(commit_message="Training complete!")
53
+ except Exception as e:
54
+ return f"An error occurred: {str(e)}, TB: {traceback.format_exc()}"
55
  return 'DONE!'#model
56
  '''
57
  # Define Gradio interface
 
63
  '''
64
  # Create Gradio interface
65
  iface = gr.Interface(
66
+ fine_tune_model,
67
  inputs=[
68
  gr.inputs.Textbox(label="Model Name (e.g., 'google/t5-efficient-tiny-nh8')"),
69
  gr.inputs.Textbox(label="Dataset Name (e.g., 'imdb')"),