"""ConMamba encoder and Mamba decoder implementation.
Authors
-------
* Xilin Jiang 2024
"""
import warnings
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import (
MultiheadAttention,
PositionalwiseFeedForward,
RelPosMHAXL,
)
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
# Mamba
from mamba_ssm import Mamba
from .mamba.bimamba import Mamba as BiMamba
class ConvolutionModule(nn.Module):
"""This is an implementation of convolution module in Conmamba.
"""
def __init__(
self,
input_size,
kernel_size=31,
bias=True,
activation=Swish,
dropout=0.0,
causal=False,
dilation=1,
):
super().__init__()
self.kernel_size = kernel_size
self.causal = causal
self.dilation = dilation
if self.causal:
self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
else:
self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2
self.layer_norm = nn.LayerNorm(input_size)
self.bottleneck = nn.Sequential(
# pointwise
nn.Conv1d(
input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
),
nn.GLU(dim=1),
)
# depthwise
self.conv = nn.Conv1d(
input_size,
input_size,
kernel_size=kernel_size,
stride=1,
padding=self.padding,
dilation=dilation,
groups=input_size,
bias=bias,
)
# BatchNorm in the original Conformer replaced with a LayerNorm due to
# https://github.com/speechbrain/speechbrain/pull/1329
# see discussion
# https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884
self.after_conv = nn.Sequential(
nn.LayerNorm(input_size),
activation(),
# pointwise
nn.Linear(input_size, input_size, bias=bias),
nn.Dropout(dropout),
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
):
"""Applies the convolution to an input tensor `x`.
"""
if dynchunktrain_config is not None:
# chances are chunking+causal is unintended; i don't know where it
# may make sense, but if it does to you, feel free to implement it.
assert (
not self.causal
), "Chunked convolution not supported with causal padding"
assert (
self.dilation == 1
), "Current DynChunkTrain logic does not support dilation != 1"
# in a causal convolution, which is not the case here, an output
# frame would never be able to depend on a input frame from any
# point in the future.
# but with the dynamic chunk convolution, we instead use a "normal"
# convolution but where, for any output frame, the future beyond the
# "current" chunk gets masked.
# see the paper linked in the documentation for details.
chunk_size = dynchunktrain_config.chunk_size
batch_size = x.shape[0]
# determine the amount of padding we need to insert at the right of
# the last chunk so that all chunks end up with the same size.
if x.shape[1] % chunk_size != 0:
final_right_padding = chunk_size - (x.shape[1] % chunk_size)
else:
final_right_padding = 0
# -> [batch_size, t, in_channels]
out = self.layer_norm(x)
# -> [batch_size, in_channels, t] for the CNN
out = out.transpose(1, 2)
# -> [batch_size, in_channels, t] (pointwise)
out = self.bottleneck(out)
# -> [batch_size, in_channels, lc+t+final_right_padding]
out = F.pad(out, (self.padding, final_right_padding), value=0)
# now, make chunks with left context.
# as a recap to what the above padding and this unfold do, consider
# each a/b/c letter represents a frame as part of chunks a, b, c.
# consider a chunk size of 4 and a kernel size of 5 (padding=2):
#
# input seq: 00aaaabbbbcc00
# chunk #1: 00aaaa
# chunk #2: aabbbb
# chunk #3: bbcc00
#
# a few remarks here:
# - the left padding gets inserted early so that the unfold logic
# works trivially
# - the right 0-padding got inserted as the number of time steps
# could not be evenly split in `chunk_size` chunks
# -> [batch_size, in_channels, num_chunks, lc+chunk_size]
out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)
# as we manually disable padding in the convolution below, we insert
# right 0-padding to the chunks, e.g. reusing the above example:
#
# chunk #1: 00aaaa00
# chunk #2: aabbbb00
# chunk #3: bbcc0000
# -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
out = F.pad(out, (0, self.padding), value=0)
# the transpose+flatten effectively flattens chunks into the batch
# dimension to be processed into the time-wise convolution. the
# chunks will later on be unflattened.
# -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
out = out.transpose(1, 2)
# -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
out = out.flatten(start_dim=0, end_dim=1)
# TODO: experiment around reflect padding, which is difficult
# because small chunks have too little time steps to reflect from
# let's keep backwards compat by pointing at the weights from the
# already declared Conv1d.
#
# still reusing the above example, the convolution will be applied,
# with the padding truncated on both ends. the following example
# shows the letter corresponding to the input frame on which the
# convolution was centered.
#
# as you can see, the sum of lengths of all chunks is equal to our
# input sequence length + `final_right_padding`.
#
# chunk #1: aaaa
# chunk #2: bbbb
# chunk #3: cc00
# -> [batch_size * num_chunks, out_channels, chunk_size]
out = F.conv1d(
out,
weight=self.conv.weight,
bias=self.conv.bias,
stride=self.conv.stride,
padding=0,
dilation=self.conv.dilation,
groups=self.conv.groups,
)
# -> [batch_size * num_chunks, chunk_size, out_channels]
out = out.transpose(1, 2)
out = self.after_conv(out)
# -> [batch_size, num_chunks, chunk_size, out_channels]
out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))
# -> [batch_size, t + final_right_padding, out_channels]
out = torch.flatten(out, start_dim=1, end_dim=2)
# -> [batch_size, t, out_channels]
if final_right_padding > 0:
out = out[:, :-final_right_padding, :]
else:
out = self.layer_norm(x)
out = out.transpose(1, 2)
out = self.bottleneck(out)
out = self.conv(out)
if self.causal:
# chomp
out = out[..., : -self.padding]
out = out.transpose(1, 2)
out = self.after_conv(out)
if mask is not None:
out.masked_fill_(mask, 0.0)
return out
class ConmambaEncoderLayer(nn.Module):
"""This is an implementation of Conmamba encoder layer.
"""
def __init__(
self,
d_model,
d_ffn,
kernel_size=31,
activation=Swish,
bias=True,
dropout=0.0,
causal=False,
mamba_config=None
):
super().__init__()
        assert mamba_config is not None

        # 'bidirectional' is not a Mamba keyword argument: pop it to choose the
        # Mamba variant, then restore it so the shared config dict stays intact
        # for the following layers.
        bidirectional = mamba_config.pop('bidirectional')
        if causal or (not bidirectional):
            self.mamba = Mamba(
                d_model=d_model,
                **mamba_config
            )
        else:
            self.mamba = BiMamba(
                d_model=d_model,
                bimamba_type='v2',
                **mamba_config
            )
        mamba_config['bidirectional'] = bidirectional
