import torch
import torch.nn as nn
import torch.optim as optim
from SimpleRNN import SimpleRNN
import os
import json
from tqdm import tqdm
import time

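# Load the training corpus and build the character-level vocabulary with index lookups in both directions.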
with open("train_data.txt", encoding="utf-8") as f:
    training_text = f.read()
chars = sorted(set(training_text))  # Unique characters in the corpus
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

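# Model and training hyperparameters, mostly read from parameter.json; the vocabulary size fixes the input/output width.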
with open("parameter.json") as f:
    parameters = json.load(f)
input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
sequence_length = parameters["sequence_length"]
epochs = 1000
learning_rate = parameters["learning_rate"]
model_path = parameters["model_path"]

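# Sliding-window training pairs: each input is sequence_length characters, the target is the character that follows.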
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append((torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char]))

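# Resume from an existing checkpoint if one is found; otherwise start a new model from scratch.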
if os.path.exists(model_path):
    model = torch.load(model_path, weights_only=False)
    print("Loaded pre-trained model. Continue training...")
else:
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)

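# Cross-entropy loss over the vocabulary, optimized with Adam.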
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Train one next-character prediction at a time; Ctrl+C stops early and the model is still saved below.
for epoch in range(epochs):
    try:
        total_loss = 0
        hidden = torch.zeros(1, 1, hidden_size)  # Fresh hidden state at the start of each epoch

        pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
        count = 0
        for input_seq, target in pbar:
            count += 1
            optimizer.zero_grad()
            # Detach the hidden state so gradients do not propagate across sequences
            output, hidden = model(input_seq, hidden.detach())
            loss = criterion(output, torch.tensor([target]))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pbar.set_description(f"Epoch={epoch}, Loss={total_loss / count:.12f}")

        pbar.close()
        time.sleep(1)  # Brief pause between epochs
    except KeyboardInterrupt:
        break

# Persist the trained model and the vocabulary it was built with.
torch.save(model, model_path)
with open("vocab.json", "w", encoding="utf-8") as f:
    json.dump(chars, f)
print("Model saved.")
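
# A minimal generation sketch (hedged example): it assumes SimpleRNN's forward takes
# (tensor of character indices, hidden state) and returns (logits of shape (1, vocab_size),
# new hidden state), matching how the model is called in the training loop above. The seed
# text must only contain characters present in the training corpus. Uncomment to try it.
#
# seed = "hello"
# model.eval()
# hidden = torch.zeros(1, 1, hidden_size)
# generated = seed
# with torch.no_grad():
#     for _ in range(200):
#         input_seq = torch.tensor([char_to_idx[ch] for ch in generated[-sequence_length:]])
#         output, hidden = model(input_seq, hidden)
#         probs = torch.softmax(output[0], dim=0)
#         next_idx = torch.multinomial(probs, 1).item()
#         generated += idx_to_char[next_idx]
# print(generated)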