|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torch.nn.functional as F |
|
|
|
class Encoder(nn.Module):
    """Bidirectional GRU encoder over a sequence of feature vectors.

    Input features are passed through dropout, then through a single-layer
    bidirectional GRU. The final forward and backward hidden states are
    concatenated and projected to the decoder's hidden size with ``tanh``.
    """

    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        """
        Args:
            emb_dim: size of each input feature vector (``Seq2Seq`` passes
                ``img_channel`` here).
            enc_hid_dim: hidden size of each GRU direction.
            dec_hid_dim: size of the projected summary state handed to the
                decoder.
            dropout: dropout probability applied to the input features.
        """
        super().__init__()

        # Creation order (rnn, then fc) is kept so that parameter
        # initialisation under a fixed RNG seed is unchanged.
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(2 * enc_hid_dim, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src``.

        Args:
            src: (src_len, batch_size, emb_dim) tensor; the GRU is not
                ``batch_first``, so time is the leading axis.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` is
            (src_len, batch_size, 2 * enc_hid_dim) and ``hidden`` is
            (batch_size, dec_hid_dim), squashed into (-1, 1) by ``tanh``.
        """
        rnn_out, h_n = self.rnn(self.dropout(src))

        # h_n is (2 * num_layers, batch, enc_hid_dim); with the default single
        # layer its last two rows are the final forward and backward states.
        last_fwd, last_bwd = h_n[-2], h_n[-1]
        summary = torch.cat([last_fwd, last_bwd], dim=1)

        return rnn_out, torch.tanh(self.fc(summary))
|
|
|
class Attention(nn.Module):
    """Additive attention of a decoder state over encoder outputs.

    Scores are ``v(tanh(attn([hidden; encoder_output])))`` for every source
    position, normalised with a softmax over the source axis.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        """
        Args:
            enc_hid_dim: per-direction encoder hidden size; encoder outputs
                are expected to carry ``2 * enc_hid_dim`` features.
            dec_hid_dim: decoder hidden size (also the attention layer width).
        """
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights.

        Args:
            hidden: (batch_size, dec_hid_dim) current decoder state.
            encoder_outputs: (src_len, batch_size, 2 * enc_hid_dim).

        Returns:
            (batch_size, src_len) weights; each row sums to 1.
        """
        src_len = encoder_outputs.shape[0]

        # Broadcast the decoder state to every source position. ``expand``
        # is a view; the subsequent ``torch.cat`` materialises the result,
        # so no extra ``repeat`` copy is needed.
        hidden = hidden.unsqueeze(1).expand(-1, src_len, -1)

        # (src_len, batch, feat) -> (batch, src_len, feat) to align with hidden.
        encoder_outputs = encoder_outputs.permute(1, 0, 2)

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))

        # (batch, src_len, 1) -> (batch, src_len)
        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)
|
|
|
class Decoder(nn.Module):
    """Single-step GRU decoder conditioned on attention over encoder outputs.

    Each call consumes one target token per batch element. It attends over
    the encoder outputs with the supplied ``attention`` module and advances
    the GRU by one step. It then projects ``[gru_output; context; embedding]``
    to ``output_dim`` scores.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        """
        Args:
            output_dim: vocabulary size (embedding rows and number of output
                scores).
            emb_dim: token embedding size.
            enc_hid_dim: per-direction encoder hidden size; encoder outputs
                are expected to carry ``2 * enc_hid_dim`` features.
            dec_hid_dim: decoder GRU hidden size.
            dropout: dropout probability applied to the token embeddings.
            attention: module mapping ``(hidden, encoder_outputs)`` to
                (batch_size, src_len) weights.
        """
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: (batch_size,) token indices for the current step.
            hidden: (batch_size, dec_hid_dim) previous decoder state.
            encoder_outputs: (src_len, batch_size, 2 * enc_hid_dim).

        Returns:
            ``(prediction, hidden, attention)``:
            prediction is (batch_size, output_dim) raw scores (no softmax is
            applied here); hidden is the new (batch_size, dec_hid_dim) state;
            attention is the (batch_size, src_len) weights used this step.
        """
        # Add the length-1 time axis expected by the non-batch_first GRU.
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))  # (1, batch, emb_dim)

        a = self.attention(hidden, encoder_outputs).unsqueeze(1)  # (batch, 1, src_len)

        # Context vector: attention-weighted sum of the encoder outputs.
        weighted = torch.bmm(a, encoder_outputs.permute(1, 0, 2))  # (batch, 1, 2*enc)
        weighted = weighted.permute(1, 0, 2)  # (1, batch, 2*enc)

        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        # For this single-layer, unidirectional GRU fed one time step,
        # ``output`` and ``hidden`` hold the same values. A runtime
        # ``assert (output == hidden).all()`` used to live here. It was
        # removed because converting a tensor to a Python bool forces a
        # device->host sync on every step, and ``assert`` is stripped under -O.

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))

        return prediction, hidden.squeeze(0), a.squeeze(1)
|
|
|
class Seq2Seq(nn.Module):
    """Attention-based GRU encoder/decoder.

    Wires an ``Encoder``, an ``Attention`` module and a ``Decoder`` together.
    ``forward`` decodes a whole target sequence with the ground-truth token
    fed at each step. The step-wise helpers (``forward_encoder``,
    ``forward_decoder``, ``expand_memory``, ``get_memory``) operate on a
    ``(hidden, encoder_outputs)`` memory tuple. ``expand_memory`` takes a
    ``beam_size``, so these are presumably driven by a search routine
    elsewhere — TODO confirm.
    """

    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        """
        Args:
            vocab_size: number of output tokens.
            encoder_hidden: per-direction encoder GRU hidden size.
            decoder_hidden: decoder GRU hidden size.
            img_channel: feature size of each input time step.
            decoder_embedded: token embedding size in the decoder.
            dropout: dropout probability used by encoder and decoder.
        """
        super().__init__()

        # Creation order (attention, encoder, decoder) determines parameter
        # initialisation order under a fixed seed; keep it.
        attn = Attention(encoder_hidden, decoder_hidden)

        self.encoder = Encoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = Decoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout, attn)

    def forward_encoder(self, src):
        """Encode ``src`` and return the decoder memory.

        Args:
            src: (timestep, batch_size, channel) input features.

        Returns:
            ``(hidden, encoder_outputs)`` with hidden (batch_size, dec_hid_dim)
            and encoder_outputs (src_len, batch_size, 2 * enc_hid_dim).
        """
        encoder_outputs, hidden = self.encoder(src)

        return (hidden, encoder_outputs)

    def forward_decoder(self, tgt, memory):
        """Advance the decoder one step from the last token of ``tgt``.

        Args:
            tgt: (timestep, batch_size) tokens so far; only ``tgt[-1]`` is used.
            memory: ``(hidden, encoder_outputs)`` as returned by
                ``forward_encoder`` or a previous call; hidden is
                (batch_size, dec_hid_dim), encoder_outputs is
                (src_len, batch_size, 2 * enc_hid_dim).

        Returns:
            ``(output, memory)`` where output is (batch_size, 1, vocab_size)
            and memory carries the updated hidden state.
        """
        tgt = tgt[-1]
        hidden, encoder_outputs = memory
        output, hidden, _ = self.decoder(tgt, hidden, encoder_outputs)
        output = output.unsqueeze(1)

        return output, (hidden, encoder_outputs)

    def forward(self, src, trg):
        """Decode the full target sequence, feeding ``trg[t]`` at step ``t``.

        The decoder's own predictions are not fed back; each step consumes
        the given target token.

        Args:
            src: (src_len, batch_size, img_channel) input features.
            trg: (trg_len, batch_size) target token indices.

        Returns:
            (batch_size, trg_len, vocab_size) decoder scores.
        """
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        device = src.device

        # Allocate directly on the target device rather than building the
        # buffer on the CPU and copying it over with ``.to(device)``.
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=device)
        encoder_outputs, hidden = self.encoder(src)

        for t in range(trg_len):
            input = trg[t]
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            outputs[t] = output

        outputs = outputs.transpose(0, 1).contiguous()

        return outputs

    def expand_memory(self, memory, beam_size):
        """Tile the memory ``beam_size`` times along the batch axis.

        The whole batch is repeated as a block, so hidden rows and
        encoder_outputs columns stay aligned (``[b0, b1, b0, b1, ...]``).
        """
        hidden, encoder_outputs = memory
        hidden = hidden.repeat(beam_size, 1)
        encoder_outputs = encoder_outputs.repeat(1, beam_size, 1)

        return (hidden, encoder_outputs)

    def get_memory(self, memory, i):
        """Select batch element ``i`` from the memory, keeping the batch axis.

        List indexing (``[i]``) keeps a batch dimension of size 1.
        """
        hidden, encoder_outputs = memory
        hidden = hidden[[i]]
        encoder_outputs = encoder_outputs[:, [i], :]

        return (hidden, encoder_outputs)
|
|