# VoiceStar/models/modules/transformer.py
import copy, logging
import numbers
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from .activation import MultiheadAttention
from .scaling import ActivationBalancer, BalancedDoubleSwish
from .scaling import BasicNorm as _BasicNorm
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class AdaptiveLayerNorm(nn.Module):
r"""Adaptive Layer Normalization"""
def __init__(self, d_model, norm) -> None:
super(AdaptiveLayerNorm, self).__init__()
self.project_layer = nn.Linear(d_model, 2 * d_model)
self.norm = norm
self.d_model = d_model
self.eps = self.norm.eps
    def forward(self, input: Tensor, embedding: Optional[Tensor] = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1,
)
return weight * self.norm(input) + bias
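# Illustrative sketch (not part of the original code): AdaptiveLayerNorm predicts a
# per-example scale and shift from a conditioning embedding and applies them to the
# wrapped norm's output. A minimal usage example, assuming the embedding shares d_model:
#
#   norm = AdaptiveLayerNorm(d_model=256, norm=LayerNorm(256))
#   x = torch.randn(8, 100, 256)    # [batch, time, d_model]
#   cond = torch.randn(8, 1, 256)   # conditioning embedding, broadcast over time
#   y = norm(x, cond)               # same shape as x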
class BasicNorm(_BasicNorm):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
):
super(BasicNorm, self).__init__(d_model, eps=eps)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (
super(BasicNorm, self).forward(input),
embedding,
)
assert embedding is None
return super(BasicNorm, self).forward(input)
class BalancedBasicNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
):
super(BalancedBasicNorm, self).__init__()
self.balancer = ActivationBalancer(
d_model,
channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
max_abs=6.0,
)
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return self.norm((self.balancer(input), embedding))
assert embedding is None
return self.norm(self.balancer(input))
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
assert embedding is None
return input
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
activation = activation(d_model)
elif activation == BalancedDoubleSwish:
activation = BalancedDoubleSwish(d_model)
# # We can't test self.activation in forward() in TorchScript,
# # so stash some information about it instead.
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
# self.activation_relu_or_gelu = 1
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
# self.activation_relu_or_gelu = 2
# else:
# self.activation_relu_or_gelu = 0
self.activation = activation
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
else:
self.norm1 = norm1
self.norm2 = norm2
def __setstate__(self, state):
super(TransformerEncoderLayer, self).__setstate__(state)
if not hasattr(self, "activation"):
self.activation = F.relu
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
need_weights: Optional[bool] = False,
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
if isinstance(src, dict):
sinu = src["sinu"]
pm_sinu = src["pm_sinu"]
src = src["input"]
else:
sinu = None
pm_sinu = None
x, stage_embedding = src, None
is_src_tuple = False
if isinstance(src, tuple):
x, stage_embedding = src
is_src_tuple = True
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if need_weights:
raise NotImplementedError
if self.norm_first:
out, attn = self._sa_block_attn(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,
past, sinu = sinu
)
out, present = out # present is the kvcache of the present timestep
x = x + out
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past, sinu = sinu)
out, present = out # present is the kvcache of the present timestep
x = self.norm1(
x + out,
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
assert not is_src_tuple
# return (x, stage_embedding)
return (x, attn)
else:
if self.norm_first:
out = self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask, past, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q']
)
out, present = out # present is the kvcache of the present timestep
x = x + out
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
                out = self._sa_block(x, src_mask, src_key_padding_mask, past, sinu=sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'])
                out, present = out # present is the kvcache of the present timestep
                x = self.norm1(
                    x + out,
                    stage_embedding,
                )
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
x = (x, stage_embedding)
            if present is not None:
x = [x, present]
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
past: Optional[Tensor] = None,
sinu = None,
q_sinu = None,
k_sinu = None
) -> Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
past=past,
sinu = sinu,
q_sinu = q_sinu,
k_sinu = k_sinu
)
x, present = x
return self.dropout1(x), present
# self-attention block, also return attention weights
def _sa_block_attn(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
        past: Optional[Tensor] = None,
        sinu=None,
    ) -> Tensor:
x, attn = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=True,
            past=past,
            sinu=sinu,
)
x, present = x
return (self.dropout1(x), present), attn
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
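# Illustrative sketch (not part of the original code): TransformerEncoderLayer is meant
# to be driven by TransformerEncoder below, which wraps the hidden states together with
# the rotary tables into a dict. Assuming .activation.MultiheadAttention accepts the
# sinu/q_sinu/k_sinu keywords forwarded by _sa_block (that module is not shown here):
#
#   layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, norm_first=True)
#   src = {"input": torch.randn(2, 50, 512), "sinu": None, "pm_sinu": {"q": None}}
#   out = layer(src)   # hidden states, or [hidden states, kv-cache] when a cache is returned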
def pre_compute_sinusoidal(dim, base, max_len=10000):  # max_len mimi codes cover max_len / 12.5 s (mimi runs at 12.5 Hz), e.g. 4000 codes ~= 320 s
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1]
inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2]
freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2]
freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d]
return {"sin": freqs.sin(), "cos": freqs.cos()}
def pre_compute_freqs(dim, base, max_len=10000):  # max_len mimi codes cover max_len / 12.5 s (mimi runs at 12.5 Hz), e.g. 4000 codes ~= 320 s
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1]
inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2]
freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2]
freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d]
return freqs
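# Illustrative note (not part of the original code): pre_compute_sinusoidal returns the
# standard rotary-embedding tables, while pre_compute_freqs returns the raw angles so that
# they can be rescaled per example (see the progress-monitoring branches below) before
# taking sin/cos. The conventional application, which lives in .activation rather than in
# this file, is x_rot = x * cos + rotate_half(x) * sin, where rotate_half negates and
# swaps the two halves of the last dimension.
#
#   tables = pre_compute_sinusoidal(dim=64, base=10000)
#   tables["sin"].shape, tables["cos"].shape   # both torch.Size([1, 10000, 64])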
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
enable_nested_tensor: if True, input will automatically convert to nested tensor
(and convert back on output). This will improve the overall performance of
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None, rope_base=None, d_model=None, nhead=None, args=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
        if args is not None:
self.progress_no_multiple = args.progress_no_multiple
self.progress_scale = args.progress_scale
else:
self.progress_no_multiple = False
self.progress_scale = 1
if rope_base is not None:
if self.progress_no_multiple:
self.pm_freqs = pre_compute_freqs(d_model//nhead, rope_base)
self.sinu = None
else:
                self.sinu = pre_compute_sinusoidal(d_model // nhead, rope_base)
self.pm_freqs = None
# logging.info(f"get precomputed sinusoidal for {rope_base=}: {self.sinu=}")
else:
self.sinu = None
self.pm_freqs = None
def forward(
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
        need_weights: Optional[bool] = False,
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
return_layer_states: return layers' state (optional).
Shape:
see the docs in Transformer class.
"""
if return_layer_states:
raise NotImplementedError
assert not need_weights
layer_states = [] # layers' output
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
past=past
)
layer_states.append(output[0])
if self.norm is not None:
output = self.norm(output)
return layer_states, output
if need_weights:
raise NotImplementedError
assert not return_layer_states
            layer_attn = []  # layers' attention weights
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
need_weights=True,
past=past
)
layer_attn.append(output[1])
if self.norm is not None:
output = self.norm(output)
return layer_attn, output
output = src
all_present = []
if self.sinu is not None:
# use rope
assert self.pm_freqs is None
for k, v in self.sinu.items():
self.sinu[k] = v.to(output.device)
if self.pm_freqs is not None:
assert self.sinu is None
self.pm_freqs = self.pm_freqs.to(output.device)
            if src_key_padding_mask is not None:
query_lens = (~src_key_padding_mask).int().sum(-1).to(output.device)
else:
query_lens = torch.tensor([output.shape[1]]*output.shape[0]).to(output.device)
assert query_lens.ndim==1, query_lens
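            # Sketch of the intent (comment added for clarity): position i in a sequence of
            # valid length L is mapped to the relative progress (i / (L - 1)) * progress_scale
            # before the rotary sin/cos are taken; the two lines below realize this as
            # freqs * (L / (L - 1)) / L * progress_scale.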
q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d]
q_emb = q_emb / q_lens_expanded * self.progress_scale
q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead
q_sin = q_emb.sin().unsqueeze(1)
self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}}
else:
self.pm_sinu = {"q": None}
output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu}
for n_layer, mod in enumerate(self.layers):
output = mod(
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
)
if isinstance(output, list):
output, present = output
all_present.append(present)
if self.sinu is not None or self.pm_sinu is not None:
output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu}
if self.sinu is not None or self.pm_sinu is not None:
output = output["input"]
if self.norm is not None:
output = self.norm(output)
        if all_present:
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
output = [output, all_present]
return output
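# Illustrative sketch (not part of the original code): building an encoder stack with
# rotary positions. `args`, when given, is expected to expose `progress_no_multiple` and
# `progress_scale`; with args=None the plain rotary tables from pre_compute_sinusoidal
# are used.
#
#   enc_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, norm_first=True)
#   enc = TransformerEncoder(enc_layer, num_layers=6, rope_base=10000, d_model=512, nhead=8)
#   out = enc(torch.randn(2, 50, 512))   # [hidden states, stacked kv-cache] when a cache is
#                                        # returned, otherwise the hidden states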
class TransformerDecoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.multihead_attn = MultiheadAttention(
d_model,
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
elif isinstance(activation, partial):
self.activation = activation(d_model)
elif activation == BalancedDoubleSwish:
self.activation = BalancedDoubleSwish(d_model)
else:
self.activation = activation
if adaptive_layer_norm:
norm1 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
norm3 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
else:
self.norm1 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
self.norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
if layer_norm_cls == IdentityNorm:
self.norm3 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs
)
else:
self.norm3 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs
)
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
tgt_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used
memory_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used
past: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
past: the previous kvcache of the decoder (optional). shape: (2, batch_size, num_heads, seq_len, head_dim)
Shape:
see the docs in Transformer class.
"""
if isinstance(tgt, dict):
pm_sinu = tgt["pm_sinu"]
sinu = tgt["sinu"]
args = tgt["args"]
tgt = tgt["input"]
else:
pm_sinu = None
sinu = None
args = None
tgt_is_tuple = False
if isinstance(tgt, tuple):
x, stage_embedding = tgt
tgt_is_tuple = True
else:
x, stage_embedding = tgt, None
# logging.info(f"{tgt_key_padding_mask=}, {memory_key_padding_mask=}")
# logging.info(f"{tgt_key_padding_mask.shape=}, {memory_key_padding_mask.shape=}")
# logging.info(f"{query_lens=}, {key_lens=}")
# past stores the kvcache for self-attention, and it can also be used to infer q_offset
if past is not None and past.ndim > 2:
q_offset = past[0].shape[-2] # past is (2, batch_size, num_heads, seq_len, head_dim), 2 contains [k, v], these are for self-attn, therefore also reflect the length of q
else:
q_offset = 0
if self.norm_first:
temp = self._sa_block(
                self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask,
                q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args=args,
                past=past, q_offset=q_offset
)
present = temp[1]
x = x + temp[0]
cross_out = self._mha_block(
self.norm2(x, stage_embedding),
memory,
memory_mask,
                memory_key_padding_mask,
                q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args,
                q_offset=q_offset
)
if isinstance(cross_out, dict):
attention_weights = cross_out["attention_weights"]
cross_out = cross_out["x"]
else:
attention_weights = None
x = x + cross_out
x = x + self._ff_block(self.norm3(x, stage_embedding))
else:
temp = self._sa_block(x, tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset)
present = temp[1]
x = self.norm1(
x + temp[0],
stage_embedding,
)
cross_out = self._mha_block(
                x, memory, memory_mask, memory_key_padding_mask,
                q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args,
                q_offset=q_offset
)
if isinstance(cross_out, dict):
attention_weights = cross_out["attention_weights"]
cross_out = cross_out["x"]
else:
attention_weights = None
x = self.norm2(
x
+ cross_out,
stage_embedding,
)
x = self.norm3(x + self._ff_block(x), stage_embedding)
if attention_weights is not None:
x = {"x": x, "attention_weights": attention_weights}
if tgt_is_tuple:
x = (x, stage_embedding)
        if present is not None:
x = [x, present]
return x
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
q_sinu=None,
k_sinu=None,
sinu = None,
args = None,
past = None,
q_offset = 0
) -> Tensor:
# if past is not None and past.ndim > 2:
# print(f"self-attn, k len: {past[0].shape[-2] + x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
# else:
# print(f"self-attn, k len: {x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
q_sinu = q_sinu,
k_sinu = k_sinu,
sinu = sinu,
past = past,
q_offset = q_offset
)
x, present = x
return self.dropout1(x), present
# multihead attention block
def _mha_block(
self,
x: Tensor,
mem: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
q_sinu = None,
k_sinu = None,
sinu = None,
args = None,
q_offset = 0
) -> Tensor:
# print(f"cross-attn, k len: {mem.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}")
x = self.multihead_attn(
x,
mem,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
q_sinu = q_sinu,
k_sinu = k_sinu,
sinu = sinu,
args = args,
q_offset = q_offset
)
if len(x) == 2 and isinstance(x[0], dict) and "attention_weights" in x[0]:
x, present = x
attention_weights = x['attention_weights']
x = x['attn_output']
return {"x": self.dropout2(x), "attention_weights": attention_weights}
elif len(x) == 2:
x = x[0]
return self.dropout2(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
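# Illustrative note (not part of the original code): incremental decoding with the
# kv-cache. On the first forward call `past` is None (or a small placeholder tensor with
# ndim <= 2, see the q_offset logic above); each layer then returns `present` with shape
# (2, batch, nhead, seq_len_so_far, head_dim), and TransformerDecoder stacks these per
# layer so they can be fed back as `past` on the next step, so only the newest token's
# queries are recomputed.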
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
def _generate_square_subsequent_mask(
sz: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
r"""Generate a square causal mask for the sequence.
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
"""
if device is None:
device = torch.device('cpu')
if dtype is None:
dtype = torch.float32
return torch.triu(
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
diagonal=1,
)
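# Illustrative example (not part of the original code): for sz=3 the generated mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# i.e. query position i may attend only to key positions <= i.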
def _get_seq_len(
src: Tensor,
batch_first: bool
) -> Optional[int]:
if src.is_nested:
return None
else:
src_size = src.size()
if len(src_size) == 2:
# unbatched: S, E
return src_size[0]
else:
# batched: B, S, E if batch_first else S, B, E
seq_len_pos = 1 if batch_first else 0
return src_size[seq_len_pos]
def _detect_is_causal_mask(
mask: Optional[Tensor],
is_causal: Optional[bool] = None,
size: Optional[int] = None,
) -> bool:
"""Return whether the given attention mask is causal.
Warning:
If ``is_causal`` is not ``None``, its value will be returned as is. If a
user supplies an incorrect ``is_causal`` hint,
    ``is_causal=False`` when the mask is in fact a causal attention mask
    may lead to reduced performance relative to what would be achievable
    with ``is_causal=True``;
    ``is_causal=True`` when the mask is in fact not a causal attention mask
    may lead to incorrect and unpredictable execution: in some scenarios
    a causal mask may be applied based on the hint, in others the
    specified mask may be used. The choice may not appear deterministic,
    since factors such as alignment and hardware SKU influence whether
    the mask is used or the hint is relied on.
    If ``size`` is not None, checks whether the mask is a causal mask of the provided size;
    otherwise, checks for any causal mask.
"""
# Prevent type refinement
make_causal = (is_causal is True)
if is_causal is None and mask is not None:
sz = size if size is not None else mask.size(-2)
causal_comparison = _generate_square_subsequent_mask(
sz, device=mask.device, dtype=mask.dtype)
# Do not use `torch.equal` so we handle batched masks by
# broadcasting the comparison.
if mask.size() == causal_comparison.size():
make_causal = bool((mask == causal_comparison).all())
else:
make_causal = False
return make_causal
class TransformerDecoder(nn.Module):
r"""TransformerDecoder is a stack of N decoder layers.
Args:
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
num_layers: the number of sub-decoder-layers in the decoder (required).
norm: the layer normalization component (optional).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
"""
__constants__ = ['norm']
def __init__(
self,
decoder_layer: "TransformerDecoderLayer",
num_layers: int,
norm: Optional[nn.Module] = None,
rope_base=None,
d_model=None,
nhead=None,
args=None
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.args = args
if getattr(self.args, 'decoder_regular_rope', False):
            self.sinu = pre_compute_sinusoidal(d_model // nhead, rope_base)
self.pm_freqs = None
else:
self.sinu = None
if rope_base is not None:
                self.pm_freqs = pre_compute_freqs(d_model // nhead, rope_base)
# logging.info(f"get precomputed freqs for {rope_base=}: {self.freqs=}")
else:
self.pm_freqs = None
self.progress_scale = getattr(self.args, "progress_scale", 1.0)
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
memory_is_causal: bool = False, query_lens: Optional[Tensor] = None, key_lens: Optional[Tensor] = None, past: Optional[Tensor] = None) -> Tensor:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
memory_is_causal: If specified, applies a causal mask as
``memory mask``.
Default: ``False``.
Warning:
``memory_is_causal`` provides a hint that
``memory_mask`` is the causal mask. Providing incorrect
hints can result in incorrect execution, including
forward and backward compatibility.
Shape:
see the docs in :class:`~torch.nn.Transformer`.
"""
output = tgt
# seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
# tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
if self.sinu is not None:
assert self.pm_freqs is None
for key in self.sinu:
self.sinu[key] = self.sinu[key].to(output.device)
if self.pm_freqs is not None:
assert self.sinu is None
if not self.training and hasattr(self, "pm_sinu") and past is not None and past[0].ndim > 2: # inference mode, will use cached sinu for the same example
assert self.pm_sinu['q'] is not None and self.pm_sinu['k'] is not None
# check batch size, need to modify the batch size if we use multi_trial during inference
if self.pm_sinu['q']['cos'].shape[0] != tgt.shape[0]:
if self.pm_sinu['q']['cos'].shape[0] > tgt.shape[0]:
self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'][:tgt.shape[0]]
self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'][:tgt.shape[0]]
self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'][:tgt.shape[0]]
self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'][:tgt.shape[0]]
else:
assert self.pm_sinu['q']['cos'].shape[0] == 1
self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'].repeat(tgt.shape[0], 1, 1, 1)
self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'].repeat(tgt.shape[0], 1, 1, 1)
self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'].repeat(tgt.shape[0], 1, 1, 1)
self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'].repeat(tgt.shape[0], 1, 1, 1)
pass
else:
self.pm_freqs = self.pm_freqs.to(output.device)
if query_lens is None:
query_lens = (~tgt_key_padding_mask).int().sum(-1).to(tgt.device)
if key_lens is None:
key_lens = (~memory_key_padding_mask).int().sum(-1).to(tgt.device)
assert key_lens.ndim==1, key_lens
assert query_lens.ndim==1, query_lens
q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1]
query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1)
key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1)
q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d]
k_emb = self.pm_freqs * key_ids_multiple # [B, k_len_max, d]
q_emb = q_emb / q_lens_expanded * self.progress_scale
k_emb = k_emb / k_lens_expanded * self.progress_scale
q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead
q_sin = q_emb.sin().unsqueeze(1)
k_cos = k_emb.cos().unsqueeze(1)
k_sin = k_emb.sin().unsqueeze(1)
self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}, "k": {"cos": k_cos, "sin": k_sin}}
else:
self.pm_sinu = {"q": None, "k": None}
output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args}
        if past is not None:
all_present = []
if self.training and getattr(self.args, "attention_alignment_loss", 0):
all_attn_weights = []
for i, mod in enumerate(self.layers):
output = mod(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
                         past=past[i] if past is not None else None
# tgt_is_causal=tgt_is_causal,
# memory_is_causal=memory_is_causal
)
            if past is not None:
output, cur_present = output
all_present.append(cur_present)
if isinstance(output, dict):
current_attn_weights = output["attention_weights"]
all_attn_weights.append(current_attn_weights)
output = output["x"]
if self.sinu is not None or self.pm_sinu is not None:
output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args}
if self.pm_sinu is not None or self.sinu is not None:
output = output["input"]
if self.norm is not None:
output = self.norm(output)
if self.training and getattr(self.args, "attention_alignment_loss", 0):
assert len(all_attn_weights) == self.num_layers, f"{len(all_attn_weights)=}, {self.num_layers=}"
output = {"output": output, "attention_weights": all_attn_weights}
        if past is not None:
all_present = torch.stack(all_present, dim=0)
output = [output, all_present]
else:
output = [output, None]
return output
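# Illustrative sketch (not part of the original code): a decoder stack with
# progress-monitored rotary positions. `args` is assumed to expose `progress_scale` and,
# optionally, `decoder_regular_rope` / `attention_alignment_loss`; the padding masks are
# used to derive the per-example query/key lengths when query_lens/key_lens are not given.
#
#   dec_layer = TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True, norm_first=True)
#   dec = TransformerDecoder(dec_layer, num_layers=6, rope_base=10000, d_model=512, nhead=8, args=args)
#   hidden, presents = dec(tgt, memory,
#                          tgt_key_padding_mask=tgt_pad_mask,
#                          memory_key_padding_mask=mem_pad_mask)   # presents is None when no past cache is passed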