"""Added ConMamba and Mamba | |
Authors | |
* Xilin Jiang 2024 | |
""" | |
"""Transformer for ASR in the SpeechBrain style. | |
Authors | |
* Jianyuan Zhong 2020 | |
* Titouan Parcollet 2024 | |
* Luca Della Libera 2024 | |
""" | |
from dataclasses import dataclass
from typing import Any, Optional

import torch  # noqa 42
from torch import nn

from speechbrain.dataio.dataio import length_to_mask
from modules.Transformer import (
    NormalizedEmbedding,
    TransformerInterface,
    get_key_padding_mask,
    get_lookahead_mask,
)
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.containers import ModuleList
from speechbrain.nnet.linear import Linear
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig


@dataclass
class TransformerASRStreamingContext:
    """Streaming metadata and state for a `TransformerASR` instance."""

    dynchunktrain_config: DynChunkTrainConfig
    """Dynamic Chunk Training configuration holding chunk size and context size
    information."""

    encoder_context: Any
    """Opaque encoder context information. It is constructed by the encoder's
    `make_streaming_context` method and is passed to the encoder when using
    `encode_streaming`.
    """


def make_transformer_src_mask(
    src: torch.Tensor,
    causal: bool = False,
    dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
) -> Optional[torch.Tensor]:
    """Prepare the source transformer mask that restricts which frames can
    attend to which frames depending on causal or other simple restricted
    attention methods.

    Arguments
    ---------
    src: torch.Tensor
        The source tensor to build a mask from. The contents of the tensor are
        not actually used currently; only its shape and other metadata (e.g.
        device).
    causal: bool
        Whether strict causality shall be used. Frames will not be able to
        attend to any future frame.
    dynchunktrain_config: DynChunkTrainConfig, optional
        Dynamic Chunk Training configuration. This implements a simple form of
        chunkwise attention. Incompatible with `causal`.

    Returns
    -------
    torch.Tensor
        A boolean mask Tensor of shape (timesteps, timesteps).
    """
    if causal:
        assert dynchunktrain_config is None
        return get_lookahead_mask(src)

    if dynchunktrain_config is None:
        return

    # The following is not really the sole source used to implement this,
    # but it helps introduce the concept.
    # ref: Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
    # https://arxiv.org/pdf/2012.05481.pdf
    timesteps = src.size(1)

    # Mask the future at the right of each chunk
    chunk_size = dynchunktrain_config.chunk_size
    num_chunks = timesteps // chunk_size
    timestep_idx = torch.arange(timesteps, device=src.device)
    mask_idx = torch.arange(
        chunk_size, chunk_size * (num_chunks + 2), chunk_size, device=src.device
    ).repeat_interleave(chunk_size)[:timesteps]
    src_mask = timestep_idx[None] >= mask_idx[:, None]
    # Mask the past at the left of each chunk (accounting for left context)
    # only relevant if using left context
    if not dynchunktrain_config.is_infinite_left_context():
        num_left_chunks = dynchunktrain_config.left_context_size
        mask_idx -= chunk_size * (num_left_chunks + 1)
        src_mask += timestep_idx[None] < mask_idx[:, None]
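    # Illustration (not executed): with timesteps=6, chunk_size=2 and
    # left_context_size=1, queries in chunk {0, 1} may attend to frames
    # {0, 1}, queries in chunk {2, 3} to frames {0, 1, 2, 3}, and queries in
    # chunk {4, 5} to frames {2, 3, 4, 5}: each chunk sees itself plus one
    # chunk of left context, and never any future frame.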
    return src_mask


def make_transformer_src_tgt_masks(
    src,
    tgt=None,
    wav_len=None,
    pad_idx=0,
    causal: bool = False,
    dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
):
    """This function generates masks for training the transformer model,
    opinionated for an ASR context with encoding masks and, optionally, decoding
    masks (if specifying `tgt`).

    Arguments
    ---------
    src : torch.Tensor
        The sequence to the encoder (required).
    tgt : torch.Tensor
        The sequence to the decoder.
    wav_len : torch.Tensor
        The lengths of the inputs.
    pad_idx : int
        The index for <pad> token (default=0).
    causal: bool
        Whether strict causality shall be used. See `make_transformer_src_mask`.
    dynchunktrain_config: DynChunkTrainConfig, optional
        Dynamic Chunk Training configuration. See `make_transformer_src_mask`.

    Returns
    -------
    src_key_padding_mask : torch.Tensor
        Key padding mask for ignoring padding in the source sequence.
    tgt_key_padding_mask : torch.Tensor
        Key padding mask for ignoring padding in the target sequence.
    src_mask : torch.Tensor
        Mask for ignoring invalid (e.g. future) source timesteps.
    tgt_mask : torch.Tensor
        Mask for ignoring invalid (e.g. future) target timesteps.
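
    Example
    -------
    A minimal illustrative sketch (shapes only):

    >>> src = torch.rand([2, 6, 8])
    >>> tgt = torch.randint(1, 10, [2, 5])
    >>> spm, tpm, sm, tm = make_transformer_src_tgt_masks(
    ...     src, tgt, wav_len=torch.tensor([1.0, 0.5])
    ... )
    >>> spm.shape, tpm.shape, tm.shape
    (torch.Size([2, 6]), torch.Size([2, 5]), torch.Size([5, 5]))
    >>> sm is None  # no causal or chunkwise restriction requested
    True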
""" | |
src_key_padding_mask = None | |
# mask out audio beyond the length of audio for each batch | |
if wav_len is not None: | |
abs_len = torch.round(wav_len * src.shape[1]) | |
src_key_padding_mask = ~length_to_mask(abs_len).bool() | |
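    # Illustration (not executed): with src of shape (2, 4, F) and
    # wav_len = [1.0, 0.5], abs_len is [4, 2], so the last two frames of the
    # second example are flagged as padding (True) in src_key_padding_mask.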
    # mask out the source
    src_mask = make_transformer_src_mask(
        src, causal=causal, dynchunktrain_config=dynchunktrain_config
    )

    # Build the decoder masks only if a target sequence is provided
    if tgt is not None:
        tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
        tgt_mask = get_lookahead_mask(tgt)
    else:
        tgt_key_padding_mask = None
        tgt_mask = None

    return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask


class TransformerASR(TransformerInterface):
    """This is an implementation of transformer model for ASR.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    tgt_vocab: int
        Size of vocabulary.
    input_size: int
        Input feature size.
    d_model : int, optional
        Embedding dimension size (default=512).
    nhead : int, optional
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers : int, optional
        The number of sub-encoder-layers in the encoder (default=6).
    num_decoder_layers : int, optional
        The number of sub-decoder-layers in the decoder (default=6).
    d_ffn : int, optional
        The dimension of the feedforward network model (default=2048).
    dropout : float, optional
        The dropout value (default=0.1).
    activation : torch.nn.Module, optional
        The activation function of FFN layers.
        Recommended: relu or gelu (default=relu).
    positional_encoding: str, optional
        Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
    normalize_before: bool, optional
        Whether normalization should be applied before or after MHA or FFN in Transformer layers.
        Pre-normalization (True) was shown to lead to better performance and training stability.
    kernel_size: int, optional
        Kernel size in convolutional layers when Conformer is used.
    bias: bool, optional
        Whether to use bias in Conformer convolutional layers.
    encoder_module: str, optional
        Choose between Branchformer, Conformer, ConMamba, and Transformer for the encoder.
    decoder_module: str, optional
        Choose between Transformer and Mamba for the decoder.
    conformer_activation: torch.nn.Module, optional
        Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module.
    branchformer_activation: torch.nn.Module, optional
        Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module.
    attention_type: str, optional
        Type of attention layer used in all Transformer or Conformer layers.
        e.g. regularMHA or RelPosMHA.
    max_length: int, optional
        Max length for the target and source sequence in input.
        Used for positional encodings.
    causal: bool, optional
        Whether the encoder should be causal or not (the decoder is always causal).
        If causal the Conformer convolutional layer is causal.
    csgu_linear_units: int, optional
        Number of neurons in the hidden linear units of the CSGU Module.
        -> Branchformer
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module.
        -> Branchformer
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2.
        -> Branchformer
    mamba_config: dict, optional
        Mamba parameters if encoder_module or decoder_module is Mamba or ConMamba.

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerASR(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
    ... )
    >>> enc_out, dec_out = net.forward(src, tgt)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    >>> dec_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(
        self,
        tgt_vocab,
        input_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ffn=2048,
        dropout=0.1,
        activation=nn.ReLU,
        positional_encoding="fixed_abs_sine",
        normalize_before=False,
        kernel_size: Optional[int] = 31,
        bias: Optional[bool] = True,
        encoder_module: Optional[str] = "transformer",
        decoder_module: Optional[str] = "transformer",
        conformer_activation: Optional[nn.Module] = Swish,
        branchformer_activation: Optional[nn.Module] = nn.GELU,
        attention_type: Optional[str] = "regularMHA",
        max_length: Optional[int] = 2500,
        causal: Optional[bool] = True,
        csgu_linear_units: Optional[int] = 3072,
        gate_activation: Optional[nn.Module] = nn.Identity,
        use_linear_after_conv: Optional[bool] = False,
        mamba_config=None,
    ):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            d_ffn=d_ffn,
            dropout=dropout,
            activation=activation,
            positional_encoding=positional_encoding,
            normalize_before=normalize_before,
            kernel_size=kernel_size,
            bias=bias,
            encoder_module=encoder_module,
            decoder_module=decoder_module,
            conformer_activation=conformer_activation,
            branchformer_activation=branchformer_activation,
            attention_type=attention_type,
            max_length=max_length,
            causal=causal,
            csgu_linear_units=csgu_linear_units,
            gate_activation=gate_activation,
            use_linear_after_conv=use_linear_after_conv,
            mamba_config=mamba_config,
        )

        self.custom_src_module = ModuleList(
            Linear(
                input_size=input_size,
                n_neurons=d_model,
                bias=True,
                combine_dims=False,
            ),
            torch.nn.Dropout(dropout),
        )

        self.num_decoder_layers = num_decoder_layers
        if num_decoder_layers > 0:
            self.custom_tgt_module = ModuleList(
                NormalizedEmbedding(d_model, tgt_vocab)
            )

        # reset parameters using xavier_normal_
        self._init_params()

    def forward(self, src, tgt, wav_len=None, pad_idx=0):
        """Forward pass through the encoder and (if present) the decoder.

        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder.
        tgt : torch.Tensor
            The sequence to the decoder.
        wav_len: torch.Tensor, optional
            Torch Tensor of shape (batch, ) containing the relative length to padded length for each example.
        pad_idx : int, optional
            The index for <pad> token (default=0).

        Returns
        -------
        encoder_out : torch.Tensor
            Output of the encoder.
        decoder_out : torch.Tensor or None
            Output of the decoder, or None if the model has no decoder layers.
        """
        # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
        if src.ndim == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        (
            src_key_padding_mask,
            tgt_key_padding_mask,
            src_mask,
            tgt_mask,
        ) = make_transformer_src_tgt_masks(
            src, tgt, wav_len, causal=self.causal, pad_idx=pad_idx
        )

        src = self.custom_src_module(src)
        # add positional encoding to the queries if sinusoidal encodings are
        # used; otherwise compute relative positional embeddings for the
        # attention layers
        if self.attention_type == "hypermixing":
            pos_embs_encoder = None
        elif self.attention_type == "RelPosMHAXL":
            pos_embs_encoder = self.positional_encoding(src)
        elif self.positional_encoding_type == "fixed_abs_sine":
            src = src + self.positional_encoding(src)  # add the encodings here
            pos_embs_encoder = None

        encoder_out, _ = self.encoder(
            src=src,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs_encoder,
        )

        if self.num_decoder_layers > 0:
            tgt = self.custom_tgt_module(tgt)

            if self.attention_type == "RelPosMHAXL":
                tgt = tgt + self.positional_encoding_decoder(tgt)
                pos_embs_encoder = None  # self.positional_encoding(src)
                pos_embs_target = None
            elif (
                self.positional_encoding_type == "fixed_abs_sine"
                or self.attention_type == "hypermixing"
            ):
                tgt = tgt + self.positional_encoding(tgt)
                pos_embs_target = None
                pos_embs_encoder = None

            decoder_out, _, _ = self.decoder(
                tgt=tgt,
                memory=encoder_out,
                memory_mask=None,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask,
                pos_embs_tgt=pos_embs_target,
                pos_embs_src=pos_embs_encoder,
            )
        else:
            decoder_out = None

        return encoder_out, decoder_out

    def decode(self, tgt, encoder_out, enc_len=None):
        """This method implements a decoding step for the transformer model.

        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder.
        encoder_out : torch.Tensor
            Hidden output of the encoder.
        enc_len : torch.LongTensor
            The actual length of encoder states.

        Returns
        -------
        prediction : torch.Tensor
            Output of the decoder.
        attention : torch.Tensor
            Multi-head attention weights of the last decoder layer.
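
        Example
        -------
        A minimal illustrative sketch reusing the class-level example
        configuration (shapes only):

        >>> net = TransformerASR(
        ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
        ... )
        >>> enc_out = torch.rand([8, 120, 512])
        >>> tgt = torch.randint(0, 720, [8, 10])
        >>> prediction, attn = net.decode(tgt, enc_out)
        >>> prediction.shape
        torch.Size([8, 10, 512])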
""" | |
tgt_mask = get_lookahead_mask(tgt) | |
src_key_padding_mask = None | |
if enc_len is not None: | |
src_key_padding_mask = (1 - length_to_mask(enc_len)).bool() | |
if self.num_decoder_layers > 0: | |
tgt = self.custom_tgt_module(tgt) | |
if self.attention_type == "RelPosMHAXL": | |
tgt = tgt + self.positional_encoding_decoder(tgt) | |
pos_embs_encoder = None # self.positional_encoding(src) | |
pos_embs_target = None | |
elif ( | |
self.positional_encoding_type == "fixed_abs_sine" | |
or self.attention_type == "hypermixing" | |
): | |
tgt = tgt + self.positional_encoding(tgt) # add the encodings here | |
pos_embs_target = None | |
pos_embs_encoder = None | |
prediction, self_attns, multihead_attns = self.decoder( | |
tgt, | |
encoder_out, | |
tgt_mask=tgt_mask, | |
memory_key_padding_mask=src_key_padding_mask, | |
pos_embs_tgt=pos_embs_target, | |
pos_embs_src=pos_embs_encoder, | |
) | |
return prediction, multihead_attns[-1] | |

    def encode(
        self,
        src,
        wav_len=None,
        pad_idx=0,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Encoder forward pass.

        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder.
        wav_len : torch.Tensor, optional
            Torch Tensor of shape (batch, ) containing the relative length to padded length for each example.
        pad_idx : int
            The index used for padding.
        dynchunktrain_config : DynChunkTrainConfig
            Dynamic chunking config.

        Returns
        -------
        encoder_out : torch.Tensor
            Output of the encoder.
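
        Example
        -------
        A minimal illustrative sketch reusing the class-level example
        configuration (shapes only):

        >>> src = torch.rand([8, 120, 512])
        >>> net = TransformerASR(
        ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
        ... )
        >>> enc_out = net.encode(src)
        >>> enc_out.shape
        torch.Size([8, 120, 512])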
""" | |
# reshape the src vector to [Batch, Time, Fea] if a 4d vector is given | |
if src.dim() == 4: | |
bz, t, ch1, ch2 = src.shape | |
src = src.reshape(bz, t, ch1 * ch2) | |
( | |
src_key_padding_mask, | |
_, | |
src_mask, | |
_, | |
) = make_transformer_src_tgt_masks( | |
src, | |
None, | |
wav_len, | |
pad_idx=pad_idx, | |
causal=self.causal, | |
dynchunktrain_config=dynchunktrain_config, | |
) | |
src = self.custom_src_module(src) | |
if self.attention_type == "hypermixing": | |
pos_embs_source = None | |
elif self.attention_type == "RelPosMHAXL": | |
pos_embs_source = self.positional_encoding(src) | |
elif self.positional_encoding_type == "fixed_abs_sine": | |
src = src + self.positional_encoding(src) | |
pos_embs_source = None | |
encoder_out, _ = self.encoder( | |
src=src, | |
src_mask=src_mask, | |
src_key_padding_mask=src_key_padding_mask, | |
pos_embs=pos_embs_source, | |
dynchunktrain_config=dynchunktrain_config, | |
) | |
return encoder_out | |

    def encode_streaming(self, src, context: TransformerASRStreamingContext):
        """
        Streaming encoder forward pass.

        Arguments
        ---------
        src : torch.Tensor
            The sequence (chunk) to the encoder.
        context : TransformerASRStreamingContext
            Mutable reference to the streaming context. This holds the state
            needed to persist across chunk inferences and can be built using
            `make_streaming_context`. This will get mutated by this function.

        Returns
        -------
        Encoder output for this chunk.

        Example
        -------
        >>> import torch
        >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
        >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
        >>> net = TransformerASR(
        ...     tgt_vocab=100,
        ...     input_size=64,
        ...     d_model=64,
        ...     nhead=8,
        ...     num_encoder_layers=1,
        ...     num_decoder_layers=0,
        ...     d_ffn=128,
        ...     attention_type="RelPosMHAXL",
        ...     positional_encoding=None,
        ...     encoder_module="conformer",
        ...     normalize_before=True,
        ...     causal=False,
        ... )
        >>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
        >>> src1 = torch.rand([8, 16, 64])
        >>> src2 = torch.rand([8, 16, 64])
        >>> out1 = net.encode_streaming(src1, ctx)
        >>> out1.shape
        torch.Size([8, 16, 64])
        >>> ctx.encoder_context.layers[0].mha_left_context.shape
        torch.Size([8, 16, 64])
        >>> out2 = net.encode_streaming(src2, ctx)
        >>> out2.shape
        torch.Size([8, 16, 64])
        >>> ctx.encoder_context.layers[0].mha_left_context.shape
        torch.Size([8, 16, 64])
        >>> combined_out = torch.concat((out1, out2), dim=1)
        >>> combined_out.shape
        torch.Size([8, 32, 64])
        """
        if src.dim() == 4:
            bz, t, ch1, ch2 = src.shape
            src = src.reshape(bz, t, ch1 * ch2)

        # HACK: our problem here is that the positional_encoding is computed
        # against the size of our source tensor, but we only know how many left
        # context frames we're injecting to the encoder within the encoder
        # context.
        # so this workaround does just that.
        #
        # i'm not sure how this would be best refactored, but an option would be
        # to let the encoder get the pos embedding itself and have a way to
        # cache it.
        #
        # additionally, positional encoding functions take in a whole source
        # tensor just to get its attributes (size, device, type) but this is
        # sort of silly for the embeddings that don't need one.
        # so we craft a dummy empty (uninitialized) tensor to help...
        known_left_context = context.encoder_context.layers[0].mha_left_context
        if known_left_context is None:
            pos_encoding_dummy = src
        else:
            target_shape = list(src.shape)
            target_shape[-2] += known_left_context.shape[-2]
            pos_encoding_dummy = torch.empty(size=target_shape).to(src)

        src = self.custom_src_module(src)
        if self.attention_type == "RelPosMHAXL":
            pos_embs_source = self.positional_encoding(pos_encoding_dummy)
        elif self.positional_encoding_type == "fixed_abs_sine":
            src = src + self.positional_encoding(pos_encoding_dummy)
            pos_embs_source = None

        encoder_out, _ = self.encoder.forward_streaming(
            src=src, pos_embs=pos_embs_source, context=context.encoder_context
        )
        return encoder_out

    def make_streaming_context(
        self, dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={}
    ):
        """Creates a blank streaming context for this transformer and its
        encoder.

        Arguments
        ---------
        dynchunktrain_config : DynChunkTrainConfig
            Runtime chunkwise attention configuration.
        encoder_kwargs : dict
            Parameters to be forwarded to the encoder's `make_streaming_context`.
            Metadata required for the encoder could differ depending on the
            encoder.

        Returns
        -------
        TransformerASRStreamingContext
        """
        return TransformerASRStreamingContext(
            dynchunktrain_config=dynchunktrain_config,
            encoder_context=self.encoder.make_streaming_context(
                dynchunktrain_config,
                **encoder_kwargs,
            ),
        )

    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_normal_(p)


class EncoderWrapper(nn.Module):
    """This is a wrapper of any ASR transformer encoder. By default, the
    TransformerASR .forward() function encodes and decodes. With this wrapper
    the .forward() function becomes .encode() only.

    Important: The TransformerASR class must contain a .encode() function.

    Arguments
    ---------
    transformer : sb.lobes.models.TransformerInterface
        A Transformer instance that contains a .encode() function.
    *args : tuple
    **kwargs : dict
        Arguments to forward to parent class.

    Example
    -------
    >>> src = torch.rand([8, 120, 512])
    >>> tgt = torch.randint(0, 720, [8, 120])
    >>> net = TransformerASR(
    ...     720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
    ... )
    >>> encoder = EncoderWrapper(net)
    >>> enc_out = encoder(src)
    >>> enc_out.shape
    torch.Size([8, 120, 512])
    """

    def __init__(self, transformer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.transformer = transformer

    def forward(self, x, wav_lens=None, pad_idx=0, **kwargs):
        """Processes the input tensor x and returns an output tensor."""
        x = self.transformer.encode(x, wav_lens, pad_idx, **kwargs)
        return x

    def forward_streaming(self, x, context):
        """Processes the input audio chunk tensor `x`, using and updating the
        mutable encoder `context`."""
        x = self.transformer.encode_streaming(x, context)
        return x

    def make_streaming_context(self, *args, **kwargs):
        """Initializes a streaming context. Forwards all arguments to the
        underlying transformer. See :meth:`speechbrain.lobes.models.transformer.TransformerASR.make_streaming_context`.
        """
        return self.transformer.make_streaming_context(*args, **kwargs)