# tiny_llm/train.py
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from SimpleRNN import SimpleRNN
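# SimpleRNN is the author's local module and its definition is not part of this
# file. From the way it is called below (model(input_seq, hidden) returning
# next-character logits plus a new hidden state, with hidden shaped
# (1, 1, hidden_size)), it is assumed to look roughly like this hypothetical
# sketch:
#
#   class SimpleRNN(nn.Module):
#       def __init__(self, input_size, hidden_size, output_size):
#           super().__init__()
#           self.embed = nn.Embedding(input_size, hidden_size)
#           self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
#           self.fc = nn.Linear(hidden_size, output_size)
#
#       def forward(self, x, hidden):
#           out, hidden = self.rnn(self.embed(x).unsqueeze(0), hidden)
#           return self.fc(out[:, -1, :]), hidden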
# Load the training corpus and build a character-level vocabulary.
with open("train_data.txt", encoding="utf-8") as f:
    training_text = f.read()

chars = sorted(set(training_text))  # unique characters, in a stable order
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
# Read hyperparameters from parameter.json; the vocabulary size fixes the
# input and output dimensions.
with open("parameter.json") as f:
    parameters = json.load(f)

input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
sequence_length = parameters["sequence_length"]
epochs = 1000
learning_rate = parameters["learning_rate"]
model_path = parameters["model_path"]
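# For reference, parameter.json is assumed to hold the four keys read above,
# roughly like this hypothetical example (values are illustrative only):
#
#   {
#       "hidden_size": 128,
#       "sequence_length": 16,
#       "learning_rate": 0.001,
#       "model_path": "model.pth"
#   }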
# Build (input sequence, next character) training pairs by sliding a window of
# length sequence_length over the text.
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append(
        (torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char])
    )
# Resume from a previously saved model if one exists, otherwise create a new one.
if os.path.exists(model_path):
    model = torch.load(model_path, weights_only=False)
    print("Loaded pre-trained model. Continuing training...")
else:
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(epochs):
    try:
        total_loss = 0
        # Fresh hidden state each epoch: shape (num_layers, batch, hidden_size).
        hidden = torch.zeros(1, 1, hidden_size)
        pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
        count = 0
        for input_seq, target in pbar:
            count += 1
            optimizer.zero_grad()
            # Detach the hidden state so gradients do not flow across steps
            # (truncated backpropagation through time).
            output, hidden = model(input_seq, hidden.detach())
            loss = criterion(output, torch.tensor([target]))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pbar.set_description(f"Epoch={epoch}, Loss={total_loss / count:.12f}")
        pbar.close()
        time.sleep(1)
    except KeyboardInterrupt:
        # Ctrl+C stops training early; the model is still saved below.
        break
# Run one more forward pass on the last training sequence with a fresh hidden
# state, then persist the full model and the vocabulary used to encode the text.
hidden = torch.zeros(1, 1, hidden_size)
output, hidden = model(input_seq, hidden.detach())
torch.save(model, model_path)
with open("vocab.json", "w") as f:
    json.dump(chars, f)
print("Model saved.")
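# For completeness, a hypothetical sketch of how a companion inference script
# might consume the artifacts written above (the saved model and vocab.json);
# this is an assumption about usage, not part of the original training code:
#
#   chars = json.load(open("vocab.json"))
#   char_to_idx = {ch: i for i, ch in enumerate(chars)}
#   idx_to_char = {i: ch for i, ch in enumerate(chars)}
#   model = torch.load(model_path, weights_only=False)
#   hidden = torch.zeros(1, 1, hidden_size)
#   seed = "hello"[-sequence_length:]
#   input_seq = torch.tensor([char_to_idx[ch] for ch in seed])
#   logits, hidden = model(input_seq, hidden)
#   next_char = idx_to_char[int(logits.argmax(dim=-1).item())]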