Spaces:
Sleeping
Sleeping
"""
Overview:
    This file implements the core modules of the GTrXL (Gated Transformer-XL) model described in
    "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764).
"""
from typing import Optional, Dict, List | |
import warnings | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from ding.torch_utils.network.nn_module import fc_block, build_normalization, F | |
class PositionalEmbedding(nn.Module):
    """
    Overview:
        Sinusoidal positional embedding, as used in the vanilla Transformer and in Transformer-XL.
    Interfaces:
        ``__init__``, ``forward``
    .. note::
        Adapted from https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
    """

    def __init__(self, embedding_dim: int):
        """
        Overview:
            Precompute the inverse frequencies used to build the sinusoids.
        Arguments:
            - embedding_dim (:obj:`int`): Requested size of each embedding vector. One frequency is created for \
                every even index in ``[0, embedding_dim)``, so the output width is ``2 * ceil(embedding_dim / 2)`` \
                (exactly ``embedding_dim`` when it is even).
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        # Exponents 0/d, 2/d, 4/d, ...: one frequency per (sin, cos) pair.
        exponent = torch.arange(0.0, embedding_dim, 2.0) / embedding_dim
        inv_freq = 1 / (10000 ** exponent)
        # Registered as a buffer so that it follows the module across ``.to()`` calls and appears in state_dict.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Map a 1D sequence of positions to their sinusoidal embeddings.
        Arguments:
            - pos_seq (:obj:`torch.Tensor`): 1D tensor of positions, typically descending such as \
                ``[seq_len - 1, ..., 1, 0]``.
        Returns:
            - pos_embedding (:obj:`torch.Tensor`): Tensor of shape ``(seq_len, 1, 2 * len(inv_freq))``. The sine \
                features come first, followed by the cosine features.
        """
        angles = torch.outer(pos_seq, self.inv_freq)  # (seq_len, len(inv_freq))
        # The sin/cos ordering does not matter: the embedding is consumed by a matrix multiplication, which can
        # absorb any fixed permutation of the features.
        embedding = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return embedding[:, None, :]
class GRUGatingUnit(torch.nn.Module):
    """
    Overview:
        The GRU-style gating mechanism used in GTrXL in place of a plain residual connection. For inputs ``x`` and \
        ``y`` it computes:
            r = sigmoid(Wr(y) + Ur(x))
            z = sigmoid(Wz(y) + Uz(x) - bg)
            h = tanh(Wg(y) + Ug(r * x))
            g = (1 - z) * x + z * h
        In ``GatedTransformerXLLayer``, ``x`` is the layer's skip input and ``y`` is the sub-layer output.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, input_dim: int, bg: float = 2.):
        """
        Overview:
            Initialize the GRUGatingUnit module.
        Arguments:
            - input_dim (:obj:`int`): The dimensionality of the input.
            - bg (:obj:`float`): The gate bias, subtracted inside the update gate ``z``. Setting bg > 0 pushes ``z`` \
                towards 0, so the unit starts close to the identity map on ``x``. This can greatly improve learning \
                speed and stability, since the agent starts close to a Markovian policy (ignoring attention at the \
                beginning). Integer values are accepted and converted to float.
        """
        super(GRUGatingUnit, self).__init__()
        self.Wr = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.Ur = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.Wz = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.Uz = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.Wg = torch.nn.Linear(input_dim, input_dim, bias=False)
        self.Ug = torch.nn.Linear(input_dim, input_dim, bias=False)
        # ``torch.full`` infers its dtype from the fill value. An int ``bg`` would produce an integer tensor, which
        # ``nn.Parameter`` rejects (only floating-point tensors can require grad), so cast explicitly.
        self.bg = nn.Parameter(torch.full([input_dim], float(bg)))  # bias
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the output value using the GRU gating mechanism.
        Arguments:
            - x: (:obj:`torch.Tensor`): The first input tensor (the value passed through when the gate is closed).
            - y: (:obj:`torch.Tensor`): The second input tensor. \
                x and y should have the same shape, and their last dimension should match input_dim.
        Returns:
            - g: (:obj:`torch.Tensor`): The output of the GRU gating mechanism, with the same shape as x and y.
        """
        r = self.sigmoid(self.Wr(y) + self.Ur(x))  # reset gate
        z = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg)  # update gate, biased towards 0 by bg
        h = self.tanh(self.Wg(y) + self.Ug(torch.mul(r, x)))  # candidate; element-wise r * x
        g = torch.mul(1 - z, x) + torch.mul(z, h)
        return g  # x.shape == y.shape == g.shape
