# File size: 19,218 Bytes
# 1423dc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
"""ConMamba encoder and Mamba decoder implementation.

Authors
-------
* Xilin Jiang 2024
"""

import warnings
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import (
    MultiheadAttention,
    PositionalwiseFeedForward,
    RelPosMHAXL,
)
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig

# Mamba
from mamba_ssm import Mamba
from .mamba.bimamba import Mamba as BiMamba 


class ConvolutionModule(nn.Module):
    """This is an implementation of convolution module in Conmamba.

    The module applies, in order: a LayerNorm, a pointwise convolution
    followed by a GLU (``bottleneck``), a depthwise convolution (``conv``),
    and LayerNorm -> activation -> pointwise Linear -> Dropout
    (``after_conv``).

    Arguments
    ---------
    input_size : int
        Number of input (and output) channels.
    kernel_size : int, optional
        Kernel size of the depthwise convolution.
    bias : bool, optional
        Whether the convolutions and the final Linear layer use a bias.
    activation : type, optional
        Activation class, instantiated inside ``after_conv``.
    dropout : float, optional
        Dropout rate applied at the end of the module.
    causal : bool, optional
        If True, the depthwise convolution only looks at current and past
        frames: the input is padded by the full receptive field and the
        extra trailing output frames are removed ("chomped").
    dilation : int, optional
        Dilation of the depthwise convolution.
    """

    def __init__(
        self,
        input_size,
        kernel_size=31,
        bias=True,
        activation=Swish,
        dropout=0.0,
        causal=False,
        dilation=1,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.causal = causal
        self.dilation = dilation

        # A dilated kernel spans (kernel_size - 1) * dilation frames beyond
        # the current one (this is the amount nn.Conv1d shrinks the sequence
        # by). Causal mode pads all of it and chomps it off after the conv;
        # non-causal mode pads half on each side, so for odd kernel sizes the
        # output keeps the input length.
        receptive_extra = (kernel_size - 1) * dilation
        if self.causal:
            self.padding = receptive_extra
        else:
            self.padding = receptive_extra // 2

        self.layer_norm = nn.LayerNorm(input_size)
        self.bottleneck = nn.Sequential(
            # pointwise
            nn.Conv1d(
                input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
            ),
            nn.GLU(dim=1),
        )
        # depthwise
        self.conv = nn.Conv1d(
            input_size,
            input_size,
            kernel_size=kernel_size,
            stride=1,
            padding=self.padding,
            dilation=dilation,
            groups=input_size,
            bias=bias,
        )

        # BatchNorm in the original Conformer replaced with a LayerNorm due to
        # https://github.com/speechbrain/speechbrain/pull/1329
        # see discussion
        # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884

        self.after_conv = nn.Sequential(
            nn.LayerNorm(input_size),
            activation(),
            # pointwise
            nn.Linear(input_size, input_size, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Applies the convolution to an input tensor `x`.

        Arguments
        ---------
        x : torch.Tensor
            Input of shape [batch_size, t, input_size]. The LayerNorm acts on
            the last dimension, and the tensor is transposed to channels-first
            for the convolutions.
        mask : torch.Tensor, optional
            Boolean mask broadcastable to the output. Positions where it is
            True are zero-filled in the output (in place).
        dynchunktrain_config : DynChunkTrainConfig, optional
            If given, applies Dynamic Chunk Convolution. Each chunk of
            ``chunk_size`` frames sees only its own frames plus
            ``self.padding`` frames of left context. Requires non-causal mode
            and dilation 1.

        Returns
        -------
        torch.Tensor
            Output of shape [batch_size, t, input_size].
        """

        if dynchunktrain_config is not None:
            # chances are chunking+causal is unintended; i don't know where it
            # may make sense, but if it does to you, feel free to implement it.
            assert (
                not self.causal
            ), "Chunked convolution not supported with causal padding"

            assert (
                self.dilation == 1
            ), "Current DynChunkTrain logic does not support dilation != 1"

            # in a causal convolution, which is not the case here, an output
            # frame would never be able to depend on a input frame from any
            # point in the future.

            # but with the dynamic chunk convolution, we instead use a "normal"
            # convolution but where, for any output frame, the future beyond the
            # "current" chunk gets masked.
            # see the paper linked in the documentation for details.

            chunk_size = dynchunktrain_config.chunk_size
            batch_size = x.shape[0]

            # determine the amount of padding we need to insert at the right of
            # the last chunk so that all chunks end up with the same size.
            if x.shape[1] % chunk_size != 0:
                final_right_padding = chunk_size - (x.shape[1] % chunk_size)
            else:
                final_right_padding = 0

            # -> [batch_size, t, in_channels]
            out = self.layer_norm(x)

            # -> [batch_size, in_channels, t] for the CNN
            out = out.transpose(1, 2)

            # -> [batch_size, in_channels, t] (pointwise)
            out = self.bottleneck(out)

            # -> [batch_size, in_channels, lc+t+final_right_padding]
            out = F.pad(out, (self.padding, final_right_padding), value=0)

            # now, make chunks with left context.
            # as a recap to what the above padding and this unfold do, consider
            # each a/b/c letter represents a frame as part of chunks a, b, c.
            # consider a chunk size of 4 and a kernel size of 5 (padding=2):
            #
            # input seq: 00aaaabbbbcc00
            # chunk #1:  00aaaa
            # chunk #2:      aabbbb
            # chunk #3:          bbcc00
            #
            # a few remarks here:
            # - the left padding gets inserted early so that the unfold logic
            #   works trivially
            # - the right 0-padding got inserted as the number of time steps
            #   could not be evenly split in `chunk_size` chunks

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
            out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)

            # as we manually disable padding in the convolution below, we insert
            # right 0-padding to the chunks, e.g. reusing the above example:
            #
            # chunk #1:  00aaaa00
            # chunk #2:      aabbbb00
            # chunk #3:          bbcc0000

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
            out = F.pad(out, (0, self.padding), value=0)

            # the transpose+flatten effectively flattens chunks into the batch
            # dimension to be processed into the time-wise convolution. the
            # chunks will later on be unflattened.

            # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.transpose(1, 2)

            # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.flatten(start_dim=0, end_dim=1)

            # TODO: experiment around reflect padding, which is difficult
            # because small chunks have too little time steps to reflect from

            # let's keep backwards compat by pointing at the weights from the
            # already declared Conv1d.
            #
            # still reusing the above example, the convolution will be applied,
            # with the padding truncated on both ends. the following example
            # shows the letter corresponding to the input frame on which the
            # convolution was centered.
            #
            # as you can see, the sum of lengths of all chunks is equal to our
            # input sequence length + `final_right_padding`.
            #
            # chunk #1:  aaaa
            # chunk #2:      bbbb
            # chunk #3:          cc00

            # -> [batch_size * num_chunks, out_channels, chunk_size]
            out = F.conv1d(
                out,
                weight=self.conv.weight,
                bias=self.conv.bias,
                stride=self.conv.stride,
                padding=0,
                dilation=self.conv.dilation,
                groups=self.conv.groups,
            )

            # -> [batch_size * num_chunks, chunk_size, out_channels]
            out = out.transpose(1, 2)

            out = self.after_conv(out)

            # -> [batch_size, num_chunks, chunk_size, out_channels]
            out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))

            # -> [batch_size, t + final_right_padding, out_channels]
            out = torch.flatten(out, start_dim=1, end_dim=2)

            # -> [batch_size, t, out_channels]
            if final_right_padding > 0:
                out = out[:, :-final_right_padding, :]
        else:
            out = self.layer_norm(x)
            out = out.transpose(1, 2)
            out = self.bottleneck(out)
            out = self.conv(out)

            # Conv1d pads both sides, so in causal mode the last
            # `self.padding` frames would see the future: chomp them. Guarded
            # because with padding == 0 (kernel_size=1), `out[..., :-0]`
            # would return an empty tensor.
            if self.causal and self.padding > 0:
                out = out[..., : -self.padding]

            out = out.transpose(1, 2)
            out = self.after_conv(out)

        if mask is not None:
            out.masked_fill_(mask, 0.0)

        return out


class ConmambaEncoderLayer(nn.Module):
    """This is an implementation of Conmamba encoder layer.

    The layer computes, in order: a half-step feed-forward residual, a
    pre-normalized Mamba residual, a convolution-module residual, and a
    half-step feed-forward residual followed by a LayerNorm.

    Arguments
    ---------
    d_model : int
        Feature dimension of the input and output.
    d_ffn : int
        Hidden dimension of the two feed-forward modules.
    kernel_size : int, optional
        Kernel size of the depthwise convolution in the convolution module.
    activation : type, optional
        Activation class used by the feed-forward and convolution modules.
    bias : bool, optional
        Whether the convolution module uses bias terms.
    dropout : float, optional
        Dropout rate of the feed-forward and convolution modules.
    causal : bool, optional
        If True, a unidirectional ``Mamba`` and a causal convolution are used
        regardless of ``mamba_config['bidirectional']``.
    mamba_config : dict
        Keyword arguments forwarded to the Mamba constructor. It must contain
        the key ``'bidirectional'``, which selects ``BiMamba`` (v2) instead of
        ``Mamba`` when ``causal`` is False and is not forwarded. The dict
        itself is never modified.
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        kernel_size=31,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        mamba_config=None
    ):
        super().__init__()
        if mamba_config is None:
            raise ValueError("mamba_config must be provided")

        # Work on a copy: the same dict is shared by every encoder layer, and
        # popping from it directly would lose the key if a constructor raised.
        mamba_kwargs = dict(mamba_config)
        bidirectional = mamba_kwargs.pop('bidirectional')
        if causal or (not bidirectional):
            self.mamba = Mamba(
                d_model=d_model,
                **mamba_kwargs
            )
        else:
            self.mamba = BiMamba(
                d_model=d_model,
                bimamba_type='v2',
                **mamba_kwargs
            )

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        # NOTE(review): not used by forward(); kept so the module attributes
        # stay unchanged.
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: torch.Tensor = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Runs one Conmamba layer.

        Arguments
        ---------
        x : torch.Tensor
            Input sequence, expected as [batch, time, d_model] (the layout the
            convolution module works on).
        src_mask : torch.Tensor, optional
            Accepted for interface compatibility; not used.
        src_key_padding_mask : torch.Tensor, optional
            Accepted for interface compatibility; currently not applied.
        pos_embs : torch.Tensor, optional
            Accepted for interface compatibility; not used.
        dynchunktrain_config : DynChunkTrainConfig, optional
            Forwarded to the convolution module for Dynamic Chunk Convolution.

        Returns
        -------
        torch.Tensor
            Output with the same shape as ``x``.
        """
        # NOTE(review): src_key_padding_mask is not applied to the
        # convolution output, and padded frames also pass through Mamba
        # unmasked. Confirm this is intended before enabling masking (e.g.
        # passing src_key_padding_mask.unsqueeze(-1) as the conv mask).
        conv_mask: Optional[torch.Tensor] = None

        # ffn module (half-step residual)
        x = x + 0.5 * self.ffn_module1(x)
        # mamba module (pre-norm residual)
        skip = x
        x = self.norm1(x)
        x = self.mamba(x)
        x = x + skip
        # convolution module
        x = x + self.convolution_module(
            x, conv_mask, dynchunktrain_config=dynchunktrain_config
        )
        # ffn module (half-step residual) followed by the final norm
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x


class ConmambaEncoder(nn.Module):
    """This class implements the Conmamba encoder.

    A stack of ``num_layers`` ConmambaEncoderLayer modules followed by a final
    LayerNorm (eps=1e-6).

    Arguments
    ---------
    num_layers : int
        Number of encoder layers.
    d_model : int
        Feature dimension of the input and output.
    d_ffn : int
        Hidden dimension of each layer's feed-forward modules.
    kernel_size : int, optional
        Kernel size of each layer's depthwise convolution.
    activation : type, optional
        Activation class used inside each layer.
    bias : bool, optional
        Whether the convolution modules use bias terms.
    dropout : float, optional
        Dropout rate of the feed-forward and convolution modules (not used
        inside the Mamba blocks).
    causal : bool, optional
        If True, every layer uses a unidirectional Mamba and a causal conv.
    mamba_config : dict
        Mamba keyword arguments shared by every layer. It must contain
        ``'bidirectional'``.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        kernel_size=31,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        mamba_config=None
    ):
        super().__init__()
        if dropout:
            # Library code should not print; only warn when a non-zero
            # dropout was actually requested.
            warnings.warn(
                f'dropout={str(dropout)} is not used in Mamba.', stacklevel=2
            )

        self.layers = torch.nn.ModuleList(
            [
                ConmambaEncoderLayer(
                    d_model=d_model,
                    d_ffn=d_ffn,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    mamba_config=mamba_config,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ----------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs: torch.Tensor, torch.nn.Module,
            Module or tensor containing the input sequence positional embeddings
            If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
            where S is the sequence length, and E is the embedding dimension.
        dynchunktrain_config: Optional[DynChunkTrainConfig]
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to the
            convolution module.

        Returns
        -------
        tuple
            ``(output, None)``: the normalized output of the last layer, and
            ``None`` in the second slot.
        """

        output = src
        for enc_layer in self.layers:
            output = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
                dynchunktrain_config=dynchunktrain_config,
            )
        output = self.norm(output)

        return output, None


class MambaDecoderLayer(nn.Module):
    """This class implements the Mamba decoder layer.

    The layer has three residual sub-blocks: a Mamba over the target
    sequence, a "cross" Mamba run over the concatenation [memory; target]
    (only the outputs at the target positions are kept), and a position-wise
    feed-forward network.

    Arguments
    ---------
    d_model : int
        Feature dimension of the target, the memory and the output.
    d_ffn : int
        Hidden dimension of the position-wise feed-forward network.
    activation : type, optional
        Activation class of the feed-forward network.
    dropout : float, optional
        Dropout rate applied to each sub-block's output before the residual.
    normalize_before : bool, optional
        If True, each sub-block normalizes its input (pre-norm). Otherwise
        the residual sum is normalized afterwards (post-norm).
    mamba_config : dict
        Keyword arguments forwarded to both Mamba constructors. A
        ``'bidirectional'`` key, if present, is ignored, since both blocks
        are unidirectional. The dict itself is never modified.
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        activation=nn.ReLU,
        dropout=0.0,
        normalize_before=False,
        mamba_config=None
    ):
        super().__init__()

        if mamba_config is None:
            raise ValueError("mamba_config must be provided")

        # Copy instead of pop/restore so the shared dict is never mutated.
        mamba_kwargs = dict(mamba_config)
        mamba_kwargs.pop('bidirectional', None)

        self.self_mamba = Mamba(
            d_model=d_model,
            **mamba_kwargs
        )

        self.cross_mamba = Mamba(
            d_model=d_model,
            **mamba_kwargs
        )

        self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
            d_ffn=d_ffn,
            input_size=d_model,
            dropout=dropout,
            activation=activation,
        )

        # normalization layers
        self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)

        self.normalize_before = normalize_before

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ----------
        tgt: torch.Tensor
            The sequence to the decoder layer (required), laid out as
            [batch, time, d_model] (it is concatenated with ``memory`` along
            dim 1).
        memory: torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask: torch.Tensor
            Accepted for interface compatibility; not used.
        memory_mask: torch.Tensor
            Accepted for interface compatibility; not used.
        tgt_key_padding_mask: torch.Tensor
            Accepted for interface compatibility; not used.
        memory_key_padding_mask: torch.Tensor
            Accepted for interface compatibility; not used.
        pos_embs_tgt: torch.Tensor
            Accepted for interface compatibility; not used.
        pos_embs_src: torch.Tensor
            Accepted for interface compatibility; not used.

        Returns
        -------
        tuple
            ``(tgt, None, None)`` where ``tgt`` has the shape of the input
            target.
        """
        if self.normalize_before:
            tgt1 = self.norm1(tgt)
        else:
            tgt1 = tgt

        # Mamba over the target sequence
        tgt2 = self.self_mamba(tgt1)

        # add & norm
        tgt = tgt + self.dropout1(tgt2)
        if not self.normalize_before:
            tgt = self.norm1(tgt)

        if self.normalize_before:
            tgt1 = self.norm2(tgt)
        else:
            tgt1 = tgt

        # Mamba over [memory; target], keeping only the target positions.
        # Slice from the memory length rather than with -len(target): for an
        # empty target, `[:, -0:]` would select the whole sequence.
        tgt2 = self.cross_mamba(torch.cat([memory, tgt1], dim=1))[
            :, memory.shape[1]:
        ]

        # add & norm
        tgt = tgt + self.dropout2(tgt2)
        if not self.normalize_before:
            tgt = self.norm2(tgt)

        if self.normalize_before:
            tgt1 = self.norm3(tgt)
        else:
            tgt1 = tgt

        tgt2 = self.pos_ffn(tgt1)

        # add & norm
        tgt = tgt + self.dropout3(tgt2)
        if not self.normalize_before:
            tgt = self.norm3(tgt)

        return tgt, None, None


class MambaDecoder(nn.Module):
    """Stack of ``MambaDecoderLayer`` modules followed by a final LayerNorm.

    Arguments
    ---------
    num_layers : int
        How many decoder layers to stack.
    d_model : int
        Feature dimension of the target, the memory and the output.
    d_ffn : int
        Hidden dimension of each layer's feed-forward network.
    activation : type, optional
        Activation class of the feed-forward networks.
    dropout : float, optional
        Dropout rate used inside every layer.
    normalize_before : bool, optional
        Pre-norm (True) or post-norm (False) layout for every layer.
    mamba_config : dict
        Mamba keyword arguments shared by all layers.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        activation=nn.ReLU,
        dropout=0.0,
        normalize_before=False,
        mamba_config=None
    ):
        super().__init__()
        stacked = []
        for _ in range(num_layers):
            stacked.append(
                MambaDecoderLayer(
                    d_model=d_model,
                    d_ffn=d_ffn,
                    activation=activation,
                    dropout=dropout,
                    normalize_before=normalize_before,
                    mamba_config=mamba_config,
                )
            )
        self.layers = torch.nn.ModuleList(stacked)
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """Feed the target through every layer, then normalize.

        Arguments
        ----------
        tgt : torch.Tensor
            The sequence to the decoder layer (required).
        memory : torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask : torch.Tensor
            The mask for the tgt sequence (optional).
        memory_mask : torch.Tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask : torch.Tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask : torch.Tensor
            The mask for the memory keys per batch (optional).
        pos_embs_tgt : torch.Tensor
            The positional embeddings for the target (optional).
        pos_embs_src : torch.Tensor
            The positional embeddings for the source (optional).

        Returns
        -------
        tuple
            ``(output, [None], [None])`` where ``output`` is the normalized
            output of the last layer.
        """
        # Every layer receives the same masks / positional embeddings.
        shared_kwargs = {
            "tgt_mask": tgt_mask,
            "memory_mask": memory_mask,
            "tgt_key_padding_mask": tgt_key_padding_mask,
            "memory_key_padding_mask": memory_key_padding_mask,
            "pos_embs_tgt": pos_embs_tgt,
            "pos_embs_src": pos_embs_src,
        }
        hidden = tgt
        for layer in self.layers:
            hidden, _self_out, _cross_out = layer(hidden, memory, **shared_kwargs)
        return self.norm(hidden), [None], [None]