# HCK-GPT / model/train.py
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
def fine_tune_gpt2(data_path, output_dir="model/fine_tuned"):
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # GPT-2 has no padding token by default; reuse the EOS token so that
    # padding="max_length" below does not raise an error.
    tokenizer.pad_token = tokenizer.eos_token

    # Load the training data
    dataset = load_dataset("text", data_files={"train": data_path})

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # The collator copies input_ids into labels (mlm=False means causal LM),
    # which the Trainer needs in order to compute a loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Configure training
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        save_steps=500,
        save_total_limit=2,
        logging_dir="logs",
        logging_steps=100,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        data_collator=data_collator,
    )

    trainer.train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")
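
# A minimal inference sketch, not part of the original script: it assumes
# fine_tune_gpt2() has already saved weights to model/fine_tuned, and the
# helper name generate_sample is hypothetical.
def generate_sample(prompt, model_dir="model/fine_tuned", max_new_tokens=50):
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    model = GPT2LMHeadModel.from_pretrained(model_dir)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Greedy decoding; pad_token_id is passed explicitly because GPT-2
    # otherwise warns that no pad token is defined.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)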
if __name__ == "__main__":
    fine_tune_gpt2("data/processed/train.txt")