Upload 18 files
- model/.DS_Store +0 -0
- model/csp_tiny_layer.py +86 -0
- model/mamba_hf.py +41 -0
- model/modules/.DS_Store +0 -0
- model/modules/Conformer.py +1094 -0
- model/modules/Conmamba.py +607 -0
- model/modules/Transformer.py +1085 -0
- model/modules/TransformerASR.py +682 -0
- model/modules/__init__.py +0 -0
- model/modules/mamba/.DS_Store +0 -0
- model/modules/mamba/__init__.py +0 -0
- model/modules/mamba/bimamba.py +465 -0
- model/modules/mamba/mamba_blocks.py +252 -0
- model/modules/mamba/selective_scan_interface.py +714 -0
- model/patchify.py +20 -0
- model/sinc_conv.py +471 -0
- model/tiny_block.py +31 -0
- model/tinyvad.py +62 -0
model/.DS_Store
ADDED
Binary file (6.15 kB).
model/csp_tiny_layer.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from .tiny_block import TinyBlock
from transformers import MambaConfig, MambaModel
# from .conmamba import ConMamba

class CSPTinyLayer(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, ssm=False):
        super(CSPTinyLayer, self).__init__()

        self.ssm = ssm

        # Split channels
        self.split_channels = in_channels // 2

        if self.ssm:
            # Mamba Blocks
            configuration = MambaConfig(vocab_size=0, hidden_size=self.split_channels, num_hidden_layers=num_blocks)
            self.mamba_blocks = MambaModel(configuration)

            # mamba_config = {
            #     'd_state': self.split_channels,
            #     'expand': 2,
            #     'd_conv': 4,
            #     'bidirectional': True
            # }
            # self.mamba_blocks = ConMamba(
            #     num_blocks=num_blocks,
            #     channels=self.split_channels,
            #     height=8,
            #     width=8,
            #     mamba_config=mamba_config
            # )

        else:
            # TinyBlocks
            self.tiny_blocks = nn.Sequential(
                *[TinyBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)]
            )

        # Transition layer to adjust channel dimensions
        self.transition = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Split input into two parts
        p1 = x[:, :self.split_channels, :, :]
        p2 = x[:, self.split_channels:, :, :]

        if self.ssm:
            # Reshape to fit Mamba
            B, C, H, W = p2.shape
            p2 = p2.permute(0, 2, 3, 1)  # [B, H, W, C]
            p2 = p2.reshape(B, H * W, C)  # [B, L, C], L = H * W

            # Process p2 through MambaBlocks
            p2_out = self.mamba_blocks(inputs_embeds=p2).last_hidden_state

            # p2_out = self.mamba_blocks(p2)

            # Reshape back to original dimension
            p2_out = p2_out.reshape(B, H, W, -1)
            p2_out = p2_out.permute(0, 3, 1, 2)  # [B, C, H, W]
        else:
            # Process p2 through TinyBlocks
            p2_out = self.tiny_blocks(p2)

        # Concatenate p1 and processed p2
        concatenated = torch.cat((p1, p2_out), dim=1)

        # Apply transition layer
        out = self.transition(concatenated)
        return out

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = CSPTinyLayer(32, 32, 2, True).to(device)
    print(model)
    dummy_input = torch.randn(256, 32, 8, 8).to(device)
    output = model(dummy_input)
    print(output.shape)
model/mamba_hf.py
ADDED
@@ -0,0 +1,41 @@
import torch
from transformers import MambaConfig, MambaModel, Mamba2Config, Mamba2Model

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA version: {torch.version.cuda}")

batch, channel, height, width = 256, 16, 8, 8
x = torch.randn(batch, channel, height, width).to("cuda")
print(f'x: {x.shape}')

B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
print(f'Permuted x: {x.shape}')

x = x.reshape(B, H * W, C)  # [B, L, C], L = H * W
print(f'Reshaped x: {x.shape}')

# Initializing a Mamba configuration
configuration = MambaConfig(vocab_size=0, hidden_size=channel, num_hidden_layers=2)
# configuration = Mamba2Config(hidden_size=channel)

# Initializing a model (with random weights) from the configuration
model = MambaModel(configuration).to("cuda")
# model = Mamba2Model(configuration).to("cuda")
print(f'Model: {model}')

# Accessing the model configuration
configuration = model.config
print(f'Configuration: {configuration}')

# y = model(inputs_embeds=x).last_hidden_state
y = model(inputs_embeds=x, return_dict=True)[0]
print(f'y: {y.shape}')

y = y.reshape(B, H, W, -1)
print(f'Reshaped y: {y.shape}')

y = y.permute(0, 3, 1, 2)  # [B, C, H, W]
print(f'Permuted y: {y.shape}')
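The reshape round-trip in this script is the same one `CSPTinyLayer.forward` performs before and after its Mamba blocks. A short sketch of how it could be factored into a helper (the function name is hypothetical and not part of the uploaded code; it only relies on the `MambaModel(inputs_embeds=...)` call shown above):

# Hypothetical helper: run a sequence model over 2D feature maps by flattening H*W into a length axis.
def apply_over_hw(model, x):
    # x: [B, C, H, W] -> [B, H*W, C]
    B, C, H, W = x.shape
    seq = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
    y = model(inputs_embeds=seq).last_hidden_state      # [B, H*W, C]
    return y.reshape(B, H, W, -1).permute(0, 3, 1, 2)   # back to [B, C, H, W]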
model/modules/.DS_Store
ADDED
Binary file (6.15 kB).
model/modules/Conformer.py
ADDED
@@ -0,0 +1,1094 @@
"""Conformer implementation.

Authors
-------
* Jianyuan Zhong 2020
* Samuele Cornell 2021
* Sylvain de Langen 2023
"""

import warnings
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import (
    MultiheadAttention,
    PositionalwiseFeedForward,
    RelPosMHAXL,
)
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig


@dataclass
class ConformerEncoderLayerStreamingContext:
    """Streaming metadata and state for a `ConformerEncoderLayer`.

    The multi-head attention and Dynamic Chunk Convolution require to save some
    left context that gets inserted as left padding.

    See :class:`.ConvolutionModule` documentation for further details.
    """

    mha_left_context_size: int
    """For this layer, specifies how many frames of inputs should be saved.
    Usually, the same value is used across all layers, but this can be modified.
    """

    mha_left_context: Optional[torch.Tensor] = None
    """Left context to insert at the left of the current chunk as inputs to the
    multi-head attention. It can be `None` (if we're dealing with the first
    chunk) or `<= mha_left_context_size` because for the first few chunks, not
    enough left context may be available to pad.
    """

    dcconv_left_context: Optional[torch.Tensor] = None
    """Left context to insert at the left of the convolution according to the
    Dynamic Chunk Convolution method.

    Unlike `mha_left_context`, here the amount of frames to keep is fixed and
    inferred from the kernel size of the convolution module.
    """


@dataclass
class ConformerEncoderStreamingContext:
    """Streaming metadata and state for a `ConformerEncoder`."""

    dynchunktrain_config: DynChunkTrainConfig
    """Dynamic Chunk Training configuration holding chunk size and context size
    information."""

    layers: List[ConformerEncoderLayerStreamingContext]
    """Streaming metadata and state for each layer of the encoder."""


class ConvolutionModule(nn.Module):
    """This is an implementation of convolution module in Conformer.

    Arguments
    ---------
    input_size : int
        The expected size of the input embedding dimension.
    kernel_size: int, optional
        Kernel size of non-bottleneck convolutional layer.
    bias: bool, optional
        Whether to use bias in the non-bottleneck conv layer.
    activation: torch.nn.Module
        Activation function used after non-bottleneck conv layer.
    dropout: float, optional
        Dropout rate.
    causal: bool, optional
        Whether the convolution should be causal or not.
    dilation: int, optional
        Dilation factor for the non bottleneck conv layer.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = ConvolutionModule(512, 3)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_size,
        kernel_size=31,
        bias=True,
        activation=Swish,
        dropout=0.0,
        causal=False,
        dilation=1,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.causal = causal
        self.dilation = dilation

        if self.causal:
            self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
        else:
            self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2

        self.layer_norm = nn.LayerNorm(input_size)
        self.bottleneck = nn.Sequential(
            # pointwise
            nn.Conv1d(
                input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
            ),
            nn.GLU(dim=1),
        )
        # depthwise
        self.conv = nn.Conv1d(
            input_size,
            input_size,
            kernel_size=kernel_size,
            stride=1,
            padding=self.padding,
            dilation=dilation,
            groups=input_size,
            bias=bias,
        )

        # BatchNorm in the original Conformer replaced with a LayerNorm due to
        # https://github.com/speechbrain/speechbrain/pull/1329
        # see discussion
        # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884

        self.after_conv = nn.Sequential(
            nn.LayerNorm(input_size),
            activation(),
            # pointwise
            nn.Linear(input_size, input_size, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Applies the convolution to an input tensor `x`.

        Arguments
        ---------
        x: torch.Tensor
            Input tensor to the convolution module.
        mask: torch.Tensor, optional
            Mask to be applied over the output of the convolution using
            `masked_fill_`, if specified.
        dynchunktrain_config: DynChunkTrainConfig, optional
            If specified, makes the module support Dynamic Chunk Convolution
            (DCConv) as implemented by
            `Dynamic Chunk Convolution for Unified Streaming and Non-Streaming Conformer ASR <https://www.amazon.science/publications/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr>`_.
            This allows masking future frames while preserving better accuracy
            than a fully causal convolution, at a small speed cost.
            This should only be used for training (or, if you know what you're
            doing, for masked evaluation at inference time), as the forward
            streaming function should be used at inference time.

        Returns
        -------
        out: torch.Tensor
            The output tensor.
        """

        if dynchunktrain_config is not None:
            # chances are chunking+causal is unintended; i don't know where it
            # may make sense, but if it does to you, feel free to implement it.
            assert (
                not self.causal
            ), "Chunked convolution not supported with causal padding"

            assert (
                self.dilation == 1
            ), "Current DynChunkTrain logic does not support dilation != 1"

            # in a causal convolution, which is not the case here, an output
            # frame would never be able to depend on a input frame from any
            # point in the future.

            # but with the dynamic chunk convolution, we instead use a "normal"
            # convolution but where, for any output frame, the future beyond the
            # "current" chunk gets masked.
            # see the paper linked in the documentation for details.

            chunk_size = dynchunktrain_config.chunk_size
            batch_size = x.shape[0]

            # determine the amount of padding we need to insert at the right of
            # the last chunk so that all chunks end up with the same size.
            if x.shape[1] % chunk_size != 0:
                final_right_padding = chunk_size - (x.shape[1] % chunk_size)
            else:
                final_right_padding = 0

            # -> [batch_size, t, in_channels]
            out = self.layer_norm(x)

            # -> [batch_size, in_channels, t] for the CNN
            out = out.transpose(1, 2)

            # -> [batch_size, in_channels, t] (pointwise)
            out = self.bottleneck(out)

            # -> [batch_size, in_channels, lc+t+final_right_padding]
            out = F.pad(out, (self.padding, final_right_padding), value=0)

            # now, make chunks with left context.
            # as a recap to what the above padding and this unfold do, consider
            # each a/b/c letter represents a frame as part of chunks a, b, c.
            # consider a chunk size of 4 and a kernel size of 5 (padding=2):
            #
            # input seq: 00aaaabbbbcc00
            # chunk #1: 00aaaa
            # chunk #2: aabbbb
            # chunk #3: bbcc00
            #
            # a few remarks here:
            # - the left padding gets inserted early so that the unfold logic
            #   works trivially
            # - the right 0-padding got inserted as the number of time steps
            #   could not be evenly split in `chunk_size` chunks

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
            out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)

            # as we manually disable padding in the convolution below, we insert
            # right 0-padding to the chunks, e.g. reusing the above example:
            #
            # chunk #1: 00aaaa00
            # chunk #2: aabbbb00
            # chunk #3: bbcc0000

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
            out = F.pad(out, (0, self.padding), value=0)

            # the transpose+flatten effectively flattens chunks into the batch
            # dimension to be processed into the time-wise convolution. the
            # chunks will later on be unflattened.

            # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.transpose(1, 2)

            # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.flatten(start_dim=0, end_dim=1)

            # TODO: experiment around reflect padding, which is difficult
            # because small chunks have too little time steps to reflect from

            # let's keep backwards compat by pointing at the weights from the
            # already declared Conv1d.
            #
            # still reusing the above example, the convolution will be applied,
            # with the padding truncated on both ends. the following example
            # shows the letter corresponding to the input frame on which the
            # convolution was centered.
            #
            # as you can see, the sum of lengths of all chunks is equal to our
            # input sequence length + `final_right_padding`.
            #
            # chunk #1: aaaa
            # chunk #2: bbbb
            # chunk #3: cc00

            # -> [batch_size * num_chunks, out_channels, chunk_size]
            out = F.conv1d(
                out,
                weight=self.conv.weight,
                bias=self.conv.bias,
                stride=self.conv.stride,
                padding=0,
                dilation=self.conv.dilation,
                groups=self.conv.groups,
            )

            # -> [batch_size * num_chunks, chunk_size, out_channels]
            out = out.transpose(1, 2)

            out = self.after_conv(out)

            # -> [batch_size, num_chunks, chunk_size, out_channels]
            out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))

            # -> [batch_size, t + final_right_padding, out_channels]
            out = torch.flatten(out, start_dim=1, end_dim=2)

            # -> [batch_size, t, out_channels]
            if final_right_padding > 0:
                out = out[:, :-final_right_padding, :]
        else:
            out = self.layer_norm(x)
            out = out.transpose(1, 2)
            out = self.bottleneck(out)
            out = self.conv(out)

            if self.causal:
                # chomp
                out = out[..., : -self.padding]

            out = out.transpose(1, 2)
            out = self.after_conv(out)

        if mask is not None:
            out.masked_fill_(mask, 0.0)

        return out


class ConformerEncoderLayer(nn.Module):
    """This is an implementation of Conformer encoder layer.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation: torch.nn.Module
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether convolution module.
    dropout : int, optional
        Dropout for the encoder.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoderLayer(d_ffn=512, nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=causal,
            )
        elif attention_type == "hypermixing":
            self.mha_layer = HyperMixing(
                input_output_dim=d_model,
                hypernet_size=d_ffn,
                tied=False,
                num_heads=nhead,
                fix_tm_hidden_size=False,
            )

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: torch.Tensor = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ----------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional embeddings
        dynchunktrain_config: Optional[DynChunkTrainConfig]
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to
            the convolution module.
        """
        conv_mask: Optional[torch.Tensor] = None
        if src_key_padding_mask is not None:
            conv_mask = src_key_padding_mask.unsqueeze(-1)
        # ffn module
        x = x + 0.5 * self.ffn_module1(x)
        # multi-head attention module
        skip = x
        x = self.norm1(x)

        x, self_attn = self.mha_layer(
            x,
            x,
            x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs,
        )
        x = x + skip
        # convolution module
        x = x + self.convolution_module(
            x, conv_mask, dynchunktrain_config=dynchunktrain_config
        )
        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn

    def forward_streaming(
        self,
        x,
        context: ConformerEncoderLayerStreamingContext,
        pos_embs: torch.Tensor = None,
    ):
        """Conformer layer streaming forward (typically for
        DynamicChunkTraining-trained models), which is to be used at inference
        time. Relies on a mutable context object as initialized by
        `make_streaming_context` that should be used across chunks.
        Invoked by `ConformerEncoder.forward_streaming`.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor for this layer. Batching is supported as long as you
            keep the context consistent.
        context : ConformerEncoderStreamingContext
            Mutable streaming context; the same object should be passed across
            calls.
        pos_embs : torch.Tensor, optional
            Positional embeddings, if used.

        Returns
        -------
        x : torch.Tensor
            Output tensor.
        self_attn : list
            List of self attention values.
        """

        orig_len = x.shape[-2]
        # ffn module
        x = x + 0.5 * self.ffn_module1(x)

        # TODO: make the approach for MHA left context more efficient.
        # currently, this saves the inputs to the MHA.
        # the naive approach is suboptimal in a few ways, namely that the
        # outputs for this left padding is being re-computed even though we
        # discard them immediately after.

        # left pad `x` with our MHA left context
        if context.mha_left_context is not None:
            x = torch.cat((context.mha_left_context, x), dim=1)

        # compute new MHA left context for the next call to our function
        if context.mha_left_context_size > 0:
            context.mha_left_context = x[
                ..., -context.mha_left_context_size :, :
            ]

        # multi-head attention module
        skip = x
        x = self.norm1(x)

        x, self_attn = self.mha_layer(
            x,
            x,
            x,
            attn_mask=None,
            key_padding_mask=None,
            pos_embs=pos_embs,
        )
        x = x + skip

        # truncate outputs corresponding to the MHA left context (we only care
        # about our chunk's outputs); see above to-do
        x = x[..., -orig_len:, :]

        if context.dcconv_left_context is not None:
            x = torch.cat((context.dcconv_left_context, x), dim=1)

        # compute new DCConv left context for the next call to our function
        context.dcconv_left_context = x[
            ..., -self.convolution_module.padding :, :
        ]

        # convolution module
        x = x + self.convolution_module(x)

        # truncate outputs corresponding to the DCConv left context
        x = x[..., -orig_len:, :]

        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn

    def make_streaming_context(self, mha_left_context_size: int):
        """Creates a blank streaming context for this encoding layer.

        Arguments
        ---------
        mha_left_context_size : int
            How many left frames should be saved and used as left context to the
            current chunk when streaming

        Returns
        -------
        ConformerEncoderLayerStreamingContext
        """
        return ConformerEncoderLayerStreamingContext(
            mha_left_context_size=mha_left_context_size
        )


class ConformerEncoder(nn.Module):
    """This class implements the Conformer encoder.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Embedding dimension size.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation: torch.nn.Module
        Activation function used in each Confomer layer.
    bias : bool, optional
        Whether convolution module.
    dropout : int, optional
        Dropout for the encoder.
    causal: bool, optional
        Whether the convolutions should be causal or not.
    attention_type: str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.


    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoder(1, 512, 512, 8)
    >>> output, _ = net(x, pos_embs=pos_emb)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [
                ConformerEncoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    attention_type=attention_type,
                )
                for i in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)
        self.attention_type = attention_type

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ----------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module,
            Module or tensor containing the input sequence positional embeddings
            If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
            where S is the sequence length, and E is the embedding dimension.
        dynchunktrain_config: Optional[DynChunkTrainConfig]
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to the
            convolution module.
        """
        if self.attention_type == "RelPosMHAXL":
            if pos_embs is None:
                raise ValueError(
                    "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
                )

        output = src
        attention_lst = []
        for enc_layer in self.layers:
            output, attention = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
                dynchunktrain_config=dynchunktrain_config,
            )
            attention_lst.append(attention)
        output = self.norm(output)

        return output, attention_lst

    def forward_streaming(
        self,
        src: torch.Tensor,
        context: ConformerEncoderStreamingContext,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Conformer streaming forward (typically for
        DynamicChunkTraining-trained models), which is to be used at inference
        time. Relies on a mutable context object as initialized by
        `make_streaming_context` that should be used across chunks.

        Arguments
        ---------
        src : torch.Tensor
            Input tensor. Batching is supported as long as you keep the context
            consistent.
        context : ConformerEncoderStreamingContext
            Mutable streaming context; the same object should be passed across
            calls.
        pos_embs : torch.Tensor, optional
            Positional embeddings, if used.

        Returns
        -------
        output : torch.Tensor
            The output of the streaming conformer.
        attention_lst : list
            The attention values.
        """

        if self.attention_type == "RelPosMHAXL":
            if pos_embs is None:
                raise ValueError(
                    "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
                )

        output = src
        attention_lst = []
        for i, enc_layer in enumerate(self.layers):
            output, attention = enc_layer.forward_streaming(
                output, pos_embs=pos_embs, context=context.layers[i]
            )
            attention_lst.append(attention)
        output = self.norm(output)

        return output, attention_lst

    def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
        """Creates a blank streaming context for the encoder.

        Arguments
        ---------
        dynchunktrain_config: Optional[DynChunkTrainConfig]
            Dynamic Chunk Training configuration object for streaming

        Returns
        -------
        ConformerEncoderStreamingContext
        """
        return ConformerEncoderStreamingContext(
            dynchunktrain_config=dynchunktrain_config,
            layers=[
                layer.make_streaming_context(
                    mha_left_context_size=dynchunktrain_config.left_context_size_frames()
                )
                for layer in self.layers
            ],
        )


class ConformerDecoderLayer(nn.Module):
    """This is an implementation of Conformer encoder layer.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of convolution model.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module, optional
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether convolution module.
    dropout : int, optional
        Dropout for the encoder.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoderLayer(d_ffn=512, nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        nhead,
        kernel_size,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        if not causal:
            warnings.warn(
                "Decoder is not causal, in most applications it should be causal, you have been warned !"
            )

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=causal,
            )

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt: torch.Tensor
            The sequence to the decoder layer.
        memory: torch.Tensor
            The sequence from the last layer of the encoder.
        tgt_mask: torch.Tensor, optional, optional
            The mask for the tgt sequence.
        memory_mask: torch.Tensor, optional
            The mask for the memory sequence.
        tgt_key_padding_mask: torch.Tensor, optional
            The mask for the tgt keys per batch.
        memory_key_padding_mask: torch.Tensor, optional
            The mask for the memory keys per batch.
        pos_embs_tgt: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the target sequence positional embeddings for each attention layer.
        pos_embs_src: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the source sequence positional embeddings for each attention layer.

        Returns
        -------
        x: torch.Tensor
            The output tensor
        self_attn : torch.Tensor
        self_attn : torch.Tensor
            The self attention tensor
        """
        # ffn module
        tgt = tgt + 0.5 * self.ffn_module1(tgt)
        # multi-head attention module
        skip = tgt
        x = self.norm1(tgt)
        x, self_attn = self.mha_layer(
            x,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            pos_embs=pos_embs_src,
        )
        x = x + skip
        # convolution module
        x = x + self.convolution_module(x)
        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn, self_attn


class ConformerDecoder(nn.Module):
    """This class implements the Transformer decoder.

    Arguments
    ---------
    num_layers: int
        Number of layers.
    nhead: int
        Number of attention heads.
    d_ffn: int
        Hidden size of self-attention Feed Forward layer.
    d_model: int
        Embedding dimension size.
    kdim: int, optional
        Dimension for key.
    vdim: int, optional
        Dimension for value.
    dropout: float, optional
        Dropout rate.
    activation: torch.nn.Module, optional
        Activation function used after non-bottleneck conv layer.
    kernel_size : int, optional
        Kernel size of convolutional layer.
    bias : bool, optional
        Whether convolution module.
    causal: bool, optional
        Whether the convolutions should be causal or not.
    attention_type: str, optional
        type of attention layer, e.g. regularMHA for regular MultiHeadAttention.


    Example
    -------
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = ConformerDecoder(1, 8, 1024, 512, attention_type="regularMHA")
    >>> output, _, _ = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        nhead,
        d_ffn,
        d_model,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=Swish,
        kernel_size=3,
        bias=True,
        causal=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                ConformerDecoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    attention_type=attention_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt: torch.Tensor
            The sequence to the decoder layer.
        memory: torch.Tensor
            The sequence from the last layer of the encoder.
        tgt_mask: torch.Tensor, optional, optional
            The mask for the tgt sequence.
        memory_mask: torch.Tensor, optional
            The mask for the memory sequence.
        tgt_key_padding_mask : torch.Tensor, optional
            The mask for the tgt keys per batch.
        memory_key_padding_mask : torch.Tensor, optional
            The mask for the memory keys per batch.
        pos_embs_tgt: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the target sequence positional embeddings for each attention layer.
        pos_embs_src: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the source sequence positional embeddings for each attention layer.

        Returns
        -------
        output: torch.Tensor
            Conformer decoder output.
        self_attns : list
            Location of self attentions.
        multihead_attns : list
            Location of multihead attentions.
        """
        output = tgt
        self_attns, multihead_attns = [], []
        for dec_layer in self.layers:
            output, self_attn, multihead_attn = dec_layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos_embs_tgt=pos_embs_tgt,
                pos_embs_src=pos_embs_src,
            )
            self_attns.append(self_attn)
            multihead_attns.append(multihead_attn)
        output = self.norm(output)

        return output, self_attns, multihead_attns
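A brief usage sketch of the streaming path defined above. Only `make_streaming_context` and `forward_streaming` come from the code in this file; the `DynChunkTrainConfig` constructor arguments and the chunk sizes are assumptions based on SpeechBrain's dynamic chunk training utilities, not part of this upload:

# Hypothetical streaming sketch: feed fixed-size chunks through a DynChunkTrain-trained encoder.
# Assumes DynChunkTrainConfig(chunk_size=..., left_context_size=...) as in SpeechBrain.
encoder = ConformerEncoder(num_layers=12, d_model=256, d_ffn=1024, nhead=4, attention_type="regularMHA")
cfg = DynChunkTrainConfig(chunk_size=16, left_context_size=2)
ctx = encoder.make_streaming_context(cfg)      # per-layer MHA/DCConv left-context state
stream = torch.rand(1, 64, 256)                # pretend input: 4 chunks of 16 frames each
for chunk in stream.split(16, dim=1):
    out, _ = encoder.forward_streaming(chunk, context=ctx)  # ctx is mutated across calls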
model/modules/Conmamba.py
ADDED
@@ -0,0 +1,607 @@
1 |
+
"""ConMamba encoder and Mamba decoder implementation.
|
2 |
+
|
3 |
+
Authors
|
4 |
+
-------
|
5 |
+
* Xilin Jiang 2024
|
6 |
+
"""
|
7 |
+
|
8 |
+
import warnings
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import List, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
import speechbrain as sb
|
17 |
+
from speechbrain.nnet.activations import Swish
|
18 |
+
from speechbrain.nnet.attention import (
|
19 |
+
MultiheadAttention,
|
20 |
+
PositionalwiseFeedForward,
|
21 |
+
RelPosMHAXL,
|
22 |
+
)
|
23 |
+
from speechbrain.nnet.hypermixing import HyperMixing
|
24 |
+
from speechbrain.nnet.normalization import LayerNorm
|
25 |
+
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
|
26 |
+
|
27 |
+
# Mamba
|
28 |
+
from mamba_ssm import Mamba
|
29 |
+
from .mamba.bimamba import Mamba as BiMamba
|
30 |
+
|
31 |
+
|
32 |
+
class ConvolutionModule(nn.Module):
|
33 |
+
"""This is an implementation of convolution module in Conmamba.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
input_size,
|
39 |
+
kernel_size=31,
|
40 |
+
bias=True,
|
41 |
+
activation=Swish,
|
42 |
+
dropout=0.0,
|
43 |
+
causal=False,
|
44 |
+
dilation=1,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.kernel_size = kernel_size
|
49 |
+
self.causal = causal
|
50 |
+
self.dilation = dilation
|
51 |
+
|
52 |
+
if self.causal:
|
53 |
+
self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
|
54 |
+
else:
|
55 |
+
self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2
|
56 |
+
|
57 |
+
self.layer_norm = nn.LayerNorm(input_size)
|
58 |
+
self.bottleneck = nn.Sequential(
|
59 |
+
# pointwise
|
60 |
+
nn.Conv1d(
|
61 |
+
input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
|
62 |
+
),
|
63 |
+
nn.GLU(dim=1),
|
64 |
+
)
|
65 |
+
# depthwise
|
66 |
+
self.conv = nn.Conv1d(
|
67 |
+
input_size,
|
68 |
+
input_size,
|
69 |
+
kernel_size=kernel_size,
|
70 |
+
stride=1,
|
71 |
+
padding=self.padding,
|
72 |
+
dilation=dilation,
|
73 |
+
groups=input_size,
|
74 |
+
bias=bias,
|
75 |
+
)
|
76 |
+
|
77 |
+
# BatchNorm in the original Conformer replaced with a LayerNorm due to
|
78 |
+
# https://github.com/speechbrain/speechbrain/pull/1329
|
79 |
+
# see discussion
|
80 |
+
# https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884
|
81 |
+
|
82 |
+
self.after_conv = nn.Sequential(
|
83 |
+
nn.LayerNorm(input_size),
|
84 |
+
activation(),
|
85 |
+
# pointwise
|
86 |
+
nn.Linear(input_size, input_size, bias=bias),
|
87 |
+
nn.Dropout(dropout),
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self,
|
92 |
+
x: torch.Tensor,
|
93 |
+
mask: Optional[torch.Tensor] = None,
|
94 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
|
95 |
+
):
|
96 |
+
"""Applies the convolution to an input tensor `x`.
|
97 |
+
"""
|
98 |
+
|
99 |
+
if dynchunktrain_config is not None:
|
100 |
+
# chances are chunking+causal is unintended; i don't know where it
|
101 |
+
# may make sense, but if it does to you, feel free to implement it.
|
102 |
+
assert (
|
103 |
+
not self.causal
|
104 |
+
), "Chunked convolution not supported with causal padding"
|
105 |
+
|
106 |
+
assert (
|
107 |
+
self.dilation == 1
|
108 |
+
), "Current DynChunkTrain logic does not support dilation != 1"
|
109 |
+
|
110 |
+
# in a causal convolution, which is not the case here, an output
|
111 |
+
# frame would never be able to depend on a input frame from any
|
112 |
+
# point in the future.
|
113 |
+
|
114 |
+
# but with the dynamic chunk convolution, we instead use a "normal"
|
115 |
+
# convolution but where, for any output frame, the future beyond the
|
116 |
+
# "current" chunk gets masked.
|
117 |
+
# see the paper linked in the documentation for details.
|
118 |
+
|
119 |
+
chunk_size = dynchunktrain_config.chunk_size
|
120 |
+
batch_size = x.shape[0]
|
121 |
+
|
122 |
+
# determine the amount of padding we need to insert at the right of
|
123 |
+
# the last chunk so that all chunks end up with the same size.
|
124 |
+
if x.shape[1] % chunk_size != 0:
|
125 |
+
final_right_padding = chunk_size - (x.shape[1] % chunk_size)
|
126 |
+
else:
|
127 |
+
final_right_padding = 0
|
128 |
+
|
129 |
+
# -> [batch_size, t, in_channels]
|
130 |
+
out = self.layer_norm(x)
|
131 |
+
|
132 |
+
# -> [batch_size, in_channels, t] for the CNN
|
133 |
+
out = out.transpose(1, 2)
|
134 |
+
|
135 |
+
# -> [batch_size, in_channels, t] (pointwise)
|
136 |
+
out = self.bottleneck(out)
|
137 |
+
|
138 |
+
# -> [batch_size, in_channels, lc+t+final_right_padding]
|
139 |
+
out = F.pad(out, (self.padding, final_right_padding), value=0)
|
140 |
+
|
141 |
+
# now, make chunks with left context.
|
142 |
+
# as a recap to what the above padding and this unfold do, consider
|
143 |
+
# each a/b/c letter represents a frame as part of chunks a, b, c.
|
144 |
+
# consider a chunk size of 4 and a kernel size of 5 (padding=2):
|
145 |
+
#
|
146 |
+
# input seq: 00aaaabbbbcc00
|
147 |
+
# chunk #1: 00aaaa
|
148 |
+
# chunk #2: aabbbb
|
149 |
+
# chunk #3: bbcc00
|
150 |
+
#
|
151 |
+
# a few remarks here:
|
152 |
+
# - the left padding gets inserted early so that the unfold logic
|
153 |
+
# works trivially
|
154 |
+
# - the right 0-padding got inserted as the number of time steps
|
155 |
+
# could not be evenly split in `chunk_size` chunks
|
156 |
+
|
157 |
+
# -> [batch_size, in_channels, num_chunks, lc+chunk_size]
|
158 |
+
out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)
|
159 |
+
|
160 |
+
# as we manually disable padding in the convolution below, we insert
|
161 |
+
# right 0-padding to the chunks, e.g. reusing the above example:
|
162 |
+
#
|
163 |
+
# chunk #1: 00aaaa00
|
164 |
+
# chunk #2: aabbbb00
|
165 |
+
# chunk #3: bbcc0000
|
166 |
+
|
167 |
+
# -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
|
168 |
+
out = F.pad(out, (0, self.padding), value=0)
|
169 |
+
|
170 |
+
# the transpose+flatten effectively flattens chunks into the batch
|
171 |
+
# dimension to be processed into the time-wise convolution. the
|
172 |
+
# chunks will later on be unflattened.
|
173 |
+
|
174 |
+
# -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
|
175 |
+
out = out.transpose(1, 2)
|
176 |
+
|
177 |
+
# -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
|
178 |
+
out = out.flatten(start_dim=0, end_dim=1)
|
179 |
+
|
180 |
+
# TODO: experiment around reflect padding, which is difficult
|
181 |
+
# because small chunks have too little time steps to reflect from
|
182 |
+
|
183 |
+
# let's keep backwards compat by pointing at the weights from the
|
184 |
+
# already declared Conv1d.
|
185 |
+
#
|
186 |
+
# still reusing the above example, the convolution will be applied,
|
187 |
+
# with the padding truncated on both ends. the following example
|
188 |
+
# shows the letter corresponding to the input frame on which the
|
189 |
+
# convolution was centered.
|
190 |
+
#
|
191 |
+
# as you can see, the sum of lengths of all chunks is equal to our
|
192 |
+
# input sequence length + `final_right_padding`.
|
193 |
+
#
|
194 |
+
# chunk #1: aaaa
|
195 |
+
# chunk #2: bbbb
|
196 |
+
# chunk #3: cc00
|
197 |
+
|
198 |
+
# -> [batch_size * num_chunks, out_channels, chunk_size]
|
199 |
+
out = F.conv1d(
|
200 |
+
out,
|
201 |
+
weight=self.conv.weight,
|
202 |
+
bias=self.conv.bias,
|
203 |
+
stride=self.conv.stride,
|
204 |
+
padding=0,
|
205 |
+
dilation=self.conv.dilation,
|
206 |
+
groups=self.conv.groups,
|
207 |
+
)
|
208 |
+
|
209 |
+
# -> [batch_size * num_chunks, chunk_size, out_channels]
|
210 |
+
out = out.transpose(1, 2)
|
211 |
+
|
212 |
+
out = self.after_conv(out)
|
213 |
+
|
214 |
+
# -> [batch_size, num_chunks, chunk_size, out_channels]
|
215 |
+
out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))
|
216 |
+
|
217 |
+
# -> [batch_size, t + final_right_padding, out_channels]
|
218 |
+
out = torch.flatten(out, start_dim=1, end_dim=2)
|
219 |
+
|
220 |
+
# -> [batch_size, t, out_channels]
|
221 |
+
if final_right_padding > 0:
|
222 |
+
out = out[:, :-final_right_padding, :]
|
223 |
+
else:
|
224 |
+
out = self.layer_norm(x)
|
225 |
+
out = out.transpose(1, 2)
|
226 |
+
out = self.bottleneck(out)
|
227 |
+
out = self.conv(out)
|
228 |
+
|
229 |
+
if self.causal:
|
230 |
+
# chomp
|
231 |
+
out = out[..., : -self.padding]
|
232 |
+
|
233 |
+
out = out.transpose(1, 2)
|
234 |
+
out = self.after_conv(out)
|
235 |
+
|
236 |
+
if mask is not None:
|
237 |
+
out.masked_fill_(mask, 0.0)
|
238 |
+
|
239 |
+
return out
|
240 |
+
|
241 |
+
|
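# A minimal sketch (illustrative code, not part of the module above) of what the
# chunked unfold in ConvolutionModule.forward produces, assuming chunk_size=4 and
# kernel_size=5 (so the padding is 2), matching the worked example in the comments:
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.arange(1., 11.).view(1, 1, 10)                        # [B=1, C=1, T=10]
#   chunk_size, padding = 4, 2
#   right = (chunk_size - x.shape[-1] % chunk_size) % chunk_size    # -> 2
#   out = F.pad(x, (padding, right))                                # left context + right fill
#   chunks = out.unfold(2, size=chunk_size + padding, step=chunk_size)
#   # chunks.shape == torch.Size([1, 1, 3, 6]): three chunks, each carrying two
#   # frames of left context, i.e. the "00aaaa / aabbbb / bbcc00" layout above.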
242 |
+
class ConmambaEncoderLayer(nn.Module):
|
243 |
+
"""This is an implementation of Conmamba encoder layer.
|
244 |
+
"""
|
245 |
+
|
246 |
+
def __init__(
|
247 |
+
self,
|
248 |
+
d_model,
|
249 |
+
d_ffn,
|
250 |
+
kernel_size=31,
|
251 |
+
activation=Swish,
|
252 |
+
bias=True,
|
253 |
+
dropout=0.0,
|
254 |
+
causal=False,
|
255 |
+
mamba_config=None
|
256 |
+
):
|
257 |
+
super().__init__()
|
258 |
+
assert mamba_config is not None
|
259 |
+
|
260 |
+
bidirectional = mamba_config.pop('bidirectional')
|
261 |
+
if causal or (not bidirectional):
|
262 |
+
self.mamba = Mamba(
|
263 |
+
d_model=d_model,
|
264 |
+
**mamba_config
|
265 |
+
)
|
266 |
+
else:
|
267 |
+
self.mamba = BiMamba(
|
268 |
+
d_model=d_model,
|
269 |
+
bimamba_type='v2',
|
270 |
+
**mamba_config
|
271 |
+
)
|
272 |
+
mamba_config['bidirectional'] = bidirectional
|
273 |
+
|
274 |
+
self.convolution_module = ConvolutionModule(
|
275 |
+
d_model, kernel_size, bias, activation, dropout, causal=causal
|
276 |
+
)
|
277 |
+
|
278 |
+
self.ffn_module1 = nn.Sequential(
|
279 |
+
nn.LayerNorm(d_model),
|
280 |
+
PositionalwiseFeedForward(
|
281 |
+
d_ffn=d_ffn,
|
282 |
+
input_size=d_model,
|
283 |
+
dropout=dropout,
|
284 |
+
activation=activation,
|
285 |
+
),
|
286 |
+
nn.Dropout(dropout),
|
287 |
+
)
|
288 |
+
|
289 |
+
self.ffn_module2 = nn.Sequential(
|
290 |
+
nn.LayerNorm(d_model),
|
291 |
+
PositionalwiseFeedForward(
|
292 |
+
d_ffn=d_ffn,
|
293 |
+
input_size=d_model,
|
294 |
+
dropout=dropout,
|
295 |
+
activation=activation,
|
296 |
+
),
|
297 |
+
nn.Dropout(dropout),
|
298 |
+
)
|
299 |
+
|
300 |
+
self.norm1 = LayerNorm(d_model)
|
301 |
+
self.norm2 = LayerNorm(d_model)
|
302 |
+
self.drop = nn.Dropout(dropout)
|
303 |
+
|
304 |
+
def forward(
|
305 |
+
self,
|
306 |
+
x,
|
307 |
+
src_mask: Optional[torch.Tensor] = None,
|
308 |
+
src_key_padding_mask: Optional[torch.Tensor] = None,
|
309 |
+
pos_embs: torch.Tensor = None,
|
310 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
|
311 |
+
):
|
312 |
+
conv_mask: Optional[torch.Tensor] = None
|
313 |
+
if src_key_padding_mask is not None:
|
314 |
+
conv_mask = src_key_padding_mask.unsqueeze(-1)
|
315 |
+
|
316 |
+
conv_mask = None
|
317 |
+
|
318 |
+
# ffn module
|
319 |
+
x = x + 0.5 * self.ffn_module1(x)
|
320 |
+
# mamba module
|
321 |
+
skip = x
|
322 |
+
x = self.norm1(x)
|
323 |
+
x = self.mamba(x)
|
324 |
+
x = x + skip
|
325 |
+
# convolution module
|
326 |
+
x = x + self.convolution_module(
|
327 |
+
x, conv_mask, dynchunktrain_config=dynchunktrain_config
|
328 |
+
)
|
329 |
+
# ffn module
|
330 |
+
x = self.norm2(x + 0.5 * self.ffn_module2(x))
|
331 |
+
return x
|
332 |
+
|
333 |
+
|
334 |
+
class ConmambaEncoder(nn.Module):
|
335 |
+
"""This class implements the Conmamba encoder.
|
336 |
+
"""
|
337 |
+
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
num_layers,
|
341 |
+
d_model,
|
342 |
+
d_ffn,
|
343 |
+
kernel_size=31,
|
344 |
+
activation=Swish,
|
345 |
+
bias=True,
|
346 |
+
dropout=0.0,
|
347 |
+
causal=False,
|
348 |
+
mamba_config=None
|
349 |
+
):
|
350 |
+
super().__init__()
|
351 |
+
print(f'dropout={dropout} is not used in Mamba.')
|
352 |
+
|
353 |
+
self.layers = torch.nn.ModuleList(
|
354 |
+
[
|
355 |
+
ConmambaEncoderLayer(
|
356 |
+
d_model=d_model,
|
357 |
+
d_ffn=d_ffn,
|
358 |
+
dropout=dropout,
|
359 |
+
activation=activation,
|
360 |
+
kernel_size=kernel_size,
|
361 |
+
bias=bias,
|
362 |
+
causal=causal,
|
363 |
+
mamba_config=mamba_config,
|
364 |
+
)
|
365 |
+
for i in range(num_layers)
|
366 |
+
]
|
367 |
+
)
|
368 |
+
self.norm = LayerNorm(d_model, eps=1e-6)
|
369 |
+
|
370 |
+
def forward(
|
371 |
+
self,
|
372 |
+
src,
|
373 |
+
src_mask: Optional[torch.Tensor] = None,
|
374 |
+
src_key_padding_mask: Optional[torch.Tensor] = None,
|
375 |
+
pos_embs: Optional[torch.Tensor] = None,
|
376 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
|
377 |
+
):
|
378 |
+
"""
|
379 |
+
Arguments
|
380 |
+
----------
|
381 |
+
src : torch.Tensor
|
382 |
+
The sequence to the encoder layer.
|
383 |
+
src_mask : torch.Tensor, optional
|
384 |
+
The mask for the src sequence.
|
385 |
+
src_key_padding_mask : torch.Tensor, optional
|
386 |
+
The mask for the src keys per batch.
|
387 |
+
pos_embs: torch.Tensor, torch.nn.Module,
|
388 |
+
Module or tensor containing the input sequence positional embeddings
|
389 |
+
If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
|
390 |
+
where S is the sequence length, and E is the embedding dimension.
|
391 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig]
|
392 |
+
Dynamic Chunk Training configuration object for streaming,
|
393 |
+
specifically involved here to apply Dynamic Chunk Convolution to the
|
394 |
+
convolution module.
|
395 |
+
"""
|
396 |
+
|
397 |
+
output = src
|
398 |
+
for enc_layer in self.layers:
|
399 |
+
output = enc_layer(
|
400 |
+
output,
|
401 |
+
src_mask=src_mask,
|
402 |
+
src_key_padding_mask=src_key_padding_mask,
|
403 |
+
pos_embs=pos_embs,
|
404 |
+
dynchunktrain_config=dynchunktrain_config,
|
405 |
+
)
|
406 |
+
output = self.norm(output)
|
407 |
+
|
408 |
+
return output, None
|
409 |
+
|
410 |
+
|
411 |
+
class MambaDecoderLayer(nn.Module):
|
412 |
+
"""This class implements the Mamba decoder layer.
|
413 |
+
"""
|
414 |
+
|
415 |
+
def __init__(
|
416 |
+
self,
|
417 |
+
d_model,
|
418 |
+
d_ffn,
|
419 |
+
activation=nn.ReLU,
|
420 |
+
dropout=0.0,
|
421 |
+
normalize_before=False,
|
422 |
+
mamba_config=None
|
423 |
+
):
|
424 |
+
super().__init__()
|
425 |
+
|
426 |
+
assert mamba_config is not None
|
427 |
+
|
428 |
+
bidirectional = mamba_config.pop('bidirectional')
|
429 |
+
|
430 |
+
self.self_mamba = Mamba(
|
431 |
+
d_model=d_model,
|
432 |
+
**mamba_config
|
433 |
+
)
|
434 |
+
|
435 |
+
self.cross_mamba = Mamba(
|
436 |
+
d_model=d_model,
|
437 |
+
**mamba_config
|
438 |
+
)
|
439 |
+
|
440 |
+
mamba_config['bidirectional'] = bidirectional
|
441 |
+
|
442 |
+
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
|
443 |
+
d_ffn=d_ffn,
|
444 |
+
input_size=d_model,
|
445 |
+
dropout=dropout,
|
446 |
+
activation=activation,
|
447 |
+
)
|
448 |
+
|
449 |
+
# normalization layers
|
450 |
+
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
451 |
+
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
452 |
+
self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
453 |
+
self.dropout1 = torch.nn.Dropout(dropout)
|
454 |
+
self.dropout2 = torch.nn.Dropout(dropout)
|
455 |
+
self.dropout3 = torch.nn.Dropout(dropout)
|
456 |
+
|
457 |
+
self.normalize_before = normalize_before
|
458 |
+
|
459 |
+
def forward(
|
460 |
+
self,
|
461 |
+
tgt,
|
462 |
+
memory,
|
463 |
+
tgt_mask=None,
|
464 |
+
memory_mask=None,
|
465 |
+
tgt_key_padding_mask=None,
|
466 |
+
memory_key_padding_mask=None,
|
467 |
+
pos_embs_tgt=None,
|
468 |
+
pos_embs_src=None,
|
469 |
+
):
|
470 |
+
"""
|
471 |
+
Arguments
|
472 |
+
----------
|
473 |
+
tgt: torch.Tensor
|
474 |
+
The sequence to the decoder layer (required).
|
475 |
+
memory: torch.Tensor
|
476 |
+
The sequence from the last layer of the encoder (required).
|
477 |
+
tgt_mask: torch.Tensor
|
478 |
+
The mask for the tgt sequence (optional).
|
479 |
+
memory_mask: torch.Tensor
|
480 |
+
The mask for the memory sequence (optional).
|
481 |
+
tgt_key_padding_mask: torch.Tensor
|
482 |
+
The mask for the tgt keys per batch (optional).
|
483 |
+
memory_key_padding_mask: torch.Tensor
|
484 |
+
The mask for the memory keys per batch (optional).
|
485 |
+
pos_embs_tgt: torch.Tensor
|
486 |
+
The positional embeddings for the target (optional).
|
487 |
+
pos_embs_src: torch.Tensor
|
488 |
+
The positional embeddings for the source (optional).
|
489 |
+
"""
|
490 |
+
if self.normalize_before:
|
491 |
+
tgt1 = self.norm1(tgt)
|
492 |
+
else:
|
493 |
+
tgt1 = tgt
|
494 |
+
|
495 |
+
# Mamba over the target sequence
|
496 |
+
tgt2 = self.self_mamba(tgt1)
|
497 |
+
|
498 |
+
# add & norm
|
499 |
+
tgt = tgt + self.dropout1(tgt2)
|
500 |
+
if not self.normalize_before:
|
501 |
+
tgt = self.norm1(tgt)
|
502 |
+
|
503 |
+
if self.normalize_before:
|
504 |
+
tgt1 = self.norm2(tgt)
|
505 |
+
else:
|
506 |
+
tgt1 = tgt
|
507 |
+
|
508 |
+
# Mamba over key=value + query
|
509 |
+
# and only take the last len(query) tokens
|
510 |
+
tgt2 = self.cross_mamba(torch.cat([memory, tgt1], dim=1))[:, -tgt1.shape[1]:]
|
511 |
+
|
512 |
+
# add & norm
|
513 |
+
tgt = tgt + self.dropout2(tgt2)
|
514 |
+
if not self.normalize_before:
|
515 |
+
tgt = self.norm2(tgt)
|
516 |
+
|
517 |
+
if self.normalize_before:
|
518 |
+
tgt1 = self.norm3(tgt)
|
519 |
+
else:
|
520 |
+
tgt1 = tgt
|
521 |
+
|
522 |
+
tgt2 = self.pos_ffn(tgt1)
|
523 |
+
|
524 |
+
# add & norm
|
525 |
+
tgt = tgt + self.dropout3(tgt2)
|
526 |
+
if not self.normalize_before:
|
527 |
+
tgt = self.norm3(tgt)
|
528 |
+
|
529 |
+
return tgt, None, None
|
530 |
+
|
531 |
+
|
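# A shape-only sketch (stand-in code, not the module itself) of the cross-Mamba step
# in MambaDecoderLayer.forward above: since there is no explicit cross-attention, the
# encoder memory and the target are concatenated along time and only the last
# len(query) frames of the output are kept.
#
#   import torch
#   B, T_mem, T_tgt, D = 2, 50, 7, 256
#   memory = torch.randn(B, T_mem, D)
#   tgt1 = torch.randn(B, T_tgt, D)
#   fused = torch.cat([memory, tgt1], dim=1)    # [B, T_mem + T_tgt, D]
#   out = fused                                 # stand-in for self.cross_mamba(fused)
#   tgt2 = out[:, -tgt1.shape[1]:]              # keep the last T_tgt frames -> [B, 7, D]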
532 |
+
class MambaDecoder(nn.Module):
|
533 |
+
"""This class implements the Mamba decoder.
|
534 |
+
"""
|
535 |
+
|
536 |
+
def __init__(
|
537 |
+
self,
|
538 |
+
num_layers,
|
539 |
+
d_model,
|
540 |
+
d_ffn,
|
541 |
+
activation=nn.ReLU,
|
542 |
+
dropout=0.0,
|
543 |
+
normalize_before=False,
|
544 |
+
mamba_config=None
|
545 |
+
):
|
546 |
+
super().__init__()
|
547 |
+
self.layers = torch.nn.ModuleList(
|
548 |
+
[
|
549 |
+
MambaDecoderLayer(
|
550 |
+
d_model=d_model,
|
551 |
+
d_ffn=d_ffn,
|
552 |
+
activation=activation,
|
553 |
+
dropout=dropout,
|
554 |
+
normalize_before=normalize_before,
|
555 |
+
mamba_config=mamba_config
|
556 |
+
)
|
557 |
+
for _ in range(num_layers)
|
558 |
+
]
|
559 |
+
)
|
560 |
+
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
561 |
+
|
562 |
+
def forward(
|
563 |
+
self,
|
564 |
+
tgt,
|
565 |
+
memory,
|
566 |
+
tgt_mask=None,
|
567 |
+
memory_mask=None,
|
568 |
+
tgt_key_padding_mask=None,
|
569 |
+
memory_key_padding_mask=None,
|
570 |
+
pos_embs_tgt=None,
|
571 |
+
pos_embs_src=None,
|
572 |
+
):
|
573 |
+
"""
|
574 |
+
Arguments
|
575 |
+
----------
|
576 |
+
tgt : torch.Tensor
|
577 |
+
The sequence to the decoder layer (required).
|
578 |
+
memory : torch.Tensor
|
579 |
+
The sequence from the last layer of the encoder (required).
|
580 |
+
tgt_mask : torch.Tensor
|
581 |
+
The mask for the tgt sequence (optional).
|
582 |
+
memory_mask : torch.Tensor
|
583 |
+
The mask for the memory sequence (optional).
|
584 |
+
tgt_key_padding_mask : torch.Tensor
|
585 |
+
The mask for the tgt keys per batch (optional).
|
586 |
+
memory_key_padding_mask : torch.Tensor
|
587 |
+
The mask for the memory keys per batch (optional).
|
588 |
+
pos_embs_tgt : torch.Tensor
|
589 |
+
The positional embeddings for the target (optional).
|
590 |
+
pos_embs_src : torch.Tensor
|
591 |
+
The positional embeddings for the source (optional).
|
592 |
+
"""
|
593 |
+
output = tgt
|
594 |
+
for dec_layer in self.layers:
|
595 |
+
output, _, _ = dec_layer(
|
596 |
+
output,
|
597 |
+
memory,
|
598 |
+
tgt_mask=tgt_mask,
|
599 |
+
memory_mask=memory_mask,
|
600 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
601 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
602 |
+
pos_embs_tgt=pos_embs_tgt,
|
603 |
+
pos_embs_src=pos_embs_src,
|
604 |
+
)
|
605 |
+
output = self.norm(output)
|
606 |
+
|
607 |
+
return output, [None], [None]
|
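A minimal usage sketch for the Conmamba modules listed above, assuming the mamba-ssm
package is installed, that its Mamba block accepts d_state, d_conv, and expand as
keyword arguments, and that the bundled bidirectional variant builds on the target
machine (the 'bidirectional' key itself is consumed by the encoder layer):

    import torch
    from modules.Conmamba import ConmambaEncoder

    mamba_config = {'bidirectional': True, 'd_state': 16, 'd_conv': 4, 'expand': 2}
    encoder = ConmambaEncoder(num_layers=2, d_model=144, d_ffn=576,
                              kernel_size=31, mamba_config=mamba_config)
    x = torch.randn(4, 120, 144)   # [batch, time, d_model]
    out, _ = encoder(x)            # the encoder returns (output, None)
    print(out.shape)               # torch.Size([4, 120, 144])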
model/modules/Transformer.py
ADDED
@@ -0,0 +1,1085 @@
1 |
+
"""Added ConMamba and Mamba
|
2 |
+
|
3 |
+
Authors
|
4 |
+
* Xilin Jiang 2024
|
5 |
+
"""
|
6 |
+
|
7 |
+
"""Transformer implementation in the SpeechBrain style.
|
8 |
+
|
9 |
+
Authors
|
10 |
+
* Jianyuan Zhong 2020
|
11 |
+
* Samuele Cornell 2021
|
12 |
+
"""
|
13 |
+
|
14 |
+
import math
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
import speechbrain as sb
|
22 |
+
from speechbrain.nnet.activations import Swish
|
23 |
+
from speechbrain.nnet.attention import RelPosEncXL
|
24 |
+
from speechbrain.nnet.CNN import Conv1d
|
25 |
+
|
26 |
+
from modules.Conformer import ConformerEncoder
|
27 |
+
from modules.Conmamba import ConmambaEncoder, MambaDecoder
|
28 |
+
|
29 |
+
|
30 |
+
class TransformerInterface(nn.Module):
|
31 |
+
"""This is an interface for transformer model.
|
32 |
+
Users can modify the attributes and define the forward function as
|
33 |
+
needed according to their own tasks.
|
34 |
+
The architecture is based on the paper "Attention Is All You Need":
|
35 |
+
https://arxiv.org/pdf/1706.03762.pdf
|
36 |
+
|
37 |
+
Arguments
|
38 |
+
---------
|
39 |
+
d_model: int
|
40 |
+
The number of expected features in the encoder/decoder inputs (default=512).
|
41 |
+
nhead: int
|
42 |
+
The number of heads in the multi-head attention models (default=8).
|
43 |
+
num_encoder_layers: int, optional
|
44 |
+
The number of encoder layers in the encoder.
|
45 |
+
num_decoder_layers: int, optional
|
46 |
+
The number of decoder layers in the decoder.
|
47 |
+
d_ffn: int, optional
|
48 |
+
The dimension of the feedforward network model hidden layer.
|
49 |
+
dropout: int, optional
|
50 |
+
The dropout value.
|
51 |
+
activation: torch.nn.Module, optional
|
52 |
+
The activation function for Feed-Forward Network layer,
|
53 |
+
e.g., relu or gelu or swish.
|
54 |
+
custom_src_module: torch.nn.Module, optional
|
55 |
+
Module that processes the src features to expected feature dim.
|
56 |
+
custom_tgt_module: torch.nn.Module, optional
|
57 |
+
Module that processes the tgt features to expected feature dim.
|
58 |
+
positional_encoding: str, optional
|
59 |
+
Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
|
60 |
+
normalize_before: bool, optional
|
61 |
+
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
|
62 |
+
Defaults to True as this was shown to lead to better performance and training stability.
|
63 |
+
kernel_size: int, optional
|
64 |
+
Kernel size in convolutional layers when Conformer is used.
|
65 |
+
bias: bool, optional
|
66 |
+
Whether to use bias in Conformer convolutional layers.
|
67 |
+
encoder_module: str, optional
|
68 |
+
Choose between Branchformer, Conformer, ConMamba, and Transformer for the encoder.
|
69 |
+
decoder_module: str, optional
|
70 |
+
Choose between Mamba and Transformer for the decoder.
|
71 |
+
conformer_activation: torch.nn.Module, optional
|
72 |
+
Activation module used after Conformer convolutional layers, e.g. Swish or ReLU. It has to be a torch Module.
|
73 |
+
branchformer_activation: torch.nn.Module, optional
|
74 |
+
Activation module used within the Branchformer Encoder, e.g. Swish or ReLU. It has to be a torch Module.
|
75 |
+
attention_type: str, optional
|
76 |
+
Type of attention layer used in all Transformer or Conformer layers.
|
77 |
+
e.g. regularMHA or RelPosMHA.
|
78 |
+
max_length: int, optional
|
79 |
+
Max length for the target and source sequence in input.
|
80 |
+
Used for positional encodings.
|
81 |
+
causal: bool, optional
|
82 |
+
Whether the encoder should be causal or not (the decoder is always causal).
|
83 |
+
If causal the Conformer convolutional layer is causal.
|
84 |
+
encoder_kdim: int, optional
|
85 |
+
Dimension of the key for the encoder.
|
86 |
+
encoder_vdim: int, optional
|
87 |
+
Dimension of the value for the encoder.
|
88 |
+
decoder_kdim: int, optional
|
89 |
+
Dimension of the key for the decoder.
|
90 |
+
decoder_vdim: int, optional
|
91 |
+
Dimension of the value for the decoder.
|
92 |
+
csgu_linear_units: int, optional
|
93 |
+
Number of neurons in the hidden linear units of the CSGU Module.
|
94 |
+
-> Branchformer
|
95 |
+
gate_activation: torch.nn.Module, optional
|
96 |
+
Activation function used at the gate of the CSGU module.
|
97 |
+
-> Branchformer
|
98 |
+
use_linear_after_conv: bool, optional
|
99 |
+
If True, will apply a linear transformation of size input_size//2.
|
100 |
+
-> Branchformer
|
101 |
+
mamba_config: dict, optional
|
102 |
+
Mamba parameters if encoder_module or decoder_module is Mamba or ConMamba
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
d_model=512,
|
108 |
+
nhead=8,
|
109 |
+
num_encoder_layers=6,
|
110 |
+
num_decoder_layers=6,
|
111 |
+
d_ffn=2048,
|
112 |
+
dropout=0.1,
|
113 |
+
activation=nn.ReLU,
|
114 |
+
custom_src_module=None,
|
115 |
+
custom_tgt_module=None,
|
116 |
+
positional_encoding="fixed_abs_sine",
|
117 |
+
normalize_before=True,
|
118 |
+
kernel_size: Optional[int] = 31,
|
119 |
+
bias: Optional[bool] = True,
|
120 |
+
encoder_module: Optional[str] = "transformer",
|
121 |
+
decoder_module: Optional[str] = "transformer",
|
122 |
+
conformer_activation: Optional[nn.Module] = Swish,
|
123 |
+
branchformer_activation: Optional[nn.Module] = nn.GELU,
|
124 |
+
attention_type: Optional[str] = "regularMHA",
|
125 |
+
max_length: Optional[int] = 2500,
|
126 |
+
causal: Optional[bool] = False,
|
127 |
+
encoder_kdim: Optional[int] = None,
|
128 |
+
encoder_vdim: Optional[int] = None,
|
129 |
+
decoder_kdim: Optional[int] = None,
|
130 |
+
decoder_vdim: Optional[int] = None,
|
131 |
+
csgu_linear_units: Optional[int] = 3072,
|
132 |
+
gate_activation: Optional[nn.Module] = nn.Identity,
|
133 |
+
use_linear_after_conv: Optional[bool] = False,
|
134 |
+
mamba_config=None
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
self.causal = causal
|
138 |
+
self.attention_type = attention_type
|
139 |
+
self.positional_encoding_type = positional_encoding
|
140 |
+
self.encoder_kdim = encoder_kdim
|
141 |
+
self.encoder_vdim = encoder_vdim
|
142 |
+
self.decoder_kdim = decoder_kdim
|
143 |
+
self.decoder_vdim = decoder_vdim
|
144 |
+
|
145 |
+
assert attention_type in ["regularMHA", "RelPosMHAXL", "hypermixing"]
|
146 |
+
assert positional_encoding in ["fixed_abs_sine", None]
|
147 |
+
|
148 |
+
assert (
|
149 |
+
num_encoder_layers + num_decoder_layers > 0
|
150 |
+
), "number of encoder layers and number of decoder layers cannot both be 0!"
|
151 |
+
|
152 |
+
if positional_encoding == "fixed_abs_sine":
|
153 |
+
self.positional_encoding = PositionalEncoding(d_model, max_length)
|
154 |
+
elif positional_encoding is None:
|
155 |
+
pass
|
156 |
+
# no positional encodings
|
157 |
+
|
158 |
+
# overrides any other pos_embedding
|
159 |
+
if attention_type == "RelPosMHAXL":
|
160 |
+
self.positional_encoding = RelPosEncXL(d_model)
|
161 |
+
self.positional_encoding_decoder = PositionalEncoding(
|
162 |
+
d_model, max_length
|
163 |
+
)
|
164 |
+
|
165 |
+
# initialize the encoder
|
166 |
+
if num_encoder_layers > 0:
|
167 |
+
if custom_src_module is not None:
|
168 |
+
self.custom_src_module = custom_src_module(d_model)
|
169 |
+
if encoder_module == "transformer":
|
170 |
+
self.encoder = TransformerEncoder(
|
171 |
+
nhead=nhead,
|
172 |
+
num_layers=num_encoder_layers,
|
173 |
+
d_ffn=d_ffn,
|
174 |
+
d_model=d_model,
|
175 |
+
dropout=dropout,
|
176 |
+
activation=activation,
|
177 |
+
normalize_before=normalize_before,
|
178 |
+
causal=self.causal,
|
179 |
+
attention_type=self.attention_type,
|
180 |
+
kdim=self.encoder_kdim,
|
181 |
+
vdim=self.encoder_vdim,
|
182 |
+
)
|
183 |
+
elif encoder_module == "conformer":
|
184 |
+
self.encoder = ConformerEncoder(
|
185 |
+
nhead=nhead,
|
186 |
+
num_layers=num_encoder_layers,
|
187 |
+
d_ffn=d_ffn,
|
188 |
+
d_model=d_model,
|
189 |
+
dropout=dropout,
|
190 |
+
activation=conformer_activation,
|
191 |
+
kernel_size=kernel_size,
|
192 |
+
bias=bias,
|
193 |
+
causal=self.causal,
|
194 |
+
attention_type=self.attention_type,
|
195 |
+
)
|
196 |
+
assert (
|
197 |
+
normalize_before
|
198 |
+
), "normalize_before must be True for Conformer"
|
199 |
+
|
200 |
+
assert (
|
201 |
+
conformer_activation is not None
|
202 |
+
), "conformer_activation must not be None"
|
203 |
+
elif encoder_module == "branchformer":
|
204 |
+
self.encoder = BranchformerEncoder(
|
205 |
+
nhead=nhead,
|
206 |
+
num_layers=num_encoder_layers,
|
207 |
+
d_model=d_model,
|
208 |
+
dropout=dropout,
|
209 |
+
activation=branchformer_activation,
|
210 |
+
kernel_size=kernel_size,
|
211 |
+
attention_type=self.attention_type,
|
212 |
+
csgu_linear_units=csgu_linear_units,
|
213 |
+
gate_activation=gate_activation,
|
214 |
+
use_linear_after_conv=use_linear_after_conv,
|
215 |
+
)
|
216 |
+
elif encoder_module == "conmamba":
|
217 |
+
self.encoder = ConmambaEncoder(
|
218 |
+
num_layers=num_encoder_layers,
|
219 |
+
d_model=d_model,
|
220 |
+
d_ffn=d_ffn,
|
221 |
+
dropout=dropout,
|
222 |
+
activation=branchformer_activation,
|
223 |
+
kernel_size=kernel_size,
|
224 |
+
bias=bias,
|
225 |
+
causal=self.causal,
|
226 |
+
mamba_config=mamba_config
|
227 |
+
)
|
228 |
+
assert (
|
229 |
+
normalize_before
|
230 |
+
), "normalize_before must be True for Conmamba"
|
231 |
+
|
232 |
+
assert (
|
233 |
+
conformer_activation is not None
|
234 |
+
), "conformer_activation must not be None"
|
235 |
+
|
236 |
+
# initialize the decoder
|
237 |
+
if num_decoder_layers > 0:
|
238 |
+
if custom_tgt_module is not None:
|
239 |
+
self.custom_tgt_module = custom_tgt_module(d_model)
|
240 |
+
if decoder_module == 'transformer':
|
241 |
+
self.decoder = TransformerDecoder(
|
242 |
+
num_layers=num_decoder_layers,
|
243 |
+
nhead=nhead,
|
244 |
+
d_ffn=d_ffn,
|
245 |
+
d_model=d_model,
|
246 |
+
dropout=dropout,
|
247 |
+
activation=activation,
|
248 |
+
normalize_before=normalize_before,
|
249 |
+
causal=True,
|
250 |
+
attention_type="regularMHA", # always use regular attention in decoder
|
251 |
+
kdim=self.decoder_kdim,
|
252 |
+
vdim=self.decoder_vdim,
|
253 |
+
)
|
254 |
+
elif decoder_module in ['mamba']:
|
255 |
+
self.decoder = MambaDecoder(
|
256 |
+
num_layers=num_decoder_layers,
|
257 |
+
d_ffn=d_ffn,
|
258 |
+
d_model=d_model,
|
259 |
+
activation=activation,
|
260 |
+
dropout=dropout,
|
261 |
+
normalize_before=normalize_before,
|
262 |
+
mamba_config=mamba_config
|
263 |
+
)
|
264 |
+
else:
|
265 |
+
raise NotImplementedError(decoder_module)
|
266 |
+
|
267 |
+
def forward(self, **kwargs):
|
268 |
+
"""Users should modify this function according to their own tasks."""
|
269 |
+
raise NotImplementedError
|
270 |
+
|
271 |
+
|
272 |
+
class PositionalEncoding(nn.Module):
|
273 |
+
"""This class implements the absolute sinusoidal positional encoding function.
|
274 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
275 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
276 |
+
|
277 |
+
Arguments
|
278 |
+
---------
|
279 |
+
input_size: int
|
280 |
+
Embedding dimension.
|
281 |
+
max_len : int, optional
|
282 |
+
Max length of the input sequences (default 2500).
|
283 |
+
|
284 |
+
Example
|
285 |
+
-------
|
286 |
+
>>> a = torch.rand((8, 120, 512))
|
287 |
+
>>> enc = PositionalEncoding(input_size=a.shape[-1])
|
288 |
+
>>> b = enc(a)
|
289 |
+
>>> b.shape
|
290 |
+
torch.Size([1, 120, 512])
|
291 |
+
"""
|
292 |
+
|
293 |
+
def __init__(self, input_size, max_len=2500):
|
294 |
+
super().__init__()
|
295 |
+
if input_size % 2 != 0:
|
296 |
+
raise ValueError(
|
297 |
+
f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
|
298 |
+
)
|
299 |
+
self.max_len = max_len
|
300 |
+
pe = torch.zeros(self.max_len, input_size, requires_grad=False)
|
301 |
+
positions = torch.arange(0, self.max_len).unsqueeze(1).float()
|
302 |
+
denominator = torch.exp(
|
303 |
+
torch.arange(0, input_size, 2).float()
|
304 |
+
* -(math.log(10000.0) / input_size)
|
305 |
+
)
|
306 |
+
|
307 |
+
pe[:, 0::2] = torch.sin(positions * denominator)
|
308 |
+
pe[:, 1::2] = torch.cos(positions * denominator)
|
309 |
+
pe = pe.unsqueeze(0)
|
310 |
+
self.register_buffer("pe", pe)
|
311 |
+
|
312 |
+
def forward(self, x):
|
313 |
+
"""
|
314 |
+
Arguments
|
315 |
+
---------
|
316 |
+
x : torch.Tensor
|
317 |
+
Input feature shape (batch, time, fea)
|
318 |
+
|
319 |
+
Returns
|
320 |
+
-------
|
321 |
+
The positional encoding.
|
322 |
+
"""
|
323 |
+
return self.pe[:, : x.size(1)].clone().detach()
|
324 |
+
|
325 |
+
|
326 |
+
class TransformerEncoderLayer(nn.Module):
|
327 |
+
"""This is an implementation of self-attention encoder layer.
|
328 |
+
|
329 |
+
Arguments
|
330 |
+
---------
|
331 |
+
d_ffn: int, optional
|
332 |
+
The dimension of the feedforward network model hidden layer.
|
333 |
+
nhead: int
|
334 |
+
The number of heads in the multi-head attention models (default=8).
|
335 |
+
d_model: int
|
336 |
+
The number of expected features in the encoder/decoder inputs (default=512).
|
337 |
+
kdim: int, optional
|
338 |
+
Dimension of the key.
|
339 |
+
vdim: int, optional
|
340 |
+
Dimension of the value.
|
341 |
+
dropout: int, optional
|
342 |
+
The dropout value.
|
343 |
+
activation: torch.nn.Module, optional
|
344 |
+
The activation function for Feed-Forward Network layer,
|
345 |
+
e.g., relu or gelu or swish.
|
346 |
+
normalize_before: bool, optional
|
347 |
+
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
|
348 |
+
Defaults to True as this was shown to lead to better performance and training stability.
|
349 |
+
attention_type: str, optional
|
350 |
+
Type of attention layer used in all Transformer or Conformer layers.
|
351 |
+
e.g. regularMHA or RelPosMHA.
|
352 |
+
ffn_type: str
|
353 |
+
type of ffn: regularFFN/1dcnn
|
354 |
+
ffn_cnn_kernel_size_list: list of int
|
355 |
+
kernel size of 2 1d-convs if ffn_type is 1dcnn
|
356 |
+
causal: bool, optional
|
357 |
+
Whether the encoder should be causal or not (the decoder is always causal).
|
358 |
+
If causal the Conformer convolutional layer is causal.
|
359 |
+
|
360 |
+
Example
|
361 |
+
-------
|
362 |
+
>>> import torch
|
363 |
+
>>> x = torch.rand((8, 60, 512))
|
364 |
+
>>> net = TransformerEncoderLayer(512, 8, d_model=512)
|
365 |
+
>>> output = net(x)
|
366 |
+
>>> output[0].shape
|
367 |
+
torch.Size([8, 60, 512])
|
368 |
+
"""
|
369 |
+
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
d_ffn,
|
373 |
+
nhead,
|
374 |
+
d_model,
|
375 |
+
kdim=None,
|
376 |
+
vdim=None,
|
377 |
+
dropout=0.0,
|
378 |
+
activation=nn.ReLU,
|
379 |
+
normalize_before=False,
|
380 |
+
attention_type="regularMHA",
|
381 |
+
ffn_type="regularFFN",
|
382 |
+
ffn_cnn_kernel_size_list=[3, 3],
|
383 |
+
causal=False,
|
384 |
+
):
|
385 |
+
super().__init__()
|
386 |
+
|
387 |
+
if attention_type == "regularMHA":
|
388 |
+
self.self_att = sb.nnet.attention.MultiheadAttention(
|
389 |
+
nhead=nhead,
|
390 |
+
d_model=d_model,
|
391 |
+
dropout=dropout,
|
392 |
+
kdim=kdim,
|
393 |
+
vdim=vdim,
|
394 |
+
)
|
395 |
+
|
396 |
+
elif attention_type == "RelPosMHAXL":
|
397 |
+
self.self_att = sb.nnet.attention.RelPosMHAXL(
|
398 |
+
d_model, nhead, dropout, mask_pos_future=causal
|
399 |
+
)
|
400 |
+
elif attention_type == "hypermixing":
|
401 |
+
self.self_att = sb.nnet.hypermixing.HyperMixing(
|
402 |
+
input_output_dim=d_model,
|
403 |
+
hypernet_size=d_ffn,
|
404 |
+
tied=False,
|
405 |
+
num_heads=nhead,
|
406 |
+
fix_tm_hidden_size=False,
|
407 |
+
)
|
408 |
+
|
409 |
+
if ffn_type == "regularFFN":
|
410 |
+
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
|
411 |
+
d_ffn=d_ffn,
|
412 |
+
input_size=d_model,
|
413 |
+
dropout=dropout,
|
414 |
+
activation=activation,
|
415 |
+
)
|
416 |
+
elif ffn_type == "1dcnn":
|
417 |
+
self.pos_ffn = nn.Sequential(
|
418 |
+
Conv1d(
|
419 |
+
in_channels=d_model,
|
420 |
+
out_channels=d_ffn,
|
421 |
+
kernel_size=ffn_cnn_kernel_size_list[0],
|
422 |
+
padding="causal" if causal else "same",
|
423 |
+
),
|
424 |
+
nn.ReLU(),
|
425 |
+
Conv1d(
|
426 |
+
in_channels=d_ffn,
|
427 |
+
out_channels=d_model,
|
428 |
+
kernel_size=ffn_cnn_kernel_size_list[1],
|
429 |
+
padding="causal" if causal else "same",
|
430 |
+
),
|
431 |
+
)
|
432 |
+
|
433 |
+
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
434 |
+
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
435 |
+
self.dropout1 = torch.nn.Dropout(dropout)
|
436 |
+
self.dropout2 = torch.nn.Dropout(dropout)
|
437 |
+
|
438 |
+
self.normalize_before = normalize_before
|
439 |
+
self.pos_ffn_type = ffn_type
|
440 |
+
|
441 |
+
def forward(
|
442 |
+
self,
|
443 |
+
src,
|
444 |
+
src_mask: Optional[torch.Tensor] = None,
|
445 |
+
src_key_padding_mask: Optional[torch.Tensor] = None,
|
446 |
+
pos_embs: Optional[torch.Tensor] = None,
|
447 |
+
):
|
448 |
+
"""
|
449 |
+
Arguments
|
450 |
+
---------
|
451 |
+
src : torch.Tensor
|
452 |
+
The sequence to the encoder layer.
|
453 |
+
src_mask : torch.Tensor
|
454 |
+
The mask for the src query for each example in the batch.
|
455 |
+
src_key_padding_mask : torch.Tensor, optional
|
456 |
+
The mask for the src keys for each example in the batch.
|
457 |
+
pos_embs: torch.Tensor, optional
|
458 |
+
The positional embeddings tensor.
|
459 |
+
|
460 |
+
Returns
|
461 |
+
-------
|
462 |
+
output : torch.Tensor
|
463 |
+
The output of the transformer encoder layer.
|
464 |
+
"""
|
465 |
+
|
466 |
+
if self.normalize_before:
|
467 |
+
src1 = self.norm1(src)
|
468 |
+
else:
|
469 |
+
src1 = src
|
470 |
+
|
471 |
+
output, self_attn = self.self_att(
|
472 |
+
src1,
|
473 |
+
src1,
|
474 |
+
src1,
|
475 |
+
attn_mask=src_mask,
|
476 |
+
key_padding_mask=src_key_padding_mask,
|
477 |
+
pos_embs=pos_embs,
|
478 |
+
)
|
479 |
+
|
480 |
+
# add & norm
|
481 |
+
src = src + self.dropout1(output)
|
482 |
+
if not self.normalize_before:
|
483 |
+
src = self.norm1(src)
|
484 |
+
|
485 |
+
if self.normalize_before:
|
486 |
+
src1 = self.norm2(src)
|
487 |
+
else:
|
488 |
+
src1 = src
|
489 |
+
output = self.pos_ffn(src1)
|
490 |
+
|
491 |
+
# add & norm
|
492 |
+
output = src + self.dropout2(output)
|
493 |
+
if not self.normalize_before:
|
494 |
+
output = self.norm2(output)
|
495 |
+
return output, self_attn
|
496 |
+
|
497 |
+
|
498 |
+
class TransformerEncoder(nn.Module):
|
499 |
+
"""This class implements the transformer encoder.
|
500 |
+
|
501 |
+
Arguments
|
502 |
+
---------
|
503 |
+
num_layers : int
|
504 |
+
Number of transformer layers to include.
|
505 |
+
nhead : int
|
506 |
+
Number of attention heads.
|
507 |
+
d_ffn : int
|
508 |
+
Hidden size of self-attention Feed Forward layer.
|
509 |
+
input_shape : tuple
|
510 |
+
Expected shape of the input.
|
511 |
+
d_model : int
|
512 |
+
The dimension of the input embedding.
|
513 |
+
kdim : int
|
514 |
+
Dimension for key (Optional).
|
515 |
+
vdim : int
|
516 |
+
Dimension for value (Optional).
|
517 |
+
dropout : float
|
518 |
+
Dropout for the encoder (Optional).
|
519 |
+
activation: torch.nn.Module, optional
|
520 |
+
The activation function for Feed-Forward Network layer,
|
521 |
+
e.g., relu or gelu or swish.
|
522 |
+
normalize_before: bool, optional
|
523 |
+
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
|
524 |
+
Defaults to True as this was shown to lead to better performance and training stability.
|
525 |
+
causal: bool, optional
|
526 |
+
Whether the encoder should be causal or not (the decoder is always causal).
|
527 |
+
If causal the Conformer convolutional layer is causal.
|
528 |
+
layerdrop_prob: float
|
529 |
+
The probability to drop an entire layer
|
530 |
+
attention_type: str, optional
|
531 |
+
Type of attention layer used in all Transformer or Conformer layers.
|
532 |
+
e.g. regularMHA or RelPosMHA.
|
533 |
+
ffn_type: str
|
534 |
+
type of ffn: regularFFN/1dcnn
|
535 |
+
ffn_cnn_kernel_size_list: list of int
|
536 |
+
conv kernel size of 2 1d-convs if ffn_type is 1dcnn
|
537 |
+
|
538 |
+
Example
|
539 |
+
-------
|
540 |
+
>>> import torch
|
541 |
+
>>> x = torch.rand((8, 60, 512))
|
542 |
+
>>> net = TransformerEncoder(1, 8, 512, d_model=512)
|
543 |
+
>>> output, _ = net(x)
|
544 |
+
>>> output.shape
|
545 |
+
torch.Size([8, 60, 512])
|
546 |
+
"""
|
547 |
+
|
548 |
+
def __init__(
|
549 |
+
self,
|
550 |
+
num_layers,
|
551 |
+
nhead,
|
552 |
+
d_ffn,
|
553 |
+
input_shape=None,
|
554 |
+
d_model=None,
|
555 |
+
kdim=None,
|
556 |
+
vdim=None,
|
557 |
+
dropout=0.0,
|
558 |
+
activation=nn.ReLU,
|
559 |
+
normalize_before=False,
|
560 |
+
causal=False,
|
561 |
+
layerdrop_prob=0.0,
|
562 |
+
attention_type="regularMHA",
|
563 |
+
ffn_type="regularFFN",
|
564 |
+
ffn_cnn_kernel_size_list=[3, 3],
|
565 |
+
):
|
566 |
+
super().__init__()
|
567 |
+
|
568 |
+
self.layers = torch.nn.ModuleList(
|
569 |
+
[
|
570 |
+
TransformerEncoderLayer(
|
571 |
+
d_ffn=d_ffn,
|
572 |
+
nhead=nhead,
|
573 |
+
d_model=d_model,
|
574 |
+
kdim=kdim,
|
575 |
+
vdim=vdim,
|
576 |
+
dropout=dropout,
|
577 |
+
activation=activation,
|
578 |
+
normalize_before=normalize_before,
|
579 |
+
causal=causal,
|
580 |
+
attention_type=attention_type,
|
581 |
+
ffn_type=ffn_type,
|
582 |
+
ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
|
583 |
+
)
|
584 |
+
for i in range(num_layers)
|
585 |
+
]
|
586 |
+
)
|
587 |
+
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
588 |
+
self.layerdrop_prob = layerdrop_prob
|
589 |
+
self.rng = np.random.default_rng()
|
590 |
+
|
591 |
+
def forward(
|
592 |
+
self,
|
593 |
+
src,
|
594 |
+
src_mask: Optional[torch.Tensor] = None,
|
595 |
+
src_key_padding_mask: Optional[torch.Tensor] = None,
|
596 |
+
pos_embs: Optional[torch.Tensor] = None,
|
597 |
+
dynchunktrain_config=None,
|
598 |
+
):
|
599 |
+
"""
|
600 |
+
Arguments
|
601 |
+
---------
|
602 |
+
src : torch.Tensor
|
603 |
+
The sequence to the encoder layer (required).
|
604 |
+
src_mask : torch.Tensor
|
605 |
+
The mask for the src sequence (optional).
|
606 |
+
src_key_padding_mask : torch.Tensor
|
607 |
+
The mask for the src keys per batch (optional).
|
608 |
+
pos_embs : torch.Tensor
|
609 |
+
The positional embedding tensor
|
610 |
+
dynchunktrain_config : config
|
611 |
+
Not supported for this encoder.
|
612 |
+
|
613 |
+
Returns
|
614 |
+
-------
|
615 |
+
output : torch.Tensor
|
616 |
+
The output of the transformer.
|
617 |
+
attention_lst : list
|
618 |
+
The attention values.
|
619 |
+
"""
|
620 |
+
assert (
|
621 |
+
dynchunktrain_config is None
|
622 |
+
), "Dynamic Chunk Training unsupported for this encoder"
|
623 |
+
|
624 |
+
output = src
|
625 |
+
if self.layerdrop_prob > 0.0:
|
626 |
+
keep_probs = self.rng.random(len(self.layers))
|
627 |
+
else:
|
628 |
+
keep_probs = None
|
629 |
+
attention_lst = []
|
630 |
+
for i, enc_layer in enumerate(self.layers):
|
631 |
+
if (
|
632 |
+
not self.training
|
633 |
+
or self.layerdrop_prob == 0.0
|
634 |
+
or keep_probs[i] > self.layerdrop_prob
|
635 |
+
):
|
636 |
+
output, attention = enc_layer(
|
637 |
+
output,
|
638 |
+
src_mask=src_mask,
|
639 |
+
src_key_padding_mask=src_key_padding_mask,
|
640 |
+
pos_embs=pos_embs,
|
641 |
+
)
|
642 |
+
|
643 |
+
attention_lst.append(attention)
|
644 |
+
output = self.norm(output)
|
645 |
+
return output, attention_lst
|
646 |
+
|
647 |
+
|
648 |
+
class TransformerDecoderLayer(nn.Module):
|
649 |
+
"""This class implements the self-attention decoder layer.
|
650 |
+
|
651 |
+
Arguments
|
652 |
+
---------
|
653 |
+
d_ffn : int
|
654 |
+
Hidden size of self-attention Feed Forward layer.
|
655 |
+
nhead : int
|
656 |
+
Number of attention heads.
|
657 |
+
d_model : int
|
658 |
+
Dimension of the model.
|
659 |
+
kdim : int
|
660 |
+
Dimension for key (optional).
|
661 |
+
vdim : int
|
662 |
+
Dimension for value (optional).
|
663 |
+
dropout : float
|
664 |
+
Dropout for the decoder (optional).
|
665 |
+
activation : Callable
|
666 |
+
Function to use between layers, default nn.ReLU
|
667 |
+
normalize_before : bool
|
668 |
+
Whether to normalize before layers.
|
669 |
+
attention_type : str
|
670 |
+
Type of attention to use, "regularMHA" or "RelPosMHAXL"
|
671 |
+
causal : bool
|
672 |
+
Whether to mask future positions.
|
673 |
+
|
674 |
+
Example
|
675 |
+
-------
|
676 |
+
>>> src = torch.rand((8, 60, 512))
|
677 |
+
>>> tgt = torch.rand((8, 60, 512))
|
678 |
+
>>> net = TransformerDecoderLayer(1024, 8, d_model=512)
|
679 |
+
>>> output, self_attn, multihead_attn = net(src, tgt)
|
680 |
+
>>> output.shape
|
681 |
+
torch.Size([8, 60, 512])
|
682 |
+
"""
|
683 |
+
|
684 |
+
def __init__(
|
685 |
+
self,
|
686 |
+
d_ffn,
|
687 |
+
nhead,
|
688 |
+
d_model,
|
689 |
+
kdim=None,
|
690 |
+
vdim=None,
|
691 |
+
dropout=0.0,
|
692 |
+
activation=nn.ReLU,
|
693 |
+
normalize_before=False,
|
694 |
+
attention_type="regularMHA",
|
695 |
+
causal=None,
|
696 |
+
):
|
697 |
+
super().__init__()
|
698 |
+
self.nhead = nhead
|
699 |
+
|
700 |
+
if attention_type == "regularMHA":
|
701 |
+
self.self_attn = sb.nnet.attention.MultiheadAttention(
|
702 |
+
nhead=nhead,
|
703 |
+
d_model=d_model,
|
704 |
+
kdim=kdim,
|
705 |
+
vdim=vdim,
|
706 |
+
dropout=dropout,
|
707 |
+
)
|
708 |
+
self.multihead_attn = sb.nnet.attention.MultiheadAttention(
|
709 |
+
nhead=nhead,
|
710 |
+
d_model=d_model,
|
711 |
+
kdim=kdim,
|
712 |
+
vdim=vdim,
|
713 |
+
dropout=dropout,
|
714 |
+
)
|
715 |
+
|
716 |
+
elif attention_type == "RelPosMHAXL":
|
717 |
+
self.self_attn = sb.nnet.attention.RelPosMHAXL(
|
718 |
+
d_model, nhead, dropout, mask_pos_future=causal
|
719 |
+
)
|
720 |
+
self.multihead_attn = sb.nnet.attention.RelPosMHAXL(
|
721 |
+
d_model, nhead, dropout, mask_pos_future=causal
|
722 |
+
)
|
723 |
+
|
724 |
+
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
|
725 |
+
d_ffn=d_ffn,
|
726 |
+
input_size=d_model,
|
727 |
+
dropout=dropout,
|
728 |
+
activation=activation,
|
729 |
+
)
|
730 |
+
|
731 |
+
# normalization layers
|
732 |
+
self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
733 |
+
self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
734 |
+
self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
735 |
+
self.dropout1 = torch.nn.Dropout(dropout)
|
736 |
+
self.dropout2 = torch.nn.Dropout(dropout)
|
737 |
+
self.dropout3 = torch.nn.Dropout(dropout)
|
738 |
+
|
739 |
+
self.normalize_before = normalize_before
|
740 |
+
|
741 |
+
def forward(
|
742 |
+
self,
|
743 |
+
tgt,
|
744 |
+
memory,
|
745 |
+
tgt_mask=None,
|
746 |
+
memory_mask=None,
|
747 |
+
tgt_key_padding_mask=None,
|
748 |
+
memory_key_padding_mask=None,
|
749 |
+
pos_embs_tgt=None,
|
750 |
+
pos_embs_src=None,
|
751 |
+
):
|
752 |
+
"""
|
753 |
+
Arguments
|
754 |
+
----------
|
755 |
+
tgt: torch.Tensor
|
756 |
+
The sequence to the decoder layer (required).
|
757 |
+
memory: torch.Tensor
|
758 |
+
The sequence from the last layer of the encoder (required).
|
759 |
+
tgt_mask: torch.Tensor
|
760 |
+
The mask for the tgt sequence (optional).
|
761 |
+
memory_mask: torch.Tensor
|
762 |
+
The mask for the memory sequence (optional).
|
763 |
+
tgt_key_padding_mask: torch.Tensor
|
764 |
+
The mask for the tgt keys per batch (optional).
|
765 |
+
memory_key_padding_mask: torch.Tensor
|
766 |
+
The mask for the memory keys per batch (optional).
|
767 |
+
pos_embs_tgt: torch.Tensor
|
768 |
+
The positional embeddings for the target (optional).
|
769 |
+
pos_embs_src: torch.Tensor
|
770 |
+
The positional embeddings for the source (optional).
|
771 |
+
"""
|
772 |
+
if self.normalize_before:
|
773 |
+
tgt1 = self.norm1(tgt)
|
774 |
+
else:
|
775 |
+
tgt1 = tgt
|
776 |
+
|
777 |
+
# self-attention over the target sequence
|
778 |
+
tgt2, self_attn = self.self_attn(
|
779 |
+
query=tgt1,
|
780 |
+
key=tgt1,
|
781 |
+
value=tgt1,
|
782 |
+
attn_mask=tgt_mask,
|
783 |
+
key_padding_mask=tgt_key_padding_mask,
|
784 |
+
pos_embs=pos_embs_tgt,
|
785 |
+
)
|
786 |
+
|
787 |
+
# add & norm
|
788 |
+
tgt = tgt + self.dropout1(tgt2)
|
789 |
+
if not self.normalize_before:
|
790 |
+
tgt = self.norm1(tgt)
|
791 |
+
|
792 |
+
if self.normalize_before:
|
793 |
+
tgt1 = self.norm2(tgt)
|
794 |
+
else:
|
795 |
+
tgt1 = tgt
|
796 |
+
|
797 |
+
# multi-head attention over the target sequence and encoder states
|
798 |
+
|
799 |
+
tgt2, multihead_attention = self.multihead_attn(
|
800 |
+
query=tgt1,
|
801 |
+
key=memory,
|
802 |
+
value=memory,
|
803 |
+
attn_mask=memory_mask,
|
804 |
+
key_padding_mask=memory_key_padding_mask,
|
805 |
+
pos_embs=pos_embs_src,
|
806 |
+
)
|
807 |
+
|
808 |
+
# add & norm
|
809 |
+
tgt = tgt + self.dropout2(tgt2)
|
810 |
+
if not self.normalize_before:
|
811 |
+
tgt = self.norm2(tgt)
|
812 |
+
|
813 |
+
if self.normalize_before:
|
814 |
+
tgt1 = self.norm3(tgt)
|
815 |
+
else:
|
816 |
+
tgt1 = tgt
|
817 |
+
|
818 |
+
tgt2 = self.pos_ffn(tgt1)
|
819 |
+
|
820 |
+
# add & norm
|
821 |
+
tgt = tgt + self.dropout3(tgt2)
|
822 |
+
if not self.normalize_before:
|
823 |
+
tgt = self.norm3(tgt)
|
824 |
+
|
825 |
+
return tgt, self_attn, multihead_attention
|
826 |
+
|
827 |
+
|
828 |
+
class TransformerDecoder(nn.Module):
|
829 |
+
"""This class implements the Transformer decoder.
|
830 |
+
|
831 |
+
Arguments
|
832 |
+
---------
|
833 |
+
num_layers : int
|
834 |
+
Number of transformer layers for the decoder.
|
835 |
+
nhead : int
|
836 |
+
Number of attention heads.
|
837 |
+
d_ffn : int
|
838 |
+
Hidden size of self-attention Feed Forward layer.
|
839 |
+
d_model : int
|
840 |
+
Dimension of the model.
|
841 |
+
kdim : int, optional
|
842 |
+
Dimension for key (Optional).
|
843 |
+
vdim : int, optional
|
844 |
+
Dimension for value (Optional).
|
845 |
+
dropout : float, optional
|
846 |
+
Dropout for the decoder (Optional).
|
847 |
+
activation : Callable
|
848 |
+
The function to apply between layers, default nn.ReLU
|
849 |
+
normalize_before : bool
|
850 |
+
Whether to normalize before layers.
|
851 |
+
causal : bool
|
852 |
+
Whether to mask future positions in decoding.
|
853 |
+
attention_type : str
|
854 |
+
Type of attention to use, "regularMHA" or "RelPosMHAXL"
|
855 |
+
|
856 |
+
Example
|
857 |
+
-------
|
858 |
+
>>> src = torch.rand((8, 60, 512))
|
859 |
+
>>> tgt = torch.rand((8, 60, 512))
|
860 |
+
>>> net = TransformerDecoder(1, 8, 1024, d_model=512)
|
861 |
+
>>> output, _, _ = net(src, tgt)
|
862 |
+
>>> output.shape
|
863 |
+
torch.Size([8, 60, 512])
|
864 |
+
"""
|
865 |
+
|
866 |
+
def __init__(
|
867 |
+
self,
|
868 |
+
num_layers,
|
869 |
+
nhead,
|
870 |
+
d_ffn,
|
871 |
+
d_model,
|
872 |
+
kdim=None,
|
873 |
+
vdim=None,
|
874 |
+
dropout=0.0,
|
875 |
+
activation=nn.ReLU,
|
876 |
+
normalize_before=False,
|
877 |
+
causal=False,
|
878 |
+
attention_type="regularMHA",
|
879 |
+
):
|
880 |
+
super().__init__()
|
881 |
+
self.layers = torch.nn.ModuleList(
|
882 |
+
[
|
883 |
+
TransformerDecoderLayer(
|
884 |
+
d_ffn=d_ffn,
|
885 |
+
nhead=nhead,
|
886 |
+
d_model=d_model,
|
887 |
+
kdim=kdim,
|
888 |
+
vdim=vdim,
|
889 |
+
dropout=dropout,
|
890 |
+
activation=activation,
|
891 |
+
normalize_before=normalize_before,
|
892 |
+
causal=causal,
|
893 |
+
attention_type=attention_type,
|
894 |
+
)
|
895 |
+
for _ in range(num_layers)
|
896 |
+
]
|
897 |
+
)
|
898 |
+
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
|
899 |
+
|
900 |
+
def forward(
|
901 |
+
self,
|
902 |
+
tgt,
|
903 |
+
memory,
|
904 |
+
tgt_mask=None,
|
905 |
+
memory_mask=None,
|
906 |
+
tgt_key_padding_mask=None,
|
907 |
+
memory_key_padding_mask=None,
|
908 |
+
pos_embs_tgt=None,
|
909 |
+
pos_embs_src=None,
|
910 |
+
):
|
911 |
+
"""
|
912 |
+
Arguments
|
913 |
+
----------
|
914 |
+
tgt : torch.Tensor
|
915 |
+
The sequence to the decoder layer (required).
|
916 |
+
memory : torch.Tensor
|
917 |
+
The sequence from the last layer of the encoder (required).
|
918 |
+
tgt_mask : torch.Tensor
|
919 |
+
The mask for the tgt sequence (optional).
|
920 |
+
memory_mask : torch.Tensor
|
921 |
+
The mask for the memory sequence (optional).
|
922 |
+
tgt_key_padding_mask : torch.Tensor
|
923 |
+
The mask for the tgt keys per batch (optional).
|
924 |
+
memory_key_padding_mask : torch.Tensor
|
925 |
+
The mask for the memory keys per batch (optional).
|
926 |
+
pos_embs_tgt : torch.Tensor
|
927 |
+
The positional embeddings for the target (optional).
|
928 |
+
pos_embs_src : torch.Tensor
|
929 |
+
The positional embeddings for the source (optional).
|
930 |
+
"""
|
931 |
+
output = tgt
|
932 |
+
self_attns, multihead_attns = [], []
|
933 |
+
for dec_layer in self.layers:
|
934 |
+
output, self_attn, multihead_attn = dec_layer(
|
935 |
+
output,
|
936 |
+
memory,
|
937 |
+
tgt_mask=tgt_mask,
|
938 |
+
memory_mask=memory_mask,
|
939 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
940 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
941 |
+
pos_embs_tgt=pos_embs_tgt,
|
942 |
+
pos_embs_src=pos_embs_src,
|
943 |
+
)
|
944 |
+
self_attns.append(self_attn)
|
945 |
+
multihead_attns.append(multihead_attn)
|
946 |
+
output = self.norm(output)
|
947 |
+
|
948 |
+
return output, self_attns, multihead_attns
|
949 |
+
|
950 |
+
|
951 |
+
class NormalizedEmbedding(nn.Module):
|
952 |
+
"""This class implements the normalized embedding layer for the transformer.
|
953 |
+
Since the dot product of the self-attention is always normalized by sqrt(d_model)
|
954 |
+
and the final linear projection for prediction shares weight with the embedding layer,
|
955 |
+
we multiply the output of the embedding by sqrt(d_model).
|
956 |
+
|
957 |
+
Arguments
|
958 |
+
---------
|
959 |
+
d_model: int
|
960 |
+
The number of expected features in the encoder/decoder inputs (default=512).
|
961 |
+
vocab: int
|
962 |
+
The vocab size.
|
963 |
+
|
964 |
+
Example
|
965 |
+
-------
|
966 |
+
>>> emb = NormalizedEmbedding(512, 1000)
|
967 |
+
>>> trg = torch.randint(0, 999, (8, 50))
|
968 |
+
>>> emb_fea = emb(trg)
|
969 |
+
"""
|
970 |
+
|
971 |
+
def __init__(self, d_model, vocab):
|
972 |
+
super().__init__()
|
973 |
+
self.emb = sb.nnet.embedding.Embedding(
|
974 |
+
num_embeddings=vocab, embedding_dim=d_model, blank_id=0
|
975 |
+
)
|
976 |
+
self.d_model = d_model
|
977 |
+
|
978 |
+
def forward(self, x):
|
979 |
+
"""Processes the input tensor x and returns an output tensor."""
|
980 |
+
return self.emb(x) * math.sqrt(self.d_model)
|
981 |
+
|
982 |
+
|
983 |
+
def get_key_padding_mask(padded_input, pad_idx):
|
984 |
+
"""Creates a binary mask to prevent attention to padded locations.
|
985 |
+
We suggest using ``get_mask_from_lengths`` instead of this function.
|
986 |
+
|
987 |
+
Arguments
|
988 |
+
---------
|
989 |
+
padded_input: torch.Tensor
|
990 |
+
Padded input.
|
991 |
+
pad_idx: int
|
992 |
+
idx for padding element.
|
993 |
+
|
994 |
+
Returns
|
995 |
+
-------
|
996 |
+
key_padded_mask: torch.Tensor
|
997 |
+
Binary mask to prevent attention to padding.
|
998 |
+
|
999 |
+
Example
|
1000 |
+
-------
|
1001 |
+
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
|
1002 |
+
>>> get_key_padding_mask(a, pad_idx=0)
|
1003 |
+
tensor([[False, False, True],
|
1004 |
+
[False, False, True],
|
1005 |
+
[False, False, True]])
|
1006 |
+
"""
|
1007 |
+
if len(padded_input.shape) == 4:
|
1008 |
+
bz, time, ch1, ch2 = padded_input.shape
|
1009 |
+
padded_input = padded_input.reshape(bz, time, ch1 * ch2)
|
1010 |
+
|
1011 |
+
key_padded_mask = padded_input.eq(pad_idx).to(padded_input.device)
|
1012 |
+
|
1013 |
+
# if the input is more than 2d, mask the locations where they are silence
|
1014 |
+
# across all channels
|
1015 |
+
if len(padded_input.shape) > 2:
|
1016 |
+
key_padded_mask = key_padded_mask.float().prod(dim=-1).bool()
|
1017 |
+
return key_padded_mask.detach()
|
1018 |
+
|
1019 |
+
return key_padded_mask.detach()
|
1020 |
+
|
1021 |
+
|
1022 |
+
def get_lookahead_mask(padded_input):
|
1023 |
+
"""Creates a binary mask for each sequence which masks future frames.
|
1024 |
+
|
1025 |
+
Arguments
|
1026 |
+
---------
|
1027 |
+
padded_input: torch.Tensor
|
1028 |
+
Padded input tensor.
|
1029 |
+
|
1030 |
+
Returns
|
1031 |
+
-------
|
1032 |
+
mask : torch.Tensor
|
1033 |
+
Binary mask for masking future frames.
|
1034 |
+
|
1035 |
+
Example
|
1036 |
+
-------
|
1037 |
+
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
|
1038 |
+
>>> get_lookahead_mask(a)
|
1039 |
+
tensor([[0., -inf, -inf],
|
1040 |
+
[0., 0., -inf],
|
1041 |
+
[0., 0., 0.]])
|
1042 |
+
"""
|
1043 |
+
seq_len = padded_input.shape[1]
|
1044 |
+
mask = (
|
1045 |
+
torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device))
|
1046 |
+
== 1
|
1047 |
+
).transpose(0, 1)
|
1048 |
+
mask = (
|
1049 |
+
mask.float()
|
1050 |
+
.masked_fill(mask == 0, float("-inf"))
|
1051 |
+
.masked_fill(mask == 1, float(0.0))
|
1052 |
+
)
|
1053 |
+
return mask.detach().to(padded_input.device)
|
1054 |
+
|
1055 |
+
|
1056 |
+
def get_mask_from_lengths(lengths, max_len=None):
|
1057 |
+
"""Creates a binary mask from sequence lengths
|
1058 |
+
|
1059 |
+
Arguments
|
1060 |
+
---------
|
1061 |
+
lengths: torch.Tensor
|
1062 |
+
A tensor of sequence lengths
|
1063 |
+
max_len: int (Optional)
|
1064 |
+
Maximum sequence length, defaults to None.
|
1065 |
+
|
1066 |
+
Returns
|
1067 |
+
-------
|
1068 |
+
mask: torch.Tensor
|
1069 |
+
the mask where padded elements are set to True.
|
1070 |
+
Then one can use tensor.masked_fill_(mask, 0) for the masking.
|
1071 |
+
|
1072 |
+
Example
|
1073 |
+
-------
|
1074 |
+
>>> lengths = torch.tensor([3, 2, 4])
|
1075 |
+
>>> get_mask_from_lengths(lengths)
|
1076 |
+
tensor([[False, False, False, True],
|
1077 |
+
[False, False, True, True],
|
1078 |
+
[False, False, False, False]])
|
1079 |
+
"""
|
1080 |
+
if max_len is None:
|
1081 |
+
max_len = torch.max(lengths).item()
|
1082 |
+
seq_range = torch.arange(
|
1083 |
+
max_len, device=lengths.device, dtype=lengths.dtype
|
1084 |
+
)
|
1085 |
+
return ~(seq_range.unsqueeze(0) < lengths.unsqueeze(1))
|
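A small end-to-end sketch of the helpers defined above, with illustrative shapes and
the assumption that SpeechBrain is installed and the modules package is importable:
a padding mask and a look-ahead mask are built and passed to the TransformerEncoder.

    import torch
    from modules.Transformer import (
        TransformerEncoder,
        get_lookahead_mask,
        get_mask_from_lengths,
    )

    feats = torch.randn(3, 25, 64)                                # [batch, time, d_model]
    pad_mask = get_mask_from_lengths(torch.tensor([25, 20, 15]))  # True at padded steps
    causal_mask = get_lookahead_mask(feats)                       # [25, 25], -inf above diagonal

    encoder = TransformerEncoder(num_layers=2, nhead=4, d_ffn=128, d_model=64)
    out, attn = encoder(feats, src_mask=causal_mask, src_key_padding_mask=pad_mask)
    print(out.shape)                                              # torch.Size([3, 25, 64])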
model/modules/TransformerASR.py
ADDED
@@ -0,0 +1,682 @@
1 |
+
"""Added ConMamba and Mamba
|
2 |
+
|
3 |
+
Authors
|
4 |
+
* Xilin Jiang 2024
|
5 |
+
"""
|
6 |
+
|
7 |
+
"""Transformer for ASR in the SpeechBrain style.
|
8 |
+
|
9 |
+
Authors
|
10 |
+
* Jianyuan Zhong 2020
|
11 |
+
* Titouan Parcollet 2024
|
12 |
+
* Luca Della Libera 2024
|
13 |
+
"""
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Optional
|
17 |
+
|
18 |
+
import torch # noqa 42
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from speechbrain.dataio.dataio import length_to_mask
|
22 |
+
from modules.Transformer import (
|
23 |
+
NormalizedEmbedding,
|
24 |
+
TransformerInterface,
|
25 |
+
get_key_padding_mask,
|
26 |
+
get_lookahead_mask,
|
27 |
+
)
|
28 |
+
from speechbrain.nnet.activations import Swish
|
29 |
+
from speechbrain.nnet.containers import ModuleList
|
30 |
+
from speechbrain.nnet.linear import Linear
|
31 |
+
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class TransformerASRStreamingContext:
|
36 |
+
"""Streaming metadata and state for a `TransformerASR` instance."""
|
37 |
+
|
38 |
+
dynchunktrain_config: DynChunkTrainConfig
|
39 |
+
"""Dynamic Chunk Training configuration holding chunk size and context size
|
40 |
+
information."""
|
41 |
+
|
42 |
+
encoder_context: Any
|
43 |
+
"""Opaque encoder context information. It is constructed by the encoder's
|
44 |
+
`make_streaming_context` method and is passed to the encoder when using
|
45 |
+
`encode_streaming`.
|
46 |
+
"""
|
47 |
+
|
48 |
+
|
49 |
+
def make_transformer_src_mask(
|
50 |
+
src: torch.Tensor,
|
51 |
+
causal: bool = False,
|
52 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
|
53 |
+
) -> Optional[torch.Tensor]:
|
54 |
+
"""Prepare the source transformer mask that restricts which frames can
|
55 |
+
attend to which frames depending on causal or other simple restricted
|
56 |
+
attention methods.
|
57 |
+
|
58 |
+
Arguments
|
59 |
+
---------
|
60 |
+
src: torch.Tensor
|
61 |
+
The source tensor to build a mask from. The contents of the tensor are
|
62 |
+
not actually used currently; only its shape and other metadata (e.g.
|
63 |
+
device).
|
64 |
+
causal: bool
|
65 |
+
Whether strict causality shall be used. Frames will not be able to
|
66 |
+
attend to any future frame.
|
67 |
+
dynchunktrain_config: DynChunkTrainConfig, optional
|
68 |
+
Dynamic Chunk Training configuration. This implements a simple form of
|
69 |
+
chunkwise attention. Incompatible with `causal`.
|
70 |
+
|
71 |
+
Returns
|
72 |
+
-------
|
73 |
+
torch.Tensor
|
74 |
+
A boolean mask Tensor of shape (timesteps, timesteps).
|
75 |
+
"""
|
76 |
+
if causal:
|
77 |
+
assert dynchunktrain_config is None
|
78 |
+
return get_lookahead_mask(src)
|
79 |
+
|
80 |
+
if dynchunktrain_config is None:
|
81 |
+
return
|
82 |
+
|
83 |
+
# The following is not really the sole source used to implement this,
|
84 |
+
# but it helps introduce the concept.
|
85 |
+
# ref: Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
|
86 |
+
# https://arxiv.org/pdf/2012.05481.pdf
|
87 |
+
timesteps = src.size(1)
|
88 |
+
|
89 |
+
# Mask the future at the right of each chunk
|
90 |
+
chunk_size = dynchunktrain_config.chunk_size
|
91 |
+
num_chunks = timesteps // chunk_size
|
92 |
+
timestep_idx = torch.arange(timesteps, device=src.device)
|
93 |
+
mask_idx = torch.arange(
|
94 |
+
chunk_size, chunk_size * (num_chunks + 2), chunk_size, device=src.device
|
95 |
+
).repeat_interleave(chunk_size)[:timesteps]
|
96 |
+
src_mask = timestep_idx[None] >= mask_idx[:, None]
|
97 |
+
|
98 |
+
# Mask the past at the left of each chunk (accounting for left context)
|
99 |
+
# only relevant if using left context
|
100 |
+
if not dynchunktrain_config.is_infinite_left_context():
|
101 |
+
num_left_chunks = dynchunktrain_config.left_context_size
|
102 |
+
mask_idx -= chunk_size * (num_left_chunks + 1)
|
103 |
+
src_mask += timestep_idx[None] < mask_idx[:, None]
|
104 |
+
|
105 |
+
return src_mask
|
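# Worked example of the chunkwise mask (rows = query frames, columns = key frames,
# True = blocked): with timesteps=6, chunk_size=2 and left_context_size=1 the
# returned src_mask is
#
#     frames 0-1:  F F T T T T   (chunk 0 attends only to itself)
#     frames 2-3:  F F F F T T   (chunk 1 attends to chunk 0 and itself)
#     frames 4-5:  T T F F F F   (chunk 2 attends to one left chunk and itself)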
106 |
+
|
107 |
+
|
108 |
+
def make_transformer_src_tgt_masks(
|
109 |
+
src,
|
110 |
+
tgt=None,
|
111 |
+
wav_len=None,
|
112 |
+
pad_idx=0,
|
113 |
+
causal: bool = False,
|
114 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
|
115 |
+
):
|
116 |
+
"""This function generates masks for training the transformer model,
|
117 |
+
opinionated for an ASR context with encoding masks and, optionally, decoding
|
118 |
+
masks (if specifying `tgt`).
|
119 |
+
|
120 |
+
Arguments
|
121 |
+
---------
|
122 |
+
src : torch.Tensor
|
123 |
+
The sequence to the encoder (required).
|
124 |
+
tgt : torch.Tensor
|
125 |
+
The sequence to the decoder.
|
126 |
+
wav_len : torch.Tensor
|
127 |
+
The lengths of the inputs.
|
128 |
+
pad_idx : int
|
129 |
+
The index for <pad> token (default=0).
|
130 |
+
causal: bool
|
131 |
+
Whether strict causality shall be used. See `make_asr_src_mask`
|
132 |
+
dynchunktrain_config: DynChunkTrainConfig, optional
|
133 |
+
Dynamic Chunk Training configuration. See `make_asr_src_mask`
|
134 |
+
|
135 |
+
Returns
|
136 |
+
-------
|
137 |
+
src_key_padding_mask : torch.Tensor
|
138 |
+
Key padding mask for ignoring padding
|
139 |
+
tgt_key_padding_mask : torch.Tensor
|
140 |
+
Key padding mask for ignoring padding
|
141 |
+
src_mask : torch.Tensor
|
142 |
+
Mask for ignoring invalid (e.g. future) timesteps
|
143 |
+
tgt_mask : torch.Tensor
|
144 |
+
Mask for ignoring invalid (e.g. future) timesteps
|
145 |
+
"""
|
146 |
+
src_key_padding_mask = None
|
147 |
+
|
148 |
+
# mask out audio beyond the length of audio for each batch
|
149 |
+
if wav_len is not None:
|
150 |
+
abs_len = torch.round(wav_len * src.shape[1])
|
151 |
+
src_key_padding_mask = ~length_to_mask(abs_len).bool()
|
152 |
+
|
153 |
+
# mask out the source
|
154 |
+
src_mask = make_transformer_src_mask(
|
155 |
+
src, causal=causal, dynchunktrain_config=dynchunktrain_config
|
156 |
+
)
|
157 |
+
|
158 |
+
# If no decoder in the transformer...
|
159 |
+
if tgt is not None:
|
160 |
+
tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
|
161 |
+
tgt_mask = get_lookahead_mask(tgt)
|
162 |
+
else:
|
163 |
+
tgt_key_padding_mask = None
|
164 |
+
tgt_mask = None
|
165 |
+
|
166 |
+
return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
|
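# Shape sketch for a padded batch (illustrative values):
#
#     feats = torch.randn(4, 100, 80)          # (batch, time, features)
#     tokens = torch.randint(1, 30, (4, 20))   # decoder inputs, 0 = <pad>
#     wav_len = torch.tensor([1.0, 0.9, 0.75, 0.5])
#     src_kpm, tgt_kpm, src_mask, tgt_mask = make_transformer_src_tgt_masks(
#         feats, tokens, wav_len, pad_idx=0, causal=False
#     )
#     # src_kpm: (4, 100) bool, tgt_kpm: (4, 20) bool,
#     # src_mask: None (non-causal, no chunking), tgt_mask: (20, 20) float (0 / -inf)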
167 |
+
|
168 |
+
|
169 |
+
class TransformerASR(TransformerInterface):
|
170 |
+
"""This is an implementation of transformer model for ASR.
|
171 |
+
|
172 |
+
The architecture is based on the paper "Attention Is All You Need":
|
173 |
+
https://arxiv.org/pdf/1706.03762.pdf
|
174 |
+
|
175 |
+
Arguments
|
176 |
+
---------
|
177 |
+
tgt_vocab: int
|
178 |
+
Size of vocabulary.
|
179 |
+
input_size: int
|
180 |
+
Input feature size.
|
181 |
+
d_model : int, optional
|
182 |
+
Embedding dimension size.
|
183 |
+
(default=512).
|
184 |
+
nhead : int, optional
|
185 |
+
The number of heads in the multi-head attention models (default=8).
|
186 |
+
num_encoder_layers : int, optional
|
187 |
+
The number of sub-encoder-layers in the encoder (default=6).
|
188 |
+
num_decoder_layers : int, optional
|
189 |
+
The number of sub-decoder-layers in the decoder (default=6).
|
190 |
+
d_ffn : int, optional
|
191 |
+
The dimension of the feedforward network model (default=2048).
|
192 |
+
dropout : int, optional
|
193 |
+
The dropout value (default=0.1).
|
194 |
+
activation : torch.nn.Module, optional
|
195 |
+
The activation function of FFN layers.
|
196 |
+
Recommended: relu or gelu (default=relu).
|
197 |
+
positional_encoding: str, optional
|
198 |
+
Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
|
199 |
+
normalize_before: bool, optional
|
200 |
+
Whether normalization should be applied before or after MHA or FFN in Transformer layers.
|
201 |
+
Defaults to True as this was shown to lead to better performance and training stability.
|
202 |
+
kernel_size: int, optional
|
203 |
+
Kernel size in convolutional layers when Conformer is used.
|
204 |
+
bias: bool, optional
|
205 |
+
Whether to use bias in Conformer convolutional layers.
|
206 |
+
encoder_module: str, optional
|
207 |
+
Choose between Branchformer, Conformer, ConMamba, and Transformer for the encoder.
|
208 |
+
decoder_module: str, optional
|
209 |
+
Choose between Mamba and Transformer for the decoder.
|
212 |
+
conformer_activation: torch.nn.Module, optional
|
213 |
+
Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module.
|
214 |
+
branchformer_activation: torch.nn.Module, optional
|
215 |
+
Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module.
|
216 |
+
attention_type: str, optional
|
217 |
+
Type of attention layer used in all Transformer or Conformer layers.
|
218 |
+
e.g. regularMHA or RelPosMHA.
|
219 |
+
max_length: int, optional
|
220 |
+
Max length for the target and source sequence in input.
|
221 |
+
Used for positional encodings.
|
222 |
+
causal: bool, optional
|
223 |
+
Whether the encoder should be causal or not (the decoder is always causal).
|
224 |
+
If causal the Conformer convolutional layer is causal.
|
225 |
+
csgu_linear_units: int, optional
|
226 |
+
Number of neurons in the hidden linear units of the CSGU Module.
|
227 |
+
-> Branchformer
|
228 |
+
gate_activation: torch.nn.Module, optional
|
229 |
+
Activation function used at the gate of the CSGU module.
|
230 |
+
-> Branchformer
|
231 |
+
use_linear_after_conv: bool, optional
|
232 |
+
If True, will apply a linear transformation of size input_size//2.
|
233 |
+
-> Branchformer
|
234 |
+
mamba_config: dict, optional
|
235 |
+
Mamba parameters if encoder_module or decoder_module is Mamba or ConMamba
|
236 |
+
|
237 |
+
Example
|
238 |
+
-------
|
239 |
+
>>> src = torch.rand([8, 120, 512])
|
240 |
+
>>> tgt = torch.randint(0, 720, [8, 120])
|
241 |
+
>>> net = TransformerASR(
|
242 |
+
... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
|
243 |
+
... )
|
244 |
+
>>> enc_out, dec_out = net.forward(src, tgt)
|
245 |
+
>>> enc_out.shape
|
246 |
+
torch.Size([8, 120, 512])
|
247 |
+
>>> dec_out.shape
|
248 |
+
torch.Size([8, 120, 512])
|
249 |
+
"""
|
250 |
+
|
251 |
+
def __init__(
|
252 |
+
self,
|
253 |
+
tgt_vocab,
|
254 |
+
input_size,
|
255 |
+
d_model=512,
|
256 |
+
nhead=8,
|
257 |
+
num_encoder_layers=6,
|
258 |
+
num_decoder_layers=6,
|
259 |
+
d_ffn=2048,
|
260 |
+
dropout=0.1,
|
261 |
+
activation=nn.ReLU,
|
262 |
+
positional_encoding="fixed_abs_sine",
|
263 |
+
normalize_before=False,
|
264 |
+
kernel_size: Optional[int] = 31,
|
265 |
+
bias: Optional[bool] = True,
|
266 |
+
encoder_module: Optional[str] = "transformer",
|
267 |
+
decoder_module: Optional[str] = "transformer",
|
268 |
+
conformer_activation: Optional[nn.Module] = Swish,
|
269 |
+
branchformer_activation: Optional[nn.Module] = nn.GELU,
|
270 |
+
attention_type: Optional[str] = "regularMHA",
|
271 |
+
max_length: Optional[int] = 2500,
|
272 |
+
causal: Optional[bool] = True,
|
273 |
+
csgu_linear_units: Optional[int] = 3072,
|
274 |
+
gate_activation: Optional[nn.Module] = nn.Identity,
|
275 |
+
use_linear_after_conv: Optional[bool] = False,
|
276 |
+
mamba_config=None
|
277 |
+
):
|
278 |
+
super().__init__(
|
279 |
+
d_model=d_model,
|
280 |
+
nhead=nhead,
|
281 |
+
num_encoder_layers=num_encoder_layers,
|
282 |
+
num_decoder_layers=num_decoder_layers,
|
283 |
+
d_ffn=d_ffn,
|
284 |
+
dropout=dropout,
|
285 |
+
activation=activation,
|
286 |
+
positional_encoding=positional_encoding,
|
287 |
+
normalize_before=normalize_before,
|
288 |
+
kernel_size=kernel_size,
|
289 |
+
bias=bias,
|
290 |
+
encoder_module=encoder_module,
|
291 |
+
decoder_module=decoder_module,
|
292 |
+
conformer_activation=conformer_activation,
|
293 |
+
branchformer_activation=branchformer_activation,
|
294 |
+
attention_type=attention_type,
|
295 |
+
max_length=max_length,
|
296 |
+
causal=causal,
|
297 |
+
csgu_linear_units=csgu_linear_units,
|
298 |
+
gate_activation=gate_activation,
|
299 |
+
use_linear_after_conv=use_linear_after_conv,
|
300 |
+
mamba_config=mamba_config
|
301 |
+
)
|
302 |
+
|
303 |
+
self.custom_src_module = ModuleList(
|
304 |
+
Linear(
|
305 |
+
input_size=input_size,
|
306 |
+
n_neurons=d_model,
|
307 |
+
bias=True,
|
308 |
+
combine_dims=False,
|
309 |
+
),
|
310 |
+
torch.nn.Dropout(dropout),
|
311 |
+
)
|
312 |
+
|
313 |
+
self.num_decoder_layers = num_decoder_layers
|
314 |
+
if num_decoder_layers > 0:
|
315 |
+
self.custom_tgt_module = ModuleList(
|
316 |
+
NormalizedEmbedding(d_model, tgt_vocab)
|
317 |
+
)
|
318 |
+
|
319 |
+
# reset parameters using xavier_normal_
|
320 |
+
self._init_params()
|
321 |
+
|
322 |
+
def forward(self, src, tgt, wav_len=None, pad_idx=0):
|
323 |
+
"""
|
324 |
+
Arguments
|
325 |
+
----------
|
326 |
+
src : torch.Tensor
|
327 |
+
The sequence to the encoder.
|
328 |
+
tgt : torch.Tensor
|
329 |
+
The sequence to the decoder.
|
330 |
+
wav_len: torch.Tensor, optional
|
331 |
+
Torch Tensor of shape (batch, ) containing the relative length to padded length for each example.
|
332 |
+
pad_idx : int, optional
|
333 |
+
The index for <pad> token (default=0).
|
334 |
+
"""
|
335 |
+
|
336 |
+
# reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
|
337 |
+
if src.ndim == 4:
|
338 |
+
bz, t, ch1, ch2 = src.shape
|
339 |
+
src = src.reshape(bz, t, ch1 * ch2)
|
340 |
+
|
341 |
+
(
|
342 |
+
src_key_padding_mask,
|
343 |
+
tgt_key_padding_mask,
|
344 |
+
src_mask,
|
345 |
+
tgt_mask,
|
346 |
+
) = make_transformer_src_tgt_masks(
|
347 |
+
src, tgt, wav_len, causal=self.causal, pad_idx=pad_idx
|
348 |
+
)
|
349 |
+
|
350 |
+
src = self.custom_src_module(src)
|
351 |
+
# add pos encoding to the queries if absolute sinusoidal encodings are used
|
352 |
+
if self.attention_type == "hypermixing":
|
353 |
+
pos_embs_encoder = None
|
354 |
+
elif self.attention_type == "RelPosMHAXL":
|
355 |
+
pos_embs_encoder = self.positional_encoding(src)
|
356 |
+
elif self.positional_encoding_type == "fixed_abs_sine":
|
357 |
+
src = src + self.positional_encoding(src) # add the encodings here
|
358 |
+
pos_embs_encoder = None
|
359 |
+
|
360 |
+
encoder_out, _ = self.encoder(
|
361 |
+
src=src,
|
362 |
+
src_mask=src_mask,
|
363 |
+
src_key_padding_mask=src_key_padding_mask,
|
364 |
+
pos_embs=pos_embs_encoder,
|
365 |
+
)
|
366 |
+
|
367 |
+
if self.num_decoder_layers > 0:
|
368 |
+
tgt = self.custom_tgt_module(tgt)
|
369 |
+
|
370 |
+
if self.attention_type == "RelPosMHAXL":
|
371 |
+
tgt = tgt + self.positional_encoding_decoder(tgt)
|
372 |
+
pos_embs_encoder = None # self.positional_encoding(src)
|
373 |
+
pos_embs_target = None
|
374 |
+
elif (
|
375 |
+
self.positional_encoding_type == "fixed_abs_sine"
|
376 |
+
or self.attention_type == "hypermixing"
|
377 |
+
):
|
378 |
+
tgt = tgt + self.positional_encoding(tgt)
|
379 |
+
pos_embs_target = None
|
380 |
+
pos_embs_encoder = None
|
381 |
+
|
382 |
+
decoder_out, _, _ = self.decoder(
|
383 |
+
tgt=tgt,
|
384 |
+
memory=encoder_out,
|
385 |
+
memory_mask=None,
|
386 |
+
tgt_mask=tgt_mask,
|
387 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
388 |
+
memory_key_padding_mask=src_key_padding_mask,
|
389 |
+
pos_embs_tgt=pos_embs_target,
|
390 |
+
pos_embs_src=pos_embs_encoder,
|
391 |
+
)
|
392 |
+
|
393 |
+
else:
|
394 |
+
decoder_out = None
|
395 |
+
|
396 |
+
return encoder_out, decoder_out
|
397 |
+
|
398 |
+
@torch.no_grad()
|
399 |
+
def decode(self, tgt, encoder_out, enc_len=None):
|
400 |
+
"""This method implements a decoding step for the transformer model.
|
401 |
+
|
402 |
+
Arguments
|
403 |
+
---------
|
404 |
+
tgt : torch.Tensor
|
405 |
+
The sequence to the decoder.
|
406 |
+
encoder_out : torch.Tensor
|
407 |
+
Hidden output of the encoder.
|
408 |
+
enc_len : torch.LongTensor
|
409 |
+
The actual length of encoder states.
|
410 |
+
|
411 |
+
Returns
|
412 |
+
-------
|
413 |
+
prediction
|
414 |
+
"""
|
415 |
+
tgt_mask = get_lookahead_mask(tgt)
|
416 |
+
src_key_padding_mask = None
|
417 |
+
if enc_len is not None:
|
418 |
+
src_key_padding_mask = (1 - length_to_mask(enc_len)).bool()
|
419 |
+
|
420 |
+
if self.num_decoder_layers > 0:
|
421 |
+
tgt = self.custom_tgt_module(tgt)
|
422 |
+
if self.attention_type == "RelPosMHAXL":
|
423 |
+
tgt = tgt + self.positional_encoding_decoder(tgt)
|
424 |
+
pos_embs_encoder = None # self.positional_encoding(src)
|
425 |
+
pos_embs_target = None
|
426 |
+
elif (
|
427 |
+
self.positional_encoding_type == "fixed_abs_sine"
|
428 |
+
or self.attention_type == "hypermixing"
|
429 |
+
):
|
430 |
+
tgt = tgt + self.positional_encoding(tgt) # add the encodings here
|
431 |
+
pos_embs_target = None
|
432 |
+
pos_embs_encoder = None
|
433 |
+
|
434 |
+
|
435 |
+
prediction, self_attns, multihead_attns = self.decoder(
|
436 |
+
tgt,
|
437 |
+
encoder_out,
|
438 |
+
tgt_mask=tgt_mask,
|
439 |
+
memory_key_padding_mask=src_key_padding_mask,
|
440 |
+
pos_embs_tgt=pos_embs_target,
|
441 |
+
pos_embs_src=pos_embs_encoder,
|
442 |
+
)
|
443 |
+
return prediction, multihead_attns[-1]
|
444 |
+
|
445 |
+
def encode(
|
446 |
+
self,
|
447 |
+
src,
|
448 |
+
wav_len=None,
|
449 |
+
pad_idx=0,
|
450 |
+
dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
|
451 |
+
):
|
452 |
+
"""
|
453 |
+
Encoder forward pass
|
454 |
+
|
455 |
+
Arguments
|
456 |
+
---------
|
457 |
+
src : torch.Tensor
|
458 |
+
The sequence to the encoder.
|
459 |
+
wav_len : torch.Tensor, optional
|
460 |
+
Torch Tensor of shape (batch, ) containing the relative length to padded length for each example.
|
461 |
+
pad_idx : int
|
462 |
+
The index used for padding.
|
463 |
+
dynchunktrain_config : DynChunkTrainConfig
|
464 |
+
Dynamic chunking config.
|
465 |
+
|
466 |
+
Returns
|
467 |
+
-------
|
468 |
+
encoder_out : torch.Tensor
|
469 |
+
"""
|
470 |
+
# reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
|
471 |
+
if src.dim() == 4:
|
472 |
+
bz, t, ch1, ch2 = src.shape
|
473 |
+
src = src.reshape(bz, t, ch1 * ch2)
|
474 |
+
|
475 |
+
(
|
476 |
+
src_key_padding_mask,
|
477 |
+
_,
|
478 |
+
src_mask,
|
479 |
+
_,
|
480 |
+
) = make_transformer_src_tgt_masks(
|
481 |
+
src,
|
482 |
+
None,
|
483 |
+
wav_len,
|
484 |
+
pad_idx=pad_idx,
|
485 |
+
causal=self.causal,
|
486 |
+
dynchunktrain_config=dynchunktrain_config,
|
487 |
+
)
|
488 |
+
|
489 |
+
src = self.custom_src_module(src)
|
490 |
+
if self.attention_type == "hypermixing":
|
491 |
+
pos_embs_source = None
|
492 |
+
elif self.attention_type == "RelPosMHAXL":
|
493 |
+
pos_embs_source = self.positional_encoding(src)
|
494 |
+
elif self.positional_encoding_type == "fixed_abs_sine":
|
495 |
+
src = src + self.positional_encoding(src)
|
496 |
+
pos_embs_source = None
|
497 |
+
|
498 |
+
encoder_out, _ = self.encoder(
|
499 |
+
src=src,
|
500 |
+
src_mask=src_mask,
|
501 |
+
src_key_padding_mask=src_key_padding_mask,
|
502 |
+
pos_embs=pos_embs_source,
|
503 |
+
dynchunktrain_config=dynchunktrain_config,
|
504 |
+
)
|
505 |
+
|
506 |
+
return encoder_out
|
507 |
+
|
508 |
+
def encode_streaming(self, src, context: TransformerASRStreamingContext):
|
509 |
+
"""
|
510 |
+
Streaming encoder forward pass
|
511 |
+
|
512 |
+
Arguments
|
513 |
+
---------
|
514 |
+
src : torch.Tensor
|
515 |
+
The sequence (chunk) to the encoder.
|
516 |
+
context : TransformerASRStreamingContext
|
517 |
+
Mutable reference to the streaming context. This holds the state
|
518 |
+
needed to persist across chunk inferences and can be built using
|
519 |
+
`make_streaming_context`. This will get mutated by this function.
|
520 |
+
|
521 |
+
Returns
|
522 |
+
-------
|
523 |
+
Encoder output for this chunk.
|
524 |
+
|
525 |
+
Example
|
526 |
+
-------
|
527 |
+
>>> import torch
|
528 |
+
>>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
|
529 |
+
>>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
|
530 |
+
>>> net = TransformerASR(
|
531 |
+
... tgt_vocab=100,
|
532 |
+
... input_size=64,
|
533 |
+
... d_model=64,
|
534 |
+
... nhead=8,
|
535 |
+
... num_encoder_layers=1,
|
536 |
+
... num_decoder_layers=0,
|
537 |
+
... d_ffn=128,
|
538 |
+
... attention_type="RelPosMHAXL",
|
539 |
+
... positional_encoding=None,
|
540 |
+
... encoder_module="conformer",
|
541 |
+
... normalize_before=True,
|
542 |
+
... causal=False,
|
543 |
+
... )
|
544 |
+
>>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
|
545 |
+
>>> src1 = torch.rand([8, 16, 64])
|
546 |
+
>>> src2 = torch.rand([8, 16, 64])
|
547 |
+
>>> out1 = net.encode_streaming(src1, ctx)
|
548 |
+
>>> out1.shape
|
549 |
+
torch.Size([8, 16, 64])
|
550 |
+
>>> ctx.encoder_context.layers[0].mha_left_context.shape
|
551 |
+
torch.Size([8, 16, 64])
|
552 |
+
>>> out2 = net.encode_streaming(src2, ctx)
|
553 |
+
>>> out2.shape
|
554 |
+
torch.Size([8, 16, 64])
|
555 |
+
>>> ctx.encoder_context.layers[0].mha_left_context.shape
|
556 |
+
torch.Size([8, 16, 64])
|
557 |
+
>>> combined_out = torch.concat((out1, out2), dim=1)
|
558 |
+
>>> combined_out.shape
|
559 |
+
torch.Size([8, 32, 64])
|
560 |
+
"""
|
561 |
+
|
562 |
+
if src.dim() == 4:
|
563 |
+
bz, t, ch1, ch2 = src.shape
|
564 |
+
src = src.reshape(bz, t, ch1 * ch2)
|
565 |
+
|
566 |
+
# HACK: our problem here is that the positional_encoding is computed
|
567 |
+
# against the size of our source tensor, but we only know how many left
|
568 |
+
# context frames we're injecting to the encoder within the encoder
|
569 |
+
# context.
|
570 |
+
# so this workaround does just that.
|
571 |
+
#
|
572 |
+
# i'm not sure how this would be best refactored, but an option would be
|
573 |
+
# to let the encoder get the pos embedding itself and have a way to
|
574 |
+
# cache it.
|
575 |
+
#
|
576 |
+
# additionally, positional encoding functions take in a whole source
|
577 |
+
# tensor just to get its attributes (size, device, type) but this is
|
578 |
+
# sort of silly for the embeddings that don't need one.
|
579 |
+
# so we craft a dummy empty (uninitialized) tensor to help...
|
580 |
+
known_left_context = context.encoder_context.layers[0].mha_left_context
|
581 |
+
if known_left_context is None:
|
582 |
+
pos_encoding_dummy = src
|
583 |
+
else:
|
584 |
+
target_shape = list(src.shape)
|
585 |
+
target_shape[-2] += known_left_context.shape[-2]
|
586 |
+
pos_encoding_dummy = torch.empty(size=target_shape).to(src)
|
587 |
+
|
588 |
+
src = self.custom_src_module(src)
|
589 |
+
if self.attention_type == "RelPosMHAXL":
|
590 |
+
pos_embs_source = self.positional_encoding(pos_encoding_dummy)
|
591 |
+
|
592 |
+
elif self.positional_encoding_type == "fixed_abs_sine":
|
593 |
+
src = src + self.positional_encoding(pos_encoding_dummy)
|
594 |
+
pos_embs_source = None
|
595 |
+
|
596 |
+
encoder_out, _ = self.encoder.forward_streaming(
|
597 |
+
src=src, pos_embs=pos_embs_source, context=context.encoder_context
|
598 |
+
)
|
599 |
+
return encoder_out
|
600 |
+
|
601 |
+
def make_streaming_context(
|
602 |
+
self, dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={}
|
603 |
+
):
|
604 |
+
"""Creates a blank streaming context for this transformer and its
|
605 |
+
encoder.
|
606 |
+
|
607 |
+
Arguments
|
608 |
+
---------
|
609 |
+
dynchunktrain_config : DynChunkTrainConfig
|
610 |
+
Runtime chunkwise attention configuration.
|
611 |
+
encoder_kwargs : dict
|
612 |
+
Parameters to be forward to the encoder's `make_streaming_context`.
|
613 |
+
Metadata required for the encoder could differ depending on the
|
614 |
+
encoder.
|
615 |
+
|
616 |
+
Returns
|
617 |
+
-------
|
618 |
+
TransformerASRStreamingContext
|
619 |
+
"""
|
620 |
+
return TransformerASRStreamingContext(
|
621 |
+
dynchunktrain_config=dynchunktrain_config,
|
622 |
+
encoder_context=self.encoder.make_streaming_context(
|
623 |
+
dynchunktrain_config,
|
624 |
+
**encoder_kwargs,
|
625 |
+
),
|
626 |
+
)
|
627 |
+
|
628 |
+
def _init_params(self):
|
629 |
+
for p in self.parameters():
|
630 |
+
if p.dim() > 1:
|
631 |
+
torch.nn.init.xavier_normal_(p)
|
632 |
+
|
633 |
+
|
634 |
+
class EncoderWrapper(nn.Module):
|
635 |
+
"""This is a wrapper of any ASR transformer encoder. By default, the
|
636 |
+
TransformerASR .forward() function encodes and decodes. With this wrapper
|
637 |
+
the .forward() function becomes .encode() only.
|
638 |
+
|
639 |
+
Important: The TransformerASR class must contain a .encode() function.
|
640 |
+
|
641 |
+
Arguments
|
642 |
+
---------
|
643 |
+
transformer : sb.lobes.models.TransformerInterface
|
644 |
+
A Transformer instance that contains a .encode() function.
|
645 |
+
*args : tuple
|
646 |
+
**kwargs : dict
|
647 |
+
Arguments to forward to parent class.
|
648 |
+
|
649 |
+
Example
|
650 |
+
-------
|
651 |
+
>>> src = torch.rand([8, 120, 512])
|
652 |
+
>>> tgt = torch.randint(0, 720, [8, 120])
|
653 |
+
>>> net = TransformerASR(
|
654 |
+
... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
|
655 |
+
... )
|
656 |
+
>>> encoder = EncoderWrapper(net)
|
657 |
+
>>> enc_out = encoder(src)
|
658 |
+
>>> enc_out.shape
|
659 |
+
torch.Size([8, 120, 512])
|
660 |
+
"""
|
661 |
+
|
662 |
+
def __init__(self, transformer, *args, **kwargs):
|
663 |
+
super().__init__(*args, **kwargs)
|
664 |
+
self.transformer = transformer
|
665 |
+
self.make_streaming_context = self.transformer.make_streaming_context
|
666 |
+
|
667 |
+
def forward(self, x, wav_lens=None, pad_idx=0, **kwargs):
|
668 |
+
"""Processes the input tensor x and returns an output tensor."""
|
669 |
+
x = self.transformer.encode(x, wav_lens, pad_idx, **kwargs)
|
670 |
+
return x
|
671 |
+
|
672 |
+
def forward_streaming(self, x, context):
|
673 |
+
"""Processes the input audio chunk tensor `x`, using and updating the
|
674 |
+
mutable encoder `context`"""
|
675 |
+
x = self.transformer.encode_streaming(x, context)
|
676 |
+
return x
|
677 |
+
|
678 |
+
def make_streaming_context(self, *args, **kwargs):
|
679 |
+
"""Initializes a streaming context. Forwards all arguments to the
|
680 |
+
underlying transformer. See :meth:`speechbrain.lobes.models.transformer.TransformerASR.make_streaming_context`.
|
681 |
+
"""
|
682 |
+
return self.transformer.make_streaming_context(*args, **kwargs)
|
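A construction sketch for the ConMamba / Mamba variant of TransformerASR. The exact `encoder_module` string and the `mamba_config` keys are assumptions for illustration (they are consumed by TransformerInterface, which is not part of this file), so treat this as a sketch rather than a verified recipe:

import torch
from modules.TransformerASR import TransformerASR

mamba_config = {"d_state": 16, "expand": 2, "d_conv": 4, "bidirectional": True}  # assumed keys

net = TransformerASR(
    tgt_vocab=5000,
    input_size=80,
    d_model=256,
    nhead=4,
    num_encoder_layers=12,
    num_decoder_layers=0,        # encoder-only: forward() returns (enc_out, None)
    d_ffn=1024,
    encoder_module="conmamba",   # assumed literal, see the docstring above
    causal=False,
    mamba_config=mamba_config,
)

feats = torch.rand(8, 120, 80)                  # (batch, time, features)
enc_out, dec_out = net(feats, tgt=None, wav_len=torch.ones(8))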
model/modules/__init__.py
ADDED
File without changes
|
model/modules/mamba/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
model/modules/mamba/__init__.py
ADDED
File without changes
|
model/modules/mamba/bimamba.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
'''
|
2 |
+
Copied and modified from
|
3 |
+
https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py
|
4 |
+
'''
|
5 |
+
|
6 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
7 |
+
|
8 |
+
import math
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
|
18 |
+
try:
|
19 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
20 |
+
except ImportError:
|
21 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
22 |
+
|
23 |
+
try:
|
24 |
+
from .selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
|
25 |
+
except ImportError:
|
26 |
+
selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None
|
27 |
+
|
28 |
+
try:
|
29 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
30 |
+
except ImportError:
|
31 |
+
selective_state_update = None
|
32 |
+
|
33 |
+
try:
|
34 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
35 |
+
except ImportError:
|
36 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
37 |
+
|
38 |
+
|
39 |
+
class Mamba(nn.Module):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
d_model,
|
43 |
+
d_state=16,
|
44 |
+
d_conv=4,
|
45 |
+
expand=2,
|
46 |
+
dt_rank="auto",
|
47 |
+
dt_min=0.001,
|
48 |
+
dt_max=0.1,
|
49 |
+
dt_init="random",
|
50 |
+
dt_scale=1.0,
|
51 |
+
dt_init_floor=1e-4,
|
52 |
+
conv_bias=True,
|
53 |
+
bias=False,
|
54 |
+
use_fast_path=True, # Fused kernel options
|
55 |
+
layer_idx=None,
|
56 |
+
device=None,
|
57 |
+
dtype=None,
|
58 |
+
bimamba_type="none",
|
59 |
+
if_devide_out=True, # False
|
60 |
+
init_layer_scale=None,
|
61 |
+
):
|
62 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
63 |
+
super().__init__()
|
64 |
+
self.d_model = d_model
|
65 |
+
self.d_state = d_state
|
66 |
+
self.d_conv = d_conv
|
67 |
+
self.expand = expand
|
68 |
+
self.d_inner = int(self.expand * self.d_model)
|
69 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
70 |
+
self.use_fast_path = use_fast_path
|
71 |
+
self.layer_idx = layer_idx
|
72 |
+
self.bimamba_type = bimamba_type
|
73 |
+
self.if_devide_out = if_devide_out
|
74 |
+
|
75 |
+
assert bimamba_type == 'v2'
|
76 |
+
|
77 |
+
self.init_layer_scale = init_layer_scale
|
78 |
+
if init_layer_scale is not None:
|
79 |
+
self.gamma = nn.Parameter(init_layer_scale * torch.ones((d_model)), requires_grad=True)
|
80 |
+
|
81 |
+
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
|
82 |
+
|
83 |
+
self.conv1d = nn.Conv1d(
|
84 |
+
in_channels=self.d_inner,
|
85 |
+
out_channels=self.d_inner,
|
86 |
+
bias=conv_bias,
|
87 |
+
kernel_size=d_conv,
|
88 |
+
groups=self.d_inner,
|
89 |
+
padding=d_conv - 1,
|
90 |
+
**factory_kwargs,
|
91 |
+
)
|
92 |
+
|
93 |
+
self.activation = "silu"
|
94 |
+
self.act = nn.SiLU()
|
95 |
+
|
96 |
+
self.x_proj = nn.Linear(
|
97 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
98 |
+
)
|
99 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
100 |
+
|
101 |
+
# Initialize special dt projection to preserve variance at initialization
|
102 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
103 |
+
if dt_init == "constant":
|
104 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
105 |
+
elif dt_init == "random":
|
106 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
107 |
+
else:
|
108 |
+
raise NotImplementedError
|
109 |
+
|
110 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
111 |
+
dt = torch.exp(
|
112 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
113 |
+
+ math.log(dt_min)
|
114 |
+
).clamp(min=dt_init_floor)
|
115 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
116 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
117 |
+
with torch.no_grad():
|
118 |
+
self.dt_proj.bias.copy_(inv_dt)
|
119 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
120 |
+
self.dt_proj.bias._no_reinit = True
|
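# Why this is the inverse of softplus: we want a bias b with softplus(b) = dt, i.e.
#   b = log(exp(dt) - 1) = dt + log(1 - exp(-dt)) = dt + log(-expm1(-dt)),
# and the expm1 form is numerically stable for small dt. After this copy_,
# F.softplus(self.dt_proj.bias) therefore lands (approximately) in [dt_min, dt_max].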
121 |
+
|
122 |
+
# S4D real initialization
|
123 |
+
A = repeat(
|
124 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
125 |
+
"n -> d n",
|
126 |
+
d=self.d_inner,
|
127 |
+
).contiguous()
|
128 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
129 |
+
self.A_log = nn.Parameter(A_log)
|
130 |
+
self.A_log._no_weight_decay = True
|
131 |
+
|
132 |
+
# D "skip" parameter
|
133 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
134 |
+
self.D._no_weight_decay = True
|
135 |
+
|
136 |
+
# bidirectional
|
137 |
+
if bimamba_type == "v1":
|
138 |
+
A_b = repeat(
|
139 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
140 |
+
"n -> d n",
|
141 |
+
d=self.d_inner,
|
142 |
+
).contiguous()
|
143 |
+
A_b_log = torch.log(A_b) # Keep A_b_log in fp32
|
144 |
+
self.A_b_log = nn.Parameter(A_b_log)
|
145 |
+
self.A_b_log._no_weight_decay = True
|
146 |
+
elif bimamba_type == "v2":
|
147 |
+
A_b = repeat(
|
148 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
149 |
+
"n -> d n",
|
150 |
+
d=self.d_inner,
|
151 |
+
).contiguous()
|
152 |
+
A_b_log = torch.log(A_b) # Keep A_b_log in fp32
|
153 |
+
self.A_b_log = nn.Parameter(A_b_log)
|
154 |
+
self.A_b_log._no_weight_decay = True
|
155 |
+
|
156 |
+
self.conv1d_b = nn.Conv1d(
|
157 |
+
in_channels=self.d_inner,
|
158 |
+
out_channels=self.d_inner,
|
159 |
+
bias=conv_bias,
|
160 |
+
kernel_size=d_conv,
|
161 |
+
groups=self.d_inner,
|
162 |
+
padding=d_conv - 1,
|
163 |
+
**factory_kwargs,
|
164 |
+
)
|
165 |
+
|
166 |
+
self.x_proj_b = nn.Linear(
|
167 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
168 |
+
)
|
169 |
+
self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
|
170 |
+
|
171 |
+
self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
172 |
+
self.D_b._no_weight_decay = True
|
173 |
+
|
174 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
175 |
+
|
176 |
+
def forward(self, hidden_states, inference_params=None):
|
177 |
+
"""
|
178 |
+
hidden_states: (B, L, D)
|
179 |
+
Returns: same shape as hidden_states
|
180 |
+
"""
|
181 |
+
batch, seqlen, dim = hidden_states.shape
|
182 |
+
conv_state, ssm_state = None, None
|
183 |
+
|
184 |
+
if inference_params is not None:
|
185 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
186 |
+
if inference_params.seqlen_offset > 0:
|
187 |
+
# The states are updated inplace
|
188 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
189 |
+
return out
|
190 |
+
|
191 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
192 |
+
xz = rearrange(
|
193 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
194 |
+
"d (b l) -> b d l",
|
195 |
+
l=seqlen,
|
196 |
+
)
|
197 |
+
if self.in_proj.bias is not None:
|
198 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
199 |
+
|
200 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
201 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
202 |
+
if self.use_fast_path and inference_params is None: # Doesn't support outputting the states
|
203 |
+
if self.bimamba_type == "v1":
|
204 |
+
A_b = -torch.exp(self.A_b_log.float())
|
205 |
+
out = bimamba_inner_fn(
|
206 |
+
xz,
|
207 |
+
self.conv1d.weight,
|
208 |
+
self.conv1d.bias,
|
209 |
+
self.x_proj.weight,
|
210 |
+
self.dt_proj.weight,
|
211 |
+
self.out_proj.weight,
|
212 |
+
self.out_proj.bias,
|
213 |
+
A,
|
214 |
+
A_b,
|
215 |
+
None, # input-dependent B
|
216 |
+
None, # input-dependent C
|
217 |
+
self.D.float(),
|
218 |
+
delta_bias=self.dt_proj.bias.float(),
|
219 |
+
delta_softplus=True,
|
220 |
+
)
|
221 |
+
elif self.bimamba_type == "v2":
|
222 |
+
A_b = -torch.exp(self.A_b_log.float())
|
223 |
+
out = mamba_inner_fn_no_out_proj(
|
224 |
+
xz,
|
225 |
+
self.conv1d.weight,
|
226 |
+
self.conv1d.bias,
|
227 |
+
self.x_proj.weight,
|
228 |
+
self.dt_proj.weight,
|
229 |
+
A,
|
230 |
+
None, # input-dependent B
|
231 |
+
None, # input-dependent C
|
232 |
+
self.D.float(),
|
233 |
+
delta_bias=self.dt_proj.bias.float(),
|
234 |
+
delta_softplus=True,
|
235 |
+
)
|
236 |
+
out_b = mamba_inner_fn_no_out_proj(
|
237 |
+
xz.flip([-1]),
|
238 |
+
self.conv1d_b.weight,
|
239 |
+
self.conv1d_b.bias,
|
240 |
+
self.x_proj_b.weight,
|
241 |
+
self.dt_proj_b.weight,
|
242 |
+
A_b,
|
243 |
+
None,
|
244 |
+
None,
|
245 |
+
self.D_b.float(),
|
246 |
+
delta_bias=self.dt_proj_b.bias.float(),
|
247 |
+
delta_softplus=True,
|
248 |
+
)
|
249 |
+
|
250 |
+
if not self.if_devide_out:
|
251 |
+
out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
|
252 |
+
else:
|
253 |
+
out = F.linear(rearrange(0.5*out + 0.5*out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
|
254 |
+
|
255 |
+
else:
|
256 |
+
out = mamba_inner_fn(
|
257 |
+
xz,
|
258 |
+
self.conv1d.weight,
|
259 |
+
self.conv1d.bias,
|
260 |
+
self.x_proj.weight,
|
261 |
+
self.dt_proj.weight,
|
262 |
+
self.out_proj.weight,
|
263 |
+
self.out_proj.bias,
|
264 |
+
A,
|
265 |
+
None, # input-dependent B
|
266 |
+
None, # input-dependent C
|
267 |
+
self.D.float(),
|
268 |
+
delta_bias=self.dt_proj.bias.float(),
|
269 |
+
delta_softplus=True,
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
x, z = xz.chunk(2, dim=1)
|
273 |
+
# Compute short convolution
|
274 |
+
if conv_state is not None:
|
275 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
276 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
277 |
+
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
|
278 |
+
if causal_conv1d_fn is None:
|
279 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
280 |
+
else:
|
281 |
+
assert self.activation in ["silu", "swish"]
|
282 |
+
x = causal_conv1d_fn(
|
283 |
+
x=x,
|
284 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
285 |
+
bias=self.conv1d.bias,
|
286 |
+
activation=self.activation,
|
287 |
+
)
|
288 |
+
|
289 |
+
# We're careful here about the layout, to avoid extra transposes.
|
290 |
+
# We want dt to have d as the slowest moving dimension
|
291 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
292 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
293 |
+
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
294 |
+
dt = self.dt_proj.weight @ dt.t()
|
295 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
296 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
297 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
298 |
+
assert self.activation in ["silu", "swish"]
|
299 |
+
y = selective_scan_fn(
|
300 |
+
x,
|
301 |
+
dt,
|
302 |
+
A,
|
303 |
+
B,
|
304 |
+
C,
|
305 |
+
self.D.float(),
|
306 |
+
z=z,
|
307 |
+
delta_bias=self.dt_proj.bias.float(),
|
308 |
+
delta_softplus=True,
|
309 |
+
return_last_state=ssm_state is not None,
|
310 |
+
)
|
311 |
+
if ssm_state is not None:
|
312 |
+
y, last_state = y
|
313 |
+
ssm_state.copy_(last_state)
|
314 |
+
y = rearrange(y, "b d l -> b l d")
|
315 |
+
out = self.out_proj(y)
|
316 |
+
if self.init_layer_scale is not None:
|
317 |
+
out = out * self.gamma
|
318 |
+
return out
|
319 |
+
|
320 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
321 |
+
dtype = hidden_states.dtype
|
322 |
+
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
323 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
324 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
325 |
+
|
326 |
+
# Conv step
|
327 |
+
if causal_conv1d_update is None:
|
328 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
329 |
+
conv_state[:, :, -1] = x
|
330 |
+
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
331 |
+
if self.conv1d.bias is not None:
|
332 |
+
x = x + self.conv1d.bias
|
333 |
+
x = self.act(x).to(dtype=dtype)
|
334 |
+
else:
|
335 |
+
x = causal_conv1d_update(
|
336 |
+
x,
|
337 |
+
conv_state,
|
338 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
339 |
+
self.conv1d.bias,
|
340 |
+
self.activation,
|
341 |
+
)
|
342 |
+
|
343 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
344 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
345 |
+
# Don't add dt_bias here
|
346 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
347 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
348 |
+
|
349 |
+
# SSM step
|
350 |
+
if selective_state_update is None:
|
351 |
+
# Discretize A and B
|
352 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
353 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
354 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
355 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
356 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
357 |
+
y = y + self.D.to(dtype) * x
|
358 |
+
y = y * self.act(z) # (B D)
|
359 |
+
else:
|
360 |
+
y = selective_state_update(
|
361 |
+
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
362 |
+
)
|
363 |
+
|
364 |
+
out = self.out_proj(y)
|
365 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
366 |
+
|
367 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
368 |
+
device = self.out_proj.weight.device
|
369 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
370 |
+
conv_state = torch.zeros(
|
371 |
+
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
|
372 |
+
)
|
373 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
374 |
+
# ssm_dtype = torch.float32
|
375 |
+
ssm_state = torch.zeros(
|
376 |
+
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
|
377 |
+
)
|
378 |
+
return conv_state, ssm_state
|
379 |
+
|
380 |
+
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
|
381 |
+
assert self.layer_idx is not None
|
382 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
383 |
+
batch_shape = (batch_size,)
|
384 |
+
conv_state = torch.zeros(
|
385 |
+
batch_size,
|
386 |
+
self.d_model * self.expand,
|
387 |
+
self.d_conv,
|
388 |
+
device=self.conv1d.weight.device,
|
389 |
+
dtype=self.conv1d.weight.dtype,
|
390 |
+
)
|
391 |
+
ssm_state = torch.zeros(
|
392 |
+
batch_size,
|
393 |
+
self.d_model * self.expand,
|
394 |
+
self.d_state,
|
395 |
+
device=self.dt_proj.weight.device,
|
396 |
+
dtype=self.dt_proj.weight.dtype,
|
397 |
+
# dtype=torch.float32,
|
398 |
+
)
|
399 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
|
400 |
+
else:
|
401 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
402 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
403 |
+
if initialize_states:
|
404 |
+
conv_state.zero_()
|
405 |
+
ssm_state.zero_()
|
406 |
+
return conv_state, ssm_state
|
407 |
+
|
408 |
+
|
409 |
+
class Block(nn.Module):
|
410 |
+
def __init__(
|
411 |
+
self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
412 |
+
):
|
413 |
+
"""
|
414 |
+
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
415 |
+
|
416 |
+
This Block has a slightly different structure compared to a regular
|
417 |
+
prenorm Transformer block.
|
418 |
+
The standard block is: LN -> MHA/MLP -> Add.
|
419 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
420 |
+
Here we have: Add -> LN -> Mixer, returning both
|
421 |
+
the hidden_states (output of the mixer) and the residual.
|
422 |
+
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
423 |
+
The residual needs to be provided (except for the very first block).
|
424 |
+
"""
|
425 |
+
super().__init__()
|
426 |
+
self.residual_in_fp32 = residual_in_fp32
|
427 |
+
self.fused_add_norm = fused_add_norm
|
428 |
+
self.mixer = mixer_cls(dim)
|
429 |
+
self.norm = norm_cls(dim)
|
430 |
+
if self.fused_add_norm:
|
431 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
432 |
+
assert isinstance(
|
433 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
434 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
435 |
+
|
436 |
+
def forward(
|
437 |
+
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|
438 |
+
):
|
439 |
+
r"""Pass the input through the encoder layer.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
hidden_states: the sequence to the encoder layer (required).
|
443 |
+
residual: hidden_states = Mixer(LN(residual))
|
444 |
+
"""
|
445 |
+
if not self.fused_add_norm:
|
446 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
447 |
+
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
448 |
+
if self.residual_in_fp32:
|
449 |
+
residual = residual.to(torch.float32)
|
450 |
+
else:
|
451 |
+
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
|
452 |
+
hidden_states, residual = fused_add_norm_fn(
|
453 |
+
hidden_states,
|
454 |
+
self.norm.weight,
|
455 |
+
self.norm.bias,
|
456 |
+
residual=residual,
|
457 |
+
prenorm=True,
|
458 |
+
residual_in_fp32=self.residual_in_fp32,
|
459 |
+
eps=self.norm.eps,
|
460 |
+
)
|
461 |
+
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
|
462 |
+
return hidden_states, residual
|
463 |
+
|
464 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
465 |
+
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
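A small usage sketch for the bidirectional Mamba class above. It assumes the CUDA mamba_ssm / causal_conv1d kernels are installed and the repository root is on PYTHONPATH; bimamba_type="v2" is the only mode this copy accepts (see the assert in __init__):

import torch
from modules.mamba.bimamba import Mamba as BiMamba

layer = BiMamba(d_model=256, d_state=16, d_conv=4, expand=2, bimamba_type="v2").cuda()

x = torch.randn(8, 100, 256, device="cuda")   # (batch, length, d_model)
y = layer(x)                                  # same shape: (8, 100, 256)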
model/modules/mamba/mamba_blocks.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
'''
|
2 |
+
Copied and modified from
|
3 |
+
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
|
4 |
+
'''
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
|
12 |
+
from mamba_ssm import Mamba
|
13 |
+
from modules.mamba.bimamba import Mamba as BiMamba
|
14 |
+
from modules.mamba.bimamba import Block as PreNormBlock
|
15 |
+
|
16 |
+
try:
|
17 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
18 |
+
except ImportError:
|
19 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
20 |
+
|
21 |
+
|
22 |
+
def create_block(
|
23 |
+
d_model,
|
24 |
+
ssm_cls=None,
|
25 |
+
ssm_cfg=None,
|
26 |
+
norm_epsilon=1e-5,
|
27 |
+
rms_norm=False,
|
28 |
+
residual_in_fp32=False,
|
29 |
+
fused_add_norm=True,
|
30 |
+
layer_idx=None,
|
31 |
+
device=None,
|
32 |
+
dtype=None,
|
33 |
+
):
|
34 |
+
if ssm_cfg is None:
|
35 |
+
ssm_cfg = {}
|
36 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
37 |
+
mixer_cls = partial(ssm_cls, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
38 |
+
norm_cls = partial(
|
39 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
40 |
+
)
|
41 |
+
block = PreNormBlock(
|
42 |
+
d_model,
|
43 |
+
mixer_cls,
|
44 |
+
norm_cls=norm_cls,
|
45 |
+
fused_add_norm=fused_add_norm,
|
46 |
+
residual_in_fp32=residual_in_fp32,
|
47 |
+
)
|
48 |
+
block.layer_idx = layer_idx
|
49 |
+
return block
|
50 |
+
|
51 |
+
|
52 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
53 |
+
def _init_weights(
|
54 |
+
module,
|
55 |
+
n_layer,
|
56 |
+
initializer_range=0.02, # Now only used for embedding layer.
|
57 |
+
rescale_prenorm_residual=True,
|
58 |
+
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
59 |
+
):
|
60 |
+
if isinstance(module, nn.Linear):
|
61 |
+
if module.bias is not None:
|
62 |
+
if not getattr(module.bias, "_no_reinit", False):
|
63 |
+
nn.init.zeros_(module.bias)
|
64 |
+
elif isinstance(module, nn.Embedding):
|
65 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
66 |
+
|
67 |
+
if rescale_prenorm_residual:
|
68 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
69 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
70 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
71 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
72 |
+
#
|
73 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
74 |
+
for name, p in module.named_parameters():
|
75 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
76 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
77 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
78 |
+
# We need to reinit p since this code could be called multiple times
|
79 |
+
# Having just p *= scale would repeatedly scale it down
|
80 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
81 |
+
with torch.no_grad():
|
82 |
+
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
83 |
+
|
84 |
+
|
85 |
+
class LnMambaAdd(nn.Module):
|
86 |
+
|
87 |
+
def __init__(self,
|
88 |
+
d_model,
|
89 |
+
ssm_cls,
|
90 |
+
ssm_cfg,
|
91 |
+
rms_norm=False,
|
92 |
+
layer_idx=None
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
if rms_norm:
|
96 |
+
self.norm = RMSNorm(d_model)
|
97 |
+
else:
|
98 |
+
self.norm = nn.LayerNorm(d_model)
|
99 |
+
self.mamba = ssm_cls(d_model=d_model, **ssm_cfg)
|
100 |
+
|
101 |
+
print(type(self.mamba))
|
102 |
+
|
103 |
+
print('Created LnMambaAdd.')
|
104 |
+
|
105 |
+
def forward(self, x, residual=None, inference_params=None):
|
106 |
+
if residual is not None:
|
107 |
+
x = x + residual
|
108 |
+
return self.mamba(self.norm(x)), x
|
109 |
+
|
110 |
+
|
111 |
+
class MambaBlocksSequential(nn.Module):
|
112 |
+
"""
|
113 |
+
A wrapper for the Mamba block to replicate it
|
114 |
+
|
115 |
+
Arguments
|
116 |
+
---------
|
117 |
+
n_mamba : int
|
118 |
+
Number of Mamba blocks
|
119 |
+
d_model : int
|
120 |
+
Input dimension to Mamba (bottleneck dimension).
|
121 |
+
d_state : int
|
122 |
+
Mamba state dimension
|
123 |
+
expand: int
|
124 |
+
First linear projection d_model -> d_model * expand
|
125 |
+
d_conv: int
|
126 |
+
Kernel size of Mamba conv
|
127 |
+
norm type : str
|
128 |
+
The type of normalization, in ['gLN', 'cLN'].
|
129 |
+
---------
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self,
|
133 |
+
n_mamba: int,
|
134 |
+
bidirectional: bool,
|
135 |
+
d_model: int, # bottleneck dimension (B)
|
136 |
+
d_state: int = 16,
|
137 |
+
expand: int = 2,
|
138 |
+
d_conv: int = 4, # kernel_size of 'Conv' in Mamba
|
139 |
+
dt_rank: str="auto",
|
140 |
+
conv_bias: bool = True,
|
141 |
+
bias: bool = False,
|
142 |
+
fused_add_norm: bool = True,
|
143 |
+
rms_norm: bool = False,
|
144 |
+
norm_epsilon: float = 1e-5,
|
145 |
+
initializer_cfg=None,
|
146 |
+
residual_in_fp32=False,
|
147 |
+
use_simple_block=False
|
148 |
+
):
|
149 |
+
super().__init__()
|
150 |
+
self.residual_in_fp32 = residual_in_fp32
|
151 |
+
self.bidirectional = bidirectional
|
152 |
+
|
153 |
+
# We change the order of residual and layer norm:
|
154 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
155 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
156 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
157 |
+
# This is for performance reason: we can fuse add + layer_norm.
|
158 |
+
self.fused_add_norm = fused_add_norm
|
159 |
+
if self.fused_add_norm:
|
160 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
161 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
162 |
+
|
163 |
+
self.use_simple_block = use_simple_block
|
164 |
+
|
165 |
+
ssm_cfg = {
|
166 |
+
"d_state": d_state,
|
167 |
+
"expand": expand,
|
168 |
+
"d_conv": d_conv,
|
169 |
+
"dt_rank": dt_rank,
|
170 |
+
"conv_bias": conv_bias,
|
171 |
+
"bias": bias
|
172 |
+
}
|
173 |
+
if bidirectional:
|
174 |
+
ssm_cfg["bimamba_type"] = "v2"
|
175 |
+
|
176 |
+
if use_simple_block:
|
177 |
+
self.layers = nn.Sequential(
|
178 |
+
*[
|
179 |
+
LnMambaAdd(
|
180 |
+
d_model=d_model,
|
181 |
+
ssm_cls=BiMamba if bidirectional else Mamba,
|
182 |
+
ssm_cfg=ssm_cfg,
|
183 |
+
rms_norm=rms_norm,
|
184 |
+
layer_idx=i
|
185 |
+
)
|
186 |
+
for i in range(n_mamba)
|
187 |
+
]
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
self.layers = nn.Sequential(
|
191 |
+
*[
|
192 |
+
create_block(
|
193 |
+
d_model=d_model,
|
194 |
+
ssm_cls=BiMamba if bidirectional else Mamba,
|
195 |
+
ssm_cfg=ssm_cfg,
|
196 |
+
norm_epsilon=norm_epsilon,
|
197 |
+
rms_norm=rms_norm,
|
198 |
+
residual_in_fp32=residual_in_fp32,
|
199 |
+
fused_add_norm=fused_add_norm,
|
200 |
+
layer_idx=i,
|
201 |
+
)
|
202 |
+
for i in range(n_mamba)
|
203 |
+
]
|
204 |
+
)
|
205 |
+
|
206 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
207 |
+
d_model, eps=norm_epsilon
|
208 |
+
)
|
209 |
+
|
210 |
+
self.apply(
|
211 |
+
partial(
|
212 |
+
_init_weights,
|
213 |
+
n_layer=n_mamba,
|
214 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
215 |
+
)
|
216 |
+
)
|
217 |
+
|
218 |
+
|
219 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
220 |
+
return {
|
221 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
222 |
+
for i, layer in enumerate(self.layers)
|
223 |
+
}
|
224 |
+
|
225 |
+
def forward(self, x, inference_params=None):
|
226 |
+
|
227 |
+
hidden_states = x
|
228 |
+
residual = None
|
229 |
+
for i, layer in enumerate(self.layers):
|
230 |
+
hidden_states, residual = layer(
|
231 |
+
hidden_states, residual, inference_params=inference_params
|
232 |
+
)
|
233 |
+
|
234 |
+
if not self.fused_add_norm:
|
235 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
236 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
237 |
+
else:
|
238 |
+
# Set prenorm=False here since we don't need the residual
|
239 |
+
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
|
240 |
+
|
241 |
+
hidden_states = fused_add_norm_fn(
|
242 |
+
hidden_states,
|
243 |
+
self.norm_f.weight,
|
244 |
+
self.norm_f.bias,
|
245 |
+
eps=self.norm_f.eps,
|
246 |
+
residual=residual,
|
247 |
+
prenorm=False,
|
248 |
+
residual_in_fp32=self.residual_in_fp32,
|
249 |
+
)
|
250 |
+
|
251 |
+
return hidden_states
|
252 |
+
|
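The `_init_weights` comments at the top of this file section describe a GPT-2 style residual rescaling: selected output projections are re-initialized with Kaiming uniform and then divided by sqrt(n_residuals_per_layer * n_layer), so the variance of the residual stream does not grow with depth. Below is a minimal, self-contained PyTorch sketch of that rescaling on a toy block stack; the `ToyBlock` module and the layer count are illustrative placeholders, not part of the uploaded code.

import math
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Hypothetical stand-in for a residual block; only the parameter names matter here.
    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 4 * d_model)
        self.fc2 = nn.Linear(4 * d_model, d_model)  # projection feeding the residual stream

n_layer, n_residuals_per_layer = 8, 1
blocks = nn.ModuleList([ToyBlock(64) for _ in range(n_layer)])

for name, p in blocks.named_parameters():
    if name.endswith("fc2.weight"):
        # Re-initialize, then scale down by 1/sqrt(n_residuals_per_layer * n_layer),
        # mirroring the comments in _init_weights above.
        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
        with torch.no_grad():
            p /= math.sqrt(n_residuals_per_layer * n_layer)

print(blocks[0].fc2.weight.std().item())  # roughly 1/sqrt(8) of the default Linear init std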
model/modules/mamba/selective_scan_interface.py
ADDED
@@ -0,0 +1,714 @@
1 |
+
'''
|
2 |
+
Copied from
|
3 |
+
https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py
|
4 |
+
'''
|
5 |
+
|
6 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
11 |
+
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
|
14 |
+
from causal_conv1d import causal_conv1d_fn
|
15 |
+
import causal_conv1d_cuda
|
16 |
+
import selective_scan_cuda
|
17 |
+
|
18 |
+
|
19 |
+
class SelectiveScanFn(torch.autograd.Function):
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
23 |
+
return_last_state=False):
|
24 |
+
if u.stride(-1) != 1:
|
25 |
+
u = u.contiguous()
|
26 |
+
if delta.stride(-1) != 1:
|
27 |
+
delta = delta.contiguous()
|
28 |
+
if D is not None:
|
29 |
+
D = D.contiguous()
|
30 |
+
if B.stride(-1) != 1:
|
31 |
+
B = B.contiguous()
|
32 |
+
if C.stride(-1) != 1:
|
33 |
+
C = C.contiguous()
|
34 |
+
if z is not None and z.stride(-1) != 1:
|
35 |
+
z = z.contiguous()
|
36 |
+
if B.dim() == 3:
|
37 |
+
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
38 |
+
ctx.squeeze_B = True
|
39 |
+
if C.dim() == 3:
|
40 |
+
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
41 |
+
ctx.squeeze_C = True
|
42 |
+
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
43 |
+
ctx.delta_softplus = delta_softplus
|
44 |
+
ctx.has_z = z is not None
|
45 |
+
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
46 |
+
if not ctx.has_z:
|
47 |
+
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
48 |
+
return out if not return_last_state else (out, last_state)
|
49 |
+
else:
|
50 |
+
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
51 |
+
out_z = rest[0]
|
52 |
+
return out_z if not return_last_state else (out_z, last_state)
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def backward(ctx, dout, *args):
|
56 |
+
if not ctx.has_z:
|
57 |
+
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
58 |
+
z = None
|
59 |
+
out = None
|
60 |
+
else:
|
61 |
+
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
62 |
+
if dout.stride(-1) != 1:
|
63 |
+
dout = dout.contiguous()
|
64 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
65 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
66 |
+
# Here we just pass in None and dz will be allocated in the C++ code.
|
67 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
68 |
+
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
69 |
+
False # option to recompute out_z, not used here
|
70 |
+
)
|
71 |
+
dz = rest[0] if ctx.has_z else None
|
72 |
+
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
73 |
+
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
74 |
+
return (du, ddelta, dA, dB, dC,
|
75 |
+
dD if D is not None else None,
|
76 |
+
dz,
|
77 |
+
ddelta_bias if delta_bias is not None else None,
|
78 |
+
None,
|
79 |
+
None)
|
80 |
+
|
81 |
+
|
82 |
+
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
83 |
+
return_last_state=False):
|
84 |
+
"""if return_last_state is True, returns (out, last_state)
|
85 |
+
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
86 |
+
not considered in the backward pass.
|
87 |
+
"""
|
88 |
+
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
89 |
+
|
90 |
+
|
91 |
+
def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
92 |
+
return_last_state=False):
|
93 |
+
"""
|
94 |
+
u: r(B D L)
|
95 |
+
delta: r(B D L)
|
96 |
+
A: c(D N) or r(D N)
|
97 |
+
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
98 |
+
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
99 |
+
D: r(D)
|
100 |
+
z: r(B D L)
|
101 |
+
delta_bias: r(D), fp32
|
102 |
+
|
103 |
+
out: r(B D L)
|
104 |
+
last_state (optional): r(B D dstate) or c(B D dstate)
|
105 |
+
"""
|
106 |
+
dtype_in = u.dtype
|
107 |
+
u = u.float()
|
108 |
+
delta = delta.float()
|
109 |
+
if delta_bias is not None:
|
110 |
+
delta = delta + delta_bias[..., None].float()
|
111 |
+
if delta_softplus:
|
112 |
+
delta = F.softplus(delta)
|
113 |
+
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
114 |
+
is_variable_B = B.dim() >= 3
|
115 |
+
is_variable_C = C.dim() >= 3
|
116 |
+
if A.is_complex():
|
117 |
+
if is_variable_B:
|
118 |
+
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
119 |
+
if is_variable_C:
|
120 |
+
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
121 |
+
else:
|
122 |
+
B = B.float()
|
123 |
+
C = C.float()
|
124 |
+
x = A.new_zeros((batch, dim, dstate))
|
125 |
+
ys = []
|
126 |
+
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
127 |
+
if not is_variable_B:
|
128 |
+
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
129 |
+
else:
|
130 |
+
if B.dim() == 3:
|
131 |
+
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
132 |
+
else:
|
133 |
+
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
134 |
+
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
135 |
+
if is_variable_C and C.dim() == 4:
|
136 |
+
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
137 |
+
last_state = None
|
138 |
+
for i in range(u.shape[2]):
|
139 |
+
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
140 |
+
if not is_variable_C:
|
141 |
+
y = torch.einsum('bdn,dn->bd', x, C)
|
142 |
+
else:
|
143 |
+
if C.dim() == 3:
|
144 |
+
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
145 |
+
else:
|
146 |
+
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
147 |
+
if i == u.shape[2] - 1:
|
148 |
+
last_state = x
|
149 |
+
if y.is_complex():
|
150 |
+
y = y.real * 2
|
151 |
+
ys.append(y)
|
152 |
+
y = torch.stack(ys, dim=2) # (batch dim L)
|
153 |
+
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
154 |
+
if z is not None:
|
155 |
+
out = out * F.silu(z)
|
156 |
+
out = out.to(dtype=dtype_in)
|
157 |
+
return out if not return_last_state else (out, last_state)
|
158 |
+
|
159 |
+
|
160 |
+
class MambaInnerFnNoOutProj(torch.autograd.Function):
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
@custom_fwd
|
164 |
+
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
165 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
166 |
+
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
167 |
+
"""
|
168 |
+
xz: (batch, dim, seqlen)
|
169 |
+
"""
|
170 |
+
assert checkpoint_lvl in [0, 1]
|
171 |
+
L = xz.shape[-1]
|
172 |
+
delta_rank = delta_proj_weight.shape[1]
|
173 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
174 |
+
if torch.is_autocast_enabled():
|
175 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
176 |
+
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
177 |
+
if xz.stride(-1) != 1:
|
178 |
+
xz = xz.contiguous()
|
179 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
180 |
+
x, z = xz.chunk(2, dim=1)
|
181 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
182 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
|
183 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
184 |
+
# We want delta to have d as the slowest moving dimension
|
185 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
186 |
+
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
187 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
188 |
+
ctx.is_variable_B = B is None
|
189 |
+
ctx.is_variable_C = C is None
|
190 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
191 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
192 |
+
if B is None: # variable B
|
193 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
194 |
+
if B_proj_bias is not None:
|
195 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
196 |
+
if not A.is_complex():
|
197 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
198 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
199 |
+
else:
|
200 |
+
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
201 |
+
else:
|
202 |
+
if B.stride(-1) != 1:
|
203 |
+
B = B.contiguous()
|
204 |
+
if C is None: # variable C
|
205 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
206 |
+
if C_proj_bias is not None:
|
207 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
208 |
+
if not A.is_complex():
|
209 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
210 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
211 |
+
else:
|
212 |
+
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
213 |
+
else:
|
214 |
+
if C.stride(-1) != 1:
|
215 |
+
C = C.contiguous()
|
216 |
+
if D is not None:
|
217 |
+
D = D.contiguous()
|
218 |
+
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
219 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
220 |
+
)
|
221 |
+
ctx.delta_softplus = delta_softplus
|
222 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
223 |
+
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
224 |
+
conv1d_out, delta = None, None
|
225 |
+
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
226 |
+
delta_proj_weight, conv1d_out, delta,
|
227 |
+
A, B, C, D, delta_bias, scan_intermediates, out)
|
228 |
+
# return rearrange(out_z, "b d l -> b l d")
|
229 |
+
return out_z
|
230 |
+
|
231 |
+
@staticmethod
|
232 |
+
@custom_bwd
|
233 |
+
def backward(ctx, dout):
|
234 |
+
# dout: (batch, seqlen, dim)
|
235 |
+
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
|
236 |
+
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
237 |
+
L = xz.shape[-1]
|
238 |
+
delta_rank = delta_proj_weight.shape[1]
|
239 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
240 |
+
x, z = xz.chunk(2, dim=1)
|
241 |
+
if dout.stride(-1) != 1:
|
242 |
+
dout = dout.contiguous()
|
243 |
+
if ctx.checkpoint_lvl == 1:
|
244 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
|
245 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
246 |
+
"d (b l) -> b d l", l = L)
|
247 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
248 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
249 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
250 |
+
dx, dz = dxz.chunk(2, dim=1)
|
251 |
+
# dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l
|
252 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
253 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
|
254 |
+
ctx.delta_softplus,
|
255 |
+
True # option to recompute out_z
|
256 |
+
)
|
257 |
+
dD = dD if D is not None else None
|
258 |
+
dx_dbl = torch.empty_like(x_dbl)
|
259 |
+
dB_proj_bias = None
|
260 |
+
if ctx.is_variable_B:
|
261 |
+
if not A.is_complex():
|
262 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
263 |
+
else:
|
264 |
+
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
265 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
266 |
+
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
267 |
+
dB = None
|
268 |
+
dC_proj_bias = None
|
269 |
+
if ctx.is_variable_C:
|
270 |
+
if not A.is_complex():
|
271 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
272 |
+
else:
|
273 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
274 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
275 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
276 |
+
dC = None
|
277 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
278 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
279 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
280 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
281 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
282 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
283 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
284 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
285 |
+
# backward of conv1d with the backward of chunk).
|
286 |
+
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
|
287 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
|
288 |
+
)
|
289 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
290 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
291 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
292 |
+
dA, dB, dC, dD,
|
293 |
+
ddelta_bias if delta_bias is not None else None,
|
294 |
+
dB_proj_bias, dC_proj_bias, None)
|
295 |
+
|
296 |
+
|
297 |
+
class MambaInnerFn(torch.autograd.Function):
|
298 |
+
|
299 |
+
@staticmethod
|
300 |
+
@custom_fwd
|
301 |
+
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
302 |
+
out_proj_weight, out_proj_bias,
|
303 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
304 |
+
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
305 |
+
"""
|
306 |
+
xz: (batch, dim, seqlen)
|
307 |
+
"""
|
308 |
+
assert checkpoint_lvl in [0, 1]
|
309 |
+
L = xz.shape[-1]
|
310 |
+
delta_rank = delta_proj_weight.shape[1]
|
311 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
312 |
+
if torch.is_autocast_enabled():
|
313 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
314 |
+
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
315 |
+
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
316 |
+
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
317 |
+
if out_proj_bias is not None else None)
|
318 |
+
if xz.stride(-1) != 1:
|
319 |
+
xz = xz.contiguous()
|
320 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
321 |
+
x, z = xz.chunk(2, dim=1)
|
322 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
323 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
|
324 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
325 |
+
# We want delta to have d as the slowest moving dimension
|
326 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
327 |
+
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
328 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
329 |
+
ctx.is_variable_B = B is None
|
330 |
+
ctx.is_variable_C = C is None
|
331 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
332 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
333 |
+
if B is None: # variable B
|
334 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
335 |
+
if B_proj_bias is not None:
|
336 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
337 |
+
if not A.is_complex():
|
338 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
339 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
340 |
+
else:
|
341 |
+
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
342 |
+
else:
|
343 |
+
if B.stride(-1) != 1:
|
344 |
+
B = B.contiguous()
|
345 |
+
if C is None: # variable C
|
346 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
347 |
+
if C_proj_bias is not None:
|
348 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
349 |
+
if not A.is_complex():
|
350 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
351 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
352 |
+
else:
|
353 |
+
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
354 |
+
else:
|
355 |
+
if C.stride(-1) != 1:
|
356 |
+
C = C.contiguous()
|
357 |
+
if D is not None:
|
358 |
+
D = D.contiguous()
|
359 |
+
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
360 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
361 |
+
)
|
362 |
+
ctx.delta_softplus = delta_softplus
|
363 |
+
ctx.out_proj_bias_is_None = out_proj_bias is None
|
364 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
365 |
+
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
366 |
+
conv1d_out, delta = None, None
|
367 |
+
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
368 |
+
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
369 |
+
A, B, C, D, delta_bias, scan_intermediates, out)
|
370 |
+
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
371 |
+
|
372 |
+
@staticmethod
|
373 |
+
@custom_bwd
|
374 |
+
def backward(ctx, dout):
|
375 |
+
# dout: (batch, seqlen, dim)
|
376 |
+
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
377 |
+
conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
|
378 |
+
L = xz.shape[-1]
|
379 |
+
delta_rank = delta_proj_weight.shape[1]
|
380 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
381 |
+
x, z = xz.chunk(2, dim=1)
|
382 |
+
if dout.stride(-1) != 1:
|
383 |
+
dout = dout.contiguous()
|
384 |
+
if ctx.checkpoint_lvl == 1:
|
385 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
|
386 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
387 |
+
"d (b l) -> b d l", l = L)
|
388 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
389 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
390 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
391 |
+
dx, dz = dxz.chunk(2, dim=1)
|
392 |
+
dout = rearrange(dout, "b l e -> e (b l)")
|
393 |
+
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
394 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
395 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
|
396 |
+
ctx.delta_softplus,
|
397 |
+
True # option to recompute out_z
|
398 |
+
)
|
399 |
+
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
400 |
+
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
401 |
+
dD = dD if D is not None else None
|
402 |
+
dx_dbl = torch.empty_like(x_dbl)
|
403 |
+
dB_proj_bias = None
|
404 |
+
if ctx.is_variable_B:
|
405 |
+
if not A.is_complex():
|
406 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
407 |
+
else:
|
408 |
+
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
409 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
410 |
+
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
411 |
+
dB = None
|
412 |
+
dC_proj_bias = None
|
413 |
+
if ctx.is_variable_C:
|
414 |
+
if not A.is_complex():
|
415 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
416 |
+
else:
|
417 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
418 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
419 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
420 |
+
dC = None
|
421 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
422 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
423 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
424 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
425 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
426 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
427 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
428 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
429 |
+
# backward of conv1d with the backward of chunk).
|
430 |
+
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
|
431 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
|
432 |
+
)
|
433 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
434 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
435 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
436 |
+
dout_proj_weight, dout_proj_bias,
|
437 |
+
dA, dB, dC, dD,
|
438 |
+
ddelta_bias if delta_bias is not None else None,
|
439 |
+
dB_proj_bias, dC_proj_bias, None)
|
440 |
+
|
441 |
+
|
442 |
+
class BiMambaInnerFn(torch.autograd.Function):
|
443 |
+
|
444 |
+
@staticmethod
|
445 |
+
@custom_fwd
|
446 |
+
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
447 |
+
out_proj_weight, out_proj_bias,
|
448 |
+
A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
449 |
+
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
|
450 |
+
"""
|
451 |
+
xz: (batch, dim, seqlen)
|
452 |
+
"""
|
453 |
+
assert checkpoint_lvl in [0, 1]
|
454 |
+
L = xz.shape[-1]
|
455 |
+
delta_rank = delta_proj_weight.shape[1]
|
456 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
457 |
+
if torch.is_autocast_enabled():
|
458 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
459 |
+
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
460 |
+
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
461 |
+
out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
462 |
+
if out_proj_bias is not None else None)
|
463 |
+
if xz.stride(-1) != 1:
|
464 |
+
xz = xz.contiguous()
|
465 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
466 |
+
x, z = xz.chunk(2, dim=1)
|
467 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
468 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
|
469 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
470 |
+
# We want delta to have d as the slowest moving dimension
|
471 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
472 |
+
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
473 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
|
474 |
+
ctx.is_variable_B = B is None
|
475 |
+
ctx.is_variable_C = C is None
|
476 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
477 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
478 |
+
if B is None: # variable B
|
479 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
|
480 |
+
if B_proj_bias is not None:
|
481 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
482 |
+
if not A.is_complex():
|
483 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
484 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
485 |
+
else:
|
486 |
+
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
487 |
+
else:
|
488 |
+
if B.stride(-1) != 1:
|
489 |
+
B = B.contiguous()
|
490 |
+
if C is None: # variable C
|
491 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
492 |
+
if C_proj_bias is not None:
|
493 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
494 |
+
if not A.is_complex():
|
495 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
496 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
497 |
+
else:
|
498 |
+
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
499 |
+
else:
|
500 |
+
if C.stride(-1) != 1:
|
501 |
+
C = C.contiguous()
|
502 |
+
if D is not None:
|
503 |
+
D = D.contiguous()
|
504 |
+
out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
|
505 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
506 |
+
)
|
507 |
+
assert not A_b.is_complex(), "A_b should not be complex"
|
508 |
+
out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
|
509 |
+
conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus,
|
510 |
+
)
|
511 |
+
|
512 |
+
out_z = out_z_f + out_z_b.flip([-1])
|
513 |
+
|
514 |
+
ctx.delta_softplus = delta_softplus
|
515 |
+
ctx.out_proj_bias_is_None = out_proj_bias is None
|
516 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
517 |
+
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
518 |
+
conv1d_out, delta = None, None
|
519 |
+
ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
|
520 |
+
delta_proj_weight, out_proj_weight, conv1d_out, delta,
|
521 |
+
A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b)
|
522 |
+
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
523 |
+
|
524 |
+
@staticmethod
|
525 |
+
@custom_bwd
|
526 |
+
def backward(ctx, dout):
|
527 |
+
# dout: (batch, seqlen, dim)
|
528 |
+
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
|
529 |
+
conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors
|
530 |
+
L = xz.shape[-1]
|
531 |
+
delta_rank = delta_proj_weight.shape[1]
|
532 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
533 |
+
x, z = xz.chunk(2, dim=1)
|
534 |
+
if dout.stride(-1) != 1:
|
535 |
+
dout = dout.contiguous()
|
536 |
+
if ctx.checkpoint_lvl == 1:
|
537 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
|
538 |
+
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
|
539 |
+
"d (b l) -> b d l", l = L)
|
540 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
541 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
542 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
543 |
+
dx, dz = dxz.chunk(2, dim=1)
|
544 |
+
dout = rearrange(dout, "b l e -> e (b l)")
|
545 |
+
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
546 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
|
547 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
|
548 |
+
ctx.delta_softplus,
|
549 |
+
True # option to recompute out_z
|
550 |
+
)
|
551 |
+
# flip one
|
552 |
+
dz_b = torch.empty_like(dz)
|
553 |
+
dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
|
554 |
+
conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
|
555 |
+
ctx.delta_softplus,
|
556 |
+
True # option to recompute out_z
|
557 |
+
)
|
558 |
+
|
559 |
+
dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
|
560 |
+
ddelta = ddelta + ddelta_f_b.flip([-1])
|
561 |
+
dB = dB + dB_f_b.flip([-1])
|
562 |
+
dC = dC + dC_f_b.flip([-1])
|
563 |
+
dD = dD + dD_b
|
564 |
+
ddelta_bias = ddelta_bias + ddelta_bias_b
|
565 |
+
dz = dz + dz_b.flip([-1])
|
566 |
+
out_z = out_z_f + out_z_b.flip([-1])
|
567 |
+
|
568 |
+
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
569 |
+
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
570 |
+
dD = dD if D is not None else None
|
571 |
+
dx_dbl = torch.empty_like(x_dbl)
|
572 |
+
dB_proj_bias = None
|
573 |
+
if ctx.is_variable_B:
|
574 |
+
if not A.is_complex():
|
575 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
576 |
+
else:
|
577 |
+
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
578 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
579 |
+
dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
|
580 |
+
dB = None
|
581 |
+
dC_proj_bias = None
|
582 |
+
if ctx.is_variable_C:
|
583 |
+
if not A.is_complex():
|
584 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
585 |
+
else:
|
586 |
+
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
587 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
588 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
589 |
+
dC = None
|
590 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
591 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
592 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
593 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
594 |
+
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
595 |
+
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
596 |
+
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
597 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
598 |
+
# backward of conv1d with the backward of chunk).
|
599 |
+
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
|
600 |
+
x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
|
601 |
+
)
|
602 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
603 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
604 |
+
return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
|
605 |
+
dout_proj_weight, dout_proj_bias,
|
606 |
+
dA, dA_b, dB, dC, dD,
|
607 |
+
ddelta_bias if delta_bias is not None else None,
|
608 |
+
dB_proj_bias, dC_proj_bias, None)
|
609 |
+
|
610 |
+
|
611 |
+
def mamba_inner_fn(
|
612 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
613 |
+
out_proj_weight, out_proj_bias,
|
614 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
615 |
+
C_proj_bias=None, delta_softplus=True
|
616 |
+
):
|
617 |
+
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
618 |
+
out_proj_weight, out_proj_bias,
|
619 |
+
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
620 |
+
|
621 |
+
def bimamba_inner_fn(
|
622 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
623 |
+
out_proj_weight, out_proj_bias,
|
624 |
+
A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
625 |
+
C_proj_bias=None, delta_softplus=True
|
626 |
+
):
|
627 |
+
return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
628 |
+
out_proj_weight, out_proj_bias,
|
629 |
+
A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
630 |
+
|
631 |
+
|
632 |
+
def mamba_inner_fn_no_out_proj(
|
633 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
634 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
635 |
+
C_proj_bias=None, delta_softplus=True
|
636 |
+
):
|
637 |
+
return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
638 |
+
A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
|
639 |
+
|
640 |
+
|
641 |
+
def mamba_inner_ref(
|
642 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
643 |
+
out_proj_weight, out_proj_bias,
|
644 |
+
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
645 |
+
C_proj_bias=None, delta_softplus=True
|
646 |
+
):
|
647 |
+
L = xz.shape[-1]
|
648 |
+
delta_rank = delta_proj_weight.shape[1]
|
649 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
650 |
+
x, z = xz.chunk(2, dim=1)
|
651 |
+
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
|
652 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
653 |
+
# We want delta to have d as the slowest moving dimension
|
654 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
655 |
+
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
656 |
+
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
657 |
+
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
658 |
+
if B is None: # variable B
|
659 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
660 |
+
if B_proj_bias is not None:
|
661 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
662 |
+
if not A.is_complex():
|
663 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
664 |
+
else:
|
665 |
+
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
666 |
+
if C is None: # variable C
|
667 |
+
C = x_dbl[:, -d_state:] # (bl d)
|
668 |
+
if C_proj_bias is not None:
|
669 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
670 |
+
if not A.is_complex():
|
671 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
672 |
+
else:
|
673 |
+
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
674 |
+
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
675 |
+
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
676 |
+
|
677 |
+
|
678 |
+
def bimamba_inner_ref(
|
679 |
+
xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
680 |
+
out_proj_weight, out_proj_bias,
|
681 |
+
A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
|
682 |
+
C_proj_bias=None, delta_softplus=True
|
683 |
+
):
|
684 |
+
L = xz.shape[-1]
|
685 |
+
delta_rank = delta_proj_weight.shape[1]
|
686 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
687 |
+
x, z = xz.chunk(2, dim=1)
|
688 |
+
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
|
689 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
690 |
+
# We want delta to have d as the slowest moving dimension
|
691 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
692 |
+
x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
|
693 |
+
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
694 |
+
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
695 |
+
if B is None: # variable B
|
696 |
+
B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
|
697 |
+
if B_proj_bias is not None:
|
698 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
699 |
+
if not A.is_complex():
|
700 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
701 |
+
else:
|
702 |
+
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
703 |
+
if C is None: # variable C
|
704 |
+
C = x_dbl[:, -d_state:] # (bl d)
|
705 |
+
if C_proj_bias is not None:
|
706 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
707 |
+
if not A.is_complex():
|
708 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
709 |
+
else:
|
710 |
+
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
|
711 |
+
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
|
712 |
+
y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
|
713 |
+
y = y + y_b.flip([-1])
|
714 |
+
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
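The `selective_scan_ref` docstring above spells out the expected layouts: u, delta, and z are (B, D, L); A is (D, N); input-dependent B and C are (B, N, L); D and delta_bias are (D,). The sketch below exercises the pure-PyTorch reference path with random tensors of those shapes. It assumes `einops` is installed and that this module's top-level imports (`causal_conv1d`, `selective_scan_cuda`) resolve, since they run at import time even though the reference function itself needs no CUDA kernels.

import torch
from model.modules.mamba.selective_scan_interface import selective_scan_ref

B, D, N, L = 2, 16, 8, 32                 # batch, channels, state size, sequence length
u     = torch.randn(B, D, L)
delta = torch.randn(B, D, L)
A     = -torch.rand(D, N)                 # negative real A keeps exp(delta * A) bounded
Bmat  = torch.randn(B, N, L)              # input-dependent ("variable") B
Cmat  = torch.randn(B, N, L)              # input-dependent ("variable") C
Dvec  = torch.randn(D)
z     = torch.randn(B, D, L)
delta_bias = torch.zeros(D)

out, last_state = selective_scan_ref(
    u, delta, A, Bmat, Cmat, D=Dvec, z=z,
    delta_bias=delta_bias, delta_softplus=True, return_last_state=True,
)
print(out.shape)         # torch.Size([2, 16, 32])  -> (B, D, L)
print(last_state.shape)  # torch.Size([2, 16, 8])   -> (B, D, N)

# Bidirectional combination, mirroring bimamba_inner_ref: a second scan over the
# time-reversed inputs with its own A matrix, flipped back and summed.
A_rev   = -torch.rand(D, N)
out_rev = selective_scan_ref(u.flip([-1]), delta.flip([-1]), A_rev, Bmat.flip([-1]),
                             Cmat.flip([-1]), D=Dvec, z=z.flip([-1]),
                             delta_bias=delta_bias, delta_softplus=True)
out_bi = out + out_rev.flip([-1])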
model/patchify.py
ADDED
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

class Patchify(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size):
        super(Patchify, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(8, patch_size), stride=(8, patch_size), padding=0, bias=False)

    def forward(self, x):
        # x.shape = (batch_size, channels, height, width)
        x = self.conv(x)

        return x

if __name__ == "__main__":
    model = Patchify(1, 32, 2)
    print(model)
    dummy_input = torch.randn(1, 1, 64, 16)
    output = model(dummy_input)
    print(output.shape)
model/sinc_conv.py
ADDED
@@ -0,0 +1,471 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.fft
|
6 |
+
import sys
|
7 |
+
from torch.autograd import Variable
|
8 |
+
import math
|
9 |
+
|
10 |
+
class GlobalLayerNorm(nn.Module):
|
11 |
+
'''
|
12 |
+
Calculate Global Layer Normalization
|
13 |
+
dim: (int or list or torch.Size) –
|
14 |
+
input shape from an expected input of size
|
15 |
+
eps: a value added to the denominator for numerical stability.
|
16 |
+
elementwise_affine: a boolean value that when set to True,
|
17 |
+
this module has learnable per-element affine parameters
|
18 |
+
initialized to ones (for weights) and zeros (for biases).
|
19 |
+
'''
|
20 |
+
|
21 |
+
def __init__(self, dim, eps=1e-05, elementwise_affine=True):
|
22 |
+
super(GlobalLayerNorm, self).__init__()
|
23 |
+
self.dim = dim
|
24 |
+
self.eps = eps
|
25 |
+
self.elementwise_affine = elementwise_affine
|
26 |
+
|
27 |
+
if self.elementwise_affine:
|
28 |
+
self.weight = nn.Parameter(torch.ones(self.dim, 1))
|
29 |
+
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
|
30 |
+
else:
|
31 |
+
self.register_parameter('weight', None)
|
32 |
+
self.register_parameter('bias', None)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
# x = N x C x L
|
36 |
+
# N x 1 x 1
|
37 |
+
# cln: mean,var N x 1 x L
|
38 |
+
# gln: mean,var N x 1 x 1
|
39 |
+
if x.dim() != 3:
|
40 |
+
raise RuntimeError("{} accepts 3D tensor as input".format(
|
41 |
+
self.__class__.__name__))
|
42 |
+
|
43 |
+
mean = torch.mean(x, (1, 2), keepdim=True)
|
44 |
+
var = torch.mean((x-mean)**2, (1, 2), keepdim=True)
|
45 |
+
# N x C x L
|
46 |
+
if self.elementwise_affine:
|
47 |
+
x = self.weight*(x-mean)/torch.sqrt(var+self.eps)+self.bias
|
48 |
+
else:
|
49 |
+
x = (x-mean)/torch.sqrt(var+self.eps)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class TimeSincExtractor(nn.Module):
|
54 |
+
"""Sinc-based convolution
|
55 |
+
Parameters
|
56 |
+
----------
|
57 |
+
in_channels : `int`
|
58 |
+
Number of input channels. Must be 1.
|
59 |
+
out_channels : `int`
|
60 |
+
Number of filters.
|
61 |
+
kernel_size : `int`
|
62 |
+
Filter length.
|
63 |
+
sample_rate : `int`, optional
|
64 |
+
Sample rate. Defaults to 16000.
|
65 |
+
triangular : `bool`
|
66 |
+
Squared sinc -> Triangular filter.
|
67 |
+
freq_nml : `bool`
|
68 |
+
Normalized to gain of 1 in frequency.
|
69 |
+
range_constraint : `bool`
|
70 |
+
Project the learned band within nyquist freq manually.
|
71 |
+
Usage
|
72 |
+
-----
|
73 |
+
See `torch.nn.Conv1d`
|
74 |
+
"""
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def to_mel(hz):
|
78 |
+
return 2595 * np.log10(1 + hz / 700)
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def to_hz(mel):
|
82 |
+
return 700 * (10 ** (mel / 2595) - 1)
|
83 |
+
|
84 |
+
def swap_(self, x, y, sort=False):
|
85 |
+
mini = torch.minimum(x, y)
|
86 |
+
maxi = torch.maximum(x, y)
|
87 |
+
|
88 |
+
if sort:
|
89 |
+
mini, idx = torch.sort(mini)
|
90 |
+
maxi = maxi[idx].view(mini.shape)
|
91 |
+
|
92 |
+
return mini, maxi
|
93 |
+
|
94 |
+
def __init__(self, out_channels, kernel_size, triangular=False,
|
95 |
+
freq_nml=False, range_constraint=False, freq_init='uniform', norm_after=True, sample_rate=16000, in_channels=1,
|
96 |
+
stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50, bi_factor=False, frame_length=400, hop_length=160):
|
97 |
+
|
98 |
+
super(TimeSincExtractor,self).__init__()
|
99 |
+
|
100 |
+
if in_channels != 1:
|
101 |
+
# msg = (f'SincConv only support one input channel '
|
102 |
+
# f'(here, in_channels = {in_channels:d}).')
|
103 |
+
msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
|
104 |
+
raise ValueError(msg)
|
105 |
+
|
106 |
+
self.out_channels = out_channels
|
107 |
+
self.kernel_size = kernel_size
|
108 |
+
self.triangular = False  # NOTE: hard-coded, so the 'triangular' constructor argument is currently ignored
|
109 |
+
self.freq_nml = False  # NOTE: hard-coded, so the 'freq_nml' constructor argument is currently ignored
|
110 |
+
|
111 |
+
# Forcing the filters to be odd (i.e, perfectly symmetrics)
|
112 |
+
if kernel_size%2 == 0:
|
113 |
+
self.kernel_size = self.kernel_size+1
|
114 |
+
|
115 |
+
self.stride = stride
|
116 |
+
self.padding = padding
|
117 |
+
self.dilation = dilation
|
118 |
+
|
119 |
+
self.frame_length = frame_length
|
120 |
+
self.hop_length = hop_length
|
121 |
+
|
122 |
+
if bias:
|
123 |
+
raise ValueError('SincConv does not support bias.')
|
124 |
+
if groups > 1:
|
125 |
+
raise ValueError('SincConv does not support groups.')
|
126 |
+
|
127 |
+
self.sample_rate = sample_rate
|
128 |
+
self.nyquist_rate = sample_rate/2
|
129 |
+
self.min_low_hz = min_low_hz
|
130 |
+
self.min_band_hz = min_band_hz
|
131 |
+
self.range_constraint = range_constraint
|
132 |
+
self.bi_factor = bi_factor
|
133 |
+
|
134 |
+
if self.range_constraint:
|
135 |
+
# msg = "Range constraint in learned frequency is not supported yet."
|
136 |
+
# raise ValueError(msg)
|
137 |
+
if freq_init == "uniform":
|
138 |
+
low_freq, high_freq = torch.rand(out_channels*2).chunk(2)
|
139 |
+
elif freq_init == "formant":
|
140 |
+
# raise NotImplementedError('Formant distribution hasn\'t been implemented yet.')
|
141 |
+
p = np.load('/share/nas165/Jasonho610/SincNet/exp/formant_distribution.npy')
|
142 |
+
low_freq, high_freq = torch.from_numpy(np.random.choice(8000, out_channels*2, p=p)).chunk(2)
|
143 |
+
low_freq = low_freq / self.nyquist_rate
|
144 |
+
high_freq = high_freq / self.nyquist_rate
|
145 |
+
elif freq_init == "mel":
|
146 |
+
# raise NotImplementedError('Mel distribution hasn\'t been implemented yet.')
|
147 |
+
low_hz = 30
|
148 |
+
high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
|
149 |
+
mel = np.linspace(self.to_mel(low_hz),
|
150 |
+
self.to_mel(high_hz),
|
151 |
+
self.out_channels + 1)
|
152 |
+
hz = self.to_hz(mel)
|
153 |
+
low_freq = torch.Tensor(hz[:-1]) / self.nyquist_rate
|
154 |
+
high_freq = torch.Tensor(hz[1:]) / self.nyquist_rate
|
155 |
+
else:
|
156 |
+
raise ValueError('SincConv must specify the freq initialization methods.')
|
157 |
+
|
158 |
+
low_freq, high_freq = self.swap_(low_freq, high_freq)
|
159 |
+
|
160 |
+
if self.bi_factor:
|
161 |
+
self.band_imp = nn.Parameter(torch.ones(out_channels))
|
162 |
+
self.low_f_ = nn.Parameter(low_freq.view(-1, 1))
|
163 |
+
self.high_f_ = nn.Parameter(high_freq.view(-1, 1))
|
164 |
+
else:
|
165 |
+
# initialize filterbanks such that they are equally spaced in Mel scale
|
166 |
+
low_hz = 30
|
167 |
+
high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
|
168 |
+
mel = np.linspace(self.to_mel(low_hz),
|
169 |
+
self.to_mel(high_hz),
|
170 |
+
self.out_channels + 1)
|
171 |
+
hz = self.to_hz(mel)
|
172 |
+
# filter lower frequency (out_channels, 1)
|
173 |
+
self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
|
174 |
+
|
175 |
+
# filter frequency band (out_channels, 1)
|
176 |
+
self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
|
177 |
+
|
178 |
+
# Hamming window
|
179 |
+
# self.window_ = torch.hamming_window(self.kernel_size)
|
180 |
+
n_lin = torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window
|
181 |
+
self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)
|
182 |
+
|
183 |
+
# (1, kernel_size/2)
|
184 |
+
n = (self.kernel_size - 1) / 2.0
|
185 |
+
self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axes
|
186 |
+
|
187 |
+
self.norm_after = norm_after
|
188 |
+
if self.norm_after:
|
189 |
+
self.ln = GlobalLayerNorm(out_channels)
|
190 |
+
|
191 |
+
|
192 |
+
def forward(self, waveforms, embedding):
|
193 |
+
"""
|
194 |
+
Parameters
|
195 |
+
----------
|
196 |
+
waveforms : `torch.Tensor` (batch_size, 1, n_samples)
|
197 |
+
Batch of waveforms.
|
198 |
+
Returns
|
199 |
+
-------
|
200 |
+
features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
|
201 |
+
Batch of sinc filters activations.
|
202 |
+
"""
|
203 |
+
|
204 |
+
self.n_ = self.n_.to(waveforms.device)
|
205 |
+
self.window_ = self.window_.to(waveforms.device)
|
206 |
+
# waveforms = waveforms.unsqueeze(1)
|
207 |
+
# print("Waveforms:", waveforms.shape)
|
208 |
+
|
209 |
+
framing_padding = self.frame_length - (waveforms.shape[-1] % self.hop_length)
|
210 |
+
waveforms = F.pad(waveforms, (0, framing_padding))
|
211 |
+
frames = waveforms.unfold(-1, self.frame_length, self.hop_length)
|
212 |
+
|
213 |
+
batch_size = frames.shape[0]
|
214 |
+
n_frames = frames.shape[2]
|
215 |
+
|
216 |
+
if self.range_constraint:
|
217 |
+
            low_f_, high_f_ = self.swap_(torch.abs(self.low_f_), torch.abs(self.high_f_))

            low = self.min_low_hz + low_f_*self.nyquist_rate
            high = torch.clamp(self.min_band_hz + high_f_*self.nyquist_rate, self.min_low_hz, self.nyquist_rate)
            band = (high-low)[:,0]
        else:
            low = self.min_low_hz + torch.abs(self.low_hz_)
            high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.nyquist_rate)
            band = (high-low)[:,0]

        self.low = low
        self.high = high
        self.band = band

        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)

        band_pass_left = ((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_  # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations.
        band_pass_center = 2*band.view(-1,1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)

        band_pass = band_pass / (2*band[:,None])

        if self.triangular:
            band_pass = band_pass**2

        if self.freq_nml:
            mag_resp = torch.fft.rfft(band_pass).abs()
            mag_max = torch.max(mag_resp, dim=-1)[0]
            band_pass = band_pass / mag_max.unsqueeze(-1)

        if self.bi_factor:
            band_imp = F.relu(self.band_imp)
            band_pass = band_pass * band_imp.unsqueeze(-1)

        self.filters = (band_pass).view(
            self.out_channels, 1, self.kernel_size)

        # print("Filters:", self.filters.shape)
        # print("Frames:", frames.shape)

        rs_frames = frames.reshape(batch_size*n_frames, 1, self.frame_length)
        # print("Reshaped frames:", rs_frames.shape)

        filtered = F.conv1d(rs_frames, self.filters, stride=self.stride,
                            padding=self.padding, dilation=self.dilation,
                            bias=None, groups=1)
        # print('Pass conv1d')
        # print("Filtered:", filtered.shape)
        if self.norm_after:
            filtered = self.ln(filtered)

        # print("Normed filtered:", filtered.shape)

        filtered = filtered.reshape(batch_size, n_frames, self.out_channels, -1)

        # print("Final filtered:", filtered.shape)

        energy = torch.mean(filtered**2, dim=-1)
        log_filtered_energy = torch.log10(energy + 1e-6)
        # print("Log filtered energy:", log_filtered_energy.shape)  # (batch_size, n_samples_out(time), out_channels(frequency))

        log_filtered_energy = log_filtered_energy.unsqueeze(1)
        # print("Unsqueezed log filtered energy:", log_filtered_energy.shape)  # (batch_size, channels, n_samples_out(time), out_channels(frequency))

        log_filtered_energy = log_filtered_energy.permute(0, 1, 3, 2)
        # print("Permuted log filtered energy:", log_filtered_energy.shape)  # (batch_size, channels, out_channels(frequency), n_samples_out(time))

        return log_filtered_energy, self.filters, self.stride, self.padding


class FreqSincExtractor(nn.Module):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def swap_(self, x, y, sort=False):
        mini = torch.minimum(x, y)
        maxi = torch.maximum(x, y)
        if sort:
            mini, idx = torch.sort(mini)
            maxi = maxi[idx].view(mini.shape)
        return mini, maxi

    def __init__(self, out_channels, kernel_size, triangular=False,
                 freq_nml=False, range_constraint=False, freq_init='uniform',
                 norm_after=True, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1,
                 min_low_hz=50, min_band_hz=50, bi_factor=False,
                 frame_length=400, hop_length=160, n_fft=400):
        super(FreqSincExtractor, self).__init__()

        if in_channels != 1:
            msg = "FreqSincExtractor only supports one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.triangular = triangular
        self.freq_nml = freq_nml
        self.sample_rate = sample_rate
        self.nyquist_rate = sample_rate/2
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        self.range_constraint = range_constraint
        self.bi_factor = bi_factor
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.stride = stride
        self.padding = padding
        self.output_size = 64

        # Initialize frequency bands
        if self.range_constraint:
            if freq_init == "uniform":
                low_freq, high_freq = torch.rand(out_channels*2).chunk(2)
            elif freq_init == "mel":
                low_hz = 30
                high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
                mel = np.linspace(self.to_mel(low_hz),
                                  self.to_mel(high_hz),
                                  self.out_channels + 1)
                hz = self.to_hz(mel)
                low_freq = torch.Tensor(hz[:-1]) / self.nyquist_rate
                high_freq = torch.Tensor(hz[1:]) / self.nyquist_rate
            else:
                raise ValueError('FreqSincExtractor must specify the freq initialization methods.')

            low_freq, high_freq = self.swap_(low_freq, high_freq)

            if self.bi_factor:
                self.band_imp = nn.Parameter(torch.ones(out_channels))
            self.low_f_ = nn.Parameter(low_freq.view(-1, 1))
            self.high_f_ = nn.Parameter(high_freq.view(-1, 1))
        else:
            low_hz = 30
            high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
            mel = np.linspace(self.to_mel(low_hz),
                              self.to_mel(high_hz),
                              self.out_channels + 1)
            hz = self.to_hz(mel)
            self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
            self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Frequency axis for STFT
        self.freq_axis = torch.linspace(0, self.nyquist_rate, self.n_fft//2 + 1)

        self.norm_after = norm_after
        if self.norm_after:
            self.ln = GlobalLayerNorm(out_channels)

    def get_filters(self):
        if self.range_constraint:
            low_f_, high_f_ = self.swap_(torch.abs(self.low_f_), torch.abs(self.high_f_))
            low = self.min_low_hz + low_f_ * self.nyquist_rate
            high = torch.clamp(self.min_low_hz + high_f_ * self.nyquist_rate,
                               self.min_low_hz, self.nyquist_rate)
        else:
            low = self.min_low_hz + torch.abs(self.low_hz_)
            high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                               self.min_low_hz, self.nyquist_rate)

        # Create frequency domain filters
        freq_axis = self.freq_axis.to(low.device)
        filters = torch.zeros((self.out_channels, len(freq_axis))).to(low.device)

        for i in range(self.out_channels):
            mask = (freq_axis >= low[i]) & (freq_axis <= high[i])
            filters[i, mask] = 1.0

            if self.triangular:
                center_freq = (low[i] + high[i]) / 2
                bandwidth = high[i] - low[i]
                mask = (freq_axis >= low[i]) & (freq_axis <= high[i])
                freq_response = 1.0 - torch.abs(freq_axis[mask] - center_freq) / (bandwidth/2)
                filters[i, mask] = freq_response

        if self.freq_nml:
            filters = F.normalize(filters, p=2, dim=1)

        if self.bi_factor:
            band_imp = F.relu(self.band_imp)
            filters = filters * band_imp.unsqueeze(-1)

        return filters

    def forward(self, waveforms, embedding=None):
        batch_size = waveforms.shape[0]

        # Calculate necessary padding to achieve the correct output size
        target_length = self.hop_length * (self.output_size - 1) + self.frame_length
        current_length = waveforms.shape[-1]
        padding_needed = target_length - current_length

        # Pad the input if necessary
        if padding_needed > 0:
            waveforms = F.pad(waveforms, (0, padding_needed))

        # Compute STFT
        stft = torch.stft(waveforms.squeeze(1),
                          n_fft=self.n_fft,
                          hop_length=self.hop_length,
                          win_length=self.frame_length,
                          window=torch.hann_window(self.frame_length).to(waveforms.device),
                          return_complex=True)

        # Get magnitude spectrogram
        mag_spec = torch.abs(stft)  # (batch_size, freq_bins, time_frames)

        # Get and apply filters
        filters = self.get_filters()  # (out_channels, freq_bins)
        filtered = torch.matmul(filters, mag_spec)  # (batch_size, out_channels, time_frames)

        if self.norm_after:
            filtered = self.ln(filtered)

        # Compute log energy
        energy = filtered ** 2
        log_energy = torch.log10(energy + 1e-6)

        # Ensure correct time dimension
        if log_energy.shape[-1] != self.output_size:
            log_energy = F.interpolate(
                log_energy,
                size=self.output_size,
                mode='linear',
                align_corners=False
            )

        # Reshape to the desired output format
        log_energy = log_energy.unsqueeze(1)   # Add channel dimension
        log_energy = log_energy.permute(0, 1, 3, 2)  # Rearrange to (batch, channel, freq, time)

        return log_energy, filters, self.stride, self.padding


if __name__ == "__main__":
    batch_size = 256
    n_samples = 10080
    waveforms = torch.rand(batch_size, 1, n_samples)

    # model = TimeSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
    model = FreqSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
    print(model)

    outputs, _, _, _ = model(waveforms, embedding=None)
    print("Outputs:", outputs.shape)
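For reference, a minimal usage sketch of FreqSincExtractor (assuming, on my part, that the repo root is on PYTHONPATH so the `model` package imports resolve). With the defaults above, target_length = hop_length*(output_size - 1) + frame_length = 160*63 + 400 = 10480 samples, so the 10080-sample demo input is right-padded by 400 samples before the STFT; torch.stft (center=True by default) then yields 66 frames, which F.interpolate resizes to the fixed 64-frame output.

    import torch
    from model.sinc_conv import FreqSincExtractor  # assumed import path

    extractor = FreqSincExtractor(out_channels=64, kernel_size=101,
                                  range_constraint=True, stride=2)
    waveforms = torch.rand(2, 1, 10080)               # (batch, 1, samples) at 16 kHz
    log_energy, filters, _, _ = extractor(waveforms, embedding=None)
    print(log_energy.shape)   # expected: (2, 1, 64, 64) -> (batch, channel, bands, frames)
    print(filters.shape)      # expected: (64, 201)      -> (out_channels, n_fft//2 + 1)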
model/tiny_block.py
ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dilation=2):
        super(TinyBlock, self).__init__()

        # f1: 3x3 depthwise convolution + BatchNorm
        self.f1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels)
        )

        # f2: 1x1 grouped pointwise convolutions with 8 groups + ReLU
        self.f2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=8, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        f1_out = self.f1(x)
        f2_out = self.f2(x + f1_out)
        out = x + f1_out + f2_out
        return out

if __name__ == "__main__":
    model = TinyBlock(16, 16)
    print(model)
    dummy_input = torch.randn(256, 16, 8, 8)
    output = model(dummy_input)
    print(output.shape)
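A quick note on TinyBlock's constraints (a sketch, under the same import-path assumption as above): because f2 uses groups=8 and its output is summed with the residual x, the block is shape-preserving and in practice needs in_channels == out_channels with the channel count divisible by 8, as in the 16-channel smoke test above.

    import torch
    from model.tiny_block import TinyBlock  # assumed import path

    block = TinyBlock(32, 32)               # channels must match and be a multiple of 8
    y = block(torch.randn(8, 32, 8, 8))     # (batch, channels, H, W)
    print(y.shape)                          # expected: torch.Size([8, 32, 8, 8])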
model/tinyvad.py
ADDED
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
from .sinc_conv import TimeSincExtractor, FreqSincExtractor
from .patchify import Patchify
from .csp_tiny_layer import CSPTinyLayer

class TinyVAD(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, patch_size, num_blocks, sinc_conv, ssm):
        super(TinyVAD, self).__init__()

        self.sinc_conv = sinc_conv

        if self.sinc_conv:
            # self.extractor = TimeSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
            self.extractor = FreqSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)

        self.patchify = Patchify(in_channels, hidden_channels, patch_size)

        self.csp_tiny_layer1 = CSPTinyLayer(hidden_channels, hidden_channels, num_blocks, ssm)
        self.csp_tiny_layer2 = CSPTinyLayer(hidden_channels, hidden_channels, num_blocks, ssm)
        self.csp_tiny_layer3 = CSPTinyLayer(hidden_channels, out_channels, num_blocks, ssm)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Sequential(
            nn.Linear(out_channels, 1),
            # nn.Sigmoid()
        )

    def forward(self, x):
        if self.sinc_conv:
            x = self.extractor(x, None)
            x = x[0]  # Untuple

        x = self.patchify(x)

        x = self.csp_tiny_layer1(x)
        x = self.csp_tiny_layer2(x)
        x = self.csp_tiny_layer3(x)

        x = self.avg_pool(x).view(x.size(0), -1)

        x = self.classifier(x)

        return x

    def predict(self, inputs):
        logits = self.forward(inputs)
        probs = torch.sigmoid(logits)

        return probs

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = TinyVAD(1, 32, 64, 2, 2, False, False).to(device)
    print(model)
    dummy_input = torch.randn(1, 1, 64, 16).to(device)
    output = model(dummy_input)
    print(output)
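Finally, a sketch of the raw-waveform path through TinyVAD (assumptions: the repo root is importable as the `model` package, and Patchify maps the (batch, 1, 64, 64) log-energy map to hidden_channels feature maps as its constructor arguments suggest). With sinc_conv=True the FreqSincExtractor front end replaces the precomputed-feature input used in the __main__ block above, and predict() applies a sigmoid to the single output logit.

    import torch
    from model.tinyvad import TinyVAD  # assumed import path

    model = TinyVAD(in_channels=1, hidden_channels=32, out_channels=64,
                    patch_size=2, num_blocks=2, sinc_conv=True, ssm=False).eval()

    waveforms = torch.randn(4, 1, 10080)   # raw audio: (batch, 1, samples), ~0.63 s at 16 kHz
    with torch.no_grad():
        probs = model.predict(waveforms)   # sigmoid over the single logit
    print(probs.shape)                     # expected: torch.Size([4, 1]), values in (0, 1)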