class Memory:
    """
    Overview:
        Holds the hidden-state context that gives a Transformer-XL style model access to past segments. The stored \
        tensor has shape ``(layer_num + 1, memory_len, bs, embedding_dim)``. In GTrXL, slice 0 holds the embedded \
        input and slice ``i + 1`` holds the output of layer ``i``.
    Interfaces:
        ``__init__``, ``init``, ``update``, ``get``, ``to``
    .. note::
        For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860
    """

    def __init__(
            self,
            memory_len: int = 20,
            batch_size: int = 64,
            embedding_dim: int = 256,
            layer_num: int = 3,
            memory: Optional[torch.Tensor] = None
    ) -> None:
        """
        Overview:
            Record the memory dimensions and allocate (or adopt) the memory tensor.
        Arguments:
            - memory_len (:obj:`int`): How many past time steps are kept as memory.
            - batch_size (:obj:`int`): The batch size.
            - embedding_dim (:obj:`int`): The size of a single embedded observation.
            - layer_num (:obj:`int`): The number of transformer layers; ``layer_num + 1`` slices are stored.
            - memory (:obj:`Optional[torch.Tensor]`): An initial memory tensor. If given, its shape overrides the \
                dimensions above. Default is None, which means zero-initialized memory.
        """
        super().__init__()
        self.memory_len = memory_len
        self.bs = batch_size
        self.embedding_dim = embedding_dim
        self.layer_num = layer_num
        self.memory = None
        self.init(memory)

    def init(self, memory: Optional[torch.Tensor] = None):
        """
        Overview:
            (Re)initialize the memory, either by adopting ``memory`` or by allocating zeros.
        Arguments:
            - memory (:obj:`Optional[torch.Tensor]`): Tensor of shape \
                ``(layer_num + 1, memory_len, bs, embedding_dim)``. When provided, ``layer_num``, ``memory_len``, \
                ``bs`` and ``embedding_dim`` are re-derived from its shape. When None, a float32 zero tensor is \
                allocated from the current attributes.
        """
        if memory is None:
            zeros_shape = (self.layer_num + 1, self.memory_len, self.bs, self.embedding_dim)
            self.memory = torch.zeros(*zeros_shape, dtype=torch.float)
            return
        self.memory = memory
        slice_count, self.memory_len, self.bs, self.embedding_dim = memory.shape
        self.layer_num = slice_count - 1

    def update(self, hidden_state: List[torch.Tensor]):
        """
        Overview:
            Append new hidden states to the memory and keep only the trailing window of rows.
            Example (a single slice, memory_len=3, two new time steps):
                old rows: m0, m1, m2    new rows: h0, h1    =>    updated rows: m2, h0, h1
        Arguments:
            - hidden_state (:obj:`List[torch.Tensor]`): One tensor per memory slice (at least ``layer_num + 1``), \
                each of shape ``(cur_seq, bs, embedding_dim)``.
        Returns:
            - memory (:obj:`torch.Tensor`): The updated, detached memory of shape \
                ``(layer_num + 1, memory_len, bs, embedding_dim)``. It is also stored as ``self.memory``.
        Raises:
            - ValueError: If the memory or ``hidden_state`` is None.
        """
        if self.memory is None or hidden_state is None:
            raise ValueError('Failed to update memory! Memory would be None')  # TODO add support of no memory
        step_count = hidden_state[0].shape[0]
        # After concatenating old memory with the new states, rows [step_count, step_count + memory_len)
        # are the most recent ones.
        window = slice(step_count, step_count + self.memory_len)
        with torch.no_grad():
            refreshed = torch.stack(
                [
                    torch.cat([self.memory[idx], hidden_state[idx]], dim=0)[window].detach()
                    for idx in range(self.layer_num + 1)
                ],
                dim=0,
            )
        self.memory = refreshed
        return refreshed

    def get(self):
        """
        Overview:
            Return the current memory tensor.
        Returns:
            - memory (:obj:`Optional[torch.Tensor]`): Shape ``(layer_num + 1, memory_len, bs, embedding_dim)``.
        """
        return self.memory

    def to(self, device: str = 'cpu'):
        """
        Overview:
            Move the stored memory tensor to ``device``, in place on this object. Returns None.
        Arguments:
            - device (:obj:`str`): The target device. Default is 'cpu'.
        """
        self.memory = self.memory.to(device)