self.convolution_module = ConvolutionModule(
d_model, kernel_size, bias, activation, dropout, causal=causal
)
self.ffn_module1 = nn.Sequential(
nn.LayerNorm(d_model),
PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
),
nn.Dropout(dropout),
)
self.ffn_module2 = nn.Sequential(
nn.LayerNorm(d_model),
PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
),
nn.Dropout(dropout),
)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.drop = nn.Dropout(dropout)
def forward(
self,
x,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: torch.Tensor = None,
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
):
        """Processes the source sequence through one ConMamba encoder layer."""
        conv_mask: Optional[torch.Tensor] = None
        if src_key_padding_mask is not None:
            conv_mask = src_key_padding_mask.unsqueeze(-1)
        # NOTE: the padding mask is deliberately not applied inside the
        # convolution module; conv_mask is overridden to None here.
        conv_mask = None
# ffn module
x = x + 0.5 * self.ffn_module1(x)
# mamba module
skip = x
x = self.norm1(x)
x = self.mamba(x)
x = x + skip
# convolution module
x = x + self.convolution_module(
x, conv_mask, dynchunktrain_config=dynchunktrain_config
)
# ffn module
x = self.norm2(x + 0.5 * self.ffn_module2(x))
return x
class ConmambaEncoder(nn.Module):
"""This class implements the Conmamba encoder.
"""
def __init__(
self,
num_layers,
d_model,
d_ffn,
kernel_size=31,
activation=Swish,
bias=True,
dropout=0.0,
causal=False,
mamba_config=None
):
super().__init__()
        warnings.warn(
            f'dropout={dropout} is not applied inside the Mamba blocks '
            '(it is still used in the feed-forward and convolution modules).'
        )
