jethrowang commited on
Commit
1423dc8
·
verified ·
1 Parent(s): db481bf

Upload 18 files

Browse files
model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/csp_tiny_layer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .tiny_block import TinyBlock
4
+ from transformers import MambaConfig, MambaModel
5
+ # from .conmamba import ConMamba
6
+
7
+ class CSPTinyLayer(nn.Module):
8
+ def __init__(self, in_channels, out_channels, num_blocks, ssm=False):
9
+ super(CSPTinyLayer, self).__init__()
10
+
11
+ self.ssm = ssm
12
+
13
+ # Split channels
14
+ self.split_channels = in_channels // 2
15
+
16
+ if self.ssm:
17
+ # Mamba Blocks
18
+ configuration = MambaConfig(vocab_size=0, hidden_size=self.split_channels, num_hidden_layers=num_blocks)
19
+ self.mamba_blocks = MambaModel(configuration)
20
+
21
+ # mamba_config = {
22
+ # 'd_state': self.split_channels,
23
+ # 'expand': 2,
24
+ # 'd_conv': 4,
25
+ # 'bidirectional': True
26
+ # }
27
+ # self.mamba_blocks = ConMamba(
28
+ # num_blocks=num_blocks,
29
+ # channels=self.split_channels,
30
+ # height=8,
31
+ # width=8,
32
+ # mamba_config=mamba_config
33
+ # )
34
+
35
+ else:
36
+ # TinyBlocks
37
+ self.tiny_blocks = nn.Sequential(
38
+ *[TinyBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)]
39
+ )
40
+
41
+ # Transition layer to adjust channel dimensions
42
+ self.transition = nn.Sequential(
43
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
44
+ nn.BatchNorm2d(out_channels),
45
+ nn.ReLU(inplace=True)
46
+ )
47
+
48
+ def forward(self, x):
49
+ # Split input into two parts
50
+ p1 = x[:, :self.split_channels, :, :]
51
+ p2 = x[:, self.split_channels:, :, :]
52
+
53
+ if self.ssm:
54
+ # Reshape to fit Mamba
55
+ B, C, H, W = p2.shape
56
+ p2 = p2.permute(0, 2, 3, 1) # [B, H, W, C]
57
+ p2 = p2.reshape(B, H * W, C) # [B, L, C], L = H * W
58
+
59
+ # Process p2 through MambaBlocks
60
+ p2_out = self.mamba_blocks(inputs_embeds=p2).last_hidden_state
61
+
62
+ # p2_out = self.mamba_blocks(p2)
63
+
64
+ # Reshape back to original dimension
65
+ p2_out = p2_out.reshape(B, H, W, -1)
66
+ p2_out = p2_out.permute(0, 3, 1, 2) # [B, C, H, W]
67
+ else:
68
+ # Process p2 through TinyBlocks
69
+ p2_out = self.tiny_blocks(p2)
70
+
71
+ # Concatenate p1 and processed p2
72
+ concatenated = torch.cat((p1, p2_out), dim=1)
73
+
74
+ # Apply transition layer
75
+ out = self.transition(concatenated)
76
+ return out
77
+
78
+ if __name__ == "__main__":
79
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
+ print(f"Using device: {device}")
81
+
82
+ model = CSPTinyLayer(32, 32, 2, True).to(device)
83
+ print(model)
84
+ dummy_input = torch.randn(256, 32, 8, 8).to(device)
85
+ output = model(dummy_input)
86
+ print(output.shape)
model/mamba_hf.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import MambaConfig, MambaModel, Mamba2Config, Mamba2Model
3
+
4
+ print(f"CUDA available: {torch.cuda.is_available()}")
5
+ if torch.cuda.is_available():
6
+ print(f"CUDA device: {torch.cuda.get_device_name()}")
7
+ print(f"CUDA version: {torch.version.cuda}")
8
+
9
+ batch, channel, height, width = 256, 16, 8, 8
10
+ x = torch.randn(batch, channel, height, width).to("cuda")
11
+ print(f'x: {x.shape}')
12
+
13
+ B, C, H, W = x.shape
14
+ x = x.permute(0, 2, 3, 1) # [B, H, W, C]
15
+ print(f'Permuted x: {x.shape}')
16
+
17
+ x = x.reshape(B, H * W, C) # [B, L, C], L = H * W
18
+ print(f'Reshaped x: {x.shape}')
19
+
20
+ # Initializing a Mamba configuration
21
+ configuration = MambaConfig(vocab_size=0, hidden_size=channel, num_hidden_layers=2)
22
+ # configuration = Mamba2Config(hidden_size=channel)
23
+
24
+ # Initializing a model (with random weights) from the configuration
25
+ model = MambaModel(configuration).to("cuda")
26
+ # model = Mamba2Model(configuration).to("cuda")
27
+ print(f'Model: {model}')
28
+
29
+ # Accessing the model configuration
30
+ configuration = model.config
31
+ print(f'Configuration: {configuration}')
32
+
33
+ # y = model(inputs_embeds=x).last_hidden_state
34
+ y = model(inputs_embeds=x, return_dict=True)[0]
35
+ print(f'y: {y.shape}')
36
+
37
+ y = y.reshape(B, H, W, -1)
38
+ print(f'Reshaped y: {y.shape}')
39
+
40
+ y = y.permute(0, 3, 1, 2) # [B, C, H, W]
41
+ print(f'Permuted y: {y.shape}')
model/modules/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/modules/Conformer.py ADDED
@@ -0,0 +1,1094 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conformer implementation.
2
+
3
+ Authors
4
+ -------
5
+ * Jianyuan Zhong 2020
6
+ * Samuele Cornell 2021
7
+ * Sylvain de Langen 2023
8
+ """
9
+
10
+ import warnings
11
+ from dataclasses import dataclass
12
+ from typing import List, Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ import speechbrain as sb
19
+ from speechbrain.nnet.activations import Swish
20
+ from speechbrain.nnet.attention import (
21
+ MultiheadAttention,
22
+ PositionalwiseFeedForward,
23
+ RelPosMHAXL,
24
+ )
25
+ from speechbrain.nnet.hypermixing import HyperMixing
26
+ from speechbrain.nnet.normalization import LayerNorm
27
+ from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
28
+
29
+
30
+ @dataclass
31
+ class ConformerEncoderLayerStreamingContext:
32
+ """Streaming metadata and state for a `ConformerEncoderLayer`.
33
+
34
+ The multi-head attention and Dynamic Chunk Convolution require to save some
35
+ left context that gets inserted as left padding.
36
+
37
+ See :class:`.ConvolutionModule` documentation for further details.
38
+ """
39
+
40
+ mha_left_context_size: int
41
+ """For this layer, specifies how many frames of inputs should be saved.
42
+ Usually, the same value is used across all layers, but this can be modified.
43
+ """
44
+
45
+ mha_left_context: Optional[torch.Tensor] = None
46
+ """Left context to insert at the left of the current chunk as inputs to the
47
+ multi-head attention. It can be `None` (if we're dealing with the first
48
+ chunk) or `<= mha_left_context_size` because for the first few chunks, not
49
+ enough left context may be available to pad.
50
+ """
51
+
52
+ dcconv_left_context: Optional[torch.Tensor] = None
53
+ """Left context to insert at the left of the convolution according to the
54
+ Dynamic Chunk Convolution method.
55
+
56
+ Unlike `mha_left_context`, here the amount of frames to keep is fixed and
57
+ inferred from the kernel size of the convolution module.
58
+ """
59
+
60
+
61
+ @dataclass
62
+ class ConformerEncoderStreamingContext:
63
+ """Streaming metadata and state for a `ConformerEncoder`."""
64
+
65
+ dynchunktrain_config: DynChunkTrainConfig
66
+ """Dynamic Chunk Training configuration holding chunk size and context size
67
+ information."""
68
+
69
+ layers: List[ConformerEncoderLayerStreamingContext]
70
+ """Streaming metadata and state for each layer of the encoder."""
71
+
72
+
73
+ class ConvolutionModule(nn.Module):
74
+ """This is an implementation of convolution module in Conformer.
75
+
76
+ Arguments
77
+ ---------
78
+ input_size : int
79
+ The expected size of the input embedding dimension.
80
+ kernel_size: int, optional
81
+ Kernel size of non-bottleneck convolutional layer.
82
+ bias: bool, optional
83
+ Whether to use bias in the non-bottleneck conv layer.
84
+ activation: torch.nn.Module
85
+ Activation function used after non-bottleneck conv layer.
86
+ dropout: float, optional
87
+ Dropout rate.
88
+ causal: bool, optional
89
+ Whether the convolution should be causal or not.
90
+ dilation: int, optional
91
+ Dilation factor for the non bottleneck conv layer.
92
+
93
+ Example
94
+ -------
95
+ >>> import torch
96
+ >>> x = torch.rand((8, 60, 512))
97
+ >>> net = ConvolutionModule(512, 3)
98
+ >>> output = net(x)
99
+ >>> output.shape
100
+ torch.Size([8, 60, 512])
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_size,
106
+ kernel_size=31,
107
+ bias=True,
108
+ activation=Swish,
109
+ dropout=0.0,
110
+ causal=False,
111
+ dilation=1,
112
+ ):
113
+ super().__init__()
114
+
115
+ self.kernel_size = kernel_size
116
+ self.causal = causal
117
+ self.dilation = dilation
118
+
119
+ if self.causal:
120
+ self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
121
+ else:
122
+ self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2
123
+
124
+ self.layer_norm = nn.LayerNorm(input_size)
125
+ self.bottleneck = nn.Sequential(
126
+ # pointwise
127
+ nn.Conv1d(
128
+ input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
129
+ ),
130
+ nn.GLU(dim=1),
131
+ )
132
+ # depthwise
133
+ self.conv = nn.Conv1d(
134
+ input_size,
135
+ input_size,
136
+ kernel_size=kernel_size,
137
+ stride=1,
138
+ padding=self.padding,
139
+ dilation=dilation,
140
+ groups=input_size,
141
+ bias=bias,
142
+ )
143
+
144
+ # BatchNorm in the original Conformer replaced with a LayerNorm due to
145
+ # https://github.com/speechbrain/speechbrain/pull/1329
146
+ # see discussion
147
+ # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884
148
+
149
+ self.after_conv = nn.Sequential(
150
+ nn.LayerNorm(input_size),
151
+ activation(),
152
+ # pointwise
153
+ nn.Linear(input_size, input_size, bias=bias),
154
+ nn.Dropout(dropout),
155
+ )
156
+
157
+ def forward(
158
+ self,
159
+ x: torch.Tensor,
160
+ mask: Optional[torch.Tensor] = None,
161
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
162
+ ):
163
+ """Applies the convolution to an input tensor `x`.
164
+
165
+ Arguments
166
+ ---------
167
+ x: torch.Tensor
168
+ Input tensor to the convolution module.
169
+ mask: torch.Tensor, optional
170
+ Mask to be applied over the output of the convolution using
171
+ `masked_fill_`, if specified.
172
+ dynchunktrain_config: DynChunkTrainConfig, optional
173
+ If specified, makes the module support Dynamic Chunk Convolution
174
+ (DCConv) as implemented by
175
+ `Dynamic Chunk Convolution for Unified Streaming and Non-Streaming Conformer ASR <https://www.amazon.science/publications/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr>`_.
176
+ This allows masking future frames while preserving better accuracy
177
+ than a fully causal convolution, at a small speed cost.
178
+ This should only be used for training (or, if you know what you're
179
+ doing, for masked evaluation at inference time), as the forward
180
+ streaming function should be used at inference time.
181
+
182
+ Returns
183
+ -------
184
+ out: torch.Tensor
185
+ The output tensor.
186
+ """
187
+
188
+ if dynchunktrain_config is not None:
189
+ # chances are chunking+causal is unintended; i don't know where it
190
+ # may make sense, but if it does to you, feel free to implement it.
191
+ assert (
192
+ not self.causal
193
+ ), "Chunked convolution not supported with causal padding"
194
+
195
+ assert (
196
+ self.dilation == 1
197
+ ), "Current DynChunkTrain logic does not support dilation != 1"
198
+
199
+ # in a causal convolution, which is not the case here, an output
200
+ # frame would never be able to depend on a input frame from any
201
+ # point in the future.
202
+
203
+ # but with the dynamic chunk convolution, we instead use a "normal"
204
+ # convolution but where, for any output frame, the future beyond the
205
+ # "current" chunk gets masked.
206
+ # see the paper linked in the documentation for details.
207
+
208
+ chunk_size = dynchunktrain_config.chunk_size
209
+ batch_size = x.shape[0]
210
+
211
+ # determine the amount of padding we need to insert at the right of
212
+ # the last chunk so that all chunks end up with the same size.
213
+ if x.shape[1] % chunk_size != 0:
214
+ final_right_padding = chunk_size - (x.shape[1] % chunk_size)
215
+ else:
216
+ final_right_padding = 0
217
+
218
+ # -> [batch_size, t, in_channels]
219
+ out = self.layer_norm(x)
220
+
221
+ # -> [batch_size, in_channels, t] for the CNN
222
+ out = out.transpose(1, 2)
223
+
224
+ # -> [batch_size, in_channels, t] (pointwise)
225
+ out = self.bottleneck(out)
226
+
227
+ # -> [batch_size, in_channels, lc+t+final_right_padding]
228
+ out = F.pad(out, (self.padding, final_right_padding), value=0)
229
+
230
+ # now, make chunks with left context.
231
+ # as a recap to what the above padding and this unfold do, consider
232
+ # each a/b/c letter represents a frame as part of chunks a, b, c.
233
+ # consider a chunk size of 4 and a kernel size of 5 (padding=2):
234
+ #
235
+ # input seq: 00aaaabbbbcc00
236
+ # chunk #1: 00aaaa
237
+ # chunk #2: aabbbb
238
+ # chunk #3: bbcc00
239
+ #
240
+ # a few remarks here:
241
+ # - the left padding gets inserted early so that the unfold logic
242
+ # works trivially
243
+ # - the right 0-padding got inserted as the number of time steps
244
+ # could not be evenly split in `chunk_size` chunks
245
+
246
+ # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
247
+ out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)
248
+
249
+ # as we manually disable padding in the convolution below, we insert
250
+ # right 0-padding to the chunks, e.g. reusing the above example:
251
+ #
252
+ # chunk #1: 00aaaa00
253
+ # chunk #2: aabbbb00
254
+ # chunk #3: bbcc0000
255
+
256
+ # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
257
+ out = F.pad(out, (0, self.padding), value=0)
258
+
259
+ # the transpose+flatten effectively flattens chunks into the batch
260
+ # dimension to be processed into the time-wise convolution. the
261
+ # chunks will later on be unflattened.
262
+
263
+ # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
264
+ out = out.transpose(1, 2)
265
+
266
+ # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
267
+ out = out.flatten(start_dim=0, end_dim=1)
268
+
269
+ # TODO: experiment around reflect padding, which is difficult
270
+ # because small chunks have too little time steps to reflect from
271
+
272
+ # let's keep backwards compat by pointing at the weights from the
273
+ # already declared Conv1d.
274
+ #
275
+ # still reusing the above example, the convolution will be applied,
276
+ # with the padding truncated on both ends. the following example
277
+ # shows the letter corresponding to the input frame on which the
278
+ # convolution was centered.
279
+ #
280
+ # as you can see, the sum of lengths of all chunks is equal to our
281
+ # input sequence length + `final_right_padding`.
282
+ #
283
+ # chunk #1: aaaa
284
+ # chunk #2: bbbb
285
+ # chunk #3: cc00
286
+
287
+ # -> [batch_size * num_chunks, out_channels, chunk_size]
288
+ out = F.conv1d(
289
+ out,
290
+ weight=self.conv.weight,
291
+ bias=self.conv.bias,
292
+ stride=self.conv.stride,
293
+ padding=0,
294
+ dilation=self.conv.dilation,
295
+ groups=self.conv.groups,
296
+ )
297
+
298
+ # -> [batch_size * num_chunks, chunk_size, out_channels]
299
+ out = out.transpose(1, 2)
300
+
301
+ out = self.after_conv(out)
302
+
303
+ # -> [batch_size, num_chunks, chunk_size, out_channels]
304
+ out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))
305
+
306
+ # -> [batch_size, t + final_right_padding, out_channels]
307
+ out = torch.flatten(out, start_dim=1, end_dim=2)
308
+
309
+ # -> [batch_size, t, out_channels]
310
+ if final_right_padding > 0:
311
+ out = out[:, :-final_right_padding, :]
312
+ else:
313
+ out = self.layer_norm(x)
314
+ out = out.transpose(1, 2)
315
+ out = self.bottleneck(out)
316
+ out = self.conv(out)
317
+
318
+ if self.causal:
319
+ # chomp
320
+ out = out[..., : -self.padding]
321
+
322
+ out = out.transpose(1, 2)
323
+ out = self.after_conv(out)
324
+
325
+ if mask is not None:
326
+ out.masked_fill_(mask, 0.0)
327
+
328
+ return out
329
+
330
+
331
+ class ConformerEncoderLayer(nn.Module):
332
+ """This is an implementation of Conformer encoder layer.
333
+
334
+ Arguments
335
+ ---------
336
+ d_model : int
337
+ The expected size of the input embedding.
338
+ d_ffn : int
339
+ Hidden size of self-attention Feed Forward layer.
340
+ nhead : int
341
+ Number of attention heads.
342
+ kernel_size : int, optional
343
+ Kernel size of convolution model.
344
+ kdim : int, optional
345
+ Dimension of the key.
346
+ vdim : int, optional
347
+ Dimension of the value.
348
+ activation: torch.nn.Module
349
+ Activation function used in each Conformer layer.
350
+ bias : bool, optional
351
+ Whether convolution module.
352
+ dropout : int, optional
353
+ Dropout for the encoder.
354
+ causal : bool, optional
355
+ Whether the convolutions should be causal or not.
356
+ attention_type : str, optional
357
+ type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
358
+
359
+ Example
360
+ -------
361
+ >>> import torch
362
+ >>> x = torch.rand((8, 60, 512))
363
+ >>> pos_embs = torch.rand((1, 2*60-1, 512))
364
+ >>> net = ConformerEncoderLayer(d_ffn=512, nhead=8, d_model=512, kernel_size=3)
365
+ >>> output = net(x, pos_embs=pos_embs)
366
+ >>> output[0].shape
367
+ torch.Size([8, 60, 512])
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ d_model,
373
+ d_ffn,
374
+ nhead,
375
+ kernel_size=31,
376
+ kdim=None,
377
+ vdim=None,
378
+ activation=Swish,
379
+ bias=True,
380
+ dropout=0.0,
381
+ causal=False,
382
+ attention_type="RelPosMHAXL",
383
+ ):
384
+ super().__init__()
385
+
386
+ if attention_type == "regularMHA":
387
+ self.mha_layer = MultiheadAttention(
388
+ nhead=nhead,
389
+ d_model=d_model,
390
+ dropout=dropout,
391
+ kdim=kdim,
392
+ vdim=vdim,
393
+ )
394
+ elif attention_type == "RelPosMHAXL":
395
+ # transformerXL style positional encoding
396
+ self.mha_layer = RelPosMHAXL(
397
+ num_heads=nhead,
398
+ embed_dim=d_model,
399
+ dropout=dropout,
400
+ mask_pos_future=causal,
401
+ )
402
+ elif attention_type == "hypermixing":
403
+ self.mha_layer = HyperMixing(
404
+ input_output_dim=d_model,
405
+ hypernet_size=d_ffn,
406
+ tied=False,
407
+ num_heads=nhead,
408
+ fix_tm_hidden_size=False,
409
+ )
410
+
411
+ self.convolution_module = ConvolutionModule(
412
+ d_model, kernel_size, bias, activation, dropout, causal=causal
413
+ )
414
+
415
+ self.ffn_module1 = nn.Sequential(
416
+ nn.LayerNorm(d_model),
417
+ PositionalwiseFeedForward(
418
+ d_ffn=d_ffn,
419
+ input_size=d_model,
420
+ dropout=dropout,
421
+ activation=activation,
422
+ ),
423
+ nn.Dropout(dropout),
424
+ )
425
+
426
+ self.ffn_module2 = nn.Sequential(
427
+ nn.LayerNorm(d_model),
428
+ PositionalwiseFeedForward(
429
+ d_ffn=d_ffn,
430
+ input_size=d_model,
431
+ dropout=dropout,
432
+ activation=activation,
433
+ ),
434
+ nn.Dropout(dropout),
435
+ )
436
+
437
+ self.norm1 = LayerNorm(d_model)
438
+ self.norm2 = LayerNorm(d_model)
439
+ self.drop = nn.Dropout(dropout)
440
+
441
+ def forward(
442
+ self,
443
+ x,
444
+ src_mask: Optional[torch.Tensor] = None,
445
+ src_key_padding_mask: Optional[torch.Tensor] = None,
446
+ pos_embs: torch.Tensor = None,
447
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
448
+ ):
449
+ """
450
+ Arguments
451
+ ----------
452
+ src : torch.Tensor
453
+ The sequence to the encoder layer.
454
+ src_mask : torch.Tensor, optional
455
+ The mask for the src sequence.
456
+ src_key_padding_mask : torch.Tensor, optional
457
+ The mask for the src keys per batch.
458
+ pos_embs: torch.Tensor, torch.nn.Module, optional
459
+ Module or tensor containing the input sequence positional embeddings
460
+ dynchunktrain_config: Optional[DynChunkTrainConfig]
461
+ Dynamic Chunk Training configuration object for streaming,
462
+ specifically involved here to apply Dynamic Chunk Convolution to
463
+ the convolution module.
464
+ """
465
+ conv_mask: Optional[torch.Tensor] = None
466
+ if src_key_padding_mask is not None:
467
+ conv_mask = src_key_padding_mask.unsqueeze(-1)
468
+ # ffn module
469
+ x = x + 0.5 * self.ffn_module1(x)
470
+ # multi-head attention module
471
+ skip = x
472
+ x = self.norm1(x)
473
+
474
+ x, self_attn = self.mha_layer(
475
+ x,
476
+ x,
477
+ x,
478
+ attn_mask=src_mask,
479
+ key_padding_mask=src_key_padding_mask,
480
+ pos_embs=pos_embs,
481
+ )
482
+ x = x + skip
483
+ # convolution module
484
+ x = x + self.convolution_module(
485
+ x, conv_mask, dynchunktrain_config=dynchunktrain_config
486
+ )
487
+ # ffn module
488
+ x = self.norm2(x + 0.5 * self.ffn_module2(x))
489
+ return x, self_attn
490
+
491
+ def forward_streaming(
492
+ self,
493
+ x,
494
+ context: ConformerEncoderLayerStreamingContext,
495
+ pos_embs: torch.Tensor = None,
496
+ ):
497
+ """Conformer layer streaming forward (typically for
498
+ DynamicChunkTraining-trained models), which is to be used at inference
499
+ time. Relies on a mutable context object as initialized by
500
+ `make_streaming_context` that should be used across chunks.
501
+ Invoked by `ConformerEncoder.forward_streaming`.
502
+
503
+ Arguments
504
+ ---------
505
+ x : torch.Tensor
506
+ Input tensor for this layer. Batching is supported as long as you
507
+ keep the context consistent.
508
+ context : ConformerEncoderStreamingContext
509
+ Mutable streaming context; the same object should be passed across
510
+ calls.
511
+ pos_embs : torch.Tensor, optional
512
+ Positional embeddings, if used.
513
+
514
+ Returns
515
+ -------
516
+ x : torch.Tensor
517
+ Output tensor.
518
+ self_attn : list
519
+ List of self attention values.
520
+ """
521
+
522
+ orig_len = x.shape[-2]
523
+ # ffn module
524
+ x = x + 0.5 * self.ffn_module1(x)
525
+
526
+ # TODO: make the approach for MHA left context more efficient.
527
+ # currently, this saves the inputs to the MHA.
528
+ # the naive approach is suboptimal in a few ways, namely that the
529
+ # outputs for this left padding is being re-computed even though we
530
+ # discard them immediately after.
531
+
532
+ # left pad `x` with our MHA left context
533
+ if context.mha_left_context is not None:
534
+ x = torch.cat((context.mha_left_context, x), dim=1)
535
+
536
+ # compute new MHA left context for the next call to our function
537
+ if context.mha_left_context_size > 0:
538
+ context.mha_left_context = x[
539
+ ..., -context.mha_left_context_size :, :
540
+ ]
541
+
542
+ # multi-head attention module
543
+ skip = x
544
+ x = self.norm1(x)
545
+
546
+ x, self_attn = self.mha_layer(
547
+ x,
548
+ x,
549
+ x,
550
+ attn_mask=None,
551
+ key_padding_mask=None,
552
+ pos_embs=pos_embs,
553
+ )
554
+ x = x + skip
555
+
556
+ # truncate outputs corresponding to the MHA left context (we only care
557
+ # about our chunk's outputs); see above to-do
558
+ x = x[..., -orig_len:, :]
559
+
560
+ if context.dcconv_left_context is not None:
561
+ x = torch.cat((context.dcconv_left_context, x), dim=1)
562
+
563
+ # compute new DCConv left context for the next call to our function
564
+ context.dcconv_left_context = x[
565
+ ..., -self.convolution_module.padding :, :
566
+ ]
567
+
568
+ # convolution module
569
+ x = x + self.convolution_module(x)
570
+
571
+ # truncate outputs corresponding to the DCConv left context
572
+ x = x[..., -orig_len:, :]
573
+
574
+ # ffn module
575
+ x = self.norm2(x + 0.5 * self.ffn_module2(x))
576
+ return x, self_attn
577
+
578
+ def make_streaming_context(self, mha_left_context_size: int):
579
+ """Creates a blank streaming context for this encoding layer.
580
+
581
+ Arguments
582
+ ---------
583
+ mha_left_context_size : int
584
+ How many left frames should be saved and used as left context to the
585
+ current chunk when streaming
586
+
587
+ Returns
588
+ -------
589
+ ConformerEncoderLayerStreamingContext
590
+ """
591
+ return ConformerEncoderLayerStreamingContext(
592
+ mha_left_context_size=mha_left_context_size
593
+ )
594
+
595
+
596
+ class ConformerEncoder(nn.Module):
597
+ """This class implements the Conformer encoder.
598
+
599
+ Arguments
600
+ ---------
601
+ num_layers : int
602
+ Number of layers.
603
+ d_model : int
604
+ Embedding dimension size.
605
+ d_ffn : int
606
+ Hidden size of self-attention Feed Forward layer.
607
+ nhead : int
608
+ Number of attention heads.
609
+ kernel_size : int, optional
610
+ Kernel size of convolution model.
611
+ kdim : int, optional
612
+ Dimension of the key.
613
+ vdim : int, optional
614
+ Dimension of the value.
615
+ activation: torch.nn.Module
616
+ Activation function used in each Confomer layer.
617
+ bias : bool, optional
618
+ Whether convolution module.
619
+ dropout : int, optional
620
+ Dropout for the encoder.
621
+ causal: bool, optional
622
+ Whether the convolutions should be causal or not.
623
+ attention_type: str, optional
624
+ type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
625
+
626
+
627
+ Example
628
+ -------
629
+ >>> import torch
630
+ >>> x = torch.rand((8, 60, 512))
631
+ >>> pos_emb = torch.rand((1, 2*60-1, 512))
632
+ >>> net = ConformerEncoder(1, 512, 512, 8)
633
+ >>> output, _ = net(x, pos_embs=pos_emb)
634
+ >>> output.shape
635
+ torch.Size([8, 60, 512])
636
+ """
637
+
638
+ def __init__(
639
+ self,
640
+ num_layers,
641
+ d_model,
642
+ d_ffn,
643
+ nhead,
644
+ kernel_size=31,
645
+ kdim=None,
646
+ vdim=None,
647
+ activation=Swish,
648
+ bias=True,
649
+ dropout=0.0,
650
+ causal=False,
651
+ attention_type="RelPosMHAXL",
652
+ ):
653
+ super().__init__()
654
+
655
+ self.layers = torch.nn.ModuleList(
656
+ [
657
+ ConformerEncoderLayer(
658
+ d_ffn=d_ffn,
659
+ nhead=nhead,
660
+ d_model=d_model,
661
+ kdim=kdim,
662
+ vdim=vdim,
663
+ dropout=dropout,
664
+ activation=activation,
665
+ kernel_size=kernel_size,
666
+ bias=bias,
667
+ causal=causal,
668
+ attention_type=attention_type,
669
+ )
670
+ for i in range(num_layers)
671
+ ]
672
+ )
673
+ self.norm = LayerNorm(d_model, eps=1e-6)
674
+ self.attention_type = attention_type
675
+
676
+ def forward(
677
+ self,
678
+ src,
679
+ src_mask: Optional[torch.Tensor] = None,
680
+ src_key_padding_mask: Optional[torch.Tensor] = None,
681
+ pos_embs: Optional[torch.Tensor] = None,
682
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
683
+ ):
684
+ """
685
+ Arguments
686
+ ----------
687
+ src : torch.Tensor
688
+ The sequence to the encoder layer.
689
+ src_mask : torch.Tensor, optional
690
+ The mask for the src sequence.
691
+ src_key_padding_mask : torch.Tensor, optional
692
+ The mask for the src keys per batch.
693
+ pos_embs: torch.Tensor, torch.nn.Module,
694
+ Module or tensor containing the input sequence positional embeddings
695
+ If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
696
+ where S is the sequence length, and E is the embedding dimension.
697
+ dynchunktrain_config: Optional[DynChunkTrainConfig]
698
+ Dynamic Chunk Training configuration object for streaming,
699
+ specifically involved here to apply Dynamic Chunk Convolution to the
700
+ convolution module.
701
+ """
702
+ if self.attention_type == "RelPosMHAXL":
703
+ if pos_embs is None:
704
+ raise ValueError(
705
+ "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
706
+ )
707
+
708
+ output = src
709
+ attention_lst = []
710
+ for enc_layer in self.layers:
711
+ output, attention = enc_layer(
712
+ output,
713
+ src_mask=src_mask,
714
+ src_key_padding_mask=src_key_padding_mask,
715
+ pos_embs=pos_embs,
716
+ dynchunktrain_config=dynchunktrain_config,
717
+ )
718
+ attention_lst.append(attention)
719
+ output = self.norm(output)
720
+
721
+ return output, attention_lst
722
+
723
+ def forward_streaming(
724
+ self,
725
+ src: torch.Tensor,
726
+ context: ConformerEncoderStreamingContext,
727
+ pos_embs: Optional[torch.Tensor] = None,
728
+ ):
729
+ """Conformer streaming forward (typically for
730
+ DynamicChunkTraining-trained models), which is to be used at inference
731
+ time. Relies on a mutable context object as initialized by
732
+ `make_streaming_context` that should be used across chunks.
733
+
734
+ Arguments
735
+ ---------
736
+ src : torch.Tensor
737
+ Input tensor. Batching is supported as long as you keep the context
738
+ consistent.
739
+ context : ConformerEncoderStreamingContext
740
+ Mutable streaming context; the same object should be passed across
741
+ calls.
742
+ pos_embs : torch.Tensor, optional
743
+ Positional embeddings, if used.
744
+
745
+ Returns
746
+ -------
747
+ output : torch.Tensor
748
+ The output of the streaming conformer.
749
+ attention_lst : list
750
+ The attention values.
751
+ """
752
+
753
+ if self.attention_type == "RelPosMHAXL":
754
+ if pos_embs is None:
755
+ raise ValueError(
756
+ "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
757
+ )
758
+
759
+ output = src
760
+ attention_lst = []
761
+ for i, enc_layer in enumerate(self.layers):
762
+ output, attention = enc_layer.forward_streaming(
763
+ output, pos_embs=pos_embs, context=context.layers[i]
764
+ )
765
+ attention_lst.append(attention)
766
+ output = self.norm(output)
767
+
768
+ return output, attention_lst
769
+
770
+ def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
771
+ """Creates a blank streaming context for the encoder.
772
+
773
+ Arguments
774
+ ---------
775
+ dynchunktrain_config: Optional[DynChunkTrainConfig]
776
+ Dynamic Chunk Training configuration object for streaming
777
+
778
+ Returns
779
+ -------
780
+ ConformerEncoderStreamingContext
781
+ """
782
+ return ConformerEncoderStreamingContext(
783
+ dynchunktrain_config=dynchunktrain_config,
784
+ layers=[
785
+ layer.make_streaming_context(
786
+ mha_left_context_size=dynchunktrain_config.left_context_size_frames()
787
+ )
788
+ for layer in self.layers
789
+ ],
790
+ )
791
+
792
+
793
+ class ConformerDecoderLayer(nn.Module):
794
+ """This is an implementation of Conformer encoder layer.
795
+
796
+ Arguments
797
+ ---------
798
+ d_model : int
799
+ The expected size of the input embedding.
800
+ d_ffn : int
801
+ Hidden size of self-attention Feed Forward layer.
802
+ nhead : int
803
+ Number of attention heads.
804
+ kernel_size : int, optional
805
+ Kernel size of convolution model.
806
+ kdim : int, optional
807
+ Dimension of the key.
808
+ vdim : int, optional
809
+ Dimension of the value.
810
+ activation : torch.nn.Module, optional
811
+ Activation function used in each Conformer layer.
812
+ bias : bool, optional
813
+ Whether convolution module.
814
+ dropout : int, optional
815
+ Dropout for the encoder.
816
+ causal : bool, optional
817
+ Whether the convolutions should be causal or not.
818
+ attention_type : str, optional
819
+ type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
820
+
821
+ Example
822
+ -------
823
+ >>> import torch
824
+ >>> x = torch.rand((8, 60, 512))
825
+ >>> pos_embs = torch.rand((1, 2*60-1, 512))
826
+ >>> net = ConformerEncoderLayer(d_ffn=512, nhead=8, d_model=512, kernel_size=3)
827
+ >>> output = net(x, pos_embs=pos_embs)
828
+ >>> output[0].shape
829
+ torch.Size([8, 60, 512])
830
+ """
831
+
832
+ def __init__(
833
+ self,
834
+ d_model,
835
+ d_ffn,
836
+ nhead,
837
+ kernel_size,
838
+ kdim=None,
839
+ vdim=None,
840
+ activation=Swish,
841
+ bias=True,
842
+ dropout=0.0,
843
+ causal=True,
844
+ attention_type="RelPosMHAXL",
845
+ ):
846
+ super().__init__()
847
+
848
+ if not causal:
849
+ warnings.warn(
850
+ "Decoder is not causal, in most applications it should be causal, you have been warned !"
851
+ )
852
+
853
+ if attention_type == "regularMHA":
854
+ self.mha_layer = MultiheadAttention(
855
+ nhead=nhead,
856
+ d_model=d_model,
857
+ dropout=dropout,
858
+ kdim=kdim,
859
+ vdim=vdim,
860
+ )
861
+ elif attention_type == "RelPosMHAXL":
862
+ # transformerXL style positional encoding
863
+ self.mha_layer = RelPosMHAXL(
864
+ num_heads=nhead,
865
+ embed_dim=d_model,
866
+ dropout=dropout,
867
+ mask_pos_future=causal,
868
+ )
869
+
870
+ self.convolution_module = ConvolutionModule(
871
+ d_model, kernel_size, bias, activation, dropout, causal=causal
872
+ )
873
+
874
+ self.ffn_module1 = nn.Sequential(
875
+ nn.LayerNorm(d_model),
876
+ PositionalwiseFeedForward(
877
+ d_ffn=d_ffn,
878
+ input_size=d_model,
879
+ dropout=dropout,
880
+ activation=activation,
881
+ ),
882
+ nn.Dropout(dropout),
883
+ )
884
+
885
+ self.ffn_module2 = nn.Sequential(
886
+ nn.LayerNorm(d_model),
887
+ PositionalwiseFeedForward(
888
+ d_ffn=d_ffn,
889
+ input_size=d_model,
890
+ dropout=dropout,
891
+ activation=activation,
892
+ ),
893
+ nn.Dropout(dropout),
894
+ )
895
+
896
+ self.norm1 = LayerNorm(d_model)
897
+ self.norm2 = LayerNorm(d_model)
898
+ self.drop = nn.Dropout(dropout)
899
+
900
+ def forward(
901
+ self,
902
+ tgt,
903
+ memory,
904
+ tgt_mask=None,
905
+ memory_mask=None,
906
+ tgt_key_padding_mask=None,
907
+ memory_key_padding_mask=None,
908
+ pos_embs_tgt=None,
909
+ pos_embs_src=None,
910
+ ):
911
+ """
912
+ Arguments
913
+ ---------
914
+ tgt: torch.Tensor
915
+ The sequence to the decoder layer.
916
+ memory: torch.Tensor
917
+ The sequence from the last layer of the encoder.
918
+ tgt_mask: torch.Tensor, optional, optional
919
+ The mask for the tgt sequence.
920
+ memory_mask: torch.Tensor, optional
921
+ The mask for the memory sequence.
922
+ tgt_key_padding_mask: torch.Tensor, optional
923
+ The mask for the tgt keys per batch.
924
+ memory_key_padding_mask: torch.Tensor, optional
925
+ The mask for the memory keys per batch.
926
+ pos_embs_tgt: torch.Tensor, torch.nn.Module, optional
927
+ Module or tensor containing the target sequence positional embeddings for each attention layer.
928
+ pos_embs_src: torch.Tensor, torch.nn.Module, optional
929
+ Module or tensor containing the source sequence positional embeddings for each attention layer.
930
+
931
+ Returns
932
+ -------
933
+ x: torch.Tensor
934
+ The output tensor
935
+ self_attn : torch.Tensor
936
+ self_attn : torch.Tensor
937
+ The self attention tensor
938
+ """
939
+ # ffn module
940
+ tgt = tgt + 0.5 * self.ffn_module1(tgt)
941
+ # multi-head attention module
942
+ skip = tgt
943
+ x = self.norm1(tgt)
944
+ x, self_attn = self.mha_layer(
945
+ x,
946
+ memory,
947
+ memory,
948
+ attn_mask=memory_mask,
949
+ key_padding_mask=memory_key_padding_mask,
950
+ pos_embs=pos_embs_src,
951
+ )
952
+ x = x + skip
953
+ # convolution module
954
+ x = x + self.convolution_module(x)
955
+ # ffn module
956
+ x = self.norm2(x + 0.5 * self.ffn_module2(x))
957
+ return x, self_attn, self_attn
958
+
959
+
960
+ class ConformerDecoder(nn.Module):
961
+ """This class implements the Transformer decoder.
962
+
963
+ Arguments
964
+ ---------
965
+ num_layers: int
966
+ Number of layers.
967
+ nhead: int
968
+ Number of attention heads.
969
+ d_ffn: int
970
+ Hidden size of self-attention Feed Forward layer.
971
+ d_model: int
972
+ Embedding dimension size.
973
+ kdim: int, optional
974
+ Dimension for key.
975
+ vdim: int, optional
976
+ Dimension for value.
977
+ dropout: float, optional
978
+ Dropout rate.
979
+ activation: torch.nn.Module, optional
980
+ Activation function used after non-bottleneck conv layer.
981
+ kernel_size : int, optional
982
+ Kernel size of convolutional layer.
983
+ bias : bool, optional
984
+ Whether convolution module.
985
+ causal: bool, optional
986
+ Whether the convolutions should be causal or not.
987
+ attention_type: str, optional
988
+ type of attention layer, e.g. regularMHA for regular MultiHeadAttention.
989
+
990
+
991
+ Example
992
+ -------
993
+ >>> src = torch.rand((8, 60, 512))
994
+ >>> tgt = torch.rand((8, 60, 512))
995
+ >>> net = ConformerDecoder(1, 8, 1024, 512, attention_type="regularMHA")
996
+ >>> output, _, _ = net(tgt, src)
997
+ >>> output.shape
998
+ torch.Size([8, 60, 512])
999
+ """
1000
+
1001
+ def __init__(
1002
+ self,
1003
+ num_layers,
1004
+ nhead,
1005
+ d_ffn,
1006
+ d_model,
1007
+ kdim=None,
1008
+ vdim=None,
1009
+ dropout=0.0,
1010
+ activation=Swish,
1011
+ kernel_size=3,
1012
+ bias=True,
1013
+ causal=True,
1014
+ attention_type="RelPosMHAXL",
1015
+ ):
1016
+ super().__init__()
1017
+ self.layers = torch.nn.ModuleList(
1018
+ [
1019
+ ConformerDecoderLayer(
1020
+ d_ffn=d_ffn,
1021
+ nhead=nhead,
1022
+ d_model=d_model,
1023
+ kdim=kdim,
1024
+ vdim=vdim,
1025
+ dropout=dropout,
1026
+ activation=activation,
1027
+ kernel_size=kernel_size,
1028
+ bias=bias,
1029
+ causal=causal,
1030
+ attention_type=attention_type,
1031
+ )
1032
+ for _ in range(num_layers)
1033
+ ]
1034
+ )
1035
+ self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
1036
+
1037
+ def forward(
1038
+ self,
1039
+ tgt,
1040
+ memory,
1041
+ tgt_mask=None,
1042
+ memory_mask=None,
1043
+ tgt_key_padding_mask=None,
1044
+ memory_key_padding_mask=None,
1045
+ pos_embs_tgt=None,
1046
+ pos_embs_src=None,
1047
+ ):
1048
+ """
1049
+ Arguments
1050
+ ---------
1051
+ tgt: torch.Tensor
1052
+ The sequence to the decoder layer.
1053
+ memory: torch.Tensor
1054
+ The sequence from the last layer of the encoder.
1055
+ tgt_mask: torch.Tensor, optional, optional
1056
+ The mask for the tgt sequence.
1057
+ memory_mask: torch.Tensor, optional
1058
+ The mask for the memory sequence.
1059
+ tgt_key_padding_mask : torch.Tensor, optional
1060
+ The mask for the tgt keys per batch.
1061
+ memory_key_padding_mask : torch.Tensor, optional
1062
+ The mask for the memory keys per batch.
1063
+ pos_embs_tgt: torch.Tensor, torch.nn.Module, optional
1064
+ Module or tensor containing the target sequence positional embeddings for each attention layer.
1065
+ pos_embs_src: torch.Tensor, torch.nn.Module, optional
1066
+ Module or tensor containing the source sequence positional embeddings for each attention layer.
1067
+
1068
+ Returns
1069
+ -------
1070
+ output: torch.Tensor
1071
+ Conformer decoder output.
1072
+ self_attns : list
1073
+ Location of self attentions.
1074
+ multihead_attns : list
1075
+ Location of multihead attentions.
1076
+ """
1077
+ output = tgt
1078
+ self_attns, multihead_attns = [], []
1079
+ for dec_layer in self.layers:
1080
+ output, self_attn, multihead_attn = dec_layer(
1081
+ output,
1082
+ memory,
1083
+ tgt_mask=tgt_mask,
1084
+ memory_mask=memory_mask,
1085
+ tgt_key_padding_mask=tgt_key_padding_mask,
1086
+ memory_key_padding_mask=memory_key_padding_mask,
1087
+ pos_embs_tgt=pos_embs_tgt,
1088
+ pos_embs_src=pos_embs_src,
1089
+ )
1090
+ self_attns.append(self_attn)
1091
+ multihead_attns.append(multihead_attn)
1092
+ output = self.norm(output)
1093
+
1094
+ return output, self_attns, multihead_attns
model/modules/Conmamba.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ConMamba encoder and Mamba decoder implementation.
2
+
3
+ Authors
4
+ -------
5
+ * Xilin Jiang 2024
6
+ """
7
+
8
+ import warnings
9
+ from dataclasses import dataclass
10
+ from typing import List, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ import speechbrain as sb
17
+ from speechbrain.nnet.activations import Swish
18
+ from speechbrain.nnet.attention import (
19
+ MultiheadAttention,
20
+ PositionalwiseFeedForward,
21
+ RelPosMHAXL,
22
+ )
23
+ from speechbrain.nnet.hypermixing import HyperMixing
24
+ from speechbrain.nnet.normalization import LayerNorm
25
+ from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
26
+
27
+ # Mamba
28
+ from mamba_ssm import Mamba
29
+ from .mamba.bimamba import Mamba as BiMamba
30
+
31
+
32
+ class ConvolutionModule(nn.Module):
33
+ """This is an implementation of convolution module in Conmamba.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ input_size,
39
+ kernel_size=31,
40
+ bias=True,
41
+ activation=Swish,
42
+ dropout=0.0,
43
+ causal=False,
44
+ dilation=1,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.kernel_size = kernel_size
49
+ self.causal = causal
50
+ self.dilation = dilation
51
+
52
+ if self.causal:
53
+ self.padding = (kernel_size - 1) * 2 ** (dilation - 1)
54
+ else:
55
+ self.padding = (kernel_size - 1) * 2 ** (dilation - 1) // 2
56
+
57
+ self.layer_norm = nn.LayerNorm(input_size)
58
+ self.bottleneck = nn.Sequential(
59
+ # pointwise
60
+ nn.Conv1d(
61
+ input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
62
+ ),
63
+ nn.GLU(dim=1),
64
+ )
65
+ # depthwise
66
+ self.conv = nn.Conv1d(
67
+ input_size,
68
+ input_size,
69
+ kernel_size=kernel_size,
70
+ stride=1,
71
+ padding=self.padding,
72
+ dilation=dilation,
73
+ groups=input_size,
74
+ bias=bias,
75
+ )
76
+
77
+ # BatchNorm in the original Conformer replaced with a LayerNorm due to
78
+ # https://github.com/speechbrain/speechbrain/pull/1329
79
+ # see discussion
80
+ # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884
81
+
82
+ self.after_conv = nn.Sequential(
83
+ nn.LayerNorm(input_size),
84
+ activation(),
85
+ # pointwise
86
+ nn.Linear(input_size, input_size, bias=bias),
87
+ nn.Dropout(dropout),
88
+ )
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask: Optional[torch.Tensor] = None,
94
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
95
+ ):
96
+ """Applies the convolution to an input tensor `x`.
97
+ """
98
+
99
+ if dynchunktrain_config is not None:
100
+ # chances are chunking+causal is unintended; i don't know where it
101
+ # may make sense, but if it does to you, feel free to implement it.
102
+ assert (
103
+ not self.causal
104
+ ), "Chunked convolution not supported with causal padding"
105
+
106
+ assert (
107
+ self.dilation == 1
108
+ ), "Current DynChunkTrain logic does not support dilation != 1"
109
+
110
+ # in a causal convolution, which is not the case here, an output
111
+ # frame would never be able to depend on a input frame from any
112
+ # point in the future.
113
+
114
+ # but with the dynamic chunk convolution, we instead use a "normal"
115
+ # convolution but where, for any output frame, the future beyond the
116
+ # "current" chunk gets masked.
117
+ # see the paper linked in the documentation for details.
118
+
119
+ chunk_size = dynchunktrain_config.chunk_size
120
+ batch_size = x.shape[0]
121
+
122
+ # determine the amount of padding we need to insert at the right of
123
+ # the last chunk so that all chunks end up with the same size.
124
+ if x.shape[1] % chunk_size != 0:
125
+ final_right_padding = chunk_size - (x.shape[1] % chunk_size)
126
+ else:
127
+ final_right_padding = 0
128
+
129
+ # -> [batch_size, t, in_channels]
130
+ out = self.layer_norm(x)
131
+
132
+ # -> [batch_size, in_channels, t] for the CNN
133
+ out = out.transpose(1, 2)
134
+
135
+ # -> [batch_size, in_channels, t] (pointwise)
136
+ out = self.bottleneck(out)
137
+
138
+ # -> [batch_size, in_channels, lc+t+final_right_padding]
139
+ out = F.pad(out, (self.padding, final_right_padding), value=0)
140
+
141
+ # now, make chunks with left context.
142
+ # as a recap to what the above padding and this unfold do, consider
143
+ # each a/b/c letter represents a frame as part of chunks a, b, c.
144
+ # consider a chunk size of 4 and a kernel size of 5 (padding=2):
145
+ #
146
+ # input seq: 00aaaabbbbcc00
147
+ # chunk #1: 00aaaa
148
+ # chunk #2: aabbbb
149
+ # chunk #3: bbcc00
150
+ #
151
+ # a few remarks here:
152
+ # - the left padding gets inserted early so that the unfold logic
153
+ # works trivially
154
+ # - the right 0-padding got inserted as the number of time steps
155
+ # could not be evenly split in `chunk_size` chunks
156
+
157
+ # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
158
+ out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)
159
+
160
+ # as we manually disable padding in the convolution below, we insert
161
+ # right 0-padding to the chunks, e.g. reusing the above example:
162
+ #
163
+ # chunk #1: 00aaaa00
164
+ # chunk #2: aabbbb00
165
+ # chunk #3: bbcc0000
166
+
167
+ # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
168
+ out = F.pad(out, (0, self.padding), value=0)
169
+
170
+ # the transpose+flatten effectively flattens chunks into the batch
171
+ # dimension to be processed into the time-wise convolution. the
172
+ # chunks will later on be unflattened.
173
+
174
+ # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
175
+ out = out.transpose(1, 2)
176
+
177
+ # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
178
+ out = out.flatten(start_dim=0, end_dim=1)
179
+
180
+ # TODO: experiment around reflect padding, which is difficult
181
+ # because small chunks have too little time steps to reflect from
182
+
183
+ # let's keep backwards compat by pointing at the weights from the
184
+ # already declared Conv1d.
185
+ #
186
+ # still reusing the above example, the convolution will be applied,
187
+ # with the padding truncated on both ends. the following example
188
+ # shows the letter corresponding to the input frame on which the
189
+ # convolution was centered.
190
+ #
191
+ # as you can see, the sum of lengths of all chunks is equal to our
192
+ # input sequence length + `final_right_padding`.
193
+ #
194
+ # chunk #1: aaaa
195
+ # chunk #2: bbbb
196
+ # chunk #3: cc00
197
+
198
+ # -> [batch_size * num_chunks, out_channels, chunk_size]
199
+ out = F.conv1d(
200
+ out,
201
+ weight=self.conv.weight,
202
+ bias=self.conv.bias,
203
+ stride=self.conv.stride,
204
+ padding=0,
205
+ dilation=self.conv.dilation,
206
+ groups=self.conv.groups,
207
+ )
208
+
209
+ # -> [batch_size * num_chunks, chunk_size, out_channels]
210
+ out = out.transpose(1, 2)
211
+
212
+ out = self.after_conv(out)
213
+
214
+ # -> [batch_size, num_chunks, chunk_size, out_channels]
215
+ out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))
216
+
217
+ # -> [batch_size, t + final_right_padding, out_channels]
218
+ out = torch.flatten(out, start_dim=1, end_dim=2)
219
+
220
+ # -> [batch_size, t, out_channels]
221
+ if final_right_padding > 0:
222
+ out = out[:, :-final_right_padding, :]
223
+ else:
224
+ out = self.layer_norm(x)
225
+ out = out.transpose(1, 2)
226
+ out = self.bottleneck(out)
227
+ out = self.conv(out)
228
+
229
+ if self.causal:
230
+ # chomp
231
+ out = out[..., : -self.padding]
232
+
233
+ out = out.transpose(1, 2)
234
+ out = self.after_conv(out)
235
+
236
+ if mask is not None:
237
+ out.masked_fill_(mask, 0.0)
238
+
239
+ return out
240
+
241
+
242
+ class ConmambaEncoderLayer(nn.Module):
243
+ """This is an implementation of Conmamba encoder layer.
244
+ """
245
+
246
+ def __init__(
247
+ self,
248
+ d_model,
249
+ d_ffn,
250
+ kernel_size=31,
251
+ activation=Swish,
252
+ bias=True,
253
+ dropout=0.0,
254
+ causal=False,
255
+ mamba_config=None
256
+ ):
257
+ super().__init__()
258
+ assert mamba_config != None
259
+
260
+ bidirectional = mamba_config.pop('bidirectional')
261
+ if causal or (not bidirectional):
262
+ self.mamba = Mamba(
263
+ d_model=d_model,
264
+ **mamba_config
265
+ )
266
+ else:
267
+ self.mamba = BiMamba(
268
+ d_model=d_model,
269
+ bimamba_type='v2',
270
+ **mamba_config
271
+ )
272
+ mamba_config['bidirectional'] = bidirectional
273
+
274
+ self.convolution_module = ConvolutionModule(
275
+ d_model, kernel_size, bias, activation, dropout, causal=causal
276
+ )
277
+
278
+ self.ffn_module1 = nn.Sequential(
279
+ nn.LayerNorm(d_model),
280
+ PositionalwiseFeedForward(
281
+ d_ffn=d_ffn,
282
+ input_size=d_model,
283
+ dropout=dropout,
284
+ activation=activation,
285
+ ),
286
+ nn.Dropout(dropout),
287
+ )
288
+
289
+ self.ffn_module2 = nn.Sequential(
290
+ nn.LayerNorm(d_model),
291
+ PositionalwiseFeedForward(
292
+ d_ffn=d_ffn,
293
+ input_size=d_model,
294
+ dropout=dropout,
295
+ activation=activation,
296
+ ),
297
+ nn.Dropout(dropout),
298
+ )
299
+
300
+ self.norm1 = LayerNorm(d_model)
301
+ self.norm2 = LayerNorm(d_model)
302
+ self.drop = nn.Dropout(dropout)
303
+
304
+ def forward(
305
+ self,
306
+ x,
307
+ src_mask: Optional[torch.Tensor] = None,
308
+ src_key_padding_mask: Optional[torch.Tensor] = None,
309
+ pos_embs: torch.Tensor = None,
310
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
311
+ ):
312
+ conv_mask: Optional[torch.Tensor] = None
313
+ if src_key_padding_mask is not None:
314
+ conv_mask = src_key_padding_mask.unsqueeze(-1)
315
+
316
+ conv_mask = None
317
+
318
+ # ffn module
319
+ x = x + 0.5 * self.ffn_module1(x)
320
+ # mamba module
321
+ skip = x
322
+ x = self.norm1(x)
323
+ x = self.mamba(x)
324
+ x = x + skip
325
+ # convolution module
326
+ x = x + self.convolution_module(
327
+ x, conv_mask, dynchunktrain_config=dynchunktrain_config
328
+ )
329
+ # ffn module
330
+ x = self.norm2(x + 0.5 * self.ffn_module2(x))
331
+ return x
332
+
333
+
334
+ class ConmambaEncoder(nn.Module):
335
+ """This class implements the Conmamba encoder.
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ num_layers,
341
+ d_model,
342
+ d_ffn,
343
+ kernel_size=31,
344
+ activation=Swish,
345
+ bias=True,
346
+ dropout=0.0,
347
+ causal=False,
348
+ mamba_config=None
349
+ ):
350
+ super().__init__()
351
+ print(f'dropout={str(dropout)} is not used in Mamba.')
352
+
353
+ self.layers = torch.nn.ModuleList(
354
+ [
355
+ ConmambaEncoderLayer(
356
+ d_model=d_model,
357
+ d_ffn=d_ffn,
358
+ dropout=dropout,
359
+ activation=activation,
360
+ kernel_size=kernel_size,
361
+ bias=bias,
362
+ causal=causal,
363
+ mamba_config=mamba_config,
364
+ )
365
+ for i in range(num_layers)
366
+ ]
367
+ )
368
+ self.norm = LayerNorm(d_model, eps=1e-6)
369
+
370
+ def forward(
371
+ self,
372
+ src,
373
+ src_mask: Optional[torch.Tensor] = None,
374
+ src_key_padding_mask: Optional[torch.Tensor] = None,
375
+ pos_embs: Optional[torch.Tensor] = None,
376
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
377
+ ):
378
+ """
379
+ Arguments
380
+ ----------
381
+ src : torch.Tensor
382
+ The sequence to the encoder layer.
383
+ src_mask : torch.Tensor, optional
384
+ The mask for the src sequence.
385
+ src_key_padding_mask : torch.Tensor, optional
386
+ The mask for the src keys per batch.
387
+ pos_embs: torch.Tensor, torch.nn.Module,
388
+ Module or tensor containing the input sequence positional embeddings
389
+ If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
390
+ where S is the sequence length, and E is the embedding dimension.
391
+ dynchunktrain_config: Optional[DynChunkTrainConfig]
392
+ Dynamic Chunk Training configuration object for streaming,
393
+ specifically involved here to apply Dynamic Chunk Convolution to the
394
+ convolution module.
395
+ """
396
+
397
+ output = src
398
+ for enc_layer in self.layers:
399
+ output = enc_layer(
400
+ output,
401
+ src_mask=src_mask,
402
+ src_key_padding_mask=src_key_padding_mask,
403
+ pos_embs=pos_embs,
404
+ dynchunktrain_config=dynchunktrain_config,
405
+ )
406
+ output = self.norm(output)
407
+
408
+ return output, None
409
+
410
+
411
+ class MambaDecoderLayer(nn.Module):
412
+ """This class implements the Mamba decoder layer.
413
+ """
414
+
415
+ def __init__(
416
+ self,
417
+ d_model,
418
+ d_ffn,
419
+ activation=nn.ReLU,
420
+ dropout=0.0,
421
+ normalize_before=False,
422
+ mamba_config=None
423
+ ):
424
+ super().__init__()
425
+
426
+ assert mamba_config != None
427
+
428
+ bidirectional = mamba_config.pop('bidirectional')
429
+
430
+ self.self_mamba = Mamba(
431
+ d_model=d_model,
432
+ **mamba_config
433
+ )
434
+
435
+ self.cross_mamba = Mamba(
436
+ d_model=d_model,
437
+ **mamba_config
438
+ )
439
+
440
+ mamba_config['bidirectional'] = bidirectional
441
+
442
+ self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
443
+ d_ffn=d_ffn,
444
+ input_size=d_model,
445
+ dropout=dropout,
446
+ activation=activation,
447
+ )
448
+
449
+ # normalization layers
450
+ self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
451
+ self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
452
+ self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
453
+ self.dropout1 = torch.nn.Dropout(dropout)
454
+ self.dropout2 = torch.nn.Dropout(dropout)
455
+ self.dropout3 = torch.nn.Dropout(dropout)
456
+
457
+ self.normalize_before = normalize_before
458
+
459
+ def forward(
460
+ self,
461
+ tgt,
462
+ memory,
463
+ tgt_mask=None,
464
+ memory_mask=None,
465
+ tgt_key_padding_mask=None,
466
+ memory_key_padding_mask=None,
467
+ pos_embs_tgt=None,
468
+ pos_embs_src=None,
469
+ ):
470
+ """
471
+ Arguments
472
+ ----------
473
+ tgt: torch.Tensor
474
+ The sequence to the decoder layer (required).
475
+ memory: torch.Tensor
476
+ The sequence from the last layer of the encoder (required).
477
+ tgt_mask: torch.Tensor
478
+ The mask for the tgt sequence (optional).
479
+ memory_mask: torch.Tensor
480
+ The mask for the memory sequence (optional).
481
+ tgt_key_padding_mask: torch.Tensor
482
+ The mask for the tgt keys per batch (optional).
483
+ memory_key_padding_mask: torch.Tensor
484
+ The mask for the memory keys per batch (optional).
485
+ pos_embs_tgt: torch.Tensor
486
+ The positional embeddings for the target (optional).
487
+ pos_embs_src: torch.Tensor
488
+ The positional embeddings for the source (optional).
489
+ """
490
+ if self.normalize_before:
491
+ tgt1 = self.norm1(tgt)
492
+ else:
493
+ tgt1 = tgt
494
+
495
+ # Mamba over the target sequence
496
+ tgt2 = self.self_mamba(tgt1)
497
+
498
+ # add & norm
499
+ tgt = tgt + self.dropout1(tgt2)
500
+ if not self.normalize_before:
501
+ tgt = self.norm1(tgt)
502
+
503
+ if self.normalize_before:
504
+ tgt1 = self.norm2(tgt)
505
+ else:
506
+ tgt1 = tgt
507
+
508
+ # Mamba over key=value + query
509
+ # and only take the last len(query) tokens
510
+ tgt2 = self.cross_mamba(torch.cat([memory, tgt1], dim=1))[:, -tgt1.shape[1]:]
511
+
512
+ # add & norm
513
+ tgt = tgt + self.dropout2(tgt2)
514
+ if not self.normalize_before:
515
+ tgt = self.norm2(tgt)
516
+
517
+ if self.normalize_before:
518
+ tgt1 = self.norm3(tgt)
519
+ else:
520
+ tgt1 = tgt
521
+
522
+ tgt2 = self.pos_ffn(tgt1)
523
+
524
+ # add & norm
525
+ tgt = tgt + self.dropout3(tgt2)
526
+ if not self.normalize_before:
527
+ tgt = self.norm3(tgt)
528
+
529
+ return tgt, None, None
530
+
531
+
532
+ class MambaDecoder(nn.Module):
533
+ """This class implements the Mamba decoder.
534
+ """
535
+
536
+ def __init__(
537
+ self,
538
+ num_layers,
539
+ d_model,
540
+ d_ffn,
541
+ activation=nn.ReLU,
542
+ dropout=0.0,
543
+ normalize_before=False,
544
+ mamba_config=None
545
+ ):
546
+ super().__init__()
547
+ self.layers = torch.nn.ModuleList(
548
+ [
549
+ MambaDecoderLayer(
550
+ d_model=d_model,
551
+ d_ffn=d_ffn,
552
+ activation=activation,
553
+ dropout=dropout,
554
+ normalize_before=normalize_before,
555
+ mamba_config=mamba_config
556
+ )
557
+ for _ in range(num_layers)
558
+ ]
559
+ )
560
+ self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
561
+
562
+ def forward(
563
+ self,
564
+ tgt,
565
+ memory,
566
+ tgt_mask=None,
567
+ memory_mask=None,
568
+ tgt_key_padding_mask=None,
569
+ memory_key_padding_mask=None,
570
+ pos_embs_tgt=None,
571
+ pos_embs_src=None,
572
+ ):
573
+ """
574
+ Arguments
575
+ ----------
576
+ tgt : torch.Tensor
577
+ The sequence to the decoder layer (required).
578
+ memory : torch.Tensor
579
+ The sequence from the last layer of the encoder (required).
580
+ tgt_mask : torch.Tensor
581
+ The mask for the tgt sequence (optional).
582
+ memory_mask : torch.Tensor
583
+ The mask for the memory sequence (optional).
584
+ tgt_key_padding_mask : torch.Tensor
585
+ The mask for the tgt keys per batch (optional).
586
+ memory_key_padding_mask : torch.Tensor
587
+ The mask for the memory keys per batch (optional).
588
+ pos_embs_tgt : torch.Tensor
589
+ The positional embeddings for the target (optional).
590
+ pos_embs_src : torch.Tensor
591
+ The positional embeddings for the source (optional).
592
+ """
593
+ output = tgt
594
+ for dec_layer in self.layers:
595
+ output, _, _ = dec_layer(
596
+ output,
597
+ memory,
598
+ tgt_mask=tgt_mask,
599
+ memory_mask=memory_mask,
600
+ tgt_key_padding_mask=tgt_key_padding_mask,
601
+ memory_key_padding_mask=memory_key_padding_mask,
602
+ pos_embs_tgt=pos_embs_tgt,
603
+ pos_embs_src=pos_embs_src,
604
+ )
605
+ output = self.norm(output)
606
+
607
+ return output, [None], [None]
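The cross-Mamba step in MambaDecoderLayer.forward replaces encoder-decoder attention: the sequence model is run over the concatenation of the encoder memory and the (normalized) target, and only the last tgt-length positions of the output are kept. Below is a minimal sketch of that slicing pattern, not part of the uploaded files; a plain nn.Identity stands in for the Mamba block, so only the shape handling is illustrated.

import torch
import torch.nn as nn

seq_model = nn.Identity()                  # placeholder for Mamba(d_model=256, ...)
memory = torch.randn(8, 120, 256)          # (batch, src_len, d_model) encoder output
tgt = torch.randn(8, 30, 256)              # (batch, tgt_len, d_model) decoder input

joint = torch.cat([memory, tgt], dim=1)    # (8, 150, 256): [memory ; tgt]
out = seq_model(joint)[:, -tgt.shape[1]:]  # keep only the last tgt_len positions
print(out.shape)                           # torch.Size([8, 30, 256])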
model/modules/Transformer.py ADDED
@@ -0,0 +1,1085 @@
1
+ """Added ConMamba and Mamba
2
+
3
+ Authors
4
+ * Xilin Jiang 2024
5
+ """
6
+
7
+ """Transformer implementation in the SpeechBrain style.
8
+
9
+ Authors
10
+ * Jianyuan Zhong 2020
11
+ * Samuele Cornell 2021
12
+ """
13
+
14
+ import math
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ import speechbrain as sb
22
+ from speechbrain.nnet.activations import Swish
23
+ from speechbrain.nnet.attention import RelPosEncXL
24
+ from speechbrain.nnet.CNN import Conv1d
25
+
26
+ from modules.Conformer import ConformerEncoder
27
+ from modules.Conmamba import ConmambaEncoder, MambaDecoder
28
+
29
+
30
+ class TransformerInterface(nn.Module):
31
+ """This is an interface for transformer model.
32
+ Users can modify the attributes and define the forward function as
33
+ needed according to their own tasks.
34
+ The architecture is based on the paper "Attention Is All You Need":
35
+ https://arxiv.org/pdf/1706.03762.pdf
36
+
37
+ Arguments
38
+ ---------
39
+ d_model: int
40
+ The number of expected features in the encoder/decoder inputs (default=512).
41
+ nhead: int
42
+ The number of heads in the multi-head attention models (default=8).
43
+ num_encoder_layers: int, optional
44
+ The number of encoder layers in the encoder.
45
+ num_decoder_layers: int, optional
46
+ The number of decoder layers in the decoder.
47
+ d_ffn: int, optional
48
+ The dimension of the feedforward network model hidden layer.
49
+ dropout: int, optional
50
+ The dropout value.
51
+ activation: torch.nn.Module, optional
52
+ The activation function for Feed-Forward Network layer,
53
+ e.g., relu or gelu or swish.
54
+ custom_src_module: torch.nn.Module, optional
55
+ Module that processes the src features to expected feature dim.
56
+ custom_tgt_module: torch.nn.Module, optional
57
+ Module that processes the src features to expected feature dim.
58
+ positional_encoding: str, optional
59
+ Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
60
+ normalize_before: bool, optional
61
+ Whether normalization should be applied before or after MHA or FFN in Transformer layers.
62
+ Defaults to True as this was shown to lead to better performance and training stability.
63
+ kernel_size: int, optional
64
+ Kernel size in convolutional layers when Conformer is used.
65
+ bias: bool, optional
66
+ Whether to use bias in Conformer convolutional layers.
67
+ encoder_module: str, optional
68
+ Choose between Branchformer, Conformer, ConMamba, and Transformer for the encoder.
69
+ decoder_module: str, optional
70
+ Choose between Mamba and Transformer for the decoder.
71
+ conformer_activation: torch.nn.Module, optional
72
+ Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module.
73
+ branchformer_activation: torch.nn.Module, optional
74
+ Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module.
75
+ attention_type: str, optional
76
+ Type of attention layer used in all Transformer or Conformer layers.
77
+ e.g. regularMHA or RelPosMHA.
78
+ max_length: int, optional
79
+ Max length for the target and source sequence in input.
80
+ Used for positional encodings.
81
+ causal: bool, optional
82
+ Whether the encoder should be causal or not (the decoder is always causal).
83
+ If causal the Conformer convolutional layer is causal.
84
+ encoder_kdim: int, optional
85
+ Dimension of the key for the encoder.
86
+ encoder_vdim: int, optional
87
+ Dimension of the value for the encoder.
88
+ decoder_kdim: int, optional
89
+ Dimension of the key for the decoder.
90
+ decoder_vdim: int, optional
91
+ Dimension of the value for the decoder.
92
+ csgu_linear_units: int, optional
93
+ Number of neurons in the hidden linear units of the CSGU Module.
94
+ -> Branchformer
95
+ gate_activation: torch.nn.Module, optional
96
+ Activation function used at the gate of the CSGU module.
97
+ -> Branchformer
98
+ use_linear_after_conv: bool, optional
99
+ If True, will apply a linear transformation of size input_size//2.
100
+ -> Branchformer
101
+ mamba_config: dict, optional
102
+ Mamba parameters if encoder_module or decoder_module is Mamba or ConMamba
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ d_model=512,
108
+ nhead=8,
109
+ num_encoder_layers=6,
110
+ num_decoder_layers=6,
111
+ d_ffn=2048,
112
+ dropout=0.1,
113
+ activation=nn.ReLU,
114
+ custom_src_module=None,
115
+ custom_tgt_module=None,
116
+ positional_encoding="fixed_abs_sine",
117
+ normalize_before=True,
118
+ kernel_size: Optional[int] = 31,
119
+ bias: Optional[bool] = True,
120
+ encoder_module: Optional[str] = "transformer",
121
+ decoder_module: Optional[str] = "transformer",
122
+ conformer_activation: Optional[nn.Module] = Swish,
123
+ branchformer_activation: Optional[nn.Module] = nn.GELU,
124
+ attention_type: Optional[str] = "regularMHA",
125
+ max_length: Optional[int] = 2500,
126
+ causal: Optional[bool] = False,
127
+ encoder_kdim: Optional[int] = None,
128
+ encoder_vdim: Optional[int] = None,
129
+ decoder_kdim: Optional[int] = None,
130
+ decoder_vdim: Optional[int] = None,
131
+ csgu_linear_units: Optional[int] = 3072,
132
+ gate_activation: Optional[nn.Module] = nn.Identity,
133
+ use_linear_after_conv: Optional[bool] = False,
134
+ mamba_config=None
135
+ ):
136
+ super().__init__()
137
+ self.causal = causal
138
+ self.attention_type = attention_type
139
+ self.positional_encoding_type = positional_encoding
140
+ self.encoder_kdim = encoder_kdim
141
+ self.encoder_vdim = encoder_vdim
142
+ self.decoder_kdim = decoder_kdim
143
+ self.decoder_vdim = decoder_vdim
144
+
145
+ assert attention_type in ["regularMHA", "RelPosMHAXL", "hypermixing"]
146
+ assert positional_encoding in ["fixed_abs_sine", None]
147
+
148
+ assert (
149
+ num_encoder_layers + num_decoder_layers > 0
150
+ ), "number of encoder layers and number of decoder layers cannot both be 0!"
151
+
152
+ if positional_encoding == "fixed_abs_sine":
153
+ self.positional_encoding = PositionalEncoding(d_model, max_length)
154
+ elif positional_encoding is None:
155
+ pass
156
+ # no positional encodings
157
+
158
+ # overrides any other pos_embedding
159
+ if attention_type == "RelPosMHAXL":
160
+ self.positional_encoding = RelPosEncXL(d_model)
161
+ self.positional_encoding_decoder = PositionalEncoding(
162
+ d_model, max_length
163
+ )
164
+
165
+ # initialize the encoder
166
+ if num_encoder_layers > 0:
167
+ if custom_src_module is not None:
168
+ self.custom_src_module = custom_src_module(d_model)
169
+ if encoder_module == "transformer":
170
+ self.encoder = TransformerEncoder(
171
+ nhead=nhead,
172
+ num_layers=num_encoder_layers,
173
+ d_ffn=d_ffn,
174
+ d_model=d_model,
175
+ dropout=dropout,
176
+ activation=activation,
177
+ normalize_before=normalize_before,
178
+ causal=self.causal,
179
+ attention_type=self.attention_type,
180
+ kdim=self.encoder_kdim,
181
+ vdim=self.encoder_vdim,
182
+ )
183
+ elif encoder_module == "conformer":
184
+ self.encoder = ConformerEncoder(
185
+ nhead=nhead,
186
+ num_layers=num_encoder_layers,
187
+ d_ffn=d_ffn,
188
+ d_model=d_model,
189
+ dropout=dropout,
190
+ activation=conformer_activation,
191
+ kernel_size=kernel_size,
192
+ bias=bias,
193
+ causal=self.causal,
194
+ attention_type=self.attention_type,
195
+ )
196
+ assert (
197
+ normalize_before
198
+ ), "normalize_before must be True for Conformer"
199
+
200
+ assert (
201
+ conformer_activation is not None
202
+ ), "conformer_activation must not be None"
203
+ elif encoder_module == "branchformer":
204
+ self.encoder = BranchformerEncoder(
205
+ nhead=nhead,
206
+ num_layers=num_encoder_layers,
207
+ d_model=d_model,
208
+ dropout=dropout,
209
+ activation=branchformer_activation,
210
+ kernel_size=kernel_size,
211
+ attention_type=self.attention_type,
212
+ csgu_linear_units=csgu_linear_units,
213
+ gate_activation=gate_activation,
214
+ use_linear_after_conv=use_linear_after_conv,
215
+ )
216
+ elif encoder_module == "conmamba":
217
+ self.encoder = ConmambaEncoder(
218
+ num_layers=num_encoder_layers,
219
+ d_model=d_model,
220
+ d_ffn=d_ffn,
221
+ dropout=dropout,
222
+ activation=branchformer_activation,
223
+ kernel_size=kernel_size,
224
+ bias=bias,
225
+ causal=self.causal,
226
+ mamba_config=mamba_config
227
+ )
228
+ assert (
229
+ normalize_before
230
+ ), "normalize_before must be True for Conmamba"
231
+
232
+ assert (
233
+ conformer_activation is not None
234
+ ), "conformer_activation must not be None"
235
+
236
+ # initialize the decoder
237
+ if num_decoder_layers > 0:
238
+ if custom_tgt_module is not None:
239
+ self.custom_tgt_module = custom_tgt_module(d_model)
240
+ if decoder_module == 'transformer':
241
+ self.decoder = TransformerDecoder(
242
+ num_layers=num_decoder_layers,
243
+ nhead=nhead,
244
+ d_ffn=d_ffn,
245
+ d_model=d_model,
246
+ dropout=dropout,
247
+ activation=activation,
248
+ normalize_before=normalize_before,
249
+ causal=True,
250
+ attention_type="regularMHA", # always use regular attention in decoder
251
+ kdim=self.decoder_kdim,
252
+ vdim=self.decoder_vdim,
253
+ )
254
+ elif decoder_module in ['mamba']:
255
+ self.decoder = MambaDecoder(
256
+ num_layers=num_decoder_layers,
257
+ d_ffn=d_ffn,
258
+ d_model=d_model,
259
+ activation=activation,
260
+ dropout=dropout,
261
+ normalize_before=normalize_before,
262
+ mamba_config=mamba_config
263
+ )
264
+ else:
265
+ raise NotImplementedError(decoder_module)
266
+
267
+ def forward(self, **kwargs):
268
+ """Users should modify this function according to their own tasks."""
269
+ raise NotImplementedError
270
+
271
+
272
+ class PositionalEncoding(nn.Module):
273
+ """This class implements the absolute sinusoidal positional encoding function.
274
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
275
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
276
+
277
+ Arguments
278
+ ---------
279
+ input_size: int
280
+ Embedding dimension.
281
+ max_len : int, optional
282
+ Max length of the input sequences (default 2500).
283
+
284
+ Example
285
+ -------
286
+ >>> a = torch.rand((8, 120, 512))
287
+ >>> enc = PositionalEncoding(input_size=a.shape[-1])
288
+ >>> b = enc(a)
289
+ >>> b.shape
290
+ torch.Size([1, 120, 512])
291
+ """
292
+
293
+ def __init__(self, input_size, max_len=2500):
294
+ super().__init__()
295
+ if input_size % 2 != 0:
296
+ raise ValueError(
297
+ f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
298
+ )
299
+ self.max_len = max_len
300
+ pe = torch.zeros(self.max_len, input_size, requires_grad=False)
301
+ positions = torch.arange(0, self.max_len).unsqueeze(1).float()
302
+ denominator = torch.exp(
303
+ torch.arange(0, input_size, 2).float()
304
+ * -(math.log(10000.0) / input_size)
305
+ )
306
+
307
+ pe[:, 0::2] = torch.sin(positions * denominator)
308
+ pe[:, 1::2] = torch.cos(positions * denominator)
309
+ pe = pe.unsqueeze(0)
310
+ self.register_buffer("pe", pe)
311
+
312
+ def forward(self, x):
313
+ """
314
+ Arguments
315
+ ---------
316
+ x : torch.Tensor
317
+ Input feature shape (batch, time, fea)
318
+
319
+ Returns
320
+ -------
321
+ The positional encoding.
322
+ """
323
+ return self.pe[:, : x.size(1)].clone().detach()
324
+
325
+
326
+ class TransformerEncoderLayer(nn.Module):
327
+ """This is an implementation of self-attention encoder layer.
328
+
329
+ Arguments
330
+ ---------
331
+ d_ffn: int, optional
332
+ The dimension of the feedforward network model hidden layer.
333
+ nhead: int
334
+ The number of heads in the multi-head attention models (default=8).
335
+ d_model: int
336
+ The number of expected features in the encoder/decoder inputs (default=512).
337
+ kdim: int, optional
338
+ Dimension of the key.
339
+ vdim: int, optional
340
+ Dimension of the value.
341
+ dropout: int, optional
342
+ The dropout value.
343
+ activation: torch.nn.Module, optional
344
+ The activation function for Feed-Forward Network layer,
345
+ e.g., relu or gelu or swish.
346
+ normalize_before: bool, optional
347
+ Whether normalization should be applied before or after MHA or FFN in Transformer layers.
348
+ Defaults to True as this was shown to lead to better performance and training stability.
349
+ attention_type: str, optional
350
+ Type of attention layer used in all Transformer or Conformer layers.
351
+ e.g. regularMHA or RelPosMHA.
352
+ ffn_type: str
353
+ type of ffn: regularFFN/1dcnn
354
+ ffn_cnn_kernel_size_list: list of int
355
+ kernel size of 2 1d-convs if ffn_type is 1dcnn
356
+ causal: bool, optional
357
+ Whether the encoder should be causal or not (the decoder is always causal).
358
+ If causal the Conformer convolutional layer is causal.
359
+
360
+ Example
361
+ -------
362
+ >>> import torch
363
+ >>> x = torch.rand((8, 60, 512))
364
+ >>> net = TransformerEncoderLayer(512, 8, d_model=512)
365
+ >>> output = net(x)
366
+ >>> output[0].shape
367
+ torch.Size([8, 60, 512])
368
+ """
369
+
370
+ def __init__(
371
+ self,
372
+ d_ffn,
373
+ nhead,
374
+ d_model,
375
+ kdim=None,
376
+ vdim=None,
377
+ dropout=0.0,
378
+ activation=nn.ReLU,
379
+ normalize_before=False,
380
+ attention_type="regularMHA",
381
+ ffn_type="regularFFN",
382
+ ffn_cnn_kernel_size_list=[3, 3],
383
+ causal=False,
384
+ ):
385
+ super().__init__()
386
+
387
+ if attention_type == "regularMHA":
388
+ self.self_att = sb.nnet.attention.MultiheadAttention(
389
+ nhead=nhead,
390
+ d_model=d_model,
391
+ dropout=dropout,
392
+ kdim=kdim,
393
+ vdim=vdim,
394
+ )
395
+
396
+ elif attention_type == "RelPosMHAXL":
397
+ self.self_att = sb.nnet.attention.RelPosMHAXL(
398
+ d_model, nhead, dropout, mask_pos_future=causal
399
+ )
400
+ elif attention_type == "hypermixing":
401
+ self.self_att = sb.nnet.hypermixing.HyperMixing(
402
+ input_output_dim=d_model,
403
+ hypernet_size=d_ffn,
404
+ tied=False,
405
+ num_heads=nhead,
406
+ fix_tm_hidden_size=False,
407
+ )
408
+
409
+ if ffn_type == "regularFFN":
410
+ self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
411
+ d_ffn=d_ffn,
412
+ input_size=d_model,
413
+ dropout=dropout,
414
+ activation=activation,
415
+ )
416
+ elif ffn_type == "1dcnn":
417
+ self.pos_ffn = nn.Sequential(
418
+ Conv1d(
419
+ in_channels=d_model,
420
+ out_channels=d_ffn,
421
+ kernel_size=ffn_cnn_kernel_size_list[0],
422
+ padding="causal" if causal else "same",
423
+ ),
424
+ nn.ReLU(),
425
+ Conv1d(
426
+ in_channels=d_ffn,
427
+ out_channels=d_model,
428
+ kernel_size=ffn_cnn_kernel_size_list[1],
429
+ padding="causal" if causal else "same",
430
+ ),
431
+ )
432
+
433
+ self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
434
+ self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
435
+ self.dropout1 = torch.nn.Dropout(dropout)
436
+ self.dropout2 = torch.nn.Dropout(dropout)
437
+
438
+ self.normalize_before = normalize_before
439
+ self.pos_ffn_type = ffn_type
440
+
441
+ def forward(
442
+ self,
443
+ src,
444
+ src_mask: Optional[torch.Tensor] = None,
445
+ src_key_padding_mask: Optional[torch.Tensor] = None,
446
+ pos_embs: Optional[torch.Tensor] = None,
447
+ ):
448
+ """
449
+ Arguments
450
+ ---------
451
+ src : torch.Tensor
452
+ The sequence to the encoder layer.
453
+ src_mask : torch.Tensor
454
+ The mask for the src query for each example in the batch.
455
+ src_key_padding_mask : torch.Tensor, optional
456
+ The mask for the src keys for each example in the batch.
457
+ pos_embs: torch.Tensor, optional
458
+ The positional embeddings tensor.
459
+
460
+ Returns
461
+ -------
462
+ output : torch.Tensor
463
+ The output of the transformer encoder layer.
464
+ """
465
+
466
+ if self.normalize_before:
467
+ src1 = self.norm1(src)
468
+ else:
469
+ src1 = src
470
+
471
+ output, self_attn = self.self_att(
472
+ src1,
473
+ src1,
474
+ src1,
475
+ attn_mask=src_mask,
476
+ key_padding_mask=src_key_padding_mask,
477
+ pos_embs=pos_embs,
478
+ )
479
+
480
+ # add & norm
481
+ src = src + self.dropout1(output)
482
+ if not self.normalize_before:
483
+ src = self.norm1(src)
484
+
485
+ if self.normalize_before:
486
+ src1 = self.norm2(src)
487
+ else:
488
+ src1 = src
489
+ output = self.pos_ffn(src1)
490
+
491
+ # add & norm
492
+ output = src + self.dropout2(output)
493
+ if not self.normalize_before:
494
+ output = self.norm2(output)
495
+ return output, self_attn
496
+
497
+
498
+ class TransformerEncoder(nn.Module):
499
+ """This class implements the transformer encoder.
500
+
501
+ Arguments
502
+ ---------
503
+ num_layers : int
504
+ Number of transformer layers to include.
505
+ nhead : int
506
+ Number of attention heads.
507
+ d_ffn : int
508
+ Hidden size of self-attention Feed Forward layer.
509
+ input_shape : tuple
510
+ Expected shape of the input.
511
+ d_model : int
512
+ The dimension of the input embedding.
513
+ kdim : int
514
+ Dimension for key (Optional).
515
+ vdim : int
516
+ Dimension for value (Optional).
517
+ dropout : float
518
+ Dropout for the encoder (Optional).
519
+ activation: torch.nn.Module, optional
520
+ The activation function for Feed-Forward Network layer,
521
+ e.g., relu or gelu or swish.
522
+ normalize_before: bool, optional
523
+ Whether normalization should be applied before or after MHA or FFN in Transformer layers.
524
+ Defaults to True as this was shown to lead to better performance and training stability.
525
+ causal: bool, optional
526
+ Whether the encoder should be causal or not (the decoder is always causal).
527
+ If causal the Conformer convolutional layer is causal.
528
+ layerdrop_prob: float
529
+ The probability to drop an entire layer
530
+ attention_type: str, optional
531
+ Type of attention layer used in all Transformer or Conformer layers.
532
+ e.g. regularMHA or RelPosMHA.
533
+ ffn_type: str
534
+ type of ffn: regularFFN/1dcnn
535
+ ffn_cnn_kernel_size_list: list of int
536
+ conv kernel size of 2 1d-convs if ffn_type is 1dcnn
537
+
538
+ Example
539
+ -------
540
+ >>> import torch
541
+ >>> x = torch.rand((8, 60, 512))
542
+ >>> net = TransformerEncoder(1, 8, 512, d_model=512)
543
+ >>> output, _ = net(x)
544
+ >>> output.shape
545
+ torch.Size([8, 60, 512])
546
+ """
547
+
548
+ def __init__(
549
+ self,
550
+ num_layers,
551
+ nhead,
552
+ d_ffn,
553
+ input_shape=None,
554
+ d_model=None,
555
+ kdim=None,
556
+ vdim=None,
557
+ dropout=0.0,
558
+ activation=nn.ReLU,
559
+ normalize_before=False,
560
+ causal=False,
561
+ layerdrop_prob=0.0,
562
+ attention_type="regularMHA",
563
+ ffn_type="regularFFN",
564
+ ffn_cnn_kernel_size_list=[3, 3],
565
+ ):
566
+ super().__init__()
567
+
568
+ self.layers = torch.nn.ModuleList(
569
+ [
570
+ TransformerEncoderLayer(
571
+ d_ffn=d_ffn,
572
+ nhead=nhead,
573
+ d_model=d_model,
574
+ kdim=kdim,
575
+ vdim=vdim,
576
+ dropout=dropout,
577
+ activation=activation,
578
+ normalize_before=normalize_before,
579
+ causal=causal,
580
+ attention_type=attention_type,
581
+ ffn_type=ffn_type,
582
+ ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
583
+ )
584
+ for i in range(num_layers)
585
+ ]
586
+ )
587
+ self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
588
+ self.layerdrop_prob = layerdrop_prob
589
+ self.rng = np.random.default_rng()
590
+
591
+ def forward(
592
+ self,
593
+ src,
594
+ src_mask: Optional[torch.Tensor] = None,
595
+ src_key_padding_mask: Optional[torch.Tensor] = None,
596
+ pos_embs: Optional[torch.Tensor] = None,
597
+ dynchunktrain_config=None,
598
+ ):
599
+ """
600
+ Arguments
601
+ ---------
602
+ src : torch.Tensor
603
+ The sequence to the encoder layer (required).
604
+ src_mask : torch.Tensor
605
+ The mask for the src sequence (optional).
606
+ src_key_padding_mask : torch.Tensor
607
+ The mask for the src keys per batch (optional).
608
+ pos_embs : torch.Tensor
609
+ The positional embedding tensor
610
+ dynchunktrain_config : config
611
+ Not supported for this encoder.
612
+
613
+ Returns
614
+ -------
615
+ output : torch.Tensor
616
+ The output of the transformer.
617
+ attention_lst : list
618
+ The attention values.
619
+ """
620
+ assert (
621
+ dynchunktrain_config is None
622
+ ), "Dynamic Chunk Training unsupported for this encoder"
623
+
624
+ output = src
625
+ if self.layerdrop_prob > 0.0:
626
+ keep_probs = self.rng.random(len(self.layers))
627
+ else:
628
+ keep_probs = None
629
+ attention_lst = []
630
+ for i, enc_layer in enumerate(self.layers):
631
+ if (
632
+ not self.training
633
+ or self.layerdrop_prob == 0.0
634
+ or keep_probs[i] > self.layerdrop_prob
635
+ ):
636
+ output, attention = enc_layer(
637
+ output,
638
+ src_mask=src_mask,
639
+ src_key_padding_mask=src_key_padding_mask,
640
+ pos_embs=pos_embs,
641
+ )
642
+
643
+ attention_lst.append(attention)
644
+ output = self.norm(output)
645
+ return output, attention_lst
646
+
647
+
648
+ class TransformerDecoderLayer(nn.Module):
649
+ """This class implements the self-attention decoder layer.
650
+
651
+ Arguments
652
+ ---------
653
+ d_ffn : int
654
+ Hidden size of self-attention Feed Forward layer.
655
+ nhead : int
656
+ Number of attention heads.
657
+ d_model : int
658
+ Dimension of the model.
659
+ kdim : int
660
+ Dimension for key (optional).
661
+ vdim : int
662
+ Dimension for value (optional).
663
+ dropout : float
664
+ Dropout for the decoder (optional).
665
+ activation : Callable
666
+ Function to use between layers, default nn.ReLU
667
+ normalize_before : bool
668
+ Whether to normalize before layers.
669
+ attention_type : str
670
+ Type of attention to use, "regularMHA" or "RelPosMHAXL"
671
+ causal : bool
672
+ Whether to mask future positions.
673
+
674
+ Example
675
+ -------
676
+ >>> src = torch.rand((8, 60, 512))
677
+ >>> tgt = torch.rand((8, 60, 512))
678
+ >>> net = TransformerDecoderLayer(1024, 8, d_model=512)
679
+ >>> output, self_attn, multihead_attn = net(src, tgt)
680
+ >>> output.shape
681
+ torch.Size([8, 60, 512])
682
+ """
683
+
684
+ def __init__(
685
+ self,
686
+ d_ffn,
687
+ nhead,
688
+ d_model,
689
+ kdim=None,
690
+ vdim=None,
691
+ dropout=0.0,
692
+ activation=nn.ReLU,
693
+ normalize_before=False,
694
+ attention_type="regularMHA",
695
+ causal=None,
696
+ ):
697
+ super().__init__()
698
+ self.nhead = nhead
699
+
700
+ if attention_type == "regularMHA":
701
+ self.self_attn = sb.nnet.attention.MultiheadAttention(
702
+ nhead=nhead,
703
+ d_model=d_model,
704
+ kdim=kdim,
705
+ vdim=vdim,
706
+ dropout=dropout,
707
+ )
708
+ self.multihead_attn = sb.nnet.attention.MultiheadAttention(
709
+ nhead=nhead,
710
+ d_model=d_model,
711
+ kdim=kdim,
712
+ vdim=vdim,
713
+ dropout=dropout,
714
+ )
715
+
716
+ elif attention_type == "RelPosMHAXL":
717
+ self.self_attn = sb.nnet.attention.RelPosMHAXL(
718
+ d_model, nhead, dropout, mask_pos_future=causal
719
+ )
720
+ self.multihead_attn = sb.nnet.attention.RelPosMHAXL(
721
+ d_model, nhead, dropout, mask_pos_future=causal
722
+ )
723
+
724
+ self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
725
+ d_ffn=d_ffn,
726
+ input_size=d_model,
727
+ dropout=dropout,
728
+ activation=activation,
729
+ )
730
+
731
+ # normalization layers
732
+ self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
733
+ self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
734
+ self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
735
+ self.dropout1 = torch.nn.Dropout(dropout)
736
+ self.dropout2 = torch.nn.Dropout(dropout)
737
+ self.dropout3 = torch.nn.Dropout(dropout)
738
+
739
+ self.normalize_before = normalize_before
740
+
741
+ def forward(
742
+ self,
743
+ tgt,
744
+ memory,
745
+ tgt_mask=None,
746
+ memory_mask=None,
747
+ tgt_key_padding_mask=None,
748
+ memory_key_padding_mask=None,
749
+ pos_embs_tgt=None,
750
+ pos_embs_src=None,
751
+ ):
752
+ """
753
+ Arguments
754
+ ----------
755
+ tgt: torch.Tensor
756
+ The sequence to the decoder layer (required).
757
+ memory: torch.Tensor
758
+ The sequence from the last layer of the encoder (required).
759
+ tgt_mask: torch.Tensor
760
+ The mask for the tgt sequence (optional).
761
+ memory_mask: torch.Tensor
762
+ The mask for the memory sequence (optional).
763
+ tgt_key_padding_mask: torch.Tensor
764
+ The mask for the tgt keys per batch (optional).
765
+ memory_key_padding_mask: torch.Tensor
766
+ The mask for the memory keys per batch (optional).
767
+ pos_embs_tgt: torch.Tensor
768
+ The positional embeddings for the target (optional).
769
+ pos_embs_src: torch.Tensor
770
+ The positional embeddings for the source (optional).
771
+ """
772
+ if self.normalize_before:
773
+ tgt1 = self.norm1(tgt)
774
+ else:
775
+ tgt1 = tgt
776
+
777
+ # self-attention over the target sequence
778
+ tgt2, self_attn = self.self_attn(
779
+ query=tgt1,
780
+ key=tgt1,
781
+ value=tgt1,
782
+ attn_mask=tgt_mask,
783
+ key_padding_mask=tgt_key_padding_mask,
784
+ pos_embs=pos_embs_tgt,
785
+ )
786
+
787
+ # add & norm
788
+ tgt = tgt + self.dropout1(tgt2)
789
+ if not self.normalize_before:
790
+ tgt = self.norm1(tgt)
791
+
792
+ if self.normalize_before:
793
+ tgt1 = self.norm2(tgt)
794
+ else:
795
+ tgt1 = tgt
796
+
797
+ # multi-head attention over the target sequence and encoder states
798
+
799
+ tgt2, multihead_attention = self.multihead_attn(
800
+ query=tgt1,
801
+ key=memory,
802
+ value=memory,
803
+ attn_mask=memory_mask,
804
+ key_padding_mask=memory_key_padding_mask,
805
+ pos_embs=pos_embs_src,
806
+ )
807
+
808
+ # add & norm
809
+ tgt = tgt + self.dropout2(tgt2)
810
+ if not self.normalize_before:
811
+ tgt = self.norm2(tgt)
812
+
813
+ if self.normalize_before:
814
+ tgt1 = self.norm3(tgt)
815
+ else:
816
+ tgt1 = tgt
817
+
818
+ tgt2 = self.pos_ffn(tgt1)
819
+
820
+ # add & norm
821
+ tgt = tgt + self.dropout3(tgt2)
822
+ if not self.normalize_before:
823
+ tgt = self.norm3(tgt)
824
+
825
+ return tgt, self_attn, multihead_attention
826
+
827
+
828
+ class TransformerDecoder(nn.Module):
829
+ """This class implements the Transformer decoder.
830
+
831
+ Arguments
832
+ ---------
833
+ num_layers : int
834
+ Number of transformer layers for the decoder.
835
+ nhead : int
836
+ Number of attention heads.
837
+ d_ffn : int
838
+ Hidden size of self-attention Feed Forward layer.
839
+ d_model : int
840
+ Dimension of the model.
841
+ kdim : int, optional
842
+ Dimension for key (Optional).
843
+ vdim : int, optional
844
+ Dimension for value (Optional).
845
+ dropout : float, optional
846
+ Dropout for the decoder (Optional).
847
+ activation : Callable
848
+ The function to apply between layers, default nn.ReLU
849
+ normalize_before : bool
850
+ Whether to normalize before layers.
851
+ causal : bool
852
+ Whether to allow future information in decoding.
853
+ attention_type : str
854
+ Type of attention to use, "regularMHA" or "RelPosMHAXL"
855
+
856
+ Example
857
+ -------
858
+ >>> src = torch.rand((8, 60, 512))
859
+ >>> tgt = torch.rand((8, 60, 512))
860
+ >>> net = TransformerDecoder(1, 8, 1024, d_model=512)
861
+ >>> output, _, _ = net(src, tgt)
862
+ >>> output.shape
863
+ torch.Size([8, 60, 512])
864
+ """
865
+
866
+ def __init__(
867
+ self,
868
+ num_layers,
869
+ nhead,
870
+ d_ffn,
871
+ d_model,
872
+ kdim=None,
873
+ vdim=None,
874
+ dropout=0.0,
875
+ activation=nn.ReLU,
876
+ normalize_before=False,
877
+ causal=False,
878
+ attention_type="regularMHA",
879
+ ):
880
+ super().__init__()
881
+ self.layers = torch.nn.ModuleList(
882
+ [
883
+ TransformerDecoderLayer(
884
+ d_ffn=d_ffn,
885
+ nhead=nhead,
886
+ d_model=d_model,
887
+ kdim=kdim,
888
+ vdim=vdim,
889
+ dropout=dropout,
890
+ activation=activation,
891
+ normalize_before=normalize_before,
892
+ causal=causal,
893
+ attention_type=attention_type,
894
+ )
895
+ for _ in range(num_layers)
896
+ ]
897
+ )
898
+ self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
899
+
900
+ def forward(
901
+ self,
902
+ tgt,
903
+ memory,
904
+ tgt_mask=None,
905
+ memory_mask=None,
906
+ tgt_key_padding_mask=None,
907
+ memory_key_padding_mask=None,
908
+ pos_embs_tgt=None,
909
+ pos_embs_src=None,
910
+ ):
911
+ """
912
+ Arguments
913
+ ----------
914
+ tgt : torch.Tensor
915
+ The sequence to the decoder layer (required).
916
+ memory : torch.Tensor
917
+ The sequence from the last layer of the encoder (required).
918
+ tgt_mask : torch.Tensor
919
+ The mask for the tgt sequence (optional).
920
+ memory_mask : torch.Tensor
921
+ The mask for the memory sequence (optional).
922
+ tgt_key_padding_mask : torch.Tensor
923
+ The mask for the tgt keys per batch (optional).
924
+ memory_key_padding_mask : torch.Tensor
925
+ The mask for the memory keys per batch (optional).
926
+ pos_embs_tgt : torch.Tensor
927
+ The positional embeddings for the target (optional).
928
+ pos_embs_src : torch.Tensor
929
+ The positional embeddings for the source (optional).
930
+ """
931
+ output = tgt
932
+ self_attns, multihead_attns = [], []
933
+ for dec_layer in self.layers:
934
+ output, self_attn, multihead_attn = dec_layer(
935
+ output,
936
+ memory,
937
+ tgt_mask=tgt_mask,
938
+ memory_mask=memory_mask,
939
+ tgt_key_padding_mask=tgt_key_padding_mask,
940
+ memory_key_padding_mask=memory_key_padding_mask,
941
+ pos_embs_tgt=pos_embs_tgt,
942
+ pos_embs_src=pos_embs_src,
943
+ )
944
+ self_attns.append(self_attn)
945
+ multihead_attns.append(multihead_attn)
946
+ output = self.norm(output)
947
+
948
+ return output, self_attns, multihead_attns
949
+
950
+
951
+ class NormalizedEmbedding(nn.Module):
952
+ """This class implements the normalized embedding layer for the transformer.
953
+ Since the dot product of the self-attention is always normalized by sqrt(d_model)
954
+ and the final linear projection for prediction shares weight with the embedding layer,
955
+ we multiply the output of the embedding by sqrt(d_model).
956
+
957
+ Arguments
958
+ ---------
959
+ d_model: int
960
+ The number of expected features in the encoder/decoder inputs (default=512).
961
+ vocab: int
962
+ The vocab size.
963
+
964
+ Example
965
+ -------
966
+ >>> emb = NormalizedEmbedding(512, 1000)
967
+ >>> trg = torch.randint(0, 999, (8, 50))
968
+ >>> emb_fea = emb(trg)
969
+ """
970
+
971
+ def __init__(self, d_model, vocab):
972
+ super().__init__()
973
+ self.emb = sb.nnet.embedding.Embedding(
974
+ num_embeddings=vocab, embedding_dim=d_model, blank_id=0
975
+ )
976
+ self.d_model = d_model
977
+
978
+ def forward(self, x):
979
+ """Processes the input tensor x and returns an output tensor."""
980
+ return self.emb(x) * math.sqrt(self.d_model)
981
+
982
+
983
+ def get_key_padding_mask(padded_input, pad_idx):
984
+ """Creates a binary mask to prevent attention to padded locations.
985
+ We suggest using ``get_mask_from_lengths`` instead of this function.
986
+
987
+ Arguments
988
+ ---------
989
+ padded_input: torch.Tensor
990
+ Padded input.
991
+ pad_idx: int
992
+ idx for padding element.
993
+
994
+ Returns
995
+ -------
996
+ key_padded_mask: torch.Tensor
997
+ Binary mask to prevent attention to padding.
998
+
999
+ Example
1000
+ -------
1001
+ >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
1002
+ >>> get_key_padding_mask(a, pad_idx=0)
1003
+ tensor([[False, False, True],
1004
+ [False, False, True],
1005
+ [False, False, True]])
1006
+ """
1007
+ if len(padded_input.shape) == 4:
1008
+ bz, time, ch1, ch2 = padded_input.shape
1009
+ padded_input = padded_input.reshape(bz, time, ch1 * ch2)
1010
+
1011
+ key_padded_mask = padded_input.eq(pad_idx).to(padded_input.device)
1012
+
1013
+ # if the input is more than 2d, mask the locations where they are silence
1014
+ # across all channels
1015
+ if len(padded_input.shape) > 2:
1016
+ key_padded_mask = key_padded_mask.float().prod(dim=-1).bool()
1017
+ return key_padded_mask.detach()
1018
+
1019
+ return key_padded_mask.detach()
1020
+
1021
+
1022
+ def get_lookahead_mask(padded_input):
1023
+ """Creates a binary mask for each sequence which masks future frames.
1024
+
1025
+ Arguments
1026
+ ---------
1027
+ padded_input: torch.Tensor
1028
+ Padded input tensor.
1029
+
1030
+ Returns
1031
+ -------
1032
+ mask : torch.Tensor
1033
+ Binary mask for masking future frames.
1034
+
1035
+ Example
1036
+ -------
1037
+ >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
1038
+ >>> get_lookahead_mask(a)
1039
+ tensor([[0., -inf, -inf],
1040
+ [0., 0., -inf],
1041
+ [0., 0., 0.]])
1042
+ """
1043
+ seq_len = padded_input.shape[1]
1044
+ mask = (
1045
+ torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device))
1046
+ == 1
1047
+ ).transpose(0, 1)
1048
+ mask = (
1049
+ mask.float()
1050
+ .masked_fill(mask == 0, float("-inf"))
1051
+ .masked_fill(mask == 1, float(0.0))
1052
+ )
1053
+ return mask.detach().to(padded_input.device)
1054
+
1055
+
1056
+ def get_mask_from_lengths(lengths, max_len=None):
1057
+ """Creates a binary mask from sequence lengths
1058
+
1059
+ Arguments
1060
+ ---------
1061
+ lengths: torch.Tensor
1062
+ A tensor of sequence lengths
1063
+ max_len: int (Optional)
1064
+ Maximum sequence length, defaults to None.
1065
+
1066
+ Returns
1067
+ -------
1068
+ mask: torch.Tensor
1069
+ the mask where padded elements are set to True.
1070
+ Then one can use tensor.masked_fill_(mask, 0) for the masking.
1071
+
1072
+ Example
1073
+ -------
1074
+ >>> lengths = torch.tensor([3, 2, 4])
1075
+ >>> get_mask_from_lengths(lengths)
1076
+ tensor([[False, False, False, True],
1077
+ [False, False, True, True],
1078
+ [False, False, False, False]])
1079
+ """
1080
+ if max_len is None:
1081
+ max_len = torch.max(lengths).item()
1082
+ seq_range = torch.arange(
1083
+ max_len, device=lengths.device, dtype=lengths.dtype
1084
+ )
1085
+ return ~(seq_range.unsqueeze(0) < lengths.unsqueeze(1))
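The two mask helpers above are combined downstream (see make_transformer_src_tgt_masks in modules/TransformerASR.py): the lookahead mask blocks attention to future tokens while the key padding mask blocks <pad> positions. A small usage sketch with a hypothetical token batch, assuming the repository root is on PYTHONPATH so modules.Transformer is importable:

import torch
from modules.Transformer import get_key_padding_mask, get_lookahead_mask

tgt = torch.LongTensor([[5, 7, 2, 0], [3, 4, 0, 0]])         # 0 is the padding index
tgt_mask = get_lookahead_mask(tgt)                           # (4, 4) float mask, -inf above the diagonal
tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=0)  # (2, 4) booleans, True at padded positions
print(tgt_mask)
print(tgt_key_padding_mask)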
model/modules/TransformerASR.py ADDED
@@ -0,0 +1,682 @@
1
+ """Added ConMamba and Mamba
2
+
3
+ Authors
4
+ * Xilin Jiang 2024
5
+ """
6
+
7
+ """Transformer for ASR in the SpeechBrain style.
8
+
9
+ Authors
10
+ * Jianyuan Zhong 2020
11
+ * Titouan Parcollet 2024
12
+ * Luca Della Libera 2024
13
+ """
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Optional
17
+
18
+ import torch # noqa 42
19
+ from torch import nn
20
+
21
+ from speechbrain.dataio.dataio import length_to_mask
22
+ from modules.Transformer import (
23
+ NormalizedEmbedding,
24
+ TransformerInterface,
25
+ get_key_padding_mask,
26
+ get_lookahead_mask,
27
+ )
28
+ from speechbrain.nnet.activations import Swish
29
+ from speechbrain.nnet.containers import ModuleList
30
+ from speechbrain.nnet.linear import Linear
31
+ from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
32
+
33
+
34
+ @dataclass
35
+ class TransformerASRStreamingContext:
36
+ """Streaming metadata and state for a `TransformerASR` instance."""
37
+
38
+ dynchunktrain_config: DynChunkTrainConfig
39
+ """Dynamic Chunk Training configuration holding chunk size and context size
40
+ information."""
41
+
42
+ encoder_context: Any
43
+ """Opaque encoder context information. It is constructed by the encoder's
44
+ `make_streaming_context` method and is passed to the encoder when using
45
+ `encode_streaming`.
46
+ """
47
+
48
+
49
+ def make_transformer_src_mask(
50
+ src: torch.Tensor,
51
+ causal: bool = False,
52
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
53
+ ) -> Optional[torch.Tensor]:
54
+ """Prepare the source transformer mask that restricts which frames can
55
+ attend to which frames depending on causal or other simple restricted
56
+ attention methods.
57
+
58
+ Arguments
59
+ ---------
60
+ src: torch.Tensor
61
+ The source tensor to build a mask from. The contents of the tensor are
62
+ not actually used currently; only its shape and other metadata (e.g.
63
+ device).
64
+ causal: bool
65
+ Whether strict causality shall be used. Frames will not be able to
66
+ attend to any future frame.
67
+ dynchunktrain_config: DynChunkTrainConfig, optional
68
+ Dynamic Chunk Training configuration. This implements a simple form of
69
+ chunkwise attention. Incompatible with `causal`.
70
+
71
+ Returns
72
+ -------
73
+ torch.Tensor
74
+ A boolean mask Tensor of shape (timesteps, timesteps).
75
+ """
76
+ if causal:
77
+ assert dynchunktrain_config is None
78
+ return get_lookahead_mask(src)
79
+
80
+ if dynchunktrain_config is None:
81
+ return
82
+
83
+ # The following is not really the sole source used to implement this,
84
+ # but it helps introduce the concept.
85
+ # ref: Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
86
+ # https://arxiv.org/pdf/2012.05481.pdf
87
+ timesteps = src.size(1)
88
+
89
+ # Mask the future at the right of each chunk
90
+ chunk_size = dynchunktrain_config.chunk_size
91
+ num_chunks = timesteps // chunk_size
92
+ timestep_idx = torch.arange(timesteps, device=src.device)
93
+ mask_idx = torch.arange(
94
+ chunk_size, chunk_size * (num_chunks + 2), chunk_size, device=src.device
95
+ ).repeat_interleave(chunk_size)[:timesteps]
96
+ src_mask = timestep_idx[None] >= mask_idx[:, None]
97
+
98
+ # Mask the past at the left of each chunk (accounting for left context)
99
+ # only relevant if using left context
100
+ if not dynchunktrain_config.is_infinite_left_context():
101
+ num_left_chunks = dynchunktrain_config.left_context_size
102
+ mask_idx -= chunk_size * (num_left_chunks + 1)
103
+ src_mask += timestep_idx[None] < mask_idx[:, None]
104
+
105
+ return src_mask
106
+
107
+
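A small usage sketch of the chunkwise mask built by make_transformer_src_mask above (assuming speechbrain is installed and the repository root is on PYTHONPATH): with 6 timesteps, a chunk size of 2 and one left-context chunk, each frame may attend within its own chunk plus one chunk to the left, and True marks a blocked position.

import torch
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
from modules.TransformerASR import make_transformer_src_mask

src = torch.randn(1, 6, 8)       # (batch, time, feat); only shape and device are used
cfg = DynChunkTrainConfig(2, 1)  # chunk_size=2, left_context_size=1 chunk
mask = make_transformer_src_mask(src, dynchunktrain_config=cfg)
print(mask.int())
# expected:
# tensor([[0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0]])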
108
+ def make_transformer_src_tgt_masks(
109
+ src,
110
+ tgt=None,
111
+ wav_len=None,
112
+ pad_idx=0,
113
+ causal: bool = False,
114
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
115
+ ):
116
+ """This function generates masks for training the transformer model,
117
+ opinionated for an ASR context with encoding masks and, optionally, decoding
118
+ masks (if specifying `tgt`).
119
+
120
+ Arguments
121
+ ---------
122
+ src : torch.Tensor
123
+ The sequence to the encoder (required).
124
+ tgt : torch.Tensor
125
+ The sequence to the decoder.
126
+ wav_len : torch.Tensor
127
+ The lengths of the inputs.
128
+ pad_idx : int
129
+ The index for <pad> token (default=0).
130
+ causal: bool
131
+ Whether strict causality shall be used. See `make_asr_src_mask`
132
+ dynchunktrain_config: DynChunkTrainConfig, optional
133
+ Dynamic Chunk Training configuration. See `make_asr_src_mask`
134
+
135
+ Returns
136
+ -------
137
+ src_key_padding_mask : torch.Tensor
138
+ Key padding mask for ignoring padding
139
+ tgt_key_padding_mask : torch.Tensor
140
+ Key padding mask for ignoring padding
141
+ src_mask : torch.Tensor
142
+ Mask for ignoring invalid (e.g. future) timesteps
143
+ tgt_mask : torch.Tensor
144
+ Mask for ignoring invalid (e.g. future) timesteps
145
+ """
146
+ src_key_padding_mask = None
147
+
148
+ # mask out audio beyond the length of audio for each batch
149
+ if wav_len is not None:
150
+ abs_len = torch.round(wav_len * src.shape[1])
151
+ src_key_padding_mask = ~length_to_mask(abs_len).bool()
152
+
153
+ # mask out the source
154
+ src_mask = make_transformer_src_mask(
155
+ src, causal=causal, dynchunktrain_config=dynchunktrain_config
156
+ )
157
+
158
+ # Decoder masks are only built when a target sequence is given
159
+ if tgt is not None:
160
+ tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx)
161
+ tgt_mask = get_lookahead_mask(tgt)
162
+ else:
163
+ tgt_key_padding_mask = None
164
+ tgt_mask = None
165
+
166
+ return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask
167
+
168
+
169
+ class TransformerASR(TransformerInterface):
170
+ """This is an implementation of transformer model for ASR.
171
+
172
+ The architecture is based on the paper "Attention Is All You Need":
173
+ https://arxiv.org/pdf/1706.03762.pdf
174
+
175
+ Arguments
176
+ ---------
177
+ tgt_vocab: int
178
+ Size of vocabulary.
179
+ input_size: int
180
+ Input feature size.
181
+ d_model : int, optional
182
+ Embedding dimension size.
183
+ (default=512).
184
+ nhead : int, optional
185
+ The number of heads in the multi-head attention models (default=8).
186
+ num_encoder_layers : int, optional
187
+ The number of sub-encoder-layers in the encoder (default=6).
188
+ num_decoder_layers : int, optional
189
+ The number of sub-decoder-layers in the decoder (default=6).
190
+ d_ffn : int, optional
191
+ The dimension of the feedforward network model (default=2048).
192
+ dropout : int, optional
193
+ The dropout value (default=0.1).
194
+ activation : torch.nn.Module, optional
195
+ The activation function of FFN layers.
196
+ Recommended: relu or gelu (default=relu).
197
+ positional_encoding: str, optional
198
+ Type of positional encoding used. e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
199
+ normalize_before: bool, optional
200
+ Whether normalization should be applied before or after MHA or FFN in Transformer layers.
201
+ Defaults to True as this was shown to lead to better performance and training stability.
202
+ kernel_size: int, optional
203
+ Kernel size in convolutional layers when Conformer is used.
204
+ bias: bool, optional
205
+ Whether to use bias in Conformer convolutional layers.
206
+ encoder_module: str, optional
207
+ Choose between Branchformer, Conformer, ConMamba, and Transformer for the encoder.
208
+ decoder_module: str, optional
209
+ Choose between Mamba and Transformer for the decoder.
210
+ decoder_module: str, optional
211
+ Choose between Transformer and Mamba for the decoder.
212
+ conformer_activation: torch.nn.Module, optional
213
+ Activation module used after Conformer convolutional layers. E.g. Swish, ReLU etc. it has to be a torch Module.
214
+ branchformer_activation: torch.nn.Module, optional
215
+ Activation module used within the Branchformer Encoder. E.g. Swish, ReLU etc. it has to be a torch Module.
216
+ attention_type: str, optional
217
+ Type of attention layer used in all Transformer or Conformer layers.
218
+ e.g. regularMHA or RelPosMHA.
219
+ max_length: int, optional
220
+ Max length for the target and source sequence in input.
221
+ Used for positional encodings.
222
+ causal: bool, optional
223
+ Whether the encoder should be causal or not (the decoder is always causal).
224
+ If causal the Conformer convolutional layer is causal.
225
+ csgu_linear_units: int, optional
226
+ Number of neurons in the hidden linear units of the CSGU Module.
227
+ -> Branchformer
228
+ gate_activation: torch.nn.Module, optional
229
+ Activation function used at the gate of the CSGU module.
230
+ -> Branchformer
231
+ use_linear_after_conv: bool, optional
232
+ If True, will apply a linear transformation of size input_size//2.
233
+ -> Branchformer
234
+ mamba_config: dict, optional
235
+ Mamba parameters if encoder_module or decoder_module is Mamba or ConMamba
236
+
237
+ Example
238
+ -------
239
+ >>> src = torch.rand([8, 120, 512])
240
+ >>> tgt = torch.randint(0, 720, [8, 120])
241
+ >>> net = TransformerASR(
242
+ ... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
243
+ ... )
244
+ >>> enc_out, dec_out = net.forward(src, tgt)
245
+ >>> enc_out.shape
246
+ torch.Size([8, 120, 512])
247
+ >>> dec_out.shape
248
+ torch.Size([8, 120, 512])
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ tgt_vocab,
254
+ input_size,
255
+ d_model=512,
256
+ nhead=8,
257
+ num_encoder_layers=6,
258
+ num_decoder_layers=6,
259
+ d_ffn=2048,
260
+ dropout=0.1,
261
+ activation=nn.ReLU,
262
+ positional_encoding="fixed_abs_sine",
263
+ normalize_before=False,
264
+ kernel_size: Optional[int] = 31,
265
+ bias: Optional[bool] = True,
266
+ encoder_module: Optional[str] = "transformer",
267
+ decoder_module: Optional[str] = "transformer",
268
+ conformer_activation: Optional[nn.Module] = Swish,
269
+ branchformer_activation: Optional[nn.Module] = nn.GELU,
270
+ attention_type: Optional[str] = "regularMHA",
271
+ max_length: Optional[int] = 2500,
272
+ causal: Optional[bool] = True,
273
+ csgu_linear_units: Optional[int] = 3072,
274
+ gate_activation: Optional[nn.Module] = nn.Identity,
275
+ use_linear_after_conv: Optional[bool] = False,
276
+ mamba_config=None
277
+ ):
278
+ super().__init__(
279
+ d_model=d_model,
280
+ nhead=nhead,
281
+ num_encoder_layers=num_encoder_layers,
282
+ num_decoder_layers=num_decoder_layers,
283
+ d_ffn=d_ffn,
284
+ dropout=dropout,
285
+ activation=activation,
286
+ positional_encoding=positional_encoding,
287
+ normalize_before=normalize_before,
288
+ kernel_size=kernel_size,
289
+ bias=bias,
290
+ encoder_module=encoder_module,
291
+ decoder_module=decoder_module,
292
+ conformer_activation=conformer_activation,
293
+ branchformer_activation=branchformer_activation,
294
+ attention_type=attention_type,
295
+ max_length=max_length,
296
+ causal=causal,
297
+ csgu_linear_units=csgu_linear_units,
298
+ gate_activation=gate_activation,
299
+ use_linear_after_conv=use_linear_after_conv,
300
+ mamba_config=mamba_config
301
+ )
302
+
303
+ self.custom_src_module = ModuleList(
304
+ Linear(
305
+ input_size=input_size,
306
+ n_neurons=d_model,
307
+ bias=True,
308
+ combine_dims=False,
309
+ ),
310
+ torch.nn.Dropout(dropout),
311
+ )
312
+
313
+ self.num_decoder_layers = num_decoder_layers
314
+ if num_decoder_layers > 0:
315
+ self.custom_tgt_module = ModuleList(
316
+ NormalizedEmbedding(d_model, tgt_vocab)
317
+ )
318
+
319
+ # reset parameters using xavier_normal_
320
+ self._init_params()
321
+
322
+ def forward(self, src, tgt, wav_len=None, pad_idx=0):
323
+ """
324
+ Arguments
325
+ ----------
326
+ src : torch.Tensor
327
+ The sequence to the encoder.
328
+ tgt : torch.Tensor
329
+ The sequence to the decoder.
330
+ wav_len: torch.Tensor, optional
331
+ Torch Tensor of shape (batch, ) containing the relative length to padded length for each example.
332
+ pad_idx : int, optional
333
+ The index for <pad> token (default=0).
334
+ """
335
+
336
+ # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
337
+ if src.ndim == 4:
338
+ bz, t, ch1, ch2 = src.shape
339
+ src = src.reshape(bz, t, ch1 * ch2)
340
+
341
+ (
342
+ src_key_padding_mask,
343
+ tgt_key_padding_mask,
344
+ src_mask,
345
+ tgt_mask,
346
+ ) = make_transformer_src_tgt_masks(
347
+ src, tgt, wav_len, causal=self.causal, pad_idx=pad_idx
348
+ )
349
+
350
+ src = self.custom_src_module(src)
351
+ # add positional encoding to the queries when fixed absolute sinusoidal encodings are used
352
+ if self.attention_type == "hypermixing":
353
+ pos_embs_encoder = None
354
+ elif self.attention_type == "RelPosMHAXL":
355
+ pos_embs_encoder = self.positional_encoding(src)
356
+ elif self.positional_encoding_type == "fixed_abs_sine":
357
+ src = src + self.positional_encoding(src) # add the encodings here
358
+ pos_embs_encoder = None
359
+
360
+ encoder_out, _ = self.encoder(
361
+ src=src,
362
+ src_mask=src_mask,
363
+ src_key_padding_mask=src_key_padding_mask,
364
+ pos_embs=pos_embs_encoder,
365
+ )
366
+
367
+ if self.num_decoder_layers > 0:
368
+ tgt = self.custom_tgt_module(tgt)
369
+
370
+ if self.attention_type == "RelPosMHAXL":
371
+ tgt = tgt + self.positional_encoding_decoder(tgt)
372
+ pos_embs_encoder = None # self.positional_encoding(src)
373
+ pos_embs_target = None
374
+ elif (
375
+ self.positional_encoding_type == "fixed_abs_sine"
376
+ or self.attention_type == "hypermixing"
377
+ ):
378
+ tgt = tgt + self.positional_encoding(tgt)
379
+ pos_embs_target = None
380
+ pos_embs_encoder = None
381
+
382
+ decoder_out, _, _ = self.decoder(
383
+ tgt=tgt,
384
+ memory=encoder_out,
385
+ memory_mask=None,
386
+ tgt_mask=tgt_mask,
387
+ tgt_key_padding_mask=tgt_key_padding_mask,
388
+ memory_key_padding_mask=src_key_padding_mask,
389
+ pos_embs_tgt=pos_embs_target,
390
+ pos_embs_src=pos_embs_encoder,
391
+ )
392
+
393
+ else:
394
+ decoder_out = None
395
+
396
+ return encoder_out, decoder_out
397
+
398
+ @torch.no_grad()
399
+ def decode(self, tgt, encoder_out, enc_len=None):
400
+ """This method implements a decoding step for the transformer model.
401
+
402
+ Arguments
403
+ ---------
404
+ tgt : torch.Tensor
405
+ The sequence to the decoder.
406
+ encoder_out : torch.Tensor
407
+ Hidden output of the encoder.
408
+ enc_len : torch.LongTensor
409
+ The actual length of encoder states.
410
+
411
+ Returns
412
+ -------
413
+ prediction
414
+ """
415
+ tgt_mask = get_lookahead_mask(tgt)
416
+ src_key_padding_mask = None
417
+ if enc_len is not None:
418
+ src_key_padding_mask = (1 - length_to_mask(enc_len)).bool()
419
+
420
+ if self.num_decoder_layers > 0:
421
+ tgt = self.custom_tgt_module(tgt)
422
+ if self.attention_type == "RelPosMHAXL":
423
+ tgt = tgt + self.positional_encoding_decoder(tgt)
424
+ pos_embs_encoder = None # self.positional_encoding(src)
425
+ pos_embs_target = None
426
+ elif (
427
+ self.positional_encoding_type == "fixed_abs_sine"
428
+ or self.attention_type == "hypermixing"
429
+ ):
430
+ tgt = tgt + self.positional_encoding(tgt) # add the encodings here
431
+ pos_embs_target = None
432
+ pos_embs_encoder = None
433
+
434
+
435
+ prediction, self_attns, multihead_attns = self.decoder(
436
+ tgt,
437
+ encoder_out,
438
+ tgt_mask=tgt_mask,
439
+ memory_key_padding_mask=src_key_padding_mask,
440
+ pos_embs_tgt=pos_embs_target,
441
+ pos_embs_src=pos_embs_encoder,
442
+ )
443
+ return prediction, multihead_attns[-1]
444
+
445
+ def encode(
446
+ self,
447
+ src,
448
+ wav_len=None,
449
+ pad_idx=0,
450
+ dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
451
+ ):
452
+ """
453
+ Encoder forward pass
454
+
455
+ Arguments
456
+ ---------
457
+ src : torch.Tensor
458
+ The sequence to the encoder.
459
+ wav_len : torch.Tensor, optional
460
+ Torch Tensor of shape (batch, ) containing the relative length to padded length for each example.
461
+ pad_idx : int
462
+ The index used for padding.
463
+ dynchunktrain_config : DynChunkTrainConfig
464
+ Dynamic chunking config.
465
+
466
+ Returns
467
+ -------
468
+ encoder_out : torch.Tensor
469
+ """
470
+ # reshape the src vector to [Batch, Time, Fea] if a 4d vector is given
471
+ if src.dim() == 4:
472
+ bz, t, ch1, ch2 = src.shape
473
+ src = src.reshape(bz, t, ch1 * ch2)
474
+
475
+ (
476
+ src_key_padding_mask,
477
+ _,
478
+ src_mask,
479
+ _,
480
+ ) = make_transformer_src_tgt_masks(
481
+ src,
482
+ None,
483
+ wav_len,
484
+ pad_idx=pad_idx,
485
+ causal=self.causal,
486
+ dynchunktrain_config=dynchunktrain_config,
487
+ )
488
+
489
+ src = self.custom_src_module(src)
490
+ if self.attention_type == "hypermixing":
491
+ pos_embs_source = None
492
+ elif self.attention_type == "RelPosMHAXL":
493
+ pos_embs_source = self.positional_encoding(src)
494
+ elif self.positional_encoding_type == "fixed_abs_sine":
495
+ src = src + self.positional_encoding(src)
496
+ pos_embs_source = None
497
+
498
+ encoder_out, _ = self.encoder(
499
+ src=src,
500
+ src_mask=src_mask,
501
+ src_key_padding_mask=src_key_padding_mask,
502
+ pos_embs=pos_embs_source,
503
+ dynchunktrain_config=dynchunktrain_config,
504
+ )
505
+
506
+ return encoder_out
507
+
508
+ def encode_streaming(self, src, context: TransformerASRStreamingContext):
509
+ """
510
+ Streaming encoder forward pass
511
+
512
+ Arguments
513
+ ---------
514
+ src : torch.Tensor
515
+ The sequence (chunk) to the encoder.
516
+ context : TransformerASRStreamingContext
517
+ Mutable reference to the streaming context. This holds the state
518
+ needed to persist across chunk inferences and can be built using
519
+ `make_streaming_context`. This will get mutated by this function.
520
+
521
+ Returns
522
+ -------
523
+ Encoder output for this chunk.
524
+
525
+ Example
526
+ -------
527
+ >>> import torch
528
+ >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR
529
+ >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
530
+ >>> net = TransformerASR(
531
+ ... tgt_vocab=100,
532
+ ... input_size=64,
533
+ ... d_model=64,
534
+ ... nhead=8,
535
+ ... num_encoder_layers=1,
536
+ ... num_decoder_layers=0,
537
+ ... d_ffn=128,
538
+ ... attention_type="RelPosMHAXL",
539
+ ... positional_encoding=None,
540
+ ... encoder_module="conformer",
541
+ ... normalize_before=True,
542
+ ... causal=False,
543
+ ... )
544
+ >>> ctx = net.make_streaming_context(DynChunkTrainConfig(16, 1))
545
+ >>> src1 = torch.rand([8, 16, 64])
546
+ >>> src2 = torch.rand([8, 16, 64])
547
+ >>> out1 = net.encode_streaming(src1, ctx)
548
+ >>> out1.shape
549
+ torch.Size([8, 16, 64])
550
+ >>> ctx.encoder_context.layers[0].mha_left_context.shape
551
+ torch.Size([8, 16, 64])
552
+ >>> out2 = net.encode_streaming(src2, ctx)
553
+ >>> out2.shape
554
+ torch.Size([8, 16, 64])
555
+ >>> ctx.encoder_context.layers[0].mha_left_context.shape
556
+ torch.Size([8, 16, 64])
557
+ >>> combined_out = torch.concat((out1, out2), dim=1)
558
+ >>> combined_out.shape
559
+ torch.Size([8, 32, 64])
560
+ """
561
+
562
+ if src.dim() == 4:
563
+ bz, t, ch1, ch2 = src.shape
564
+ src = src.reshape(bz, t, ch1 * ch2)
565
+
566
+ # HACK: our problem here is that the positional_encoding is computed
567
+ # against the size of our source tensor, but we only know how many left
568
+ # context frames we're injecting to the encoder within the encoder
569
+ # context.
570
+ # so this workaround does just that.
571
+ #
572
+ # i'm not sure how this would be best refactored, but an option would be
573
+ # to let the encoder get the pos embedding itself and have a way to
574
+ # cache it.
575
+ #
576
+ # additionally, positional encoding functions take in a whole source
577
+ # tensor just to get its attributes (size, device, type) but this is
578
+ # sort of silly for the embeddings that don't need one.
579
+ # so we craft a dummy empty (uninitialized) tensor to help...
580
+ known_left_context = context.encoder_context.layers[0].mha_left_context
581
+ if known_left_context is None:
582
+ pos_encoding_dummy = src
583
+ else:
584
+ target_shape = list(src.shape)
585
+ target_shape[-2] += known_left_context.shape[-2]
586
+ pos_encoding_dummy = torch.empty(size=target_shape).to(src)
587
+
588
+ src = self.custom_src_module(src)
589
+ if self.attention_type == "RelPosMHAXL":
590
+ pos_embs_source = self.positional_encoding(pos_encoding_dummy)
591
+
592
+ elif self.positional_encoding_type == "fixed_abs_sine":
593
+ src = src + self.positional_encoding(pos_encoding_dummy)
594
+ pos_embs_source = None
595
+
596
+ encoder_out, _ = self.encoder.forward_streaming(
597
+ src=src, pos_embs=pos_embs_source, context=context.encoder_context
598
+ )
599
+ return encoder_out
600
+
601
+ def make_streaming_context(
602
+ self, dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={}
603
+ ):
604
+ """Creates a blank streaming context for this transformer and its
605
+ encoder.
606
+
607
+ Arguments
608
+ ---------
609
+ dynchunktrain_config : DynChunkTrainConfig
610
+ Runtime chunkwise attention configuration.
611
+ encoder_kwargs : dict
612
+ Parameters to be forward to the encoder's `make_streaming_context`.
613
+ Metadata required for the encoder could differ depending on the
614
+ encoder.
615
+
616
+ Returns
617
+ -------
618
+ TransformerASRStreamingContext
619
+ """
620
+ return TransformerASRStreamingContext(
621
+ dynchunktrain_config=dynchunktrain_config,
622
+ encoder_context=self.encoder.make_streaming_context(
623
+ dynchunktrain_config,
624
+ **encoder_kwargs,
625
+ ),
626
+ )
627
+
628
+ def _init_params(self):
629
+ for p in self.parameters():
630
+ if p.dim() > 1:
631
+ torch.nn.init.xavier_normal_(p)
632
+
633
+
634
+ class EncoderWrapper(nn.Module):
635
+ """This is a wrapper of any ASR transformer encoder. By default, the
636
+ TransformerASR .forward() function encodes and decodes. With this wrapper
637
+ the .forward() function becomes .encode() only.
638
+
639
+ Important: The TransformerASR class must contain a .encode() function.
640
+
641
+ Arguments
642
+ ---------
643
+ transformer : sb.lobes.models.TransformerInterface
644
+ A Transformer instance that contains a .encode() function.
645
+ *args : tuple
646
+ **kwargs : dict
647
+ Arguments to forward to parent class.
648
+
649
+ Example
650
+ -------
651
+ >>> src = torch.rand([8, 120, 512])
652
+ >>> tgt = torch.randint(0, 720, [8, 120])
653
+ >>> net = TransformerASR(
654
+ ... 720, 512, 512, 8, 1, 1, 1024, activation=torch.nn.GELU
655
+ ... )
656
+ >>> encoder = EncoderWrapper(net)
657
+ >>> enc_out = encoder(src)
658
+ >>> enc_out.shape
659
+ torch.Size([8, 120, 512])
660
+ """
661
+
662
+ def __init__(self, transformer, *args, **kwargs):
663
+ super().__init__(*args, **kwargs)
664
+ self.transformer = transformer
665
+ self.make_streaming_context = self.transformer.make_streaming_context
666
+
667
+ def forward(self, x, wav_lens=None, pad_idx=0, **kwargs):
668
+ """Processes the input tensor x and returns an output tensor."""
669
+ x = self.transformer.encode(x, wav_lens, pad_idx, **kwargs)
670
+ return x
671
+
672
+ def forward_streaming(self, x, context):
673
+ """Processes the input audio chunk tensor `x`, using and updating the
674
+ mutable encoder `context`"""
675
+ x = self.transformer.encode_streaming(x, context)
676
+ return x
677
+
678
+ def make_streaming_context(self, *args, **kwargs):
679
+ """Initializes a streaming context. Forwards all arguments to the
680
+ underlying transformer. See :meth:`speechbrain.lobes.models.transformer.TransformerASR.make_streaming_context`.
681
+ """
682
+ return self.transformer.make_streaming_context(*args, **kwargs)
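A minimal sketch of driving the encode()/decode() pair above for a single greedy step (not part of the uploaded files): the vocabulary projection layer and the start-token index below are assumptions for illustration only, since the output head lives outside this module.

import torch
from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR

net = TransformerASR(
    tgt_vocab=100, input_size=64, d_model=64, nhead=8,
    num_encoder_layers=1, num_decoder_layers=1, d_ffn=128,
)
vocab_proj = torch.nn.Linear(64, 100)          # assumed external output head
src = torch.rand(4, 50, 64)                    # (batch, time, features)
wav_len = torch.ones(4)                        # relative lengths in [0, 1]
enc_out = net.encode(src, wav_len)             # (batch, time, d_model)
tokens = torch.zeros(4, 1, dtype=torch.long)   # assumed <bos>-like start index
dec_out, _ = net.decode(tokens, enc_out)       # decoder hidden states (batch, 1, d_model)
next_tok = vocab_proj(dec_out[:, -1]).argmax(dim=-1)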
model/modules/__init__.py ADDED
File without changes
model/modules/mamba/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/modules/mamba/__init__.py ADDED
File without changes
model/modules/mamba/bimamba.py ADDED
@@ -0,0 +1,465 @@
1
+ '''
2
+ Copied and modified from
3
+ https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/modules/mamba_simple.py
4
+ '''
5
+
6
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
7
+
8
+ import math
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch import Tensor
15
+
16
+ from einops import rearrange, repeat
17
+
18
+ try:
19
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
20
+ except ImportError:
21
+ causal_conv1d_fn, causal_conv1d_update = None, None
22
+
23
+ try:
24
+ from .selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
25
+ except ImportError:
26
+ selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None
27
+
28
+ try:
29
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
30
+ except ImportError:
31
+ selective_state_update = None
32
+
33
+ try:
34
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
35
+ except ImportError:
36
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
37
+
38
+
39
+ class Mamba(nn.Module):
40
+ def __init__(
41
+ self,
42
+ d_model,
43
+ d_state=16,
44
+ d_conv=4,
45
+ expand=2,
46
+ dt_rank="auto",
47
+ dt_min=0.001,
48
+ dt_max=0.1,
49
+ dt_init="random",
50
+ dt_scale=1.0,
51
+ dt_init_floor=1e-4,
52
+ conv_bias=True,
53
+ bias=False,
54
+ use_fast_path=True, # Fused kernel options
55
+ layer_idx=None,
56
+ device=None,
57
+ dtype=None,
58
+ bimamba_type="none",
59
+ if_devide_out=True, # False
60
+ init_layer_scale=None,
61
+ ):
62
+ factory_kwargs = {"device": device, "dtype": dtype}
63
+ super().__init__()
64
+ self.d_model = d_model
65
+ self.d_state = d_state
66
+ self.d_conv = d_conv
67
+ self.expand = expand
68
+ self.d_inner = int(self.expand * self.d_model)
69
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
70
+ self.use_fast_path = use_fast_path
71
+ self.layer_idx = layer_idx
72
+ self.bimamba_type = bimamba_type
73
+ self.if_devide_out = if_devide_out
74
+
75
+ assert bimamba_type == 'v2'
76
+
77
+ self.init_layer_scale = init_layer_scale
78
+ if init_layer_scale is not None:
79
+ self.gamma = nn.Parameter(init_layer_scale * torch.ones((d_model)), requires_grad=True)
80
+
81
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
82
+
83
+ self.conv1d = nn.Conv1d(
84
+ in_channels=self.d_inner,
85
+ out_channels=self.d_inner,
86
+ bias=conv_bias,
87
+ kernel_size=d_conv,
88
+ groups=self.d_inner,
89
+ padding=d_conv - 1,
90
+ **factory_kwargs,
91
+ )
92
+
93
+ self.activation = "silu"
94
+ self.act = nn.SiLU()
95
+
96
+ self.x_proj = nn.Linear(
97
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
98
+ )
99
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
100
+
101
+ # Initialize special dt projection to preserve variance at initialization
102
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
103
+ if dt_init == "constant":
104
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
105
+ elif dt_init == "random":
106
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
107
+ else:
108
+ raise NotImplementedError
109
+
110
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
111
+ dt = torch.exp(
112
+ torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
113
+ + math.log(dt_min)
114
+ ).clamp(min=dt_init_floor)
115
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
116
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
117
+ with torch.no_grad():
118
+ self.dt_proj.bias.copy_(inv_dt)
119
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
120
+ self.dt_proj.bias._no_reinit = True
121
+
122
+ # S4D real initialization
123
+ A = repeat(
124
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
125
+ "n -> d n",
126
+ d=self.d_inner,
127
+ ).contiguous()
128
+ A_log = torch.log(A) # Keep A_log in fp32
129
+ self.A_log = nn.Parameter(A_log)
130
+ self.A_log._no_weight_decay = True
131
+
132
+ # D "skip" parameter
133
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
134
+ self.D._no_weight_decay = True
135
+
136
+ # bidirectional
137
+ if bimamba_type == "v1":
138
+ A_b = repeat(
139
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
140
+ "n -> d n",
141
+ d=self.d_inner,
142
+ ).contiguous()
143
+ A_b_log = torch.log(A_b) # Keep A_b_log in fp32
144
+ self.A_b_log = nn.Parameter(A_b_log)
145
+ self.A_b_log._no_weight_decay = True
146
+ elif bimamba_type == "v2":
147
+ A_b = repeat(
148
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
149
+ "n -> d n",
150
+ d=self.d_inner,
151
+ ).contiguous()
152
+ A_b_log = torch.log(A_b) # Keep A_b_log in fp32
153
+ self.A_b_log = nn.Parameter(A_b_log)
154
+ self.A_b_log._no_weight_decay = True
155
+
156
+ self.conv1d_b = nn.Conv1d(
157
+ in_channels=self.d_inner,
158
+ out_channels=self.d_inner,
159
+ bias=conv_bias,
160
+ kernel_size=d_conv,
161
+ groups=self.d_inner,
162
+ padding=d_conv - 1,
163
+ **factory_kwargs,
164
+ )
165
+
166
+ self.x_proj_b = nn.Linear(
167
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
168
+ )
169
+ self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
170
+
171
+ self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
172
+ self.D_b._no_weight_decay = True
173
+
174
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
175
+
176
+ def forward(self, hidden_states, inference_params=None):
177
+ """
178
+ hidden_states: (B, L, D)
179
+ Returns: same shape as hidden_states
180
+ """
181
+ batch, seqlen, dim = hidden_states.shape
182
+ conv_state, ssm_state = None, None
183
+
184
+ if inference_params is not None:
185
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
186
+ if inference_params.seqlen_offset > 0:
187
+ # The states are updated inplace
188
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
189
+ return out
190
+
191
+ # We do matmul and transpose BLH -> HBL at the same time
192
+ xz = rearrange(
193
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
194
+ "d (b l) -> b d l",
195
+ l=seqlen,
196
+ )
197
+ if self.in_proj.bias is not None:
198
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
199
+
200
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
201
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
202
+ if self.use_fast_path and inference_params is None: # Doesn't support outputting the states
203
+ if self.bimamba_type == "v1":
204
+ A_b = -torch.exp(self.A_b_log.float())
205
+ out = bimamba_inner_fn(
206
+ xz,
207
+ self.conv1d.weight,
208
+ self.conv1d.bias,
209
+ self.x_proj.weight,
210
+ self.dt_proj.weight,
211
+ self.out_proj.weight,
212
+ self.out_proj.bias,
213
+ A,
214
+ A_b,
215
+ None, # input-dependent B
216
+ None, # input-dependent C
217
+ self.D.float(),
218
+ delta_bias=self.dt_proj.bias.float(),
219
+ delta_softplus=True,
220
+ )
221
+ elif self.bimamba_type == "v2":
222
+ A_b = -torch.exp(self.A_b_log.float())
223
+ out = mamba_inner_fn_no_out_proj(
224
+ xz,
225
+ self.conv1d.weight,
226
+ self.conv1d.bias,
227
+ self.x_proj.weight,
228
+ self.dt_proj.weight,
229
+ A,
230
+ None, # input-dependent B
231
+ None, # input-dependent C
232
+ self.D.float(),
233
+ delta_bias=self.dt_proj.bias.float(),
234
+ delta_softplus=True,
235
+ )
236
+ out_b = mamba_inner_fn_no_out_proj(
237
+ xz.flip([-1]),
238
+ self.conv1d_b.weight,
239
+ self.conv1d_b.bias,
240
+ self.x_proj_b.weight,
241
+ self.dt_proj_b.weight,
242
+ A_b,
243
+ None,
244
+ None,
245
+ self.D_b.float(),
246
+ delta_bias=self.dt_proj_b.bias.float(),
247
+ delta_softplus=True,
248
+ )
249
+
250
+ if not self.if_devide_out:
251
+ out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
252
+ else:
253
+ out = F.linear(rearrange(0.5*out + 0.5*out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
254
+
255
+ else:
256
+ out = mamba_inner_fn(
257
+ xz,
258
+ self.conv1d.weight,
259
+ self.conv1d.bias,
260
+ self.x_proj.weight,
261
+ self.dt_proj.weight,
262
+ self.out_proj.weight,
263
+ self.out_proj.bias,
264
+ A,
265
+ None, # input-dependent B
266
+ None, # input-dependent C
267
+ self.D.float(),
268
+ delta_bias=self.dt_proj.bias.float(),
269
+ delta_softplus=True,
270
+ )
271
+ else:
272
+ x, z = xz.chunk(2, dim=1)
273
+ # Compute short convolution
274
+ if conv_state is not None:
275
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
276
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
277
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
278
+ if causal_conv1d_fn is None:
279
+ x = self.act(self.conv1d(x)[..., :seqlen])
280
+ else:
281
+ assert self.activation in ["silu", "swish"]
282
+ x = causal_conv1d_fn(
283
+ x=x,
284
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
285
+ bias=self.conv1d.bias,
286
+ activation=self.activation,
287
+ )
288
+
289
+ # We're careful here about the layout, to avoid extra transposes.
290
+ # We want dt to have d as the slowest moving dimension
291
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
292
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
293
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
294
+ dt = self.dt_proj.weight @ dt.t()
295
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
296
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
297
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
298
+ assert self.activation in ["silu", "swish"]
299
+ y = selective_scan_fn(
300
+ x,
301
+ dt,
302
+ A,
303
+ B,
304
+ C,
305
+ self.D.float(),
306
+ z=z,
307
+ delta_bias=self.dt_proj.bias.float(),
308
+ delta_softplus=True,
309
+ return_last_state=ssm_state is not None,
310
+ )
311
+ if ssm_state is not None:
312
+ y, last_state = y
313
+ ssm_state.copy_(last_state)
314
+ y = rearrange(y, "b d l -> b l d")
315
+ out = self.out_proj(y)
316
+ if self.init_layer_scale is not None:
317
+ out = out * self.gamma
318
+ return out
319
+
320
+ def step(self, hidden_states, conv_state, ssm_state):
321
+ dtype = hidden_states.dtype
322
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
323
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
324
+ x, z = xz.chunk(2, dim=-1) # (B D)
325
+
326
+ # Conv step
327
+ if causal_conv1d_update is None:
328
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
329
+ conv_state[:, :, -1] = x
330
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
331
+ if self.conv1d.bias is not None:
332
+ x = x + self.conv1d.bias
333
+ x = self.act(x).to(dtype=dtype)
334
+ else:
335
+ x = causal_conv1d_update(
336
+ x,
337
+ conv_state,
338
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
339
+ self.conv1d.bias,
340
+ self.activation,
341
+ )
342
+
343
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
344
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
345
+ # Don't add dt_bias here
346
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
347
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
348
+
349
+ # SSM step
350
+ if selective_state_update is None:
351
+ # Discretize A and B
352
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
353
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
354
+ dB = torch.einsum("bd,bn->bdn", dt, B)
355
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
356
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
357
+ y = y + self.D.to(dtype) * x
358
+ y = y * self.act(z) # (B D)
359
+ else:
360
+ y = selective_state_update(
361
+ ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
362
+ )
363
+
364
+ out = self.out_proj(y)
365
+ return out.unsqueeze(1), conv_state, ssm_state
366
+
367
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
368
+ device = self.out_proj.weight.device
369
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
370
+ conv_state = torch.zeros(
371
+ batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
372
+ )
373
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
374
+ # ssm_dtype = torch.float32
375
+ ssm_state = torch.zeros(
376
+ batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
377
+ )
378
+ return conv_state, ssm_state
379
+
380
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
381
+ assert self.layer_idx is not None
382
+ if self.layer_idx not in inference_params.key_value_memory_dict:
383
+ batch_shape = (batch_size,)
384
+ conv_state = torch.zeros(
385
+ batch_size,
386
+ self.d_model * self.expand,
387
+ self.d_conv,
388
+ device=self.conv1d.weight.device,
389
+ dtype=self.conv1d.weight.dtype,
390
+ )
391
+ ssm_state = torch.zeros(
392
+ batch_size,
393
+ self.d_model * self.expand,
394
+ self.d_state,
395
+ device=self.dt_proj.weight.device,
396
+ dtype=self.dt_proj.weight.dtype,
397
+ # dtype=torch.float32,
398
+ )
399
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
400
+ else:
401
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
402
+ # TODO: What if batch size changes between generation, and we reuse the same states?
403
+ if initialize_states:
404
+ conv_state.zero_()
405
+ ssm_state.zero_()
406
+ return conv_state, ssm_state
407
+
408
+
409
+ class Block(nn.Module):
410
+ def __init__(
411
+ self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
412
+ ):
413
+ """
414
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
415
+
416
+ This Block has a slightly different structure compared to a regular
417
+ prenorm Transformer block.
418
+ The standard block is: LN -> MHA/MLP -> Add.
419
+ [Ref: https://arxiv.org/abs/2002.04745]
420
+ Here we have: Add -> LN -> Mixer, returning both
421
+ the hidden_states (output of the mixer) and the residual.
422
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
423
+ The residual needs to be provided (except for the very first block).
424
+ """
425
+ super().__init__()
426
+ self.residual_in_fp32 = residual_in_fp32
427
+ self.fused_add_norm = fused_add_norm
428
+ self.mixer = mixer_cls(dim)
429
+ self.norm = norm_cls(dim)
430
+ if self.fused_add_norm:
431
+ assert RMSNorm is not None, "RMSNorm import fails"
432
+ assert isinstance(
433
+ self.norm, (nn.LayerNorm, RMSNorm)
434
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
435
+
436
+ def forward(
437
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
438
+ ):
439
+ r"""Pass the input through the encoder layer.
440
+
441
+ Args:
442
+ hidden_states: the sequence to the encoder layer (required).
443
+ residual: hidden_states = Mixer(LN(residual))
444
+ """
445
+ if not self.fused_add_norm:
446
+ residual = (hidden_states + residual) if residual is not None else hidden_states
447
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
448
+ if self.residual_in_fp32:
449
+ residual = residual.to(torch.float32)
450
+ else:
451
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
452
+ hidden_states, residual = fused_add_norm_fn(
453
+ hidden_states,
454
+ self.norm.weight,
455
+ self.norm.bias,
456
+ residual=residual,
457
+ prenorm=True,
458
+ residual_in_fp32=self.residual_in_fp32,
459
+ eps=self.norm.eps,
460
+ )
461
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
462
+ return hidden_states, residual
463
+
464
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
465
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
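A minimal sketch of running the bidirectional Mamba layer defined above on a (batch, length, d_model) tensor. It assumes the mamba-ssm and causal-conv1d CUDA kernels are installed, a GPU is available, and the repository root is on the Python path so the import resolves; the sizes are illustrative only.

import torch
from modules.mamba.bimamba import Mamba as BiMamba

layer = BiMamba(d_model=64, d_state=16, d_conv=4, expand=2, bimamba_type="v2").to("cuda")
x = torch.randn(2, 128, 64, device="cuda")   # (batch, seq_len, d_model)
y = layer(x)                                 # bidirectional scan, same shape as the input
print(y.shape)                               # torch.Size([2, 128, 64])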
model/modules/mamba/mamba_blocks.py ADDED
@@ -0,0 +1,252 @@
1
+ '''
2
+ Copied and modified from
3
+ https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
4
+ '''
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from functools import partial
11
+
12
+ from mamba_ssm import Mamba
13
+ from modules.mamba.bimamba import Mamba as BiMamba
14
+ from modules.mamba.bimamba import Block as PreNormBlock
15
+
16
+ try:
17
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
18
+ except ImportError:
19
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
20
+
21
+
22
+ def create_block(
23
+ d_model,
24
+ ssm_cls=None,
25
+ ssm_cfg=None,
26
+ norm_epsilon=1e-5,
27
+ rms_norm=False,
28
+ residual_in_fp32=False,
29
+ fused_add_norm=True,
30
+ layer_idx=None,
31
+ device=None,
32
+ dtype=None,
33
+ ):
34
+ if ssm_cfg is None:
35
+ ssm_cfg = {}
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ mixer_cls = partial(ssm_cls, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
38
+ norm_cls = partial(
39
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
40
+ )
41
+ block = PreNormBlock(
42
+ d_model,
43
+ mixer_cls,
44
+ norm_cls=norm_cls,
45
+ fused_add_norm=fused_add_norm,
46
+ residual_in_fp32=residual_in_fp32,
47
+ )
48
+ block.layer_idx = layer_idx
49
+ return block
50
+
51
+
52
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
53
+ def _init_weights(
54
+ module,
55
+ n_layer,
56
+ initializer_range=0.02, # Now only used for embedding layer.
57
+ rescale_prenorm_residual=True,
58
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
59
+ ):
60
+ if isinstance(module, nn.Linear):
61
+ if module.bias is not None:
62
+ if not getattr(module.bias, "_no_reinit", False):
63
+ nn.init.zeros_(module.bias)
64
+ elif isinstance(module, nn.Embedding):
65
+ nn.init.normal_(module.weight, std=initializer_range)
66
+
67
+ if rescale_prenorm_residual:
68
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
69
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
70
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
71
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
72
+ #
73
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
74
+ for name, p in module.named_parameters():
75
+ if name in ["out_proj.weight", "fc2.weight"]:
76
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
77
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
78
+ # We need to reinit p since this code could be called multiple times
79
+ # Having just p *= scale would repeatedly scale it down
80
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
81
+ with torch.no_grad():
82
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
83
+
84
+
85
+ class LnMambaAdd(nn.Module):
86
+
87
+ def __init__(self,
88
+ d_model,
89
+ ssm_cls,
90
+ ssm_cfg,
91
+ rms_norm=False,
92
+ layer_idx=None
93
+ ):
94
+ super().__init__()
95
+ if rms_norm:
96
+ self.norm = RMSNorm(d_model)
97
+ else:
98
+ self.norm = nn.LayerNorm(d_model)
99
+ self.mamba = ssm_cls(d_model=d_model, **ssm_cfg)
100
+
101
+ print(type(self.mamba))
102
+
103
+ print('Created LnMambaAdd.')
104
+
105
+ def forward(self, x, residual=None, inference_params=None):
106
+ if residual is not None:
107
+ x = x + residual
108
+ return self.mamba(self.norm(x)), x
109
+
110
+
111
+ class MambaBlocksSequential(nn.Module):
112
+ """
113
+ A wrapper for the Mamba block to replicate it
114
+
115
+ Arguments
116
+ ---------
117
+ n_mamba : int
118
+ Number of Mamba blocks
119
+ d_model : int
120
+ Input dimension to Mamba (bottleneck dimension).
121
+ d_state : int
122
+ Mamba state dimension
123
+ expand: int
124
+ First linear projection d_model -> d_model * expand
125
+ d_conv: int
126
+ Kernel size of Mamba conv
127
+ rms_norm : bool
128
+ Whether to use RMSNorm instead of LayerNorm.
129
+ ---------
130
+ """
131
+
132
+ def __init__(self,
133
+ n_mamba: int,
134
+ bidirectional: bool,
135
+ d_model: int, # bottleneck dimension (B)
136
+ d_state: int = 16,
137
+ expand: int = 2,
138
+ d_conv: int = 4, # kernel_size of 'Conv' in Mamba
139
+ dt_rank: str="auto",
140
+ conv_bias: bool = True,
141
+ bias: bool = False,
142
+ fused_add_norm: bool = True,
143
+ rms_norm: bool = False,
144
+ norm_epsilon: float = 1e-5,
145
+ initializer_cfg=None,
146
+ residual_in_fp32=False,
147
+ use_simple_block=False
148
+ ):
149
+ super().__init__()
150
+ self.residual_in_fp32 = residual_in_fp32
151
+ self.bidirectional = bidirectional
152
+
153
+ # We change the order of residual and layer norm:
154
+ # Instead of LN -> Attn / MLP -> Add, we do:
155
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
156
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
157
+ # This is for performance reason: we can fuse add + layer_norm.
158
+ self.fused_add_norm = fused_add_norm
159
+ if self.fused_add_norm:
160
+ if layer_norm_fn is None or rms_norm_fn is None:
161
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
162
+
163
+ self.use_simple_block = use_simple_block
164
+
165
+ ssm_cfg = {
166
+ "d_state": d_state,
167
+ "expand": expand,
168
+ "d_conv": d_conv,
169
+ "dt_rank": dt_rank,
170
+ "conv_bias": conv_bias,
171
+ "bias": bias
172
+ }
173
+ if bidirectional:
174
+ ssm_cfg["bimamba_type"] = "v2"
175
+
176
+ if use_simple_block:
177
+ self.layers = nn.Sequential(
178
+ *[
179
+ LnMambaAdd(
180
+ d_model=d_model,
181
+ ssm_cls=BiMamba if bidirectional else Mamba,
182
+ ssm_cfg=ssm_cfg,
183
+ rms_norm=rms_norm,
184
+ layer_idx=i
185
+ )
186
+ for i in range(n_mamba)
187
+ ]
188
+ )
189
+ else:
190
+ self.layers = nn.Sequential(
191
+ *[
192
+ create_block(
193
+ d_model=d_model,
194
+ ssm_cls=BiMamba if bidirectional else Mamba,
195
+ ssm_cfg=ssm_cfg,
196
+ norm_epsilon=norm_epsilon,
197
+ rms_norm=rms_norm,
198
+ residual_in_fp32=residual_in_fp32,
199
+ fused_add_norm=fused_add_norm,
200
+ layer_idx=i,
201
+ )
202
+ for i in range(n_mamba)
203
+ ]
204
+ )
205
+
206
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
207
+ d_model, eps=norm_epsilon
208
+ )
209
+
210
+ self.apply(
211
+ partial(
212
+ _init_weights,
213
+ n_layer=n_mamba,
214
+ **(initializer_cfg if initializer_cfg is not None else {}),
215
+ )
216
+ )
217
+
218
+
219
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
220
+ return {
221
+ i: block.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
222
+ for i, layer in enumerate(self.layers)
223
+ }
224
+
225
+ def forward(self, x, inference_params=None):
226
+
227
+ hidden_states = x
228
+ residual = None
229
+ for i, layer in enumerate(self.layers):
230
+ hidden_states, residual = layer(
231
+ hidden_states, residual, inference_params=inference_params
232
+ )
233
+
234
+ if not self.fused_add_norm:
235
+ residual = (hidden_states + residual) if residual is not None else hidden_states
236
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
237
+ else:
238
+ # Set prenorm=False here since we don't need the residual
239
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
240
+
241
+ hidden_states = fused_add_norm_fn(
242
+ hidden_states,
243
+ self.norm_f.weight,
244
+ self.norm_f.bias,
245
+ eps=self.norm_f.eps,
246
+ residual=residual,
247
+ prenorm=False,
248
+ residual_in_fp32=self.residual_in_fp32,
249
+ )
250
+
251
+ return hidden_states
252
+
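A minimal sketch of stacking these blocks through MambaBlocksSequential, under the same CUDA-kernel assumptions as the sketch above (fused_add_norm additionally needs the Triton layer-norm kernels from mamba-ssm); parameter values are illustrative only.

import torch
from modules.mamba.mamba_blocks import MambaBlocksSequential

blocks = MambaBlocksSequential(
    n_mamba=2, bidirectional=True, d_model=64,
    d_state=16, expand=2, d_conv=4,
).to("cuda")
x = torch.randn(2, 128, 64, device="cuda")   # (batch, seq_len, d_model)
out = blocks(x)                              # (batch, seq_len, d_model), final norm applied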
model/modules/mamba/selective_scan_interface.py ADDED
@@ -0,0 +1,714 @@
1
+ '''
2
+ Copied from
3
+ https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py
4
+ '''
5
+
6
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch.cuda.amp import custom_bwd, custom_fwd
11
+
12
+ from einops import rearrange, repeat
13
+
14
+ from causal_conv1d import causal_conv1d_fn
15
+ import causal_conv1d_cuda
16
+ import selective_scan_cuda
17
+
18
+
19
+ class SelectiveScanFn(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
23
+ return_last_state=False):
24
+ if u.stride(-1) != 1:
25
+ u = u.contiguous()
26
+ if delta.stride(-1) != 1:
27
+ delta = delta.contiguous()
28
+ if D is not None:
29
+ D = D.contiguous()
30
+ if B.stride(-1) != 1:
31
+ B = B.contiguous()
32
+ if C.stride(-1) != 1:
33
+ C = C.contiguous()
34
+ if z is not None and z.stride(-1) != 1:
35
+ z = z.contiguous()
36
+ if B.dim() == 3:
37
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
38
+ ctx.squeeze_B = True
39
+ if C.dim() == 3:
40
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
41
+ ctx.squeeze_C = True
42
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
43
+ ctx.delta_softplus = delta_softplus
44
+ ctx.has_z = z is not None
45
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
46
+ if not ctx.has_z:
47
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
48
+ return out if not return_last_state else (out, last_state)
49
+ else:
50
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
51
+ out_z = rest[0]
52
+ return out_z if not return_last_state else (out_z, last_state)
53
+
54
+ @staticmethod
55
+ def backward(ctx, dout, *args):
56
+ if not ctx.has_z:
57
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
58
+ z = None
59
+ out = None
60
+ else:
61
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
62
+ if dout.stride(-1) != 1:
63
+ dout = dout.contiguous()
64
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
65
+ # backward of selective_scan_cuda with the backward of chunk).
66
+ # Here we just pass in None and dz will be allocated in the C++ code.
67
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
68
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
69
+ False # option to recompute out_z, not used here
70
+ )
71
+ dz = rest[0] if ctx.has_z else None
72
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
73
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
74
+ return (du, ddelta, dA, dB, dC,
75
+ dD if D is not None else None,
76
+ dz,
77
+ ddelta_bias if delta_bias is not None else None,
78
+ None,
79
+ None)
80
+
81
+
82
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
83
+ return_last_state=False):
84
+ """if return_last_state is True, returns (out, last_state)
85
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
86
+ not considered in the backward pass.
87
+ """
88
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
89
+
90
+
91
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
92
+ return_last_state=False):
93
+ """
94
+ u: r(B D L)
95
+ delta: r(B D L)
96
+ A: c(D N) or r(D N)
97
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
98
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
99
+ D: r(D)
100
+ z: r(B D L)
101
+ delta_bias: r(D), fp32
102
+
103
+ out: r(B D L)
104
+ last_state (optional): r(B D dstate) or c(B D dstate)
105
+ """
106
+ dtype_in = u.dtype
107
+ u = u.float()
108
+ delta = delta.float()
109
+ if delta_bias is not None:
110
+ delta = delta + delta_bias[..., None].float()
111
+ if delta_softplus:
112
+ delta = F.softplus(delta)
113
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
114
+ is_variable_B = B.dim() >= 3
115
+ is_variable_C = C.dim() >= 3
116
+ if A.is_complex():
117
+ if is_variable_B:
118
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
119
+ if is_variable_C:
120
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
121
+ else:
122
+ B = B.float()
123
+ C = C.float()
124
+ x = A.new_zeros((batch, dim, dstate))
125
+ ys = []
126
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
127
+ if not is_variable_B:
128
+ deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
129
+ else:
130
+ if B.dim() == 3:
131
+ deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
132
+ else:
133
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
134
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
135
+ if is_variable_C and C.dim() == 4:
136
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
137
+ last_state = None
138
+ for i in range(u.shape[2]):
139
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
140
+ if not is_variable_C:
141
+ y = torch.einsum('bdn,dn->bd', x, C)
142
+ else:
143
+ if C.dim() == 3:
144
+ y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
145
+ else:
146
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
147
+ if i == u.shape[2] - 1:
148
+ last_state = x
149
+ if y.is_complex():
150
+ y = y.real * 2
151
+ ys.append(y)
152
+ y = torch.stack(ys, dim=2) # (batch dim L)
153
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
154
+ if z is not None:
155
+ out = out * F.silu(z)
156
+ out = out.to(dtype=dtype_in)
157
+ return out if not return_last_state else (out, last_state)
158
+
159
+
160
+ class MambaInnerFnNoOutProj(torch.autograd.Function):
161
+
162
+ @staticmethod
163
+ @custom_fwd
164
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
165
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
166
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
167
+ """
168
+ xz: (batch, dim, seqlen)
169
+ """
170
+ assert checkpoint_lvl in [0, 1]
171
+ L = xz.shape[-1]
172
+ delta_rank = delta_proj_weight.shape[1]
173
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
174
+ if torch.is_autocast_enabled():
175
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
176
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
177
+ if xz.stride(-1) != 1:
178
+ xz = xz.contiguous()
179
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
180
+ x, z = xz.chunk(2, dim=1)
181
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
182
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
183
+ # We're being very careful here about the layout, to avoid extra transposes.
184
+ # We want delta to have d as the slowest moving dimension
185
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
186
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
187
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
188
+ ctx.is_variable_B = B is None
189
+ ctx.is_variable_C = C is None
190
+ ctx.B_proj_bias_is_None = B_proj_bias is None
191
+ ctx.C_proj_bias_is_None = C_proj_bias is None
192
+ if B is None: # variable B
193
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
194
+ if B_proj_bias is not None:
195
+ B = B + B_proj_bias.to(dtype=B.dtype)
196
+ if not A.is_complex():
197
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
198
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
199
+ else:
200
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
201
+ else:
202
+ if B.stride(-1) != 1:
203
+ B = B.contiguous()
204
+ if C is None: # variable C
205
+ C = x_dbl[:, -d_state:] # (bl dstate)
206
+ if C_proj_bias is not None:
207
+ C = C + C_proj_bias.to(dtype=C.dtype)
208
+ if not A.is_complex():
209
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
210
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
211
+ else:
212
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
213
+ else:
214
+ if C.stride(-1) != 1:
215
+ C = C.contiguous()
216
+ if D is not None:
217
+ D = D.contiguous()
218
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
219
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
220
+ )
221
+ ctx.delta_softplus = delta_softplus
222
+ ctx.checkpoint_lvl = checkpoint_lvl
223
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
224
+ conv1d_out, delta = None, None
225
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
226
+ delta_proj_weight, conv1d_out, delta,
227
+ A, B, C, D, delta_bias, scan_intermediates, out)
228
+ # return rearrange(out_z, "b d l -> b l d")
229
+ return out_z
230
+
231
+ @staticmethod
232
+ @custom_bwd
233
+ def backward(ctx, dout):
234
+ # dout: (batch, seqlen, dim)
235
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
236
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
237
+ L = xz.shape[-1]
238
+ delta_rank = delta_proj_weight.shape[1]
239
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
240
+ x, z = xz.chunk(2, dim=1)
241
+ if dout.stride(-1) != 1:
242
+ dout = dout.contiguous()
243
+ if ctx.checkpoint_lvl == 1:
244
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
245
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
246
+ "d (b l) -> b d l", l = L)
247
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
248
+ # backward of selective_scan_cuda with the backward of chunk).
249
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
250
+ dx, dz = dxz.chunk(2, dim=1)
251
+ # dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l
252
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
253
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
254
+ ctx.delta_softplus,
255
+ True # option to recompute out_z
256
+ )
257
+ dD = dD if D is not None else None
258
+ dx_dbl = torch.empty_like(x_dbl)
259
+ dB_proj_bias = None
260
+ if ctx.is_variable_B:
261
+ if not A.is_complex():
262
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
263
+ else:
264
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
265
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
266
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
267
+ dB = None
268
+ dC_proj_bias = None
269
+ if ctx.is_variable_C:
270
+ if not A.is_complex():
271
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
272
+ else:
273
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
274
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
275
+ dx_dbl[:, -d_state:] = dC # (bl d)
276
+ dC = None
277
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
278
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
279
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
280
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
281
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
282
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
283
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
284
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
285
+ # backward of conv1d with the backward of chunk).
286
+ dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
287
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
288
+ )
289
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
290
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
291
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
292
+ dA, dB, dC, dD,
293
+ ddelta_bias if delta_bias is not None else None,
294
+ dB_proj_bias, dC_proj_bias, None)
295
+
296
+
297
+ class MambaInnerFn(torch.autograd.Function):
298
+
299
+ @staticmethod
300
+ @custom_fwd
301
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
302
+ out_proj_weight, out_proj_bias,
303
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
304
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
305
+ """
306
+ xz: (batch, dim, seqlen)
307
+ """
308
+ assert checkpoint_lvl in [0, 1]
309
+ L = xz.shape[-1]
310
+ delta_rank = delta_proj_weight.shape[1]
311
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
312
+ if torch.is_autocast_enabled():
313
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
314
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
315
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
316
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
317
+ if out_proj_bias is not None else None)
318
+ if xz.stride(-1) != 1:
319
+ xz = xz.contiguous()
320
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
321
+ x, z = xz.chunk(2, dim=1)
322
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
323
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
324
+ # We're being very careful here about the layout, to avoid extra transposes.
325
+ # We want delta to have d as the slowest moving dimension
326
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
327
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
328
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
329
+ ctx.is_variable_B = B is None
330
+ ctx.is_variable_C = C is None
331
+ ctx.B_proj_bias_is_None = B_proj_bias is None
332
+ ctx.C_proj_bias_is_None = C_proj_bias is None
333
+ if B is None: # variable B
334
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
335
+ if B_proj_bias is not None:
336
+ B = B + B_proj_bias.to(dtype=B.dtype)
337
+ if not A.is_complex():
338
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
339
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
340
+ else:
341
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
342
+ else:
343
+ if B.stride(-1) != 1:
344
+ B = B.contiguous()
345
+ if C is None: # variable C
346
+ C = x_dbl[:, -d_state:] # (bl dstate)
347
+ if C_proj_bias is not None:
348
+ C = C + C_proj_bias.to(dtype=C.dtype)
349
+ if not A.is_complex():
350
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
351
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
352
+ else:
353
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
354
+ else:
355
+ if C.stride(-1) != 1:
356
+ C = C.contiguous()
357
+ if D is not None:
358
+ D = D.contiguous()
359
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
360
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
361
+ )
362
+ ctx.delta_softplus = delta_softplus
363
+ ctx.out_proj_bias_is_None = out_proj_bias is None
364
+ ctx.checkpoint_lvl = checkpoint_lvl
365
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
366
+ conv1d_out, delta = None, None
367
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
368
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
369
+ A, B, C, D, delta_bias, scan_intermediates, out)
370
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
371
+
372
+ @staticmethod
373
+ @custom_bwd
374
+ def backward(ctx, dout):
375
+ # dout: (batch, seqlen, dim)
376
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
377
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
378
+ L = xz.shape[-1]
379
+ delta_rank = delta_proj_weight.shape[1]
380
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
381
+ x, z = xz.chunk(2, dim=1)
382
+ if dout.stride(-1) != 1:
383
+ dout = dout.contiguous()
384
+ if ctx.checkpoint_lvl == 1:
385
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
386
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
387
+ "d (b l) -> b d l", l = L)
388
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
389
+ # backward of selective_scan_cuda with the backward of chunk).
390
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
391
+ dx, dz = dxz.chunk(2, dim=1)
392
+ dout = rearrange(dout, "b l e -> e (b l)")
393
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
394
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
395
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
396
+ ctx.delta_softplus,
397
+ True # option to recompute out_z
398
+ )
399
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
400
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
401
+ dD = dD if D is not None else None
402
+ dx_dbl = torch.empty_like(x_dbl)
403
+ dB_proj_bias = None
404
+ if ctx.is_variable_B:
405
+ if not A.is_complex():
406
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
407
+ else:
408
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
409
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
410
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
411
+ dB = None
412
+ dC_proj_bias = None
413
+ if ctx.is_variable_C:
414
+ if not A.is_complex():
415
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
416
+ else:
417
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
418
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
419
+ dx_dbl[:, -d_state:] = dC # (bl d)
420
+ dC = None
421
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
422
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
423
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
424
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
425
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
426
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
427
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
428
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
429
+ # backward of conv1d with the backward of chunk).
430
+ dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
431
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
432
+ )
433
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
434
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
435
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
436
+ dout_proj_weight, dout_proj_bias,
437
+ dA, dB, dC, dD,
438
+ ddelta_bias if delta_bias is not None else None,
439
+ dB_proj_bias, dC_proj_bias, None)
440
+
441
+
442
+ class BiMambaInnerFn(torch.autograd.Function):
443
+
444
+ @staticmethod
445
+ @custom_fwd
446
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
447
+ out_proj_weight, out_proj_bias,
448
+ A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
449
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
450
+ """
451
+ xz: (batch, dim, seqlen)
452
+ """
453
+ assert checkpoint_lvl in [0, 1]
454
+ L = xz.shape[-1]
455
+ delta_rank = delta_proj_weight.shape[1]
456
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
457
+ if torch.is_autocast_enabled():
458
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
459
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
460
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
461
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
462
+ if out_proj_bias is not None else None)
463
+ if xz.stride(-1) != 1:
464
+ xz = xz.contiguous()
465
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
466
+ x, z = xz.chunk(2, dim=1)
467
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
468
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
469
+ # We're being very careful here about the layout, to avoid extra transposes.
470
+ # We want delta to have d as the slowest moving dimension
471
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
472
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
473
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
474
+ ctx.is_variable_B = B is None
475
+ ctx.is_variable_C = C is None
476
+ ctx.B_proj_bias_is_None = B_proj_bias is None
477
+ ctx.C_proj_bias_is_None = C_proj_bias is None
478
+ if B is None: # variable B
479
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
480
+ if B_proj_bias is not None:
481
+ B = B + B_proj_bias.to(dtype=B.dtype)
482
+ if not A.is_complex():
483
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
484
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
485
+ else:
486
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
487
+ else:
488
+ if B.stride(-1) != 1:
489
+ B = B.contiguous()
490
+ if C is None: # variable C
491
+ C = x_dbl[:, -d_state:] # (bl dstate)
492
+ if C_proj_bias is not None:
493
+ C = C + C_proj_bias.to(dtype=C.dtype)
494
+ if not A.is_complex():
495
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
496
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
497
+ else:
498
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
499
+ else:
500
+ if C.stride(-1) != 1:
501
+ C = C.contiguous()
502
+ if D is not None:
503
+ D = D.contiguous()
504
+ out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
505
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
506
+ )
507
+ assert not A_b.is_complex(), "A_b should not be complex"
508
+ out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
509
+ conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus,
510
+ )
511
+
512
+ out_z = out_z_f + out_z_b.flip([-1])
513
+
514
+ ctx.delta_softplus = delta_softplus
515
+ ctx.out_proj_bias_is_None = out_proj_bias is None
516
+ ctx.checkpoint_lvl = checkpoint_lvl
517
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
518
+ conv1d_out, delta = None, None
519
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
520
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
521
+ A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b)
522
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
523
+
524
+ @staticmethod
525
+ @custom_bwd
526
+ def backward(ctx, dout):
527
+ # dout: (batch, seqlen, dim)
528
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
529
+ conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b) = ctx.saved_tensors
530
+ L = xz.shape[-1]
531
+ delta_rank = delta_proj_weight.shape[1]
532
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
533
+ x, z = xz.chunk(2, dim=1)
534
+ if dout.stride(-1) != 1:
535
+ dout = dout.contiguous()
536
+ if ctx.checkpoint_lvl == 1:
537
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
538
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
539
+ "d (b l) -> b d l", l = L)
540
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
541
+ # backward of selective_scan_cuda with the backward of chunk).
542
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
543
+ dx, dz = dxz.chunk(2, dim=1)
544
+ dout = rearrange(dout, "b l e -> e (b l)")
545
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
546
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
547
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
548
+ ctx.delta_softplus,
549
+ True # option to recompute out_z
550
+ )
551
+ # flip one
552
+ dz_b = torch.empty_like(dz)
553
+ dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
554
+ conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
555
+ ctx.delta_softplus,
556
+ True # option to recompute out_z
557
+ )
558
+
559
+ dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
560
+ ddelta = ddelta + ddelta_f_b.flip([-1])
561
+ dB = dB + dB_f_b.flip([-1])
562
+ dC = dC + dC_f_b.flip([-1])
563
+ dD = dD + dD_b
564
+ ddelta_bias = ddelta_bias + ddelta_bias_b
565
+ dz = dz + dz_b.flip([-1])
566
+ out_z = out_z_f + out_z_b.flip([-1])
567
+
568
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
569
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
570
+ dD = dD if D is not None else None
571
+ dx_dbl = torch.empty_like(x_dbl)
572
+ dB_proj_bias = None
573
+ if ctx.is_variable_B:
574
+ if not A.is_complex():
575
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
576
+ else:
577
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
578
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
579
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
580
+ dB = None
581
+ dC_proj_bias = None
582
+ if ctx.is_variable_C:
583
+ if not A.is_complex():
584
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
585
+ else:
586
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
587
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
588
+ dx_dbl[:, -d_state:] = dC # (bl d)
589
+ dC = None
590
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
591
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
592
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
593
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
594
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
595
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
596
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
597
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
598
+ # backward of conv1d with the backward of chunk).
599
+ dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
600
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
601
+ )
602
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
603
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
604
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
605
+ dout_proj_weight, dout_proj_bias,
606
+ dA, dA_b, dB, dC, dD,
607
+ ddelta_bias if delta_bias is not None else None,
608
+ dB_proj_bias, dC_proj_bias, None)
609
+
610
+
611
+ def mamba_inner_fn(
612
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
613
+ out_proj_weight, out_proj_bias,
614
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
615
+ C_proj_bias=None, delta_softplus=True
616
+ ):
617
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
618
+ out_proj_weight, out_proj_bias,
619
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
620
+
621
+ def bimamba_inner_fn(
622
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
623
+ out_proj_weight, out_proj_bias,
624
+ A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
625
+ C_proj_bias=None, delta_softplus=True
626
+ ):
627
+ return BiMambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
628
+ out_proj_weight, out_proj_bias,
629
+ A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
630
+
631
+
632
+ def mamba_inner_fn_no_out_proj(
633
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
634
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
635
+ C_proj_bias=None, delta_softplus=True
636
+ ):
637
+ return MambaInnerFnNoOutProj.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
638
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
639
+
640
+
641
+ def mamba_inner_ref(
642
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
643
+ out_proj_weight, out_proj_bias,
644
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
645
+ C_proj_bias=None, delta_softplus=True
646
+ ):
647
+ L = xz.shape[-1]
648
+ delta_rank = delta_proj_weight.shape[1]
649
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
650
+ x, z = xz.chunk(2, dim=1)
651
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
652
+ # We're being very careful here about the layout, to avoid extra transposes.
653
+ # We want delta to have d as the slowest moving dimension
654
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
655
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
656
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
657
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
658
+ if B is None: # variable B
659
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
660
+ if B_proj_bias is not None:
661
+ B = B + B_proj_bias.to(dtype=B.dtype)
662
+ if not A.is_complex():
663
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
664
+ else:
665
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
666
+ if C is None: # variable C
667
+ C = x_dbl[:, -d_state:] # (bl d)
668
+ if C_proj_bias is not None:
669
+ C = C + C_proj_bias.to(dtype=C.dtype)
670
+ if not A.is_complex():
671
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
672
+ else:
673
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
674
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
675
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
676
+
677
+
678
+ def bimamba_inner_ref(
679
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
680
+ out_proj_weight, out_proj_bias,
681
+ A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
682
+ C_proj_bias=None, delta_softplus=True
683
+ ):
684
+ L = xz.shape[-1]
685
+ delta_rank = delta_proj_weight.shape[1]
686
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
687
+ x, z = xz.chunk(2, dim=1)
688
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
689
+ # We're being very careful here about the layout, to avoid extra transposes.
690
+ # We want delta to have d as the slowest moving dimension
691
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
692
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
693
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
694
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
695
+ if B is None: # variable B
696
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
697
+ if B_proj_bias is not None:
698
+ B = B + B_proj_bias.to(dtype=B.dtype)
699
+ if not A.is_complex():
700
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
701
+ else:
702
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
703
+ if C is None: # variable C
704
+ C = x_dbl[:, -d_state:] # (bl d)
705
+ if C_proj_bias is not None:
706
+ C = C + C_proj_bias.to(dtype=C.dtype)
707
+ if not A.is_complex():
708
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
709
+ else:
710
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
711
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
712
+ y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
713
+ y = y + y_b.flip([-1])
714
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
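Note on the bidirectional variants above (BiMambaInnerFn / bimamba_inner_ref): they run the selective scan twice, once over the sequence as-is and once over the time-flipped sequence, then add the re-flipped backward output to the forward output (out_z_f + out_z_b.flip([-1])) before the output projection. A minimal stand-in sketch of just that merge, with a cumulative sum replacing the CUDA selective scan (selective_scan_cuda / selective_scan_fn are assumed unavailable here), for illustration only:

    import torch

    def toy_scan(x):
        # stand-in for the causal selective scan: any strictly left-to-right recurrence
        return torch.cumsum(x, dim=-1)

    def toy_bidirectional_scan(x):
        y_f = toy_scan(x)                    # forward direction, t = 0 .. L-1
        y_b = toy_scan(x.flip([-1]))         # backward direction, t = L-1 .. 0
        return y_f + y_b.flip([-1])          # same merge as out_z_f + out_z_b.flip([-1]) above

    x = torch.randn(2, 4, 8)                 # (batch, dim, seqlen)
    print(toy_bidirectional_scan(x).shape)   # torch.Size([2, 4, 8])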
model/patchify.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Patchify(nn.Module):
5
+ def __init__(self, in_channels, out_channels, patch_size):
6
+ super(Patchify, self).__init__()
7
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(8, patch_size), stride=(8, patch_size), padding=0, bias=False)
8
+
9
+ def forward(self, x):
10
+ # x.shape = (batch_size, channels, height, width)
11
+ x = self.conv(x)
12
+
13
+ return x
14
+
15
+ if __name__ == "__main__":
16
+ model = Patchify(1, 32, 2)
17
+ print(model)
18
+ dummy_input = torch.randn(1, 1, 64, 16)
19
+ output = model(dummy_input)
20
+ print(output.shape)
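For reference, Patchify's conv uses kernel_size=(8, patch_size) with an equal stride and no padding, so the smoke test above maps a (1, 1, 64, 16) input to (1, 32, 8, 8): H_out = (64 - 8) // 8 + 1 = 8 and W_out = (16 - 2) // 2 + 1 = 8. A quick self-contained check of that arithmetic (the helper name conv_out is only for this sketch):

    def conv_out(size, kernel, stride, padding=0):
        # standard convolution output-size formula
        return (size + 2 * padding - kernel) // stride + 1

    print(conv_out(64, 8, 8), conv_out(16, 2, 2))  # 8 8 -> matches the (1, 32, 8, 8) output above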
model/sinc_conv.py ADDED
@@ -0,0 +1,471 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.nn as nn
5
+ import torch.fft
6
+ import sys
7
+ from torch.autograd import Variable
8
+ import math
9
+
10
+ class GlobalLayerNorm(nn.Module):
11
+ '''
12
+ Calculate Global Layer Normalization
13
+ dim: (int) number of channels C of the expected (N, C, L) input;
14
+ used to size the per-channel affine parameters.
15
+ eps: a value added to the denominator for numerical stability.
16
+ elementwise_affine: a boolean value that when set to True,
17
+ this module has learnable per-element affine parameters
18
+ initialized to ones (for weights) and zeros (for biases).
19
+ '''
20
+
21
+ def __init__(self, dim, eps=1e-05, elementwise_affine=True):
22
+ super(GlobalLayerNorm, self).__init__()
23
+ self.dim = dim
24
+ self.eps = eps
25
+ self.elementwise_affine = elementwise_affine
26
+
27
+ if self.elementwise_affine:
28
+ self.weight = nn.Parameter(torch.ones(self.dim, 1))
29
+ self.bias = nn.Parameter(torch.zeros(self.dim, 1))
30
+ else:
31
+ self.register_parameter('weight', None)
32
+ self.register_parameter('bias', None)
33
+
34
+ def forward(self, x):
35
+ # x = N x C x L
36
+ # N x 1 x 1
37
+ # cln: mean,var N x 1 x L
38
+ # gln: mean,var N x 1 x 1
39
+ if x.dim() != 3:
40
+ raise RuntimeError("{} accept 3D tensor as input".format(
41
+ self.__class__.__name__))
42
+
43
+ mean = torch.mean(x, (1, 2), keepdim=True)
44
+ var = torch.mean((x-mean)**2, (1, 2), keepdim=True)
45
+ # N x C x L
46
+ if self.elementwise_affine:
47
+ x = self.weight*(x-mean)/torch.sqrt(var+self.eps)+self.bias
48
+ else:
49
+ x = (x-mean)/torch.sqrt(var+self.eps)
50
+ return x
51
+
52
+
53
+ class TimeSincExtractor(nn.Module):
54
+ """Sinc-based convolution
55
+ Parameters
56
+ ----------
57
+ in_channels : `int`
58
+ Number of input channels. Must be 1.
59
+ out_channels : `int`
60
+ Number of filters.
61
+ kernel_size : `int`
62
+ Filter length.
63
+ sample_rate : `int`, optional
64
+ Sample rate. Defaults to 16000.
65
+ triangular : `bool`
66
+ Squared sinc -> Triangular filter.
67
+ freq_nml : `bool`
68
+ Normalized to gain of 1 in frequency.
69
+ range_constraint : `bool`
70
+ Project the learned band within nyquist freq manually.
71
+ Usage
72
+ -----
73
+ See `torch.nn.Conv1d`
74
+ """
75
+
76
+ @staticmethod
77
+ def to_mel(hz):
78
+ return 2595 * np.log10(1 + hz / 700)
79
+
80
+ @staticmethod
81
+ def to_hz(mel):
82
+ return 700 * (10 ** (mel / 2595) - 1)
83
+
84
+ def swap_(self, x, y, sort=False):
85
+ mini = torch.minimum(x, y)
86
+ maxi = torch.maximum(x, y)
87
+
88
+ if sort:
89
+ mini, idx = torch.sort(mini)
90
+ maxi = maxi[idx].view(mini.shape)
91
+
92
+ return mini, maxi
93
+
94
+ def __init__(self, out_channels, kernel_size, triangular=False,
95
+ freq_nml=False, range_constraint=False, freq_init='uniform', norm_after=True, sample_rate=16000, in_channels=1,
96
+ stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50, bi_factor=False, frame_length=400, hop_length=160):
97
+
98
+ super(TimeSincExtractor,self).__init__()
99
+
100
+ if in_channels != 1:
101
+ # msg = (f'SincConv only support one input channel '
102
+ # f'(here, in_channels = {in_channels:d}).')
103
+ msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
104
+ raise ValueError(msg)
105
+
106
+ self.out_channels = out_channels
107
+ self.kernel_size = kernel_size
108
+ self.triangular = triangular
109
+ self.freq_nml = freq_nml
110
+
111
+ # Forcing the filters to be odd (i.e, perfectly symmetrics)
112
+ if kernel_size%2 == 0:
113
+ self.kernel_size = self.kernel_size+1
114
+
115
+ self.stride = stride
116
+ self.padding = padding
117
+ self.dilation = dilation
118
+
119
+ self.frame_length = frame_length
120
+ self.hop_length = hop_length
121
+
122
+ if bias:
123
+ raise ValueError('SincConv does not support bias.')
124
+ if groups > 1:
125
+ raise ValueError('SincConv does not support groups.')
126
+
127
+ self.sample_rate = sample_rate
128
+ self.nyquist_rate = sample_rate/2
129
+ self.min_low_hz = min_low_hz
130
+ self.min_band_hz = min_band_hz
131
+ self.range_constraint = range_constraint
132
+ self.bi_factor = bi_factor
133
+
134
+ if self.range_constraint:
135
+ # msg = "Range constraint in learned frequency is not supported yet."
136
+ # raise ValueError(msg)
137
+ if freq_init == "uniform":
138
+ low_freq, high_freq = torch.rand(out_channels*2).chunk(2)
139
+ elif freq_init == "formant":
140
+ # raise NotImplementedError('Formant distribution hasn\'t been implemented yet.')
141
+ p = np.load('/share/nas165/Jasonho610/SincNet/exp/formant_distribution.npy')
142
+ low_freq, high_freq = torch.from_numpy(np.random.choice(8000, out_channels*2, p=p)).chunk(2)
143
+ low_freq = low_freq / self.nyquist_rate
144
+ high_freq = high_freq / self.nyquist_rate
145
+ elif freq_init == "mel":
146
+ # raise NotImplementedError('Mel distribution hasn\'t been implemented yet.')
147
+ low_hz = 30
148
+ high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
149
+ mel = np.linspace(self.to_mel(low_hz),
150
+ self.to_mel(high_hz),
151
+ self.out_channels + 1)
152
+ hz = self.to_hz(mel)
153
+ low_freq = torch.Tensor(hz[:-1]) / self.nyquist_rate
154
+ high_freq = torch.Tensor(hz[1:]) / self.nyquist_rate
155
+ else:
156
+ raise ValueError('SincConv must specify the freq initialization methods.')
157
+
158
+ low_freq, high_freq = self.swap_(low_freq, high_freq)
159
+
160
+ if self.bi_factor:
161
+ self.band_imp = nn.Parameter(torch.ones(out_channels))
162
+ self.low_f_ = nn.Parameter(low_freq.view(-1, 1))
163
+ self.high_f_ = nn.Parameter(high_freq.view(-1, 1))
164
+ else:
165
+ # initialize filterbanks such that they are equally spaced in Mel scale
166
+ low_hz = 30
167
+ high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
168
+ mel = np.linspace(self.to_mel(low_hz),
169
+ self.to_mel(high_hz),
170
+ self.out_channels + 1)
171
+ hz = self.to_hz(mel)
172
+ # filter lower frequency (out_channels, 1)
173
+ self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
174
+
175
+ # filter frequency band (out_channels, 1)
176
+ self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
177
+
178
+ # Hamming window
179
+ # self.window_ = torch.hamming_window(self.kernel_size)
180
+ n_lin = torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window
181
+ self.window_ = 0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size)
182
+
183
+ # (1, kernel_size/2)
184
+ n = (self.kernel_size - 1) / 2.0
185
+ self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axes
186
+
187
+ self.norm_after = norm_after
188
+ if self.norm_after:
189
+ self.ln = GlobalLayerNorm(out_channels)
190
+
191
+
192
+ def forward(self, waveforms, embedding):
193
+ """
194
+ Parameters
195
+ ----------
196
+ waveforms : `torch.Tensor` (batch_size, 1, n_samples)
197
+ Batch of waveforms.
198
+ Returns
199
+ -------
200
+ features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
201
+ Batch of sinc filters activations.
202
+ """
203
+
204
+ self.n_ = self.n_.to(waveforms.device)
205
+ self.window_ = self.window_.to(waveforms.device)
206
+ # waveforms = waveforms.unsqueeze(1)
207
+ # print("Waveforms:", waveforms.shape)
208
+
209
+ framing_padding = self.frame_length - (waveforms.shape[-1] % self.hop_length)
210
+ waveforms = F.pad(waveforms, (0, framing_padding))
211
+ frames = waveforms.unfold(-1, self.frame_length, self.hop_length)
212
+
213
+ batch_size = frames.shape[0]
214
+ n_frames = frames.shape[2]
215
+
216
+ if self.range_constraint:
217
+ low_f_, high_f_ = self.swap_(torch.abs(self.low_f_), torch.abs(self.high_f_))
218
+
219
+ low = self.min_low_hz + low_f_*self.nyquist_rate
220
+ high = torch.clamp(self.min_band_hz + high_f_*self.nyquist_rate, self.min_low_hz, self.nyquist_rate)
221
+ band = (high-low)[:,0]
222
+ else:
223
+ low = self.min_low_hz + torch.abs(self.low_hz_)
224
+ high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_), self.min_low_hz, self.nyquist_rate)
225
+ band = (high-low)[:,0]
226
+
227
+ self.low = low
228
+ self.high = high
229
+ self.band = band
230
+
231
+ f_times_t_low = torch.matmul(low, self.n_)
232
+ f_times_t_high = torch.matmul(high, self.n_)
233
+
234
+ band_pass_left = ((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations.
235
+ band_pass_center = 2*band.view(-1,1)
236
+ band_pass_right = torch.flip(band_pass_left,dims=[1])
237
+
238
+ band_pass = torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1)
239
+
240
+ band_pass = band_pass / (2*band[:,None])
241
+
242
+ if self.triangular:
243
+ band_pass = band_pass**2
244
+
245
+ if self.freq_nml:
246
+ mag_resp = torch.fft.rfft(band_pass).abs()
247
+ mag_max = torch.max(mag_resp, dim=-1)[0]
248
+ band_pass = band_pass / mag_max.unsqueeze(-1)
249
+
250
+ if self.bi_factor:
251
+ band_imp = F.relu(self.band_imp)
252
+ band_pass = band_pass * band_imp.unsqueeze(-1)
253
+
254
+
255
+ self.filters = (band_pass).view(
256
+ self.out_channels, 1, self.kernel_size)
257
+
258
+ # print("Filters:", self.filters.shape)
259
+ # print("Frames:", frames.shape)
260
+
261
+ rs_frames = frames.reshape(batch_size*n_frames, 1, self.frame_length)
262
+ # print("Reshaped frames:", rs_frames.shape)
263
+
264
+ filtered = F.conv1d(rs_frames, self.filters, stride=self.stride,
265
+ padding=self.padding, dilation=self.dilation,
266
+ bias=None, groups=1)
267
+ # print('Pass conv1d')
268
+ # print("Filtered:", filtered.shape)
269
+ if self.norm_after:
270
+ filtered = self.ln(filtered)
271
+
272
+ # print("Normed filtered:", filtered.shape)
273
+
274
+ filtered = filtered.reshape(batch_size, n_frames, self.out_channels , -1)
275
+
276
+ # print("Final filtered:", filtered.shape)
277
+
278
+ energy = torch.mean(filtered**2, dim=-1)
279
+ log_filtered_energy = torch.log10(energy + 1e-6)
280
+ # print("Log filtered energy:", log_filtered_energy.shape) # (batch_size, n_samples_out(time), out_channels(frequency))
281
+
282
+ log_filtered_energy = log_filtered_energy.unsqueeze(1)
283
+ # print("Unsqueezed log filtered energy:", log_filtered_energy.shape) # (batch_size, channels, n_samples_out(time), out_channels(frequency))
284
+
285
+ log_filtered_energy = log_filtered_energy.permute(0, 1, 3, 2)
286
+ # print("Permuted log filtered energy:", log_filtered_energy.shape) # (batch_size, channels, out_channels(frequency), n_samples_out(time))
287
+
288
+ return log_filtered_energy, self.filters, self.stride, self.padding
289
+
290
+
291
+ class FreqSincExtractor(nn.Module):
292
+ @staticmethod
293
+ def to_mel(hz):
294
+ return 2595 * np.log10(1 + hz / 700)
295
+
296
+ @staticmethod
297
+ def to_hz(mel):
298
+ return 700 * (10 ** (mel / 2595) - 1)
299
+
300
+ def swap_(self, x, y, sort=False):
301
+ mini = torch.minimum(x, y)
302
+ maxi = torch.maximum(x, y)
303
+ if sort:
304
+ mini, idx = torch.sort(mini)
305
+ maxi = maxi[idx].view(mini.shape)
306
+ return mini, maxi
307
+
308
+ def __init__(self, out_channels, kernel_size, triangular=False,
309
+ freq_nml=False, range_constraint=False, freq_init='uniform',
310
+ norm_after=True, sample_rate=16000, in_channels=1,
311
+ stride=1, padding=0, dilation=1, bias=False, groups=1,
312
+ min_low_hz=50, min_band_hz=50, bi_factor=False,
313
+ frame_length=400, hop_length=160, n_fft=400):
314
+ super(FreqSincExtractor, self).__init__()
315
+
316
+ if in_channels != 1:
317
+ msg = "FreqSincExtractor only supports one input channel (here, in_channels = {%i})" % (in_channels)
318
+ raise ValueError(msg)
319
+
320
+ self.out_channels = out_channels
321
+ self.kernel_size = kernel_size
322
+ self.triangular = triangular
323
+ self.freq_nml = freq_nml
324
+ self.sample_rate = sample_rate
325
+ self.nyquist_rate = sample_rate/2
326
+ self.min_low_hz = min_low_hz
327
+ self.min_band_hz = min_band_hz
328
+ self.range_constraint = range_constraint
329
+ self.bi_factor = bi_factor
330
+ self.frame_length = frame_length
331
+ self.hop_length = hop_length
332
+ self.n_fft = n_fft
333
+ self.stride = stride
334
+ self.padding = padding
335
+ self.output_size = 64
336
+
337
+ # Initialize frequency bands
338
+ if self.range_constraint:
339
+ if freq_init == "uniform":
340
+ low_freq, high_freq = torch.rand(out_channels*2).chunk(2)
341
+ elif freq_init == "mel":
342
+ low_hz = 30
343
+ high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
344
+ mel = np.linspace(self.to_mel(low_hz),
345
+ self.to_mel(high_hz),
346
+ self.out_channels + 1)
347
+ hz = self.to_hz(mel)
348
+ low_freq = torch.Tensor(hz[:-1]) / self.nyquist_rate
349
+ high_freq = torch.Tensor(hz[1:]) / self.nyquist_rate
350
+ else:
351
+ raise ValueError('FreqSincExtractor must specify the freq initialization methods.')
352
+
353
+ low_freq, high_freq = self.swap_(low_freq, high_freq)
354
+
355
+ if self.bi_factor:
356
+ self.band_imp = nn.Parameter(torch.ones(out_channels))
357
+ self.low_f_ = nn.Parameter(low_freq.view(-1, 1))
358
+ self.high_f_ = nn.Parameter(high_freq.view(-1, 1))
359
+ else:
360
+ low_hz = 30
361
+ high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
362
+ mel = np.linspace(self.to_mel(low_hz),
363
+ self.to_mel(high_hz),
364
+ self.out_channels + 1)
365
+ hz = self.to_hz(mel)
366
+ self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
367
+ self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
368
+
369
+ # Frequency axis for STFT
370
+ self.freq_axis = torch.linspace(0, self.nyquist_rate, self.n_fft//2 + 1)
371
+
372
+ self.norm_after = norm_after
373
+ if self.norm_after:
374
+ self.ln = GlobalLayerNorm(out_channels)
375
+
376
+ def get_filters(self):
377
+ if self.range_constraint:
378
+ low_f_, high_f_ = self.swap_(torch.abs(self.low_f_), torch.abs(self.high_f_))
379
+ low = self.min_low_hz + low_f_ * self.nyquist_rate
380
+ high = torch.clamp(self.min_low_hz + high_f_ * self.nyquist_rate,
381
+ self.min_low_hz, self.nyquist_rate)
382
+ else:
383
+ low = self.min_low_hz + torch.abs(self.low_hz_)
384
+ high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
385
+ self.min_low_hz, self.nyquist_rate)
386
+
387
+ # Create frequency domain filters
388
+ freq_axis = self.freq_axis.to(low.device)
389
+ filters = torch.zeros((self.out_channels, len(freq_axis))).to(low.device)
390
+
391
+ for i in range(self.out_channels):
392
+ mask = (freq_axis >= low[i]) & (freq_axis <= high[i])
393
+ filters[i, mask] = 1.0
394
+
395
+ if self.triangular:
396
+ center_freq = (low[i] + high[i]) / 2
397
+ bandwidth = high[i] - low[i]
398
+ mask = (freq_axis >= low[i]) & (freq_axis <= high[i])
399
+ freq_response = 1.0 - torch.abs(freq_axis[mask] - center_freq) / (bandwidth/2)
400
+ filters[i, mask] = freq_response
401
+
402
+ if self.freq_nml:
403
+ filters = F.normalize(filters, p=2, dim=1)
404
+
405
+ if self.bi_factor:
406
+ band_imp = F.relu(self.band_imp)
407
+ filters = filters * band_imp.unsqueeze(-1)
408
+
409
+ return filters
410
+
411
+ def forward(self, waveforms, embedding=None):
412
+ batch_size = waveforms.shape[0]
413
+
414
+ # Calculate necessary padding to achieve the correct output size
415
+ target_length = self.hop_length * (self.output_size - 1) + self.frame_length
416
+ current_length = waveforms.shape[-1]
417
+ padding_needed = target_length - current_length
418
+
419
+ # Pad the input if necessary
420
+ if padding_needed > 0:
421
+ waveforms = F.pad(waveforms, (0, padding_needed))
422
+
423
+ # Compute STFT
424
+ stft = torch.stft(waveforms.squeeze(1),
425
+ n_fft=self.n_fft,
426
+ hop_length=self.hop_length,
427
+ win_length=self.frame_length,
428
+ window=torch.hann_window(self.frame_length).to(waveforms.device),
429
+ return_complex=True)
430
+
431
+ # Get magnitude spectrogram
432
+ mag_spec = torch.abs(stft) # (batch_size, freq_bins, time_frames)
433
+
434
+ # Get and apply filters
435
+ filters = self.get_filters() # (out_channels, freq_bins)
436
+ filtered = torch.matmul(filters, mag_spec) # (batch_size, out_channels, time_frames)
437
+
438
+ if self.norm_after:
439
+ filtered = self.ln(filtered)
440
+
441
+ # Compute log energy
442
+ energy = filtered ** 2
443
+ log_energy = torch.log10(energy + 1e-6)
444
+
445
+ # Ensure correct time dimension
446
+ if log_energy.shape[-1] != self.output_size:
447
+ log_energy = F.interpolate(
448
+ log_energy,
449
+ size=self.output_size,
450
+ mode='linear',
451
+ align_corners=False
452
+ )
453
+
454
+ # Reshape to the desired output format
455
+ log_energy = log_energy.unsqueeze(1) # Add channel dimension
456
+ log_energy = log_energy.permute(0, 1, 3, 2) # Rearrange to (batch, channel, freq, time)
457
+
458
+ return log_energy, filters, self.stride, self.padding
459
+
460
+
461
+ if __name__ == "__main__":
462
+ batch_size = 256
463
+ n_samples = 10080
464
+ waveforms = torch.rand(batch_size, 1, n_samples)
465
+
466
+ # model = TimeSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
467
+ model = FreqSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
468
+ print(model)
469
+
470
+ outputs, _, _, _ = model(waveforms, embedding=None)
471
+ print("Outputs:", outputs.shape)
model/tiny_block.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class TinyBlock(nn.Module):
5
+ def __init__(self, in_channels, out_channels, dilation=2):
6
+ super(TinyBlock, self).__init__()
7
+
8
+ # f1: 3x3 depthwise convolution + BatchNorm
9
+ self.f1 = nn.Sequential(
10
+ nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False),
11
+ nn.BatchNorm2d(in_channels)
12
+ )
13
+
14
+ # f2: 1x1 grouped pointwise convolutions with 8 groups + ReLU
15
+ self.f2 = nn.Sequential(
16
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=8, bias=False),
17
+ nn.ReLU(inplace=True)
18
+ )
19
+
20
+ def forward(self, x):
21
+ f1_out = self.f1(x)
22
+ f2_out = self.f2(x + f1_out)
23
+ out = x + f1_out + f2_out
24
+ return out
25
+
26
+ if __name__ == "__main__":
27
+ model = TinyBlock(16, 16)
28
+ print(model)
29
+ dummy_input = torch.randn(256, 16, 8, 8)
30
+ output = model(dummy_input)
31
+ print(output.shape)
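Data-flow note for TinyBlock: the module is a nested residual, out = x + f1(x) + f2(x + f1(x)), which is why in_channels must equal out_channels and (because f2 uses groups=8) be divisible by 8, as in the 16-channel smoke test above. A shape-only sketch of that composition with identity stand-ins for f1 and f2:

    import torch
    import torch.nn as nn

    f1, f2 = nn.Identity(), nn.Identity()   # stand-ins; the real f1/f2 preserve the channel count
    x = torch.randn(2, 16, 8, 8)
    out = x + f1(x) + f2(x + f1(x))         # same composition as TinyBlock.forward
    print(out.shape)                        # torch.Size([2, 16, 8, 8])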
model/tinyvad.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .sinc_conv import TimeSincExtractor, FreqSincExtractor
4
+ from .patchify import Patchify
5
+ from .csp_tiny_layer import CSPTinyLayer
6
+
7
+ class TinyVAD(nn.Module):
8
+ def __init__(self, in_channels, hidden_channels, out_channels, patch_size, num_blocks, sinc_conv, ssm):
9
+ super(TinyVAD, self).__init__()
10
+
11
+ self.sinc_conv = sinc_conv
12
+
13
+ if self.sinc_conv:
14
+ # self.extractor = TimeSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
15
+ self.extractor = FreqSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
16
+
17
+ self.patchify = Patchify(in_channels, hidden_channels, patch_size)
18
+
19
+ self.csp_tiny_layer1 = CSPTinyLayer(hidden_channels, hidden_channels, num_blocks, ssm)
20
+ self.csp_tiny_layer2 = CSPTinyLayer(hidden_channels, hidden_channels, num_blocks, ssm)
21
+ self.csp_tiny_layer3 = CSPTinyLayer(hidden_channels, out_channels, num_blocks, ssm)
22
+
23
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
24
+
25
+ self.classifier = nn.Sequential(
26
+ nn.Linear(out_channels, 1),
27
+ # nn.Sigmoid()
28
+ )
29
+
30
+ def forward(self, x):
31
+ if self.sinc_conv:
32
+ x = self.extractor(x, None)
33
+ x = x[0] # Untuple
34
+
35
+ x = self.patchify(x)
36
+
37
+ x = self.csp_tiny_layer1(x)
38
+ x = self.csp_tiny_layer2(x)
39
+ x = self.csp_tiny_layer3(x)
40
+
41
+ x = self.avg_pool(x).view(x.size(0), -1)
42
+
43
+ x = self.classifier(x)
44
+
45
+ return x
46
+
47
+ def predict(self, inputs):
48
+ logits = self.forward(inputs)
49
+ probs = torch.sigmoid(logits)
50
+
51
+ return probs
52
+
53
+ if __name__ == "__main__":
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ print(f"Using device: {device}")
56
+
57
+ model = TinyVAD(1, 32, 64, 2, 2, False, False).to(device)
58
+ print(model)
59
+ dummy_input = torch.randn(1, 1, 64, 16).to(device)
60
+ output = model(dummy_input)
61
+ print(output)
62
+
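Finally, a hedged end-to-end sketch of the inference path (it assumes the package is importable as model from the repo root): the classifier head returns a raw logit because nn.Sigmoid() is commented out, so predict() applies the sigmoid explicitly.

    import torch
    from model.tinyvad import TinyVAD

    vad = TinyVAD(1, 32, 64, patch_size=2, num_blocks=2, sinc_conv=False, ssm=False).eval()
    with torch.no_grad():
        probs = vad.predict(torch.randn(4, 1, 64, 16))  # (4, 1) speech probabilities in [0, 1]
    print(probs.shape)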