class AttentionXL(torch.nn.Module):
    """
    Overview:
        Multi-head attention with Transformer-XL relative positional encoding. Queries are computed from the \
        current segment only; keys and values are computed from the concatenation of memory and the current segment.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, input_dim: int, head_dim: int, head_num: int, dropout: nn.Module) -> None:
        """
        Overview:
            Build the query, key/value, positional and output projections.
        Arguments:
            - input_dim (:obj:`int`): The dimensionality of the input features.
            - head_dim (:obj:`int`): The dimensionality of each attention head.
            - head_num (:obj:`int`): The number of attention heads.
            - dropout (:obj:`nn.Module`): The dropout layer applied to the attention weights and to the output.
        """
        super().__init__()
        inner_dim = head_dim * head_num
        self.head_num = head_num
        self.head_dim = head_dim
        self.dropout = dropout
        self.attention_kv = fc_block(input_dim, inner_dim * 2)  # key and value, over memory + current segment
        self.attention_q = fc_block(input_dim, inner_dim)  # query, current segment only (no past hidden states)
        self.project = fc_block(inner_dim, input_dim)  # merge the heads back to input_dim
        self.project_pos = fc_block(input_dim, inner_dim)  # project the positional embedding
        self.scale = 1 / (head_dim ** 0.5)  # scaled dot-product attention factor

    def _rel_shift(self, x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor:
        """
        Overview:
            Relative-shift trick from Transformer-XL: re-align a score matrix indexed by (query, position) so that \
            it is indexed by (query, key).
            Example on a 3 x 3 matrix:
                a00 a01 a02      0 a00 a01 a02      0   a00 a01      a02 0   a10
                a10 a11 a12  =>  0 a10 a11 a12  =>  a02 0   a10  =>  a11 a12 0
                a20 a21 a22      0 a20 a21 a22      a11 a12 0        a20 a21 a22
                                                    a20 a21 a22
            1) Prepend one column of zeros.
            2) Reinterpret the [rows x (cols + 1)] data as [(cols + 1) x rows].
            3) Drop the first row and reinterpret the result with the original shape.
            4) Optionally zero the upper-right triangle (giving ``a02 0 0`` in the first row above).
        .. note::
            See https://github.com/kimiyoung/transformer-xl/issues/8 and \
            https://arxiv.org/pdf/1901.02860.pdf (Appendix B).
        Arguments:
            - x (:obj:`torch.Tensor`): Score tensor of shape (bs, head_num, cur_seq, full_seq).
            - zero_upper (:obj:`bool`): If True, entries above diagonal ``full_seq - cur_seq`` are set to zero.
        Returns:
            - x (:obj:`torch.Tensor`): The shifted tensor, with the same shape as the input.
        """
        batch, heads, rows, cols = x.shape
        padded = F.pad(x, [1, 0])  # step 1: (batch, heads, rows, cols + 1)
        padded = padded.view(batch, heads, cols + 1, rows)  # step 2
        shifted = padded[:, :, 1:].view_as(x)  # step 3
        if not zero_upper:
            return shifted
        keep = torch.tril(torch.ones((rows, cols), device=x.device), cols - rows)  # step 4
        return shifted * keep[None, None]

    def forward(
            self,
            inputs: torch.Tensor,
            pos_embedding: torch.Tensor,
            full_input: torch.Tensor,
            u: torch.nn.Parameter,
            v: torch.nn.Parameter,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Compute relative multi-head attention of the current segment over memory + current segment.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Attention input of shape (cur_seq, bs, input_dim).
            - pos_embedding (:obj:`torch.Tensor`): Positional embedding of shape (full_seq, 1, input_dim).
            - full_input (:obj:`torch.Tensor`): Concatenated memory and input, of shape (full_seq, bs, input_dim).
            - u (:obj:`torch.nn.Parameter`): Content bias of shape (head_num, head_dim).
            - v (:obj:`torch.nn.Parameter`): Position bias of shape (head_num, head_dim).
            - mask (:obj:`Optional[torch.Tensor]`): Boolean mask of shape (cur_seq, full_seq, 1). True entries are \
                excluded from attention. If None (or all False), no masking is applied.
        Returns:
            - output (:obj:`torch.Tensor`): Attention output of shape (cur_seq, bs, input_dim).
        """
        cur_seq, bs = inputs.shape[0], inputs.shape[1]
        full_seq = full_input.shape[0]
        heads, per_head = self.head_num, self.head_dim

        keys, values = torch.chunk(self.attention_kv(full_input), 2, dim=-1)  # full_seq x bs x heads*per_head
        queries = self.attention_q(inputs)  # cur_seq x bs x heads*per_head
        rel = self.project_pos(pos_embedding)  # full_seq x 1 x heads*per_head

        keys = keys.view(full_seq, bs, heads, per_head)
        queries = queries.view(cur_seq, bs, heads, per_head)
        values = values.view(full_seq, bs, heads, per_head)
        rel = rel.view(full_seq, heads, per_head)

        # content term (query + u) @ key^T -> bs x heads x cur_seq x full_seq
        content_scores = (queries + u).permute(1, 2, 0, 3) @ keys.permute(1, 2, 3, 0)
        # position term (query + v) @ R^T, then re-aligned from positions to keys
        position_scores = self._rel_shift((queries + v).permute(1, 2, 0, 3) @ rel.permute(1, 2, 0))
        scores = (content_scores + position_scores) * self.scale

        # Masked positions get -inf so that softmax assigns them zero weight.
        if mask is not None and mask.any().item():
            blocked = mask.permute(2, 0, 1).unsqueeze(1)  # 1 x 1 x cur_seq x full_seq
            assert blocked.shape[2:] == scores.shape[2:]  # check shape of mask
            scores = scores.masked_fill(blocked, -float("inf")).type_as(scores)

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = weights @ values.permute(1, 2, 0, 3)  # bs x heads x cur_seq x per_head
        context = context.permute(2, 0, 1, 3).contiguous().view(cur_seq, bs, heads * per_head)
        return self.dropout(self.project(context))  # cur_seq x bs x input_dim