self.layers = torch.nn.ModuleList(
[
ConmambaEncoderLayer(
d_model=d_model,
d_ffn=d_ffn,
dropout=dropout,
activation=activation,
kernel_size=kernel_size,
bias=bias,
causal=causal,
mamba_config=mamba_config,
)
for i in range(num_layers)
]
)
self.norm = LayerNorm(d_model, eps=1e-6)
def forward(
self,
src,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
):
"""
Arguments
----------
src : torch.Tensor
The sequence to the encoder layer.
src_mask : torch.Tensor, optional
The mask for the src sequence.
src_key_padding_mask : torch.Tensor, optional
The mask for the src keys per batch.
        pos_embs : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings.
            If custom pos_embs are given, they need to have the shape
            (1, 2*S-1, E), where S is the sequence length and E is the
            embedding dimension.
dynchunktrain_config: Optional[DynChunkTrainConfig]
Dynamic Chunk Training configuration object for streaming,
specifically involved here to apply Dynamic Chunk Convolution to the
convolution module.
"""
output = src
for enc_layer in self.layers:
output = enc_layer(
output,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
pos_embs=pos_embs,
dynchunktrain_config=dynchunktrain_config,
)
output = self.norm(output)
return output, None
class MambaDecoderLayer(nn.Module):
"""This class implements the Mamba decoder layer.
"""
def __init__(
self,
d_model,
d_ffn,
activation=nn.ReLU,
dropout=0.0,
normalize_before=False,
mamba_config=None
):
super().__init__()
        assert mamba_config is not None

        # 'bidirectional' is not a Mamba keyword argument: pop it before
        # building the (unidirectional) decoder Mamba blocks, then restore it
        # so the shared config dict is left unchanged.
        bidirectional = mamba_config.pop('bidirectional')
        self.self_mamba = Mamba(
            d_model=d_model,
            **mamba_config
        )
        self.cross_mamba = Mamba(
            d_model=d_model,
            **mamba_config
        )
        mamba_config['bidirectional'] = bidirectional
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
dropout=dropout,
activation=activation,
)
# normalization layers
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.dropout3 = torch.nn.Dropout(dropout)
self.normalize_before = normalize_before
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
pos_embs_tgt=None,
pos_embs_src=None,
):
"""
Arguments
----------
tgt: torch.Tensor
The sequence to the decoder layer (required).
memory: torch.Tensor
The sequence from the last layer of the encoder (required).
tgt_mask: torch.Tensor
The mask for the tgt sequence (optional).
memory_mask: torch.Tensor
The mask for the memory sequence (optional).
tgt_key_padding_mask: torch.Tensor
The mask for the tgt keys per batch (optional).
memory_key_padding_mask: torch.Tensor
The mask for the memory keys per batch (optional).
pos_embs_tgt: torch.Tensor
The positional embeddings for the target (optional).
pos_embs_src: torch.Tensor
The positional embeddings for the source (optional).
"""
if self.normalize_before:
tgt1 = self.norm1(tgt)
else:
tgt1 = tgt
# Mamba over the target sequence
tgt2 = self.self_mamba(tgt1)
# add & norm
tgt = tgt + self.dropout1(tgt2)
if not self.normalize_before:
tgt = self.norm1(tgt)
if self.normalize_before:
tgt1 = self.norm2(tgt)
else:
tgt1 = tgt
        # "Cross" Mamba: run Mamba over the concatenation of the encoder
        # memory and the (normalized) target, then keep only the last
        # len(tgt) positions as the cross-attention replacement.
        tgt2 = self.cross_mamba(torch.cat([memory, tgt1], dim=1))[
            :, -tgt1.shape[1]:
        ]
# add & norm
tgt = tgt + self.dropout2(tgt2)
if not self.normalize_before:
tgt = self.norm2(tgt)
if self.normalize_before:
tgt1 = self.norm3(tgt)
else:
tgt1 = tgt
tgt2 = self.pos_ffn(tgt1)
# add & norm
tgt = tgt + self.dropout3(tgt2)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt, None, None
class MambaDecoder(nn.Module):
"""This class implements the Mamba decoder.
"""
def __init__(
self,
num_layers,
d_model,
d_ffn,
activation=nn.ReLU,
dropout=0.0,
normalize_before=False,
mamba_config=None
):
super().__init__()
self.layers = torch.nn.ModuleList(
[
MambaDecoderLayer(
d_model=d_model,
d_ffn=d_ffn,
activation=activation,
dropout=dropout,
normalize_before=normalize_before,
mamba_config=mamba_config
)
for _ in range(num_layers)
]
)
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
pos_embs_tgt=None,
pos_embs_src=None,
):
"""
Arguments
----------
tgt : torch.Tensor
The sequence to the decoder layer (required).
memory : torch.Tensor
The sequence from the last layer of the encoder (required).
tgt_mask : torch.Tensor
The mask for the tgt sequence (optional).
memory_mask : torch.Tensor
The mask for the memory sequence (optional).
tgt_key_padding_mask : torch.Tensor
The mask for the tgt keys per batch (optional).
memory_key_padding_mask : torch.Tensor
The mask for the memory keys per batch (optional).
pos_embs_tgt : torch.Tensor
The positional embeddings for the target (optional).
pos_embs_src : torch.Tensor
The positional embeddings for the source (optional).
"""
output = tgt
for dec_layer in self.layers:
output, _, _ = dec_layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos_embs_tgt=pos_embs_tgt,
pos_embs_src=pos_embs_src,
)
output = self.norm(output)
return output, [None], [None]