"""Conformer implementation. | |
Authors | |
------- | |
* Jianyuan Zhong 2020 | |
* Samuele Cornell 2021 | |
* Sylvain de Langen 2023 | |
""" | |
import warnings | |
from dataclasses import dataclass | |
from typing import List, Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import speechbrain as sb | |
from speechbrain.nnet.activations import Swish | |
from speechbrain.nnet.attention import ( | |
MultiheadAttention, | |
PositionalwiseFeedForward, | |
RelPosMHAXL, | |
) | |
from speechbrain.nnet.hypermixing import HyperMixing | |
from speechbrain.nnet.normalization import LayerNorm | |
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig | |


@dataclass
class ConformerEncoderLayerStreamingContext:
    """Streaming metadata and state for a `ConformerEncoderLayer`.

    The multi-head attention and Dynamic Chunk Convolution require to save some
    left context that gets inserted as left padding.

    See :class:`.ConvolutionModule` documentation for further details.
    """

    mha_left_context_size: int
    """For this layer, specifies how many frames of inputs should be saved.
    Usually, the same value is used across all layers, but this can be modified.
    """

    mha_left_context: Optional[torch.Tensor] = None
    """Left context to insert at the left of the current chunk as inputs to the
    multi-head attention. It can be `None` (if we're dealing with the first
    chunk) or `<= mha_left_context_size` because for the first few chunks, not
    enough left context may be available to pad.
    """

    dcconv_left_context: Optional[torch.Tensor] = None
    """Left context to insert at the left of the convolution according to the
    Dynamic Chunk Convolution method.

    Unlike `mha_left_context`, here the amount of frames to keep is fixed and
    inferred from the kernel size of the convolution module.
    """


@dataclass
class ConformerEncoderStreamingContext:
    """Streaming metadata and state for a `ConformerEncoder`."""

    dynchunktrain_config: DynChunkTrainConfig
    """Dynamic Chunk Training configuration holding chunk size and context size
    information."""

    layers: List[ConformerEncoderLayerStreamingContext]
    """Streaming metadata and state for each layer of the encoder."""


class ConvolutionModule(nn.Module):
    """This is an implementation of the convolution module in Conformer.

    Arguments
    ---------
    input_size : int
        The expected size of the input embedding dimension.
    kernel_size : int, optional
        Kernel size of non-bottleneck convolutional layer.
    bias : bool, optional
        Whether to use bias in the non-bottleneck conv layer.
    activation : torch.nn.Module
        Activation function used after non-bottleneck conv layer.
    dropout : float, optional
        Dropout rate.
    causal : bool, optional
        Whether the convolution should be causal or not.
    dilation : int, optional
        Dilation factor for the non-bottleneck conv layer.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = ConvolutionModule(512, 3)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_size,
        kernel_size=31,
        bias=True,
        activation=Swish,
        dropout=0.0,
        causal=False,
        dilation=1,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.causal = causal
        self.dilation = dilation
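
        # Padding for the depthwise conv below: causal mode uses the full
        # (kernel_size - 1) extent (the right overhang is chomped after the
        # conv so no output frame sees the future), while non-causal mode
        # halves it for "same"-style padding. Note the 2 ** (dilation - 1)
        # scaling coincides with the usual (kernel_size - 1) * dilation
        # formula only for dilation <= 2.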
        if self.causal:
            self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
        else:
            self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2

        self.layer_norm = nn.LayerNorm(input_size)
        self.bottleneck = nn.Sequential(
            # pointwise
            nn.Conv1d(
                input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
            ),
            nn.GLU(dim=1),
        )
        # depthwise
        self.conv = nn.Conv1d(
            input_size,
            input_size,
            kernel_size=kernel_size,
            stride=1,
            padding=self.padding,
            dilation=dilation,
            groups=input_size,
            bias=bias,
        )

        # BatchNorm in the original Conformer replaced with a LayerNorm due to
        # https://github.com/speechbrain/speechbrain/pull/1329
        # see discussion
        # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884
        self.after_conv = nn.Sequential(
            nn.LayerNorm(input_size),
            activation(),
            # pointwise
            nn.Linear(input_size, input_size, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Applies the convolution to an input tensor `x`.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor to the convolution module.
        mask : torch.Tensor, optional
            Mask to be applied over the output of the convolution using
            `masked_fill_`, if specified.
        dynchunktrain_config : DynChunkTrainConfig, optional
            If specified, makes the module support Dynamic Chunk Convolution
            (DCConv) as implemented by
            `Dynamic Chunk Convolution for Unified Streaming and Non-Streaming Conformer ASR <https://www.amazon.science/publications/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr>`_.
            This allows masking future frames while preserving better accuracy
            than a fully causal convolution, at a small speed cost.
            This should only be used for training (or, if you know what you're
            doing, for masked evaluation at inference time), as the forward
            streaming function should be used at inference time.

        Returns
        -------
        out : torch.Tensor
            The output tensor.
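
        Example
        -------
        A minimal training-time DCConv sketch. The chunk size of 16 is an
        arbitrary illustration value, and we assume `DynChunkTrainConfig`
        can be built from a chunk size alone:

        >>> import torch
        >>> x = torch.rand((8, 60, 512))
        >>> net = ConvolutionModule(512, 3)
        >>> cfg = DynChunkTrainConfig(chunk_size=16)
        >>> net(x, dynchunktrain_config=cfg).shape
        torch.Size([8, 60, 512])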
""" | |
if dynchunktrain_config is not None: | |
# chances are chunking+causal is unintended; i don't know where it | |
# may make sense, but if it does to you, feel free to implement it. | |
assert ( | |
not self.causal | |
), "Chunked convolution not supported with causal padding" | |
assert ( | |
self.dilation == 1 | |
), "Current DynChunkTrain logic does not support dilation != 1" | |
# in a causal convolution, which is not the case here, an output | |
# frame would never be able to depend on a input frame from any | |
# point in the future. | |
# but with the dynamic chunk convolution, we instead use a "normal" | |
# convolution but where, for any output frame, the future beyond the | |
# "current" chunk gets masked. | |
# see the paper linked in the documentation for details. | |
chunk_size = dynchunktrain_config.chunk_size | |
batch_size = x.shape[0] | |
# determine the amount of padding we need to insert at the right of | |
# the last chunk so that all chunks end up with the same size. | |
if x.shape[1] % chunk_size != 0: | |
final_right_padding = chunk_size - (x.shape[1] % chunk_size) | |
else: | |
final_right_padding = 0 | |
# -> [batch_size, t, in_channels] | |
out = self.layer_norm(x) | |
# -> [batch_size, in_channels, t] for the CNN | |
out = out.transpose(1, 2) | |
# -> [batch_size, in_channels, t] (pointwise) | |
out = self.bottleneck(out) | |
# -> [batch_size, in_channels, lc+t+final_right_padding] | |
out = F.pad(out, (self.padding, final_right_padding), value=0) | |
# now, make chunks with left context. | |
# as a recap to what the above padding and this unfold do, consider | |
# each a/b/c letter represents a frame as part of chunks a, b, c. | |
# consider a chunk size of 4 and a kernel size of 5 (padding=2): | |
# | |
# input seq: 00aaaabbbbcc00 | |
# chunk #1: 00aaaa | |
# chunk #2: aabbbb | |
# chunk #3: bbcc00 | |
# | |
# a few remarks here: | |
# - the left padding gets inserted early so that the unfold logic | |
# works trivially | |
# - the right 0-padding got inserted as the number of time steps | |
# could not be evenly split in `chunk_size` chunks | |
# -> [batch_size, in_channels, num_chunks, lc+chunk_size] | |
out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size) | |
# as we manually disable padding in the convolution below, we insert | |
# right 0-padding to the chunks, e.g. reusing the above example: | |
# | |
# chunk #1: 00aaaa00 | |
# chunk #2: aabbbb00 | |
# chunk #3: bbcc0000 | |
# -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad] | |
out = F.pad(out, (0, self.padding), value=0) | |
# the transpose+flatten effectively flattens chunks into the batch | |
# dimension to be processed into the time-wise convolution. the | |
# chunks will later on be unflattened. | |
# -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad] | |
out = out.transpose(1, 2) | |
# -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad] | |
out = out.flatten(start_dim=0, end_dim=1) | |
# TODO: experiment around reflect padding, which is difficult | |
# because small chunks have too little time steps to reflect from | |
# let's keep backwards compat by pointing at the weights from the | |
# already declared Conv1d. | |
# | |
# still reusing the above example, the convolution will be applied, | |
# with the padding truncated on both ends. the following example | |
# shows the letter corresponding to the input frame on which the | |
# convolution was centered. | |
# | |
# as you can see, the sum of lengths of all chunks is equal to our | |
# input sequence length + `final_right_padding`. | |
# | |
# chunk #1: aaaa | |
# chunk #2: bbbb | |
# chunk #3: cc00 | |
# -> [batch_size * num_chunks, out_channels, chunk_size] | |
out = F.conv1d( | |
out, | |
weight=self.conv.weight, | |
bias=self.conv.bias, | |
stride=self.conv.stride, | |
padding=0, | |
dilation=self.conv.dilation, | |
groups=self.conv.groups, | |
) | |
# -> [batch_size * num_chunks, chunk_size, out_channels] | |
out = out.transpose(1, 2) | |
out = self.after_conv(out) | |
# -> [batch_size, num_chunks, chunk_size, out_channels] | |
out = torch.unflatten(out, dim=0, sizes=(batch_size, -1)) | |
# -> [batch_size, t + final_right_padding, out_channels] | |
out = torch.flatten(out, start_dim=1, end_dim=2) | |
# -> [batch_size, t, out_channels] | |
if final_right_padding > 0: | |
out = out[:, :-final_right_padding, :] | |
else: | |
out = self.layer_norm(x) | |
out = out.transpose(1, 2) | |
out = self.bottleneck(out) | |
out = self.conv(out) | |
if self.causal: | |
# chomp | |
out = out[..., : -self.padding] | |
out = out.transpose(1, 2) | |
out = self.after_conv(out) | |
if mask is not None: | |
out.masked_fill_(mask, 0.0) | |
return out | |


class ConformerEncoderLayer(nn.Module):
    """This is an implementation of the Conformer encoder layer.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    dropout : float, optional
        Dropout for the encoder.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoderLayer(d_ffn=512, nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=causal,
            )
        elif attention_type == "hypermixing":
            self.mha_layer = HyperMixing(
                input_output_dim=d_model,
                hypernet_size=d_ffn,
                tied=False,
                num_heads=nhead,
                fix_tm_hidden_size=False,
            )

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )
        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: torch.Tensor = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ---------
        x : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings.
        dynchunktrain_config : DynChunkTrainConfig, optional
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to
            the convolution module.
        """
        conv_mask: Optional[torch.Tensor] = None
        if src_key_padding_mask is not None:
            conv_mask = src_key_padding_mask.unsqueeze(-1)
        # ffn module
        x = x + 0.5 * self.ffn_module1(x)
        # multi-head attention module
        skip = x
        x = self.norm1(x)

        x, self_attn = self.mha_layer(
            x,
            x,
            x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs,
        )
        x = x + skip

        # convolution module
        x = x + self.convolution_module(
            x, conv_mask, dynchunktrain_config=dynchunktrain_config
        )

        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn

    def forward_streaming(
        self,
        x,
        context: ConformerEncoderLayerStreamingContext,
        pos_embs: torch.Tensor = None,
    ):
        """Conformer layer streaming forward (typically for
        DynamicChunkTraining-trained models), which is to be used at inference
        time. Relies on a mutable context object as initialized by
        `make_streaming_context` that should be used across chunks.
        Invoked by `ConformerEncoder.forward_streaming`.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor for this layer. Batching is supported as long as you
            keep the context consistent.
        context : ConformerEncoderLayerStreamingContext
            Mutable streaming context; the same object should be passed across
            calls.
        pos_embs : torch.Tensor, optional
            Positional embeddings, if used.

        Returns
        -------
        x : torch.Tensor
            Output tensor.
        self_attn : list
            List of self attention values.
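
        Example
        -------
        A minimal chunk-by-chunk sketch. The layer sizes and the use of
        ``attention_type="regularMHA"`` (which sidesteps the positional
        embeddings that RelPosMHAXL would require) are illustrative
        assumptions, not a prescribed setup:

        >>> import torch
        >>> layer = ConformerEncoderLayer(
        ...     d_model=64, d_ffn=128, nhead=4, kernel_size=3,
        ...     attention_type="regularMHA",
        ... )
        >>> ctx = layer.make_streaming_context(mha_left_context_size=8)
        >>> out, _ = layer.forward_streaming(torch.rand((1, 16, 64)), ctx)
        >>> out.shape
        torch.Size([1, 16, 64])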
""" | |
orig_len = x.shape[-2] | |
# ffn module | |
x = x + 0.5 * self.ffn_module1(x) | |
# TODO: make the approach for MHA left context more efficient. | |
# currently, this saves the inputs to the MHA. | |
# the naive approach is suboptimal in a few ways, namely that the | |
# outputs for this left padding is being re-computed even though we | |
# discard them immediately after. | |
# left pad `x` with our MHA left context | |
if context.mha_left_context is not None: | |
x = torch.cat((context.mha_left_context, x), dim=1) | |
# compute new MHA left context for the next call to our function | |
if context.mha_left_context_size > 0: | |
context.mha_left_context = x[ | |
..., -context.mha_left_context_size :, : | |
] | |
# multi-head attention module | |
skip = x | |
x = self.norm1(x) | |
x, self_attn = self.mha_layer( | |
x, | |
x, | |
x, | |
attn_mask=None, | |
key_padding_mask=None, | |
pos_embs=pos_embs, | |
) | |
x = x + skip | |
# truncate outputs corresponding to the MHA left context (we only care | |
# about our chunk's outputs); see above to-do | |
x = x[..., -orig_len:, :] | |
if context.dcconv_left_context is not None: | |
x = torch.cat((context.dcconv_left_context, x), dim=1) | |
# compute new DCConv left context for the next call to our function | |
context.dcconv_left_context = x[ | |
..., -self.convolution_module.padding :, : | |
] | |
# convolution module | |
x = x + self.convolution_module(x) | |
# truncate outputs corresponding to the DCConv left context | |
x = x[..., -orig_len:, :] | |
# ffn module | |
x = self.norm2(x + 0.5 * self.ffn_module2(x)) | |
return x, self_attn | |

    def make_streaming_context(self, mha_left_context_size: int):
        """Creates a blank streaming context for this encoding layer.

        Arguments
        ---------
        mha_left_context_size : int
            How many left frames should be saved and used as left context to
            the current chunk when streaming.

        Returns
        -------
        ConformerEncoderLayerStreamingContext
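
        Example
        -------
        A trivial sketch; the context size of 16 is an arbitrary choice:

        >>> layer = ConformerEncoderLayer(d_model=64, d_ffn=128, nhead=4)
        >>> ctx = layer.make_streaming_context(mha_left_context_size=16)
        >>> ctx.mha_left_context_size
        16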
""" | |
return ConformerEncoderLayerStreamingContext( | |
mha_left_context_size=mha_left_context_size | |
) | |


class ConformerEncoder(nn.Module):
    """This class implements the Conformer encoder.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Embedding dimension size.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    dropout : float, optional
        Dropout for the encoder.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoder(1, 512, 512, 8)
    >>> output, _ = net(x, pos_embs=pos_emb)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [
                ConformerEncoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    attention_type=attention_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)
        self.attention_type = attention_type

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings.
            If custom pos_embs are given they need to have the shape (1, 2*S-1, E)
            where S is the sequence length, and E is the embedding dimension.
        dynchunktrain_config : DynChunkTrainConfig, optional
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to the
            convolution module.

        Returns
        -------
        output : torch.Tensor
            The output of the Conformer encoder.
        attention_lst : list
            The attention values.
        """
        if self.attention_type == "RelPosMHAXL":
            if pos_embs is None:
                raise ValueError(
                    "The chosen attention type for the Conformer is RelPosMHAXL. "
                    "For this attention type, the positional embeddings are mandatory"
                )

        output = src
        attention_lst = []
        for enc_layer in self.layers:
            output, attention = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
                dynchunktrain_config=dynchunktrain_config,
            )
            attention_lst.append(attention)
        output = self.norm(output)

        return output, attention_lst

    def forward_streaming(
        self,
        src: torch.Tensor,
        context: ConformerEncoderStreamingContext,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Conformer streaming forward (typically for
        DynamicChunkTraining-trained models), which is to be used at inference
        time. Relies on a mutable context object as initialized by
        `make_streaming_context` that should be used across chunks.

        Arguments
        ---------
        src : torch.Tensor
            Input tensor. Batching is supported as long as you keep the context
            consistent.
        context : ConformerEncoderStreamingContext
            Mutable streaming context; the same object should be passed across
            calls.
        pos_embs : torch.Tensor, optional
            Positional embeddings, if used.

        Returns
        -------
        output : torch.Tensor
            The output of the streaming conformer.
        attention_lst : list
            The attention values.
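
        Example
        -------
        A chunked-inference sketch under assumed settings:
        ``attention_type="regularMHA"`` avoids the mandatory positional
        embeddings of RelPosMHAXL, and the `DynChunkTrainConfig` arguments
        reflect our reading of that dataclass (chunk size in frames, left
        context in chunks):

        >>> import torch
        >>> enc = ConformerEncoder(
        ...     num_layers=2, d_model=64, d_ffn=128, nhead=4,
        ...     attention_type="regularMHA",
        ... )
        >>> cfg = DynChunkTrainConfig(chunk_size=16, left_context_size=1)
        >>> ctx = enc.make_streaming_context(cfg)
        >>> out, _ = enc.forward_streaming(torch.rand((1, 16, 64)), ctx)
        >>> out.shape
        torch.Size([1, 16, 64])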
""" | |
if self.attention_type == "RelPosMHAXL": | |
if pos_embs is None: | |
raise ValueError( | |
"The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory" | |
) | |
output = src | |
attention_lst = [] | |
for i, enc_layer in enumerate(self.layers): | |
output, attention = enc_layer.forward_streaming( | |
output, pos_embs=pos_embs, context=context.layers[i] | |
) | |
attention_lst.append(attention) | |
output = self.norm(output) | |
return output, attention_lst | |

    def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
        """Creates a blank streaming context for the encoder.

        Arguments
        ---------
        dynchunktrain_config : DynChunkTrainConfig
            Dynamic Chunk Training configuration object for streaming.

        Returns
        -------
        ConformerEncoderStreamingContext
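
        Example
        -------
        A minimal sketch; as above, the `DynChunkTrainConfig` arguments are
        our assumption about that dataclass:

        >>> enc = ConformerEncoder(1, 64, 128, 4)
        >>> cfg = DynChunkTrainConfig(chunk_size=16, left_context_size=2)
        >>> len(enc.make_streaming_context(cfg).layers)
        1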
""" | |
return ConformerEncoderStreamingContext( | |
dynchunktrain_config=dynchunktrain_config, | |
layers=[ | |
layer.make_streaming_context( | |
mha_left_context_size=dynchunktrain_config.left_context_size_frames() | |
) | |
for layer in self.layers | |
], | |
) | |


class ConformerDecoderLayer(nn.Module):
    """This is an implementation of the Conformer decoder layer.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module, optional
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    dropout : float, optional
        Dropout for the decoder.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> tgt = torch.rand((8, 60, 512))
    >>> memory = torch.rand((8, 60, 512))
    >>> net = ConformerDecoderLayer(
    ...     d_model=512, d_ffn=512, nhead=8, kernel_size=3,
    ...     attention_type="regularMHA",
    ... )
    >>> output, _, _ = net(tgt, memory)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        nhead,
        kernel_size,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        if not causal:
            warnings.warn(
                "Decoder is not causal; in most applications it should be. "
                "You have been warned!"
            )

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=causal,
            )

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )
        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder layer.
        memory : torch.Tensor
            The sequence from the last layer of the encoder.
        tgt_mask : torch.Tensor, optional
            The mask for the tgt sequence.
        memory_mask : torch.Tensor, optional
            The mask for the memory sequence.
        tgt_key_padding_mask : torch.Tensor, optional
            The mask for the tgt keys per batch.
        memory_key_padding_mask : torch.Tensor, optional
            The mask for the memory keys per batch.
        pos_embs_tgt : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the target sequence positional embeddings for each attention layer.
        pos_embs_src : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the source sequence positional embeddings for each attention layer.

        Returns
        -------
        x : torch.Tensor
            The output tensor.
        self_attn : torch.Tensor
            The self attention tensor (returned twice, to match the
            (output, self_attn, multihead_attn) convention of the decoder).
        """
        # ffn module
        tgt = tgt + 0.5 * self.ffn_module1(tgt)
        # multi-head attention module
        skip = tgt
        x = self.norm1(tgt)

        x, self_attn = self.mha_layer(
            x,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            pos_embs=pos_embs_src,
        )
        x = x + skip

        # convolution module
        x = x + self.convolution_module(x)

        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn, self_attn


class ConformerDecoder(nn.Module):
    """This class implements the Conformer decoder.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    d_model : int
        Embedding dimension size.
    kdim : int, optional
        Dimension for key.
    vdim : int, optional
        Dimension for value.
    dropout : float, optional
        Dropout rate.
    activation : torch.nn.Module, optional
        Activation function used after non-bottleneck conv layer.
    kernel_size : int, optional
        Kernel size of convolutional layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = ConformerDecoder(1, 8, 1024, 512, attention_type="regularMHA")
    >>> output, _, _ = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        nhead,
        d_ffn,
        d_model,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=Swish,
        kernel_size=3,
        bias=True,
        causal=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [
                ConformerDecoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    attention_type=attention_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder layer.
        memory : torch.Tensor
            The sequence from the last layer of the encoder.
        tgt_mask : torch.Tensor, optional
            The mask for the tgt sequence.
        memory_mask : torch.Tensor, optional
            The mask for the memory sequence.
        tgt_key_padding_mask : torch.Tensor, optional
            The mask for the tgt keys per batch.
        memory_key_padding_mask : torch.Tensor, optional
            The mask for the memory keys per batch.
        pos_embs_tgt : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the target sequence positional embeddings for each attention layer.
        pos_embs_src : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the source sequence positional embeddings for each attention layer.

        Returns
        -------
        output : torch.Tensor
            Conformer decoder output.
        self_attns : list
            List of self-attention tensors, one per layer.
        multihead_attns : list
            List of multi-head attention tensors, one per layer.
        """
        output = tgt
        self_attns, multihead_attns = [], []
        for dec_layer in self.layers:
            output, self_attn, multihead_attn = dec_layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos_embs_tgt=pos_embs_tgt,
                pos_embs_src=pos_embs_src,
            )
            self_attns.append(self_attn)
            multihead_attns.append(multihead_attn)
        output = self.norm(output)

        return output, self_attns, multihead_attns