class GatedTransformerXLLayer(torch.nn.Module):
    """
    Overview:
        One GTrXL (Gated Transformer-XL) block: layer norm -> relative attention over memory + input -> gate, then \
        layer norm -> MLP -> gate. When GRU gating is disabled, both gates become residual additions.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            input_dim: int,
            head_dim: int,
            hidden_dim: int,
            head_num: int,
            mlp_num: int,
            dropout: nn.Module,
            activation: nn.Module,
            gru_gating: bool = True,
            gru_bias: float = 2.
    ) -> None:
        """
        Overview:
            Initialize GatedTransformerXLLayer.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input tensor.
            - head_dim (:obj:`int`): The dimension of each head in the multi-head attention.
            - hidden_dim (:obj:`int`): The dimension of the hidden layers in the MLP.
            - head_num (:obj:`int`): The number of heads for the multi-head attention.
            - mlp_num (:obj:`int`): The number of fully connected layers in the MLP.
            - dropout (:obj:`nn.Module`): The dropout module used in the MLP and attention layers.
            - activation (:obj:`nn.Module`): The activation used in the MLP layers and after attention.
            - gru_gating (:obj:`bool`, optional): Whether to use GRU gates. If falsy, GRU gates are replaced with \
                residual connections. Default is True.
            - gru_bias (:obj:`float`, optional): The bias of the GRU gates. Default is 2.
        """
        super(GatedTransformerXLLayer, self).__init__()
        self.dropout = dropout
        self.gating = gru_gating
        # Use the same truthiness test as ``forward``. With ``is True``, a truthy non-bool (e.g. ``1``) would skip
        # creating the gates here, and ``forward`` would then fail on the missing ``gate1``.
        if self.gating:
            self.gate1 = GRUGatingUnit(input_dim, gru_bias)
            self.gate2 = GRUGatingUnit(input_dim, gru_bias)
        self.attention = AttentionXL(
            input_dim,
            head_dim,
            head_num,
            dropout,
        )
        layers = []
        dims = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim]
        for i in range(mlp_num):
            layers.append(fc_block(dims[i], dims[i + 1], activation=activation))
            if i != mlp_num - 1:
                layers.append(self.dropout)
        layers.append(self.dropout)
        self.mlp = nn.Sequential(*layers)
        self.layernorm1 = build_normalization('LN')(input_dim)
        self.layernorm2 = build_normalization('LN')(input_dim)
        self.activation = activation

    def forward(
            self,
            inputs: torch.Tensor,
            pos_embedding: torch.Tensor,
            u: torch.nn.Parameter,
            v: torch.nn.Parameter,
            memory: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Compute the forward pass of the GTrXL layer.
        Arguments:
            - inputs (:obj:`torch.Tensor`): The attention input tensor of shape (cur_seq, bs, input_dim).
            - pos_embedding (:obj:`torch.Tensor`): The positional embedding of shape (full_seq, 1, input_dim).
            - u (:obj:`torch.nn.Parameter`): The content bias of shape (head_num, head_dim).
            - v (:obj:`torch.nn.Parameter`): The position bias of shape (head_num, head_dim).
            - memory (:obj:`torch.Tensor`): The memory tensor of shape (prev_seq, bs, input_dim).
            - mask (:obj:`Optional[torch.Tensor]`): Boolean attention mask of shape (cur_seq, full_seq, 1). True \
                entries are excluded from attention. Default is None.
        Returns:
            - output (:obj:`torch.Tensor`): The layer output of shape (cur_seq, bs, input_dim).
        """
        # concatenate memory with the input along the sequence dimension
        full_input = torch.cat([memory, inputs], dim=0)  # full_seq x bs x input_dim
        x1 = self.layernorm1(full_input)
        a1 = self.dropout(self.attention(inputs, pos_embedding, x1, u, v, mask=mask))
        a1 = self.activation(a1)  # configured activation (ReLU by default in GTrXL) after attention
        o1 = self.gate1(inputs, a1) if self.gating else inputs + a1
        x2 = self.layernorm2(o1)
        m2 = self.dropout(self.mlp(x2))
        o2 = self.gate2(o1, m2) if self.gating else o1 + m2
        return o2
