File size: 617 Bytes
493e08a
 
 
 
 
 
 
 
 
 
 
 
 
 
ab33b86
493e08a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    """Single-layer RNN over one-hot encoded class indices.

    The input indices are one-hot encoded to width ``input_size``, run
    through an ``nn.RNN`` (``batch_first=True``), and the output at the
    last time step is projected to ``output_size`` by a linear layer.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        """Build the recurrent layer and the output projection.

        Args:
            input_size: Number of distinct input classes; also the one-hot
                width fed to the RNN.
            hidden_size: Size of the RNN hidden state.
            output_size: Size of the vector produced by the final linear layer.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the network on a sequence (or batch of sequences) of indices.

        Args:
            x: Integer (int64) index tensor with values in
                ``[0, input_size)``. Either a single sequence of shape
                ``(seq_len,)`` or a batch of shape ``(batch, seq_len)``.
            hidden: Initial hidden state of shape ``(1, batch, hidden_size)``
                (``batch`` is 1 for a single sequence). When ``None``,
                ``nn.RNN`` starts from an all-zero hidden state.

        Returns:
            A tuple ``(out, hidden)``: ``out`` has shape
            ``(batch, output_size)`` and is computed from the last time
            step; ``hidden`` is the final hidden state returned by the RNN.

        Raises:
            ValueError: If ``x`` is neither 1-D nor 2-D.
        """
        if x.dim() not in (1, 2):
            raise ValueError(
                "x must have shape (seq_len,) or (batch, seq_len), "
                f"got {tuple(x.shape)}"
            )
        encoded = torch.nn.functional.one_hot(x, num_classes=self.input_size).float()
        if encoded.dim() == 2:
            # Unbatched sequence: add the leading batch axis expected by
            # a batch_first RNN together with a 3-D hidden state.
            encoded = encoded.unsqueeze(0)
        out, hidden = self.rnn(encoded, hidden)
        # Only the last time step feeds the output projection.
        out = self.fc(out[:, -1, :])
        return out, hidden