hackermoon1 commited on
Commit
4e20470
·
verified ·
1 Parent(s): 869969f

Update model/train.py

Browse files
Files changed (1) hide show
  1. model/train.py +40 -0
model/train.py CHANGED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
2
+ from datasets import load_dataset
3
+ import os
4
+
5
+ def fine_tune_gpt2(data_path, output_dir="model/fine_tuned"):
6
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
7
+ model = GPT2LMHeadModel.from_pretrained("gpt2")
8
+
9
+ # Carregar dados
10
+ dataset = load_dataset("text", data_files={"train": data_path})
11
+
12
+ def tokenize_function(examples):
13
+ return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
14
+
15
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
16
+
17
+ # Configurar treinamento
18
+ training_args = TrainingArguments(
19
+ output_dir=output_dir,
20
+ num_train_epochs=3,
21
+ per_device_train_batch_size=4,
22
+ save_steps=500,
23
+ save_total_limit=2,
24
+ logging_dir="logs",
25
+ logging_steps=100,
26
+ )
27
+
28
+ trainer = Trainer(
29
+ model=model,
30
+ args=training_args,
31
+ train_dataset=tokenized_dataset["train"],
32
+ )
33
+
34
+ trainer.train()
35
+ model.save_pretrained(output_dir)
36
+ tokenizer.save_pretrained(output_dir)
37
+ print(f"Modelo salvo em {output_dir}")
38
+
39
+ if __name__ == "__main__":
40
+ fine_tune_gpt2("data/processed/train.txt")