File size: 1,200 Bytes
b9d1833
493e08a
b9d1833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from SimpleRNN import SimpleRNN
import json
from tqdm import tqdm, trange

# Load run configuration, the pre-trained model and the character vocabulary.
# JSON files are read as UTF-8 (the JSON standard encoding) so that non-ASCII
# vocabulary characters load the same on every platform, and every file handle
# is closed via `with`.
with open("parameter.json", "r", encoding="utf-8") as f:
    parameters = json.load(f)
model_path = parameters["model_path"]

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the checkpoint. Only load model
# files from trusted sources. (Presumably required because the full model
# object, not a state_dict, was saved — confirm.)
# map_location="cpu": generate_text creates its input/hidden tensors on the
# CPU, so the model must be loaded there too, even if it was saved on a GPU.
model = torch.load(model_path, map_location="cpu", weights_only=False)

with open("vocab.json", "r", encoding="utf-8") as f:
    chars = json.load(f)
# Bidirectional lookup between characters and their vocabulary indices.
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
print("Loaded pre-trained model.")

# Input and output layers both span the full character vocabulary.
input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)

def generate_text(start_text, length):
    """Greedily extend ``start_text`` by ``length`` characters using ``model``.

    At each step the current character window is passed to the model together
    with the hidden state returned by the previous step. The highest-scoring
    character is appended to the output, and the window slides forward by one
    character.

    Args:
        start_text: Prompt string. Every character must be in the vocabulary
            loaded from ``vocab.json``.
        length: Number of characters to generate. Values <= 0 generate
            nothing, and ``start_text`` is returned unchanged.

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: If ``start_text`` is empty or contains characters that are
            not in the vocabulary.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    unknown = sorted({ch for ch in start_text if ch not in char_to_idx})
    if unknown:
        raise ValueError(f"characters not in vocabulary: {unknown!r}")

    model.eval()
    # NOTE(review): shape assumes a single-layer RNN run with batch size 1 —
    # TODO confirm against SimpleRNN.
    hidden = torch.zeros(1, 1, hidden_size)
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text], dtype=torch.long)

    pieces = [start_text]
    # Inference only: without no_grad the carried-over `hidden` links every
    # step's autograd graph together, so memory and time grow with `length`.
    with torch.no_grad():
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            # Score only the last position. If the model returns one row of
            # logits this is identical to a plain argmax. If it returns
            # per-timestep logits, a flattened argmax could yield an index
            # outside the vocabulary.
            # Assumes the last dimension is the vocabulary — TODO confirm.
            last_logits = output.reshape(-1, output.shape[-1])[-1]
            predicted_idx = last_logits.argmax().item()
            pieces.append(idx_to_char[predicted_idx])
            # NOTE(review): the whole window is re-fed every step while `hidden`
            # is also carried over, so earlier characters are seen again.
            # Confirm this matches how SimpleRNN was trained.
            next_idx = torch.tensor([predicted_idx], dtype=torch.long)
            input_seq = torch.cat((input_seq[1:], next_idx))

    return "".join(pieces)

# Interactive loop: read a prompt and a length, then print the generated text.
# Invalid input is reported and re-prompted instead of crashing the session.
# Ctrl-D (EOF) or Ctrl-C at a prompt ends the session cleanly.
while True:
    try:
        prompt = input("Ask LLM: ")
        if not prompt:
            print("Please enter a non-empty prompt.")
            continue
        raw_length = input("Length of text: ")
    except (EOFError, KeyboardInterrupt):
        print()
        break

    try:
        length = int(raw_length)
    except ValueError:
        print(f"Length must be an integer, got {raw_length!r}.")
        continue
    if length < 0:
        print("Length must be zero or a positive integer.")
        continue

    try:
        print("LLM Output: ", generate_text(prompt, length))
    except (KeyError, ValueError) as err:
        # e.g. the prompt contains a character missing from the vocabulary.
        print(f"Could not generate text: {err}")