# tiny_llm / run.py
# Uploaded by xcx0902 via huggingface_hub (commit 493e08a, verified)
import torch
from SimpleRNN import SimpleRNN
import json
from tqdm import tqdm, trange
# Load hyper-parameters and file paths from the JSON config next to this script.
with open("parameter.json", "r") as f:
    parameters = json.load(f)
model_path = parameters["model_path"]
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only
# load checkpoints from a trusted source.
model = torch.load(model_path, weights_only=False)
# vocab.json holds the ordered character list the model was trained on;
# index order defines the model's input/output dimensions.
with open("vocab.json", "r") as f:
    chars = json.load(f)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
print("Loaded pre-trained model.")
input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
def generate_text(start_text, length):
    """Autoregressively extend `start_text` by `length` characters.

    Greedily picks the highest-scoring next character at each step,
    appends it, and slides the input window forward by one. Relies on
    the module-level `model`, `hidden_size`, `char_to_idx`, and
    `idx_to_char` globals; raises KeyError if `start_text` contains a
    character outside the vocabulary.
    """
    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    window = torch.tensor([char_to_idx[c] for c in start_text])
    pieces = [start_text]
    for _ in trange(length):
        logits, hidden = model(window, hidden)
        next_idx = logits.argmax().item()
        pieces.append(idx_to_char[next_idx])
        # Drop the oldest index and append the freshly generated one.
        window = torch.cat((window[1:], torch.tensor([next_idx])))
    return "".join(pieces)
# Interactive REPL: read a prompt and a target length, then generate.
while True:
    prompt = input("Ask LLM: ")
    # A stray non-numeric length should not crash the whole session.
    try:
        length = int(input("Length of text: "))
    except ValueError:
        print("Please enter a whole number for the length.")
        continue
    if length < 0:
        print("Length must be non-negative.")
        continue
    print("LLM Output: ", generate_text(prompt, length))