tiny_llm / SimpleRNN.py
xcx0902's picture
Upload folder using huggingface_hub
ab33b86 verified
raw
history blame contribute delete
617 Bytes
import torch
import torch.nn as nn
class SimpleRNN(nn.Module):
    """Single-layer RNN over one-hot-encoded token indices.

    Each input index is one-hot encoded to a vector of width ``input_size``
    and fed through an ``nn.RNN`` (``batch_first=True``). The RNN output at
    the final time step is then projected to ``output_size`` values by a
    linear layer.

    Args:
        input_size: Number of distinct token indices; also the one-hot width
            fed to the RNN.
        hidden_size: Number of features in the RNN hidden state.
        output_size: Number of outputs produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the RNN over a token sequence and project the last time step.

        Args:
            x: Integer index tensor (``one_hot`` requires int64) with values
                in ``[0, input_size)``. It is either unbatched with shape
                ``(seq_len,)`` or batched with shape ``(batch, seq_len)``.
            hidden: Initial hidden state of shape ``(1, batch, hidden_size)``
                (``batch`` is 1 for unbatched input). ``None`` lets ``nn.RNN``
                start from a zero state.

        Returns:
            A tuple ``(out, hidden)``:
            - ``out`` has shape ``(batch, output_size)``, or
              ``(1, output_size)`` for unbatched input.
            - ``hidden`` is the final RNN hidden state of shape
              ``(1, batch, hidden_size)``.

        Raises:
            ValueError: If ``x`` is neither 1-D nor 2-D.
        """
        if x.dim() not in (1, 2):
            raise ValueError(
                "expected x of shape (seq_len,) or (batch, seq_len), "
                f"got {tuple(x.shape)}"
            )
        encoded = torch.nn.functional.one_hot(x, num_classes=self.input_size).float()
        if x.dim() == 1:
            # An unbatched sequence becomes a batch of one, matching batch_first.
            encoded = encoded.unsqueeze(0)
        out, hidden = self.rnn(encoded, hidden)
        # Keep only the last time step: (batch, seq_len, hidden) -> (batch, hidden).
        return self.fc(out[:, -1, :]), hidden