"""ConMamba encoder and Mamba decoder implementation. | |
Authors | |
------- | |
* Xilin Jiang 2024 | |
""" | |
import warnings
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import (
    MultiheadAttention,
    PositionalwiseFeedForward,
    RelPosMHAXL,
)
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig

# Mamba
from mamba_ssm import Mamba
from .mamba.bimamba import Mamba as BiMamba


class ConvolutionModule(nn.Module):
"""This is an implementation of convolution module in Conmamba. | |
""" | |

    def __init__(
        self,
        input_size,
        kernel_size=31,
        bias=True,
        activation=Swish,
        dropout=0.0,
        causal=False,
        dilation=1,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.causal = causal
        self.dilation = dilation

        if self.causal:
            self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
        else:
            self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2
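        # A concrete illustration of the formula above: with the default
        # kernel_size=31 and dilation=1, causal padding is 30 (applied on
        # both sides by the Conv1d, with the right excess chomped in
        # forward()), while non-causal padding is 15, preserving the
        # sequence length.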

        self.layer_norm = nn.LayerNorm(input_size)
        self.bottleneck = nn.Sequential(
            # pointwise
            nn.Conv1d(
                input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
            ),
            nn.GLU(dim=1),
        )
        # depthwise
        self.conv = nn.Conv1d(
            input_size,
            input_size,
            kernel_size=kernel_size,
            stride=1,
            padding=self.padding,
            dilation=dilation,
            groups=input_size,
            bias=bias,
        )

        # BatchNorm in the original Conformer replaced with a LayerNorm due to
        # https://github.com/speechbrain/speechbrain/pull/1329
        # see discussion
        # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884
        self.after_conv = nn.Sequential(
            nn.LayerNorm(input_size),
            activation(),
            # pointwise
            nn.Linear(input_size, input_size, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Applies the convolution to an input tensor `x`."""
        if dynchunktrain_config is not None:
            # Chunking combined with causal padding is likely unintended;
            # it is not supported here, but feel free to implement it if
            # you find a use for it.
            assert (
                not self.causal
            ), "Chunked convolution not supported with causal padding"

            assert (
                self.dilation == 1
            ), "Current DynChunkTrain logic does not support dilation != 1"

            # In a causal convolution (which is not the case here), an output
            # frame can never depend on an input frame from any point in the
            # future. With the dynamic chunk convolution, we instead use a
            # "normal" convolution, but for any output frame the future
            # beyond the "current" chunk gets masked.
            # See the paper linked in the documentation for details.
            chunk_size = dynchunktrain_config.chunk_size
            batch_size = x.shape[0]

            # determine the amount of padding we need to insert at the right
            # of the last chunk so that all chunks end up with the same size.
            if x.shape[1] % chunk_size != 0:
                final_right_padding = chunk_size - (x.shape[1] % chunk_size)
            else:
                final_right_padding = 0

            # -> [batch_size, t, in_channels]
            out = self.layer_norm(x)

            # -> [batch_size, in_channels, t] for the CNN
            out = out.transpose(1, 2)

            # -> [batch_size, in_channels, t] (pointwise)
            out = self.bottleneck(out)

            # -> [batch_size, in_channels, lc+t+final_right_padding]
            out = F.pad(out, (self.padding, final_right_padding), value=0)

            # now, make chunks with left context.
            # as a recap of what the above padding and this unfold do, each
            # a/b/c letter below represents a frame belonging to chunk a, b
            # or c. consider a chunk size of 4 and a kernel size of 5
            # (padding=2):
            #
            # input seq: 00aaaabbbbcc00
            # chunk #1:  00aaaa
            # chunk #2:  aabbbb
            # chunk #3:  bbcc00
            #
            # a few remarks here:
            # - the left padding gets inserted early so that the unfold logic
            #   works trivially
            # - the right 0-padding got inserted as the number of time steps
            #   could not be evenly split in `chunk_size` chunks
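
            # A concrete numeric sketch of the unfold call below (the sizes
            # are illustrative, not from the original comments): with t=10,
            # chunk_size=4 and padding=2, the padded length is
            # 2 + 10 + 2 = 14, and
            #   torch.arange(14).unfold(0, 6, 4)
            # yields the three overlapping windows
            #   [[ 0,  1,  2,  3,  4,  5],
            #    [ 4,  5,  6,  7,  8,  9],
            #    [ 8,  9, 10, 11, 12, 13]]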
            # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
            out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)

            # as we manually disable padding in the convolution below, we
            # insert right 0-padding to the chunks, e.g. reusing the above
            # example:
            #
            # chunk #1:  00aaaa00
            # chunk #2:  aabbbb00
            # chunk #3:  bbcc0000

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
            out = F.pad(out, (0, self.padding), value=0)

            # the transpose+flatten effectively flattens chunks into the
            # batch dimension to be processed by the time-wise convolution.
            # the chunks will later on be unflattened.

            # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.transpose(1, 2)

            # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.flatten(start_dim=0, end_dim=1)

            # TODO: experiment with reflect padding, which is difficult
            # because small chunks have too few time steps to reflect from

            # let's keep backwards compat by pointing at the weights from the
            # already declared Conv1d.
            #
            # still reusing the above example, the convolution will be
            # applied, with the padding truncated on both ends. the following
            # example shows the letter corresponding to the input frame on
            # which the convolution was centered.
            #
            # as you can see, the sum of lengths of all chunks is equal to
            # our input sequence length + `final_right_padding`.
            #
            # chunk #1:  aaaa
            # chunk #2:  bbbb
            # chunk #3:  cc00

            # -> [batch_size * num_chunks, out_channels, chunk_size]
            out = F.conv1d(
                out,
                weight=self.conv.weight,
                bias=self.conv.bias,
                stride=self.conv.stride,
                padding=0,
                dilation=self.conv.dilation,
                groups=self.conv.groups,
            )

            # -> [batch_size * num_chunks, chunk_size, out_channels]
            out = out.transpose(1, 2)

            out = self.after_conv(out)

            # -> [batch_size, num_chunks, chunk_size, out_channels]
            out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))

            # -> [batch_size, t + final_right_padding, out_channels]
            out = torch.flatten(out, start_dim=1, end_dim=2)

            # -> [batch_size, t, out_channels]
            if final_right_padding > 0:
                out = out[:, :-final_right_padding, :]
        else:
            out = self.layer_norm(x)
            out = out.transpose(1, 2)
            out = self.bottleneck(out)
            out = self.conv(out)

            if self.causal:
                # chomp the right padding so no output frame sees the future
                out = out[..., : -self.padding]

            out = out.transpose(1, 2)
            out = self.after_conv(out)

        if mask is not None:
            out.masked_fill_(mask, 0.0)

        return out


class ConmambaEncoderLayer(nn.Module):
"""This is an implementation of Conmamba encoder layer. | |
""" | |

    def __init__(
        self,
        d_model,
        d_ffn,
        kernel_size=31,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        mamba_config=None,
    ):
        super().__init__()
        assert mamba_config is not None

        # `bidirectional` is popped here and restored below so that the same
        # config dict can be reused by the remaining Mamba constructors.
        bidirectional = mamba_config.pop('bidirectional')
        if causal or (not bidirectional):
            self.mamba = Mamba(
                d_model=d_model,
                **mamba_config,
            )
        else:
            self.mamba = BiMamba(
                d_model=d_model,
                bimamba_type='v2',
                **mamba_config,
            )
        mamba_config['bidirectional'] = bidirectional

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )
        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Processes the input through the ConMamba encoder layer."""
        conv_mask: Optional[torch.Tensor] = None
        if src_key_padding_mask is not None:
            conv_mask = src_key_padding_mask.unsqueeze(-1)
        # NOTE: the padding mask is overridden to None below, so the
        # convolution module currently runs unmasked.
        conv_mask = None

        # ffn module
        x = x + 0.5 * self.ffn_module1(x)
        # mamba module
        skip = x
        x = self.norm1(x)
        x = self.mamba(x)
        x = x + skip
        # convolution module
        x = x + self.convolution_module(
            x, conv_mask, dynchunktrain_config=dynchunktrain_config
        )
        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x


class ConmambaEncoder(nn.Module):
"""This class implements the Conmamba encoder. | |
""" | |

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        kernel_size=31,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        mamba_config=None,
    ):
        super().__init__()
        warnings.warn(f'dropout={dropout} is not used in Mamba.')

        self.layers = torch.nn.ModuleList(
            [
                ConmambaEncoderLayer(
                    d_model=d_model,
                    d_ffn=d_ffn,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    mamba_config=mamba_config,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
""" | |
Arguments | |
---------- | |
src : torch.Tensor | |
The sequence to the encoder layer. | |
src_mask : torch.Tensor, optional | |
The mask for the src sequence. | |
src_key_padding_mask : torch.Tensor, optional | |
The mask for the src keys per batch. | |
pos_embs: torch.Tensor, torch.nn.Module, | |
Module or tensor containing the input sequence positional embeddings | |
If custom pos_embs are given it needs to have the shape (1, 2*S-1, E) | |
where S is the sequence length, and E is the embedding dimension. | |
dynchunktrain_config: Optional[DynChunkTrainConfig] | |
Dynamic Chunk Training configuration object for streaming, | |
specifically involved here to apply Dynamic Chunk Convolution to the | |
convolution module. | |
""" | |
        output = src
        for enc_layer in self.layers:
            output = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
                dynchunktrain_config=dynchunktrain_config,
            )
        output = self.norm(output)

        return output, None


class MambaDecoderLayer(nn.Module):
"""This class implements the Mamba decoder layer. | |
""" | |

    def __init__(
        self,
        d_model,
        d_ffn,
        activation=nn.ReLU,
        dropout=0.0,
        normalize_before=False,
        mamba_config=None,
    ):
        super().__init__()
        assert mamba_config is not None

        # `bidirectional` is popped here and restored below so that the same
        # config dict can be reused by the remaining Mamba constructors; the
        # decoder Mamba blocks are always unidirectional.
        bidirectional = mamba_config.pop('bidirectional')

        self.self_mamba = Mamba(
            d_model=d_model,
            **mamba_config,
        )
        self.cross_mamba = Mamba(
            d_model=d_model,
            **mamba_config,
        )
        mamba_config['bidirectional'] = bidirectional

        self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
            d_ffn=d_ffn,
            input_size=d_model,
            dropout=dropout,
            activation=activation,
        )

        # normalization layers
        self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)

        self.normalize_before = normalize_before

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder layer (required).
        memory : torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask : torch.Tensor
            The mask for the tgt sequence (optional).
        memory_mask : torch.Tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask : torch.Tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask : torch.Tensor
            The mask for the memory keys per batch (optional).
        pos_embs_tgt : torch.Tensor
            The positional embeddings for the target (optional).
        pos_embs_src : torch.Tensor
            The positional embeddings for the source (optional).
        """
        if self.normalize_before:
            tgt1 = self.norm1(tgt)
        else:
            tgt1 = tgt

        # Mamba over the target sequence
        tgt2 = self.self_mamba(tgt1)

        # add & norm
        tgt = tgt + self.dropout1(tgt2)
        if not self.normalize_before:
            tgt = self.norm1(tgt)

        if self.normalize_before:
            tgt1 = self.norm2(tgt)
        else:
            tgt1 = tgt

        # Mamba over key=value + query
        # and only take the last len(query) tokens
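        # Shape sketch: memory (B, S, D) and tgt1 (B, T, D) concatenate to
        # (B, S + T, D); the slice keeps the last T frames, i.e. (B, T, D).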
        tgt2 = self.cross_mamba(torch.cat([memory, tgt1], dim=1))[
            :, -tgt1.shape[1]:
        ]

        # add & norm
        tgt = tgt + self.dropout2(tgt2)
        if not self.normalize_before:
            tgt = self.norm2(tgt)

        if self.normalize_before:
            tgt1 = self.norm3(tgt)
        else:
            tgt1 = tgt

        tgt2 = self.pos_ffn(tgt1)

        # add & norm
        tgt = tgt + self.dropout3(tgt2)
        if not self.normalize_before:
            tgt = self.norm3(tgt)

        return tgt, None, None


class MambaDecoder(nn.Module):
"""This class implements the Mamba decoder. | |
""" | |

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        activation=nn.ReLU,
        dropout=0.0,
        normalize_before=False,
        mamba_config=None,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                MambaDecoderLayer(
                    d_model=d_model,
                    d_ffn=d_ffn,
                    activation=activation,
                    dropout=dropout,
                    normalize_before=normalize_before,
                    mamba_config=mamba_config,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder layer (required).
        memory : torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask : torch.Tensor
            The mask for the tgt sequence (optional).
        memory_mask : torch.Tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask : torch.Tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask : torch.Tensor
            The mask for the memory keys per batch (optional).
        pos_embs_tgt : torch.Tensor
            The positional embeddings for the target (optional).
        pos_embs_src : torch.Tensor
            The positional embeddings for the source (optional).
        """
        output = tgt
        for dec_layer in self.layers:
            output, _, _ = dec_layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos_embs_tgt=pos_embs_tgt,
                pos_embs_src=pos_embs_src,
            )
        output = self.norm(output)

        return output, [None], [None]
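

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. It assumes
# illustrative shapes and hyperparameters, needs the `mamba_ssm` CUDA kernels
# and a GPU, and must be launched as a module (`python -m <package>.<module>`)
# because of the relative BiMamba import above. Setting `bidirectional` to
# True would exercise the BiMamba path in the encoder instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = {"bidirectional": False, "d_state": 16, "d_conv": 4, "expand": 2}
    encoder = ConmambaEncoder(
        num_layers=2, d_model=256, d_ffn=1024, mamba_config=cfg
    ).cuda()
    decoder = MambaDecoder(
        num_layers=2, d_model=256, d_ffn=1024, mamba_config=cfg
    ).cuda()

    src = torch.rand(2, 50, 256, device="cuda")
    tgt = torch.rand(2, 12, 256, device="cuda")

    memory, _ = encoder(src)
    out, _, _ = decoder(tgt, memory)
    print(out.shape)  # expected: torch.Size([2, 12, 256])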