|
import torch
|
|
from SimpleRNN import SimpleRNN
|
|
import json
|
|
from tqdm import tqdm, trange
|
|
|
|
# --- Setup: load hyperparameters, the trained model, and the vocabulary. ---

# Use a context manager so the file handle is closed deterministically
# (the original open(...).read() leaked the handle).
with open("parameter.json") as f:
    parameters = json.loads(f.read())

model_path = parameters["model_path"]

# NOTE(review): weights_only=False unpickles arbitrary Python objects.
# Safe only because this checkpoint is produced locally — never load an
# untrusted file this way.
model = torch.load(model_path, weights_only=False)

# Rebuild the character vocabulary plus both lookup tables:
# char_to_idx encodes prompt characters, idx_to_char decodes predictions.
with open("vocab.json", "r") as f:
    chars = json.loads(f.read())

char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

print("Loaded pre-trained model.")

# Model dimensions: input/output are one slot per vocabulary character;
# hidden size comes from the training-time configuration.
input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
|
|
|
|
def generate_text(start_text, length):
    """Autoregressively extend ``start_text`` by ``length`` characters.

    Greedy decoding: at each step the argmax character is appended and fed
    back as input.

    Args:
        start_text: seed string; every character must exist in the
            vocabulary, otherwise a KeyError is raised.
        length: number of characters to generate.

    Returns:
        ``start_text`` followed by ``length`` generated characters.
    """
    model.eval()
    # Fresh hidden state, shape (num_layers=1, batch=1, hidden_size).
    # NOTE(review): assumes a single-layer RNN with batch size 1 — confirm
    # against SimpleRNN's definition.
    hidden = torch.zeros(1, 1, hidden_size)
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text])

    generated_text = start_text
    # Inference only — no_grad avoids building an autograd graph on every
    # step (the original accumulated graphs for the whole generation loop).
    with torch.no_grad():
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            # NOTE(review): argmax over the *flattened* output tensor; if the
            # model returns per-timestep logits this should be
            # output[-1].argmax() — verify against SimpleRNN.forward.
            predicted_idx = output.argmax().item()
            generated_text += idx_to_char[predicted_idx]
            # Slide the window: drop the oldest index, append the prediction.
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))

    return generated_text
|
|
|
|
# Simple REPL: prompt for a seed string and a generation length, then print
# the generated continuation. Exit with Ctrl-C / Ctrl-D.
while True:
    prompt = input("Ask LLM: ")
    # A non-numeric length previously crashed the whole session with an
    # uncaught ValueError; re-prompt instead.
    try:
        length = int(input("Length of text: "))
    except ValueError:
        print("Please enter a whole number for the length.")
        continue
    print("LLM Output: ", generate_text(prompt, length))