class GTrXL(nn.Module):
    """
    Overview:
        GTrXL Transformer implementation as described in "Stabilizing Transformer for Reinforcement Learning"
        (https://arxiv.org/abs/1910.06764).
    Interfaces:
        ``__init__``, ``forward``, ``reset_memory``, ``get_memory``
    """

    def __init__(
            self,
            input_dim: int,
            head_dim: int = 128,
            embedding_dim: int = 256,
            head_num: int = 2,
            mlp_num: int = 2,
            layer_num: int = 3,
            memory_len: int = 64,
            dropout_ratio: float = 0.,
            activation: nn.Module = nn.ReLU(),
            gru_gating: bool = True,
            gru_bias: float = 2.,
            use_embedding_layer: bool = True,
    ) -> None:
        """
        Overview:
            Init GTrXL Model.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input observation. A list or tuple shape is flattened \
                into the product of its entries.
            - head_dim (:obj:`int`, optional): The dimension of each head. Default is 128.
            - embedding_dim (:obj:`int`, optional): The dimension of the embedding. Default is 256.
            - head_num (:obj:`int`, optional): The number of heads for multi-head attention. Default is 2.
            - mlp_num (:obj:`int`, optional): The number of MLP layers in the attention layer. Default is 2.
            - layer_num (:obj:`int`, optional): The number of transformer layers. Default is 3.
            - memory_len (:obj:`int`, optional): The length of memory. Default is 64.
            - dropout_ratio (:obj:`float`, optional): The dropout ratio. Default is 0.
            - activation (:obj:`nn.Module`, optional): The activation function. Default is nn.ReLU().
            - gru_gating (:obj:`bool`, optional): If False, replace GRU gates with residual connections. \
                Default is True.
            - gru_bias (:obj:`float`, optional): The GRU gate bias. Default is 2.0.
            - use_embedding_layer (:obj:`bool`, optional): If False, don't use the input embedding layer. \
                Default is True.
        Raises:
            - AssertionError: If `embedding_dim` is not an even number.
        """
        super(GTrXL, self).__init__()
        assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(embedding_dim)
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        if isinstance(input_dim, (list, tuple)):
            # flatten a multi-dimensional observation shape into a single feature size
            input_dim = int(np.prod(input_dim))
        self.use_embedding_layer = use_embedding_layer
        if use_embedding_layer:
            self.embedding = fc_block(input_dim, embedding_dim, activation=activation)
        self.activation = activation
        self.pos_embedding = PositionalEmbedding(embedding_dim)
        # Memory that saves hidden states of past segments. It is created lazily in forward / reset_memory,
        # so that its batch size can be taken from the input.
        self.memory = None
        self.memory_len = memory_len
        layers = []
        dims = [embedding_dim] + [embedding_dim] * layer_num
        self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity()
        for i in range(layer_num):
            layers.append(
                GatedTransformerXLLayer(
                    dims[i], head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation, gru_gating,
                    gru_bias
                )
            )
        self.layers = nn.Sequential(*layers)
        self.embedding_dim = embedding_dim
        # u and v are the parameters for the global content bias and the global positional bias
        self.u, self.v = (
            torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
            torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
        )
        # Per-cur_seq caches of the attention mask and positional embedding. They are plain dicts (not buffers),
        # so Module.to() does not move them; the _get_* helpers rebuild an entry when its device differs.
        self.att_mask = {}
        self.pos_embedding_dict = {}

    def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None):
        """
        Overview:
            Clear or set the memory of GTrXL.
        Arguments:
            - batch_size (:obj:`Optional[int]`): If given, allocate a zero memory for this batch size. This takes \
                precedence over ``state``. Default is None.
            - state (:obj:`Optional[torch.Tensor]`): Memory to adopt, with shape \
                (layer_num + 1, memory_len, bs, embedding_dim). Default is None.
        .. note::
            When both arguments are None, a zero memory is allocated with ``Memory``'s default batch size.
        """
        if batch_size is not None:
            self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num)
        else:
            # Passing ``state`` straight to Memory avoids allocating a zero tensor that would be discarded.
            self.memory = Memory(
                memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim, memory=state
            )

    def get_memory(self):
        """
        Overview:
            Returns the memory of GTrXL.
        Returns:
            - memory (:obj:`Optional[torch.Tensor]`): The memory, or None if it has not been initialized. \
                The shape is (layer_num + 1, memory_len, bs, embedding_dim).
        """
        if self.memory is None:
            return None
        else:
            return self.memory.get()

    def _get_attn_mask(self, cur_seq: int, prev_seq: int, device: torch.device) -> torch.Tensor:
        """Return the cached causal mask (cur_seq x full_seq x 1, True = hidden), rebuilding it if stale."""
        attn_mask = self.att_mask.get(cur_seq)
        if attn_mask is None or attn_mask.device != device:
            full_seq = cur_seq + prev_seq
            attn_mask = (
                torch.triu(
                    torch.ones((cur_seq, full_seq)),
                    diagonal=1 + prev_seq,  # fixed in train, eval, collect
                ).bool().unsqueeze(-1).to(device)
            )  # cur_seq x full_seq x 1
            self.att_mask[cur_seq] = attn_mask
        return attn_mask

    def _get_pos_embedding(self, cur_seq: int, prev_seq: int, device: torch.device) -> torch.Tensor:
        """Return the cached positional embedding (full_seq x 1 x embedding_dim), rebuilding it if stale."""
        pos_embedding = self.pos_embedding_dict.get(cur_seq)
        if pos_embedding is None or pos_embedding.device != device:
            full_seq = cur_seq + prev_seq
            pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float)  # full_seq
            pos_embedding = self.pos_embedding(pos_ips.to(device))
            self.pos_embedding_dict[cur_seq] = pos_embedding
        return pos_embedding

    def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Performs a forward pass on the GTrXL and updates the internal memory.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor with shape (seq_len, bs, input_size).
            - batch_first (:obj:`bool`, optional): If the input data has shape (bs, seq_len, input_size), \
                set this parameter to True to transpose it to (seq_len, bs, input_size). The output "logit" is \
                transposed back; the output memory is not affected. Default is False.
            - return_mem (:obj:`bool`, optional): If False, the returned dict contains only "logit". Default is True.
        Returns:
            - x (:obj:`Dict[str, torch.Tensor]`): A dictionary containing "logit", the transformer output of shape \
                (seq_len, bs, embedding_dim) (or (bs, seq_len, embedding_dim) if batch_first), and, when \
                return_mem is True, "memory": the memory before this call's update, of shape \
                (layer_num + 1, memory_len, bs, embedding_dim).
        """
        if batch_first:
            x = torch.transpose(x, 1, 0)  # bs x cur_seq x input_dim -> cur_seq x bs x input_dim
        cur_seq, bs = x.shape[:2]
        memory = None if self.memory is None else self.memory.get()
        if memory is None:
            self.reset_memory(bs)  # (layer_num+1) x memory_len x batch_size x embedding_dim
        elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim:
            warnings.warn(
                "Memory {} and Input {} dimensions don't match,"
                " this will cause the memory to be initialized to fit your input!".format(
                    list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim]
                )
            )
            self.reset_memory(bs)
        self.memory.to(x.device)
        memory = self.memory.get()

        if self.use_embedding_layer:
            x = self.dropout(self.embedding(x))
        prev_seq = self.memory_len
        attn_mask = self._get_attn_mask(cur_seq, prev_seq, x.device)
        pos_embedding = self.dropout(self._get_pos_embedding(cur_seq, prev_seq, x.device))  # full_seq x 1 x emb

        hidden_state = [x]
        out = x
        for i, layer in enumerate(self.layers):
            out = layer(
                out,
                pos_embedding,
                self.u,
                self.v,
                mask=attn_mask,
                memory=memory[i],  # memory_len x batch_size x embedding_dim
            )  # cur_seq x bs x embedding_dim
            hidden_state.append(out.clone())
        out = self.dropout(out)
        self.memory.update(hidden_state)  # (layer_num+1) x memory_len x batch_size x embedding_dim

        if batch_first:
            out = torch.transpose(out, 1, 0)  # cur_seq x bs x embedding_dim -> bs x cur_seq x embedding_dim
        if return_mem:
            output = {"logit": out, "memory": memory}  # return the content of the memory before the last update
        else:
            output = {"logit": out}
        return output