Spaces:
Running
on
Zero
Running
on
Zero
import copy, logging | |
import numbers | |
from functools import partial | |
from typing import Any, Callable, List, Optional, Tuple, Union | |
import torch | |
from torch import Tensor, nn | |
from torch.nn import functional as F | |
from .activation import MultiheadAttention | |
from .scaling import ActivationBalancer, BalancedDoubleSwish | |
from .scaling import BasicNorm as _BasicNorm | |
_shape_t = Union[int, List[int], torch.Size] | |
class LayerNorm(nn.Module): | |
__constants__ = ["normalized_shape", "eps", "elementwise_affine"] | |
normalized_shape: Tuple[int, ...] | |
eps: float | |
elementwise_affine: bool | |
def __init__( | |
self, | |
normalized_shape: _shape_t, | |
eps: float = 1e-5, | |
elementwise_affine: bool = True, | |
device=None, | |
dtype=None, | |
) -> None: | |
factory_kwargs = {"device": device, "dtype": dtype} | |
super(LayerNorm, self).__init__() | |
if isinstance(normalized_shape, numbers.Integral): | |
# mypy error: incompatible types in assignment | |
normalized_shape = (normalized_shape,) # type: ignore[assignment] | |
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] | |
self.eps = eps | |
self.elementwise_affine = elementwise_affine | |
if self.elementwise_affine: | |
self.weight = nn.Parameter( | |
torch.empty(self.normalized_shape, **factory_kwargs) | |
) | |
self.bias = nn.Parameter( | |
torch.empty(self.normalized_shape, **factory_kwargs) | |
) | |
else: | |
self.register_parameter("weight", None) | |
self.register_parameter("bias", None) | |
self.reset_parameters() | |
def reset_parameters(self) -> None: | |
if self.elementwise_affine: | |
nn.init.ones_(self.weight) | |
nn.init.zeros_(self.bias) | |
def forward(self, input: Tensor, embedding: Any = None) -> Tensor: | |
if isinstance(input, tuple): | |
input, embedding = input | |
return ( | |
F.layer_norm( | |
input, | |
self.normalized_shape, | |
self.weight, | |
self.bias, | |
self.eps, | |
), | |
embedding, | |
) | |
assert embedding is None | |
return F.layer_norm( | |
input, self.normalized_shape, self.weight, self.bias, self.eps | |
) | |
def extra_repr(self) -> str: | |
return ( | |
"{normalized_shape}, eps={eps}, " | |
"elementwise_affine={elementwise_affine}".format(**self.__dict__) | |
) | |
class AdaptiveLayerNorm(nn.Module): | |
r"""Adaptive Layer Normalization""" | |
def __init__(self, d_model, norm) -> None: | |
super(AdaptiveLayerNorm, self).__init__() | |
self.project_layer = nn.Linear(d_model, 2 * d_model) | |
self.norm = norm | |
self.d_model = d_model | |
self.eps = self.norm.eps | |
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: | |
if isinstance(input, tuple): | |
input, embedding = input | |
weight, bias = torch.split( | |
self.project_layer(embedding), | |
split_size_or_sections=self.d_model, | |
dim=-1, | |
) | |
return (weight * self.norm(input) + bias, embedding) | |
weight, bias = torch.split( | |
self.project_layer(embedding), | |
split_size_or_sections=self.d_model, | |
dim=-1, | |
) | |
return weight * self.norm(input) + bias | |
class BasicNorm(_BasicNorm): | |
def __init__( | |
self, | |
d_model: int, | |
eps: float = 1e-5, | |
device=None, | |
dtype=None, | |
): | |
super(BasicNorm, self).__init__(d_model, eps=eps) | |
def forward(self, input: Tensor, embedding: Any = None) -> Tensor: | |
if isinstance(input, tuple): | |
input, embedding = input | |
return ( | |
super(BasicNorm, self).forward(input), | |
embedding, | |
) | |
assert embedding is None | |
return super(BasicNorm, self).forward(input) | |
class BalancedBasicNorm(nn.Module): | |
def __init__( | |
self, | |
d_model: int, | |
eps: float = 1e-5, | |
device=None, | |
dtype=None, | |
): | |
super(BalancedBasicNorm, self).__init__() | |
self.balancer = ActivationBalancer( | |
d_model, | |
channel_dim=-1, | |
min_positive=0.45, | |
max_positive=0.55, | |
max_abs=6.0, | |
) | |
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype) | |
def forward(self, input: Tensor, embedding: Any = None) -> Tensor: | |
if isinstance(input, tuple): | |
input, embedding = input | |
return self.norm((self.balancer(input), embedding)) | |
assert embedding is None | |
return self.norm(self.balancer(input)) | |
class IdentityNorm(nn.Module): | |
def __init__( | |
self, | |
d_model: int, | |
eps: float = 1e-5, | |
device=None, | |
dtype=None, | |
) -> None: | |
super(IdentityNorm, self).__init__() | |
def forward(self, input: Tensor, embedding: Any = None) -> Tensor: | |
if isinstance(input, tuple): | |
return input | |
assert embedding is None | |
return input | |
class TransformerEncoderLayer(nn.Module): | |
__constants__ = ["batch_first", "norm_first"] | |
def __init__( | |
self, | |
d_model: int, | |
nhead: int, | |
dim_feedforward: int = 2048, | |
dropout: float = 0.1, | |
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, | |
batch_first: bool = False, | |
norm_first: bool = False, | |
device=None, | |
dtype=None, | |
linear1_self_attention_cls: nn.Module = nn.Linear, | |
linear2_self_attention_cls: nn.Module = nn.Linear, | |
linear1_feedforward_cls: nn.Module = nn.Linear, | |
linear2_feedforward_cls: nn.Module = nn.Linear, | |
layer_norm_cls: nn.Module = LayerNorm, | |
layer_norm_eps: float = 1e-5, | |
adaptive_layer_norm=False, | |
) -> None: | |
factory_kwargs = {"device": device, "dtype": dtype} | |
super(TransformerEncoderLayer, self).__init__() | |
self.self_attn = MultiheadAttention( | |
d_model, | |
nhead, | |
dropout=dropout, | |
batch_first=batch_first, | |
linear1_cls=linear1_self_attention_cls, | |
linear2_cls=linear2_self_attention_cls, | |
**factory_kwargs, | |
) | |
# Implementation of Feedforward model | |
self.linear1 = linear1_feedforward_cls( | |
d_model, dim_feedforward, **factory_kwargs | |
) | |
self.dropout = nn.Dropout(dropout) | |
self.linear2 = linear2_feedforward_cls( | |
dim_feedforward, d_model, **factory_kwargs | |
) | |
self.norm_first = norm_first | |
self.dropout1 = nn.Dropout(dropout) | |
self.dropout2 = nn.Dropout(dropout) | |
# Legacy string support for activation function. | |
if isinstance(activation, str): | |
activation = _get_activation_fn(activation) | |
elif isinstance(activation, partial): | |
activation = activation(d_model) | |
elif activation == BalancedDoubleSwish: | |
activation = BalancedDoubleSwish(d_model) | |
# # We can't test self.activation in forward() in TorchScript, | |
# # so stash some information about it instead. | |
# if activation is F.relu or isinstance(activation, torch.nn.ReLU): | |
# self.activation_relu_or_gelu = 1 | |
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU): | |
# self.activation_relu_or_gelu = 2 | |
# else: | |
# self.activation_relu_or_gelu = 0 | |
self.activation = activation | |
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) | |
if layer_norm_cls == IdentityNorm: | |
norm2 = BalancedBasicNorm( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
else: | |
norm2 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
if adaptive_layer_norm: | |
self.norm1 = AdaptiveLayerNorm(d_model, norm1) | |
self.norm2 = AdaptiveLayerNorm(d_model, norm2) | |
else: | |
self.norm1 = norm1 | |
self.norm2 = norm2 | |
def __setstate__(self, state): | |
super(TransformerEncoderLayer, self).__setstate__(state) | |
if not hasattr(self, "activation"): | |
self.activation = F.relu | |
def forward( | |
self, | |
src, | |
src_mask: Optional[Tensor] = None, | |
src_key_padding_mask: Optional[Tensor] = None, | |
need_weights: Optional[bool] = False, | |
past: Optional[Tensor] = None, | |
) -> Tensor: | |
r"""Pass the input through the encoder layer. | |
Args: | |
src: the sequence to the encoder layer (required). | |
src_mask: the mask for the src sequence (optional). | |
src_key_padding_mask: the mask for the src keys per batch (optional). | |
Shape: | |
see the docs in Transformer class. | |
""" | |
if isinstance(src, dict): | |
sinu = src["sinu"] | |
pm_sinu = src["pm_sinu"] | |
src = src["input"] | |
else: | |
sinu = None | |
pm_sinu = None | |
x, stage_embedding = src, None | |
is_src_tuple = False | |
if isinstance(src, tuple): | |
x, stage_embedding = src | |
is_src_tuple = True | |
if src_key_padding_mask is not None: | |
_skpm_dtype = src_key_padding_mask.dtype | |
if _skpm_dtype != torch.bool and not torch.is_floating_point( | |
src_key_padding_mask | |
): | |
raise AssertionError( | |
"only bool and floating types of key_padding_mask are supported" | |
) | |
if need_weights: | |
raise NotImplementedError | |
if self.norm_first: | |
out, attn = self._sa_block_attn( | |
self.norm1(x, stage_embedding), | |
src_mask, | |
src_key_padding_mask, | |
past, sinu = sinu | |
) | |
out, present = out # present is the kvcache of the present timestep | |
x = x + out | |
x = x + self._ff_block(self.norm2(x, stage_embedding)) | |
else: | |
out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past, sinu = sinu) | |
out, present = out # present is the kvcache of the present timestep | |
x = self.norm1( | |
x + out, | |
stage_embedding, | |
) | |
x = self.norm2(x + self._ff_block(x), stage_embedding) | |
assert not is_src_tuple | |
# return (x, stage_embedding) | |
return (x, attn) | |
else: | |
if self.norm_first: | |
out = self._sa_block( | |
self.norm1(x, stage_embedding), | |
src_mask, | |
src_key_padding_mask, past, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'] | |
) | |
out, present = out # present is the kvcache of the present timestep | |
x = x + out | |
x = x + self._ff_block(self.norm2(x, stage_embedding)) | |
else: | |
out = self._sa_block(x, src_mask, src_key_padding_mask, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q']) | |
out, present = out # present is the kvcache of the present timestep | |
x = self.norm1( | |
x + out, | |
stage_embedding, past | |
) | |
x = self.norm2(x + self._ff_block(x), stage_embedding) | |
if is_src_tuple: | |
x = (x, stage_embedding) | |
if present != None: | |
x = [x, present] | |
return x | |
# self-attention block | |
def _sa_block( | |
self, | |
x: Tensor, | |
attn_mask: Optional[Tensor], | |
key_padding_mask: Optional[Tensor], | |
past: Optional[Tensor] = None, | |
sinu = None, | |
q_sinu = None, | |
k_sinu = None | |
) -> Tensor: | |
x = self.self_attn( | |
x, | |
x, | |
x, | |
attn_mask=attn_mask, | |
key_padding_mask=key_padding_mask, | |
need_weights=False, | |
past=past, | |
sinu = sinu, | |
q_sinu = q_sinu, | |
k_sinu = k_sinu | |
) | |
x, present = x | |
return self.dropout1(x), present | |
# self-attention block, also return attention weights | |
def _sa_block_attn( | |
self, | |
x: Tensor, | |
attn_mask: Optional[Tensor], | |
key_padding_mask: Optional[Tensor], | |
past: Optional[Tensor] = None, | |
) -> Tensor: | |
x, attn = self.self_attn( | |
x, | |
x, | |
x, | |
attn_mask=attn_mask, | |
key_padding_mask=key_padding_mask, | |
need_weights=True, | |
past=past | |
) | |
x, present = x | |
return (self.dropout1(x), present), attn | |
# feed forward block | |
def _ff_block(self, x: Tensor) -> Tensor: | |
x = self.linear2(self.dropout(self.activation(self.linear1(x)))) | |
return self.dropout2(x) | |
def pre_compute_sinusoidal(dim, base, max_len = 10000): # 4000 max length equivalent of mimi code is 320s, as mimi is 12.5hz | |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) | |
position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1] | |
inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2] | |
freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2] | |
freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d] | |
return {"sin": freqs.sin(), "cos": freqs.cos()} | |
def pre_compute_freqs(dim, base, max_len = 10000): # 4000 max length equivalent of mimi code is 320s, as mimi is 12.5hz | |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) | |
position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1] | |
inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2] | |
freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2] | |
freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d] | |
return freqs | |
class TransformerEncoder(nn.Module): | |
r"""TransformerEncoder is a stack of N encoder layers. Users can build the | |
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. | |
Args: | |
encoder_layer: an instance of the TransformerEncoderLayer() class (required). | |
num_layers: the number of sub-encoder-layers in the encoder (required). | |
norm: the layer normalization component (optional). | |
enable_nested_tensor: if True, input will automatically convert to nested tensor | |
(and convert back on output). This will improve the overall performance of | |
TransformerEncoder when padding rate is high. Default: ``True`` (enabled). | |
Examples:: | |
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) | |
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) | |
>>> src = torch.rand(10, 32, 512) | |
>>> out = transformer_encoder(src) | |
""" | |
__constants__ = ["norm"] | |
def __init__(self, encoder_layer, num_layers, norm=None, rope_base=None, d_model=None, nhead=None, args=None): | |
super(TransformerEncoder, self).__init__() | |
self.layers = _get_clones(encoder_layer, num_layers) | |
self.num_layers = num_layers | |
self.norm = norm | |
if args != None: | |
self.progress_no_multiple = args.progress_no_multiple | |
self.progress_scale = args.progress_scale | |
else: | |
self.progress_no_multiple = False | |
self.progress_scale = 1 | |
if rope_base is not None: | |
if self.progress_no_multiple: | |
self.pm_freqs = pre_compute_freqs(d_model//nhead, rope_base) | |
self.sinu = None | |
else: | |
self.sinu = pre_compute_sinusoidal(d_model/nhead, rope_base) | |
self.pm_freqs = None | |
# logging.info(f"get precomputed sinusoidal for {rope_base=}: {self.sinu=}") | |
else: | |
self.sinu = None | |
self.pm_freqs = None | |
def forward( | |
self, | |
src: Tensor, | |
mask: Optional[Tensor] = None, | |
src_key_padding_mask: Optional[Tensor] = None, | |
return_layer_states: bool = False, | |
need_weights:Optional[bool] = False, | |
past: Optional[Tensor] = None, | |
) -> Tensor: | |
r"""Pass the input through the encoder layers in turn. | |
Args: | |
src: the sequence to the encoder (required). | |
mask: the mask for the src sequence (optional). | |
src_key_padding_mask: the mask for the src keys per batch (optional). | |
return_layer_states: return layers' state (optional). | |
Shape: | |
see the docs in Transformer class. | |
""" | |
if return_layer_states: | |
raise NotImplementedError | |
assert not need_weights | |
layer_states = [] # layers' output | |
output = src | |
for mod in self.layers: | |
output = mod( | |
output, | |
src_mask=mask, | |
src_key_padding_mask=src_key_padding_mask, | |
past=past | |
) | |
layer_states.append(output[0]) | |
if self.norm is not None: | |
output = self.norm(output) | |
return layer_states, output | |
if need_weights: | |
raise NotImplementedError | |
assert not return_layer_states | |
layer_attn = [] # layers' output | |
output = src | |
for mod in self.layers: | |
output = mod( | |
output, | |
src_mask=mask, | |
src_key_padding_mask=src_key_padding_mask, | |
need_weights=True, | |
past=past | |
) | |
layer_attn.append(output[1]) | |
if self.norm is not None: | |
output = self.norm(output) | |
return layer_attn, output | |
output = src | |
all_present = [] | |
if self.sinu is not None: | |
# use rope | |
assert self.pm_freqs is None | |
for k, v in self.sinu.items(): | |
self.sinu[k] = v.to(output.device) | |
if self.pm_freqs is not None: | |
assert self.sinu is None | |
self.pm_freqs = self.pm_freqs.to(output.device) | |
if src_key_padding_mask != None: | |
query_lens = (~src_key_padding_mask).int().sum(-1).to(output.device) | |
else: | |
query_lens = torch.tensor([output.shape[1]]*output.shape[0]).to(output.device) | |
assert query_lens.ndim==1, query_lens | |
q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] | |
query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1) | |
q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d] | |
q_emb = q_emb / q_lens_expanded * self.progress_scale | |
q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead | |
q_sin = q_emb.sin().unsqueeze(1) | |
self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}} | |
else: | |
self.pm_sinu = {"q": None} | |
output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu} | |
for n_layer, mod in enumerate(self.layers): | |
output = mod( | |
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer] | |
) | |
if isinstance(output, list): | |
output, present = output | |
all_present.append(present) | |
if self.sinu is not None or self.pm_sinu is not None: | |
output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu} | |
if self.sinu is not None or self.pm_sinu is not None: | |
output = output["input"] | |
if self.norm is not None: | |
output = self.norm(output) | |
if all_present != []: | |
all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim) | |
output = [output, all_present] | |
return output | |
class TransformerDecoderLayer(nn.Module): | |
__constants__ = ["batch_first", "norm_first"] | |
def __init__( | |
self, | |
d_model: int, | |
nhead: int, | |
dim_feedforward: int = 2048, | |
dropout: float = 0.1, | |
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, | |
linear1_self_attention_cls: nn.Module = nn.Linear, | |
linear2_self_attention_cls: nn.Module = nn.Linear, | |
linear1_feedforward_cls: nn.Module = nn.Linear, | |
linear2_feedforward_cls: nn.Module = nn.Linear, | |
batch_first: bool = False, | |
norm_first: bool = False, | |
device=None, | |
dtype=None, | |
layer_norm_cls: nn.Module = LayerNorm, | |
layer_norm_eps: float = 1e-5, | |
adaptive_layer_norm=False, | |
) -> None: | |
factory_kwargs = {"device": device, "dtype": dtype} | |
super(TransformerDecoderLayer, self).__init__() | |
self.self_attn = MultiheadAttention( | |
d_model, | |
nhead, | |
dropout=dropout, | |
batch_first=batch_first, | |
linear1_cls=linear1_self_attention_cls, | |
linear2_cls=linear2_self_attention_cls, | |
**factory_kwargs, | |
) | |
self.multihead_attn = MultiheadAttention( | |
d_model, | |
nhead, | |
dropout=dropout, | |
batch_first=batch_first, | |
linear1_cls=linear1_self_attention_cls, | |
linear2_cls=linear2_self_attention_cls, | |
**factory_kwargs, | |
) | |
# Implementation of Feedforward model | |
self.linear1 = linear1_feedforward_cls( | |
d_model, dim_feedforward, **factory_kwargs | |
) | |
self.dropout = nn.Dropout(dropout) | |
self.linear2 = linear2_feedforward_cls( | |
dim_feedforward, d_model, **factory_kwargs | |
) | |
self.norm_first = norm_first | |
self.dropout1 = nn.Dropout(dropout) | |
self.dropout2 = nn.Dropout(dropout) | |
self.dropout3 = nn.Dropout(dropout) | |
# Legacy string support for activation function. | |
if isinstance(activation, str): | |
self.activation = _get_activation_fn(activation) | |
elif isinstance(activation, partial): | |
self.activation = activation(d_model) | |
elif activation == BalancedDoubleSwish: | |
self.activation = BalancedDoubleSwish(d_model) | |
else: | |
self.activation = activation | |
if adaptive_layer_norm: | |
norm1 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
norm2 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
norm3 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
self.norm1 = AdaptiveLayerNorm(d_model, norm1) | |
self.norm2 = AdaptiveLayerNorm(d_model, norm2) | |
self.norm3 = AdaptiveLayerNorm(d_model, norm3) | |
else: | |
self.norm1 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
self.norm2 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
if layer_norm_cls == IdentityNorm: | |
self.norm3 = BalancedBasicNorm( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
else: | |
self.norm3 = layer_norm_cls( | |
d_model, eps=layer_norm_eps, **factory_kwargs | |
) | |
def forward( | |
self, | |
tgt: Tensor, | |
memory: Tensor, | |
tgt_mask: Optional[Tensor] = None, | |
memory_mask: Optional[Tensor] = None, | |
tgt_key_padding_mask: Optional[Tensor] = None, | |
memory_key_padding_mask: Optional[Tensor] = None, | |
tgt_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used | |
memory_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used | |
past: Optional[Tensor] = None, | |
) -> Tensor: | |
r"""Pass the inputs (and mask) through the decoder layer. | |
Args: | |
tgt: the sequence to the decoder layer (required). | |
memory: the sequence from the last layer of the encoder (required). | |
tgt_mask: the mask for the tgt sequence (optional). | |
memory_mask: the mask for the memory sequence (optional). | |
tgt_key_padding_mask: the mask for the tgt keys per batch (optional). | |
memory_key_padding_mask: the mask for the memory keys per batch (optional). | |
past: the previous kvcache of the decoder (optional). shape: (2, batch_size, num_heads, seq_len, head_dim) | |
Shape: | |
see the docs in Transformer class. | |
""" | |
if isinstance(tgt, dict): | |
pm_sinu = tgt["pm_sinu"] | |
sinu = tgt["sinu"] | |
args = tgt["args"] | |
tgt = tgt["input"] | |
else: | |
pm_sinu = None | |
sinu = None | |
args = None | |
tgt_is_tuple = False | |
if isinstance(tgt, tuple): | |
x, stage_embedding = tgt | |
tgt_is_tuple = True | |
else: | |
x, stage_embedding = tgt, None | |
# logging.info(f"{tgt_key_padding_mask=}, {memory_key_padding_mask=}") | |
# logging.info(f"{tgt_key_padding_mask.shape=}, {memory_key_padding_mask.shape=}") | |
# logging.info(f"{query_lens=}, {key_lens=}") | |
# past stores the kvcache for self-attention, and it can also be used to infer q_offset | |
if past is not None and past.ndim > 2: | |
q_offset = past[0].shape[-2] # past is (2, batch_size, num_heads, seq_len, head_dim), 2 contains [k, v], these are for self-attn, therefore also reflect the length of q | |
else: | |
q_offset = 0 | |
if self.norm_first: | |
temp = self._sa_block( | |
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset | |
) | |
present = temp[1] | |
x = x + temp[0] | |
cross_out = self._mha_block( | |
self.norm2(x, stage_embedding), | |
memory, | |
memory_mask, | |
memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args = args, q_offset=q_offset | |
) | |
if isinstance(cross_out, dict): | |
attention_weights = cross_out["attention_weights"] | |
cross_out = cross_out["x"] | |
else: | |
attention_weights = None | |
x = x + cross_out | |
x = x + self._ff_block(self.norm3(x, stage_embedding)) | |
else: | |
temp = self._sa_block(x, tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset) | |
present = temp[1] | |
x = self.norm1( | |
x + temp[0], | |
stage_embedding, | |
) | |
cross_out = self._mha_block( | |
x, memory, memory_mask, memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args, q_offset=q_offset | |
) | |
if isinstance(cross_out, dict): | |
attention_weights = cross_out["attention_weights"] | |
cross_out = cross_out["x"] | |
else: | |
attention_weights = None | |
x = self.norm2( | |
x | |
+ cross_out, | |
stage_embedding, | |
) | |
x = self.norm3(x + self._ff_block(x), stage_embedding) | |
if attention_weights is not None: | |
x = {"x": x, "attention_weights": attention_weights} | |
if tgt_is_tuple: | |
x = (x, stage_embedding) | |
if present != None: | |
x = [x, present] | |
return x | |
# self-attention block | |
def _sa_block( | |
self, | |
x: Tensor, | |
attn_mask: Optional[Tensor], | |
key_padding_mask: Optional[Tensor], | |
q_sinu=None, | |
k_sinu=None, | |
sinu = None, | |
args = None, | |
past = None, | |
q_offset = 0 | |
) -> Tensor: | |
# if past is not None and past.ndim > 2: | |
# print(f"self-attn, k len: {past[0].shape[-2] + x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}") | |
# else: | |
# print(f"self-attn, k len: {x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}") | |
x = self.self_attn( | |
x, | |
x, | |
x, | |
attn_mask=attn_mask, | |
key_padding_mask=key_padding_mask, | |
need_weights=False, | |
q_sinu = q_sinu, | |
k_sinu = k_sinu, | |
sinu = sinu, | |
past = past, | |
q_offset = q_offset | |
) | |
x, present = x | |
return self.dropout1(x), present | |
# multihead attention block | |
def _mha_block( | |
self, | |
x: Tensor, | |
mem: Tensor, | |
attn_mask: Optional[Tensor], | |
key_padding_mask: Optional[Tensor], | |
q_sinu = None, | |
k_sinu = None, | |
sinu = None, | |
args = None, | |
q_offset = 0 | |
) -> Tensor: | |
# print(f"cross-attn, k len: {mem.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}") | |
x = self.multihead_attn( | |
x, | |
mem, | |
mem, | |
attn_mask=attn_mask, | |
key_padding_mask=key_padding_mask, | |
need_weights=False, | |
q_sinu = q_sinu, | |
k_sinu = k_sinu, | |
sinu = sinu, | |
args = args, | |
q_offset = q_offset | |
) | |
if len(x) == 2 and isinstance(x[0], dict) and "attention_weights" in x[0]: | |
x, present = x | |
attention_weights = x['attention_weights'] | |
x = x['attn_output'] | |
return {"x": self.dropout2(x), "attention_weights": attention_weights} | |
elif len(x) == 2: | |
x = x[0] | |
return self.dropout2(x) | |
# feed forward block | |
def _ff_block(self, x: Tensor) -> Tensor: | |
x = self.linear2(self.dropout(self.activation(self.linear1(x)))) | |
return self.dropout3(x) | |
def _get_clones(module, N): | |
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) | |
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: | |
if activation == "relu": | |
return F.relu | |
elif activation == "gelu": | |
return F.gelu | |
raise RuntimeError( | |
"activation should be relu/gelu, not {}".format(activation) | |
) | |
def _generate_square_subsequent_mask( | |
sz: int, | |
device: Optional[torch.device] = None, | |
dtype: Optional[torch.dtype] = None, | |
) -> Tensor: | |
r"""Generate a square causal mask for the sequence. | |
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). | |
""" | |
if device is None: | |
device = torch.device('cpu') | |
if dtype is None: | |
dtype = torch.float32 | |
return torch.triu( | |
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device), | |
diagonal=1, | |
) | |
def _get_seq_len( | |
src: Tensor, | |
batch_first: bool | |
) -> Optional[int]: | |
if src.is_nested: | |
return None | |
else: | |
src_size = src.size() | |
if len(src_size) == 2: | |
# unbatched: S, E | |
return src_size[0] | |
else: | |
# batched: B, S, E if batch_first else S, B, E | |
seq_len_pos = 1 if batch_first else 0 | |
return src_size[seq_len_pos] | |
def _detect_is_causal_mask( | |
mask: Optional[Tensor], | |
is_causal: Optional[bool] = None, | |
size: Optional[int] = None, | |
) -> bool: | |
"""Return whether the given attention mask is causal. | |
Warning: | |
If ``is_causal`` is not ``None``, its value will be returned as is. If a | |
user supplies an incorrect ``is_causal`` hint, | |
``is_causal=False`` when the mask is in fact a causal attention.mask | |
may lead to reduced performance relative to what would be achievable | |
with ``is_causal=True``; | |
``is_causal=True`` when the mask is in fact not a causal attention.mask | |
may lead to incorrect and unpredictable execution - in some scenarios, | |
a causal mask may be applied based on the hint, in other execution | |
scenarios the specified mask may be used. The choice may not appear | |
to be deterministic, in that a number of factors like alignment, | |
hardware SKU, etc influence the decision whether to use a mask or | |
rely on the hint. | |
``size`` if not None, check whether the mask is a causal mask of the provided size | |
Otherwise, checks for any causal mask. | |
""" | |
# Prevent type refinement | |
make_causal = (is_causal is True) | |
if is_causal is None and mask is not None: | |
sz = size if size is not None else mask.size(-2) | |
causal_comparison = _generate_square_subsequent_mask( | |
sz, device=mask.device, dtype=mask.dtype) | |
# Do not use `torch.equal` so we handle batched masks by | |
# broadcasting the comparison. | |
if mask.size() == causal_comparison.size(): | |
make_causal = bool((mask == causal_comparison).all()) | |
else: | |
make_causal = False | |
return make_causal | |
class TransformerDecoder(nn.Module): | |
r"""TransformerDecoder is a stack of N decoder layers. | |
Args: | |
decoder_layer: an instance of the TransformerDecoderLayer() class (required). | |
num_layers: the number of sub-decoder-layers in the decoder (required). | |
norm: the layer normalization component (optional). | |
Examples:: | |
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) | |
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) | |
>>> memory = torch.rand(10, 32, 512) | |
>>> tgt = torch.rand(20, 32, 512) | |
>>> out = transformer_decoder(tgt, memory) | |
""" | |
__constants__ = ['norm'] | |
def __init__( | |
self, | |
decoder_layer: "TransformerDecoderLayer", | |
num_layers: int, | |
norm: Optional[nn.Module] = None, | |
rope_base=None, | |
d_model=None, | |
nhead=None, | |
args=None | |
) -> None: | |
super().__init__() | |
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") | |
self.layers = _get_clones(decoder_layer, num_layers) | |
self.num_layers = num_layers | |
self.norm = norm | |
self.args = args | |
if getattr(self.args, 'decoder_regular_rope', False): | |
self.sinu = pre_compute_sinusoidal(d_model/nhead, rope_base) | |
self.pm_freqs = None | |
else: | |
self.sinu = None | |
if rope_base is not None: | |
self.pm_freqs = pre_compute_freqs(d_model/nhead, rope_base) | |
# logging.info(f"get precomputed freqs for {rope_base=}: {self.freqs=}") | |
else: | |
self.pm_freqs = None | |
self.progress_scale = getattr(self.args, "progress_scale", 1.0) | |
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, | |
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, | |
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, | |
memory_is_causal: bool = False, query_lens: Optional[Tensor] = None, key_lens: Optional[Tensor] = None, past: Optional[Tensor] = None) -> Tensor: | |
r"""Pass the inputs (and mask) through the decoder layer in turn. | |
Args: | |
tgt: the sequence to the decoder (required). | |
memory: the sequence from the last layer of the encoder (required). | |
tgt_mask: the mask for the tgt sequence (optional). | |
memory_mask: the mask for the memory sequence (optional). | |
tgt_key_padding_mask: the mask for the tgt keys per batch (optional). | |
memory_key_padding_mask: the mask for the memory keys per batch (optional). | |
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``. | |
Default: ``None``; try to detect a causal mask. | |
Warning: | |
``tgt_is_causal`` provides a hint that ``tgt_mask`` is | |
the causal mask. Providing incorrect hints can result in | |
incorrect execution, including forward and backward | |
compatibility. | |
memory_is_causal: If specified, applies a causal mask as | |
``memory mask``. | |
Default: ``False``. | |
Warning: | |
``memory_is_causal`` provides a hint that | |
``memory_mask`` is the causal mask. Providing incorrect | |
hints can result in incorrect execution, including | |
forward and backward compatibility. | |
Shape: | |
see the docs in :class:`~torch.nn.Transformer`. | |
""" | |
output = tgt | |
# seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first) | |
# tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) | |
if self.sinu is not None: | |
assert self.pm_freqs is None | |
for key in self.sinu: | |
self.sinu[key] = self.sinu[key].to(output.device) | |
if self.pm_freqs is not None: | |
assert self.sinu is None | |
if not self.training and hasattr(self, "pm_sinu") and past is not None and past[0].ndim > 2: # inference mode, will use cached sinu for the same example | |
assert self.pm_sinu['q'] is not None and self.pm_sinu['k'] is not None | |
# check batch size, need to modify the batch size if we use multi_trial during inference | |
if self.pm_sinu['q']['cos'].shape[0] != tgt.shape[0]: | |
if self.pm_sinu['q']['cos'].shape[0] > tgt.shape[0]: | |
self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'][:tgt.shape[0]] | |
self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'][:tgt.shape[0]] | |
self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'][:tgt.shape[0]] | |
self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'][:tgt.shape[0]] | |
else: | |
assert self.pm_sinu['q']['cos'].shape[0] == 1 | |
self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'].repeat(tgt.shape[0], 1, 1, 1) | |
self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'].repeat(tgt.shape[0], 1, 1, 1) | |
self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'].repeat(tgt.shape[0], 1, 1, 1) | |
self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'].repeat(tgt.shape[0], 1, 1, 1) | |
pass | |
else: | |
self.pm_freqs = self.pm_freqs.to(output.device) | |
if query_lens is None: | |
query_lens = (~tgt_key_padding_mask).int().sum(-1).to(tgt.device) | |
if key_lens is None: | |
key_lens = (~memory_key_padding_mask).int().sum(-1).to(tgt.device) | |
assert key_lens.ndim==1, key_lens | |
assert query_lens.ndim==1, query_lens | |
q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] | |
k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] | |
query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1) | |
key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1) | |
q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d] | |
k_emb = self.pm_freqs * key_ids_multiple # [B, k_len_max, d] | |
q_emb = q_emb / q_lens_expanded * self.progress_scale | |
k_emb = k_emb / k_lens_expanded * self.progress_scale | |
q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead | |
q_sin = q_emb.sin().unsqueeze(1) | |
k_cos = k_emb.cos().unsqueeze(1) | |
k_sin = k_emb.sin().unsqueeze(1) | |
self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}, "k": {"cos": k_cos, "sin": k_sin}} | |
else: | |
self.pm_sinu = {"q": None, "k": None} | |
output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args} | |
if past != None: | |
all_present = [] | |
if self.training and getattr(self.args, "attention_alignment_loss", 0): | |
all_attn_weights = [] | |
for i, mod in enumerate(self.layers): | |
output = mod(output, memory, tgt_mask=tgt_mask, | |
memory_mask=memory_mask, | |
tgt_key_padding_mask=tgt_key_padding_mask, | |
memory_key_padding_mask=memory_key_padding_mask, | |
past=past[i] if past != None else None | |
# tgt_is_causal=tgt_is_causal, | |
# memory_is_causal=memory_is_causal | |
) | |
if past != None: | |
output, cur_present = output | |
all_present.append(cur_present) | |
if isinstance(output, dict): | |
current_attn_weights = output["attention_weights"] | |
all_attn_weights.append(current_attn_weights) | |
output = output["x"] | |
if self.sinu is not None or self.pm_sinu is not None: | |
output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args} | |
if self.pm_sinu is not None or self.sinu is not None: | |
output = output["input"] | |
if self.norm is not None: | |
output = self.norm(output) | |
if self.training and getattr(self.args, "attention_alignment_loss", 0): | |
assert len(all_attn_weights) == self.num_layers, f"{len(all_attn_weights)=}, {self.num_layers=}" | |
output = {"output": output, "attention_weights": all_attn_weights} | |
if past != None: | |
all_present = torch.stack(all_present, dim=0) | |
output = [output, all_present] | |
else: | |
output = [output, None] | |
return output |