"""Conformer implementation.

Authors
-------
* Jianyuan Zhong 2020
* Samuele Cornell 2021
* Sylvain de Langen 2023
"""

import warnings
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import (
    MultiheadAttention,
    PositionalwiseFeedForward,
    RelPosMHAXL,
)
from speechbrain.nnet.hypermixing import HyperMixing
from speechbrain.nnet.normalization import LayerNorm
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig


@dataclass
class ConformerEncoderLayerStreamingContext:
    """Streaming metadata and state for a `ConformerEncoderLayer`.

    The multi-head attention and the Dynamic Chunk Convolution both need to save
    some left context, which gets inserted as left padding.

    See :class:`.ConvolutionModule` documentation for further details.
    """

    mha_left_context_size: int
    """For this layer, specifies how many frames of inputs should be saved.
    Usually, the same value is used across all layers, but this can be modified.
    """

    mha_left_context: Optional[torch.Tensor] = None
    """Left context to insert at the left of the current chunk as inputs to the
    multi-head attention. It is `None` for the first chunk. Afterwards, it may
    hold fewer than `mha_left_context_size` frames, because not enough left
    context is available yet during the first few chunks.
    """

    dcconv_left_context: Optional[torch.Tensor] = None
    """Left context to insert at the left of the convolution according to the
    Dynamic Chunk Convolution method.

    Unlike `mha_left_context`, here the number of frames to keep is fixed and
    inferred from the kernel size of the convolution module.
    """
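

# Illustrative sketch only: `_sketch_update_left_context` is a hypothetical
# helper, not used anywhere in this module. It mimics, with plain Python lists
# standing in for the time axis of a tensor, the concatenate-then-keep-last-N
# bookkeeping that `ConformerEncoderLayer.forward_streaming` applies to
# `mha_left_context`. It shows why the saved context can hold fewer than
# `mha_left_context_size` frames during the first few chunks.
def _sketch_update_left_context(left_context, chunk, left_context_size):
    """Returns `(padded_chunk, new_left_context)` for one streaming step."""
    # left pad the current chunk with whatever context was saved so far
    padded = list(left_context or []) + list(chunk)
    # keep at most `left_context_size` of the most recent frames; the guard
    # matters because `padded[-0:]` would keep the whole list
    if left_context_size > 0:
        return padded, padded[-left_context_size:]
    return padded, left_context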


@dataclass
class ConformerEncoderStreamingContext:
    """Streaming metadata and state for a `ConformerEncoder`."""

    dynchunktrain_config: DynChunkTrainConfig
    """Dynamic Chunk Training configuration holding chunk size and context size
    information."""

    layers: List[ConformerEncoderLayerStreamingContext]
    """Streaming metadata and state for each layer of the encoder."""


class ConvolutionModule(nn.Module):
    """This is an implementation of the convolution module in the Conformer.

    Arguments
    ---------
    input_size : int
        The expected size of the input embedding dimension.
    kernel_size : int, optional
        Kernel size of the non-bottleneck convolutional layer.
    bias : bool, optional
        Whether to use bias in the non-bottleneck conv layer.
    activation : torch.nn.Module
        Activation function used after the non-bottleneck conv layer.
    dropout : float, optional
        Dropout rate.
    causal : bool, optional
        Whether the convolution should be causal or not.
    dilation : int, optional
        Dilation factor for the non-bottleneck conv layer.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = ConvolutionModule(512, 3)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        input_size,
        kernel_size=31,
        bias=True,
        activation=Swish,
        dropout=0.0,
        causal=False,
        dilation=1,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.causal = causal
        self.dilation = dilation

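        # a dilated convolution spans (kernel_size - 1) * dilation + 1 frames;
        # causal mode pads by the full amount and later chomps the right side
        if self.causal:
            self.padding = (kernel_size - 1) * dilation
        else:
            self.padding = (kernel_size - 1) * dilation // 2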

        self.layer_norm = nn.LayerNorm(input_size)
        self.bottleneck = nn.Sequential(
            # pointwise
            nn.Conv1d(
                input_size, 2 * input_size, kernel_size=1, stride=1, bias=bias
            ),
            nn.GLU(dim=1),
        )
        # depthwise
        self.conv = nn.Conv1d(
            input_size,
            input_size,
            kernel_size=kernel_size,
            stride=1,
            padding=self.padding,
            dilation=dilation,
            groups=input_size,
            bias=bias,
        )

        # BatchNorm in the original Conformer replaced with a LayerNorm due to
        # https://github.com/speechbrain/speechbrain/pull/1329
        # see discussion
        # https://github.com/speechbrain/speechbrain/pull/933#issuecomment-1033367884

        self.after_conv = nn.Sequential(
            nn.LayerNorm(input_size),
            activation(),
            # pointwise
            nn.Linear(input_size, input_size, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """Applies the convolution to an input tensor `x`.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor to the convolution module.
        mask : torch.Tensor, optional
            Mask to be applied over the output of the convolution using
            `masked_fill_`, if specified.
        dynchunktrain_config : DynChunkTrainConfig, optional
            If specified, makes the module support Dynamic Chunk Convolution
            (DCConv) as implemented by
            `Dynamic Chunk Convolution for Unified Streaming and Non-Streaming Conformer ASR <https://www.amazon.science/publications/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr>`_.
            This allows masking future frames while preserving better accuracy
            than a fully causal convolution, at a small speed cost.
            This should only be used for training (or, if you know what you're
            doing, for masked evaluation at inference time), as the forward
            streaming function should be used at inference time.

        Returns
        -------
        out: torch.Tensor
            The output tensor.
        """

        if dynchunktrain_config is not None:
            # chunking combined with causal padding is most likely unintended;
            # if it makes sense for your use case, feel free to implement it.
            assert (
                not self.causal
            ), "Chunked convolution not supported with causal padding"

            assert (
                self.dilation == 1
            ), "Current DynChunkTrain logic does not support dilation != 1"

            # in a causal convolution, which is not the case here, an output
            # frame would never be able to depend on an input frame from any
            # point in the future.

            # but with the dynamic chunk convolution, we instead use a "normal"
            # convolution but where, for any output frame, the future beyond the
            # "current" chunk gets masked.
            # see the paper linked in the documentation for details.

            chunk_size = dynchunktrain_config.chunk_size
            batch_size = x.shape[0]

            # determine the amount of padding we need to insert at the right of
            # the last chunk so that all chunks end up with the same size.
            if x.shape[1] % chunk_size != 0:
                final_right_padding = chunk_size - (x.shape[1] % chunk_size)
            else:
                final_right_padding = 0

            # -> [batch_size, t, in_channels]
            out = self.layer_norm(x)

            # -> [batch_size, in_channels, t] for the CNN
            out = out.transpose(1, 2)

            # -> [batch_size, in_channels, t] (pointwise)
            out = self.bottleneck(out)

            # -> [batch_size, in_channels, lc+t+final_right_padding]
            out = F.pad(out, (self.padding, final_right_padding), value=0)

            # now, make chunks with left context.
            # as a recap to what the above padding and this unfold do, consider
            # each a/b/c letter represents a frame as part of chunks a, b, c.
            # consider a chunk size of 4 and a kernel size of 5 (padding=2):
            #
            # input seq: 00aaaabbbbcc00
            # chunk #1:  00aaaa
            # chunk #2:      aabbbb
            # chunk #3:          bbcc00
            #
            # a few remarks here:
            # - the left padding gets inserted early so that the unfold logic
            #   works trivially
            # - the right 0-padding got inserted as the number of time steps
            #   could not be evenly split in `chunk_size` chunks

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size]
            out = out.unfold(2, size=chunk_size + self.padding, step=chunk_size)

            # as we manually disable padding in the convolution below, we insert
            # right 0-padding to the chunks, e.g. reusing the above example:
            #
            # chunk #1:  00aaaa00
            # chunk #2:      aabbbb00
            # chunk #3:          bbcc0000

            # -> [batch_size, in_channels, num_chunks, lc+chunk_size+rpad]
            out = F.pad(out, (0, self.padding), value=0)

            # the transpose+flatten effectively flattens chunks into the batch
            # dimension to be processed into the time-wise convolution. the
            # chunks will later on be unflattened.

            # -> [batch_size, num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.transpose(1, 2)

            # -> [batch_size * num_chunks, in_channels, lc+chunk_size+rpad]
            out = out.flatten(start_dim=0, end_dim=1)

            # TODO: experiment around reflect padding, which is difficult
            # because small chunks have too few time steps to reflect from

            # let's keep backwards compat by pointing at the weights from the
            # already declared Conv1d.
            #
            # still reusing the above example, the convolution will be applied,
            # with the padding truncated on both ends. the following example
            # shows the letter corresponding to the input frame on which the
            # convolution was centered.
            #
            # as you can see, the sum of lengths of all chunks is equal to our
            # input sequence length + `final_right_padding`.
            #
            # chunk #1:  aaaa
            # chunk #2:      bbbb
            # chunk #3:          cc00

            # -> [batch_size * num_chunks, out_channels, chunk_size]
            out = F.conv1d(
                out,
                weight=self.conv.weight,
                bias=self.conv.bias,
                stride=self.conv.stride,
                padding=0,
                dilation=self.conv.dilation,
                groups=self.conv.groups,
            )

            # -> [batch_size * num_chunks, chunk_size, out_channels]
            out = out.transpose(1, 2)

            out = self.after_conv(out)

            # -> [batch_size, num_chunks, chunk_size, out_channels]
            out = torch.unflatten(out, dim=0, sizes=(batch_size, -1))

            # -> [batch_size, t + final_right_padding, out_channels]
            out = torch.flatten(out, start_dim=1, end_dim=2)

            # -> [batch_size, t, out_channels]
            if final_right_padding > 0:
                out = out[:, :-final_right_padding, :]
        else:
            out = self.layer_norm(x)
            out = out.transpose(1, 2)
            out = self.bottleneck(out)
            out = self.conv(out)

            if self.causal:
                # chomp
                out = out[..., : -self.padding]

            out = out.transpose(1, 2)
            out = self.after_conv(out)

        if mask is not None:
            out.masked_fill_(mask, 0.0)

        return out
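

# Illustrative sketch only: `_sketch_dcconv_chunks` is a hypothetical helper,
# not used anywhere in this module. It reproduces, with one character per
# frame, the layout built by the Dynamic Chunk Convolution branch of
# `ConvolutionModule.forward` (the "00aaaabbbbcc00" example in its comments).
# It assumes dilation == 1, as that branch asserts, so the padding on each
# side is `(kernel_size - 1) // 2`.
def _sketch_dcconv_chunks(frames, chunk_size, kernel_size):
    """Returns every chunk with its left context and right zero-padding."""
    padding = (kernel_size - 1) // 2
    # right-pad so that the sequence splits evenly into `chunk_size` chunks
    final_right_padding = (-len(frames)) % chunk_size
    padded = "0" * padding + frames + "0" * final_right_padding
    chunks = []
    # same windows as `unfold(2, size=chunk_size + padding, step=chunk_size)`
    for start in range(0, len(padded) - padding - chunk_size + 1, chunk_size):
        window = padded[start : start + padding + chunk_size]
        # same as the second `F.pad(out, (0, self.padding))`: frames beyond
        # the current chunk are never visible to the convolution
        chunks.append(window + "0" * padding)
    return chunks


# Each returned chunk holds `chunk_size + 2 * padding` frames, so for an odd
# kernel size an unpadded convolution over it yields exactly `chunk_size`
# outputs, one per frame of the chunk.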


class ConformerEncoderLayer(nn.Module):
    """This is an implementation of the Conformer encoder layer.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of the convolution module.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    dropout : float, optional
        Dropout for the encoder.
    causal : bool, optional
        Whether the convolutions should be causal or not. With "RelPosMHAXL",
        this also masks future positions in the attention.
    attention_type : str, optional
        Type of attention layer: "regularMHA" for regular MultiheadAttention,
        "RelPosMHAXL" for Transformer-XL style relative positional attention,
        or "hypermixing" for HyperMixing.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_embs = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoderLayer(d_ffn=512, nhead=8, d_model=512, kernel_size=3)
    >>> output = net(x, pos_embs=pos_embs)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=causal,
            )
        elif attention_type == "hypermixing":
            self.mha_layer = HyperMixing(
                input_output_dim=d_model,
                hypernet_size=d_ffn,
                tied=False,
                num_heads=nhead,
                fix_tm_hidden_size=False,
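            )
        else:
            raise ValueError(
                f"Unsupported attention_type {attention_type!r}; expected "
                '"regularMHA", "RelPosMHAXL" or "hypermixing"'
            )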

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: torch.Tensor = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ---------
        x : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional
            embeddings.
        dynchunktrain_config : Optional[DynChunkTrainConfig]
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to
            the convolution module.
        """
        conv_mask: Optional[torch.Tensor] = None
        if src_key_padding_mask is not None:
            conv_mask = src_key_padding_mask.unsqueeze(-1)
        # ffn module
        x = x + 0.5 * self.ffn_module1(x)
        # multi-head attention module
        skip = x
        x = self.norm1(x)

        x, self_attn = self.mha_layer(
            x,
            x,
            x,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            pos_embs=pos_embs,
        )
        x = x + skip
        # convolution module
        x = x + self.convolution_module(
            x, conv_mask, dynchunktrain_config=dynchunktrain_config
        )
        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn

    def forward_streaming(
        self,
        x,
        context: ConformerEncoderLayerStreamingContext,
        pos_embs: torch.Tensor = None,
    ):
        """Conformer layer streaming forward (typically for
        DynamicChunkTraining-trained models), which is to be used at inference
        time. Relies on a mutable context object as initialized by
        `make_streaming_context` that should be used across chunks.
        Invoked by `ConformerEncoder.forward_streaming`.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor for this layer. Batching is supported as long as you
            keep the context consistent.
        context : ConformerEncoderLayerStreamingContext
            Mutable streaming context; the same object should be passed across
            calls.
        pos_embs : torch.Tensor, optional
            Positional embeddings, if used.

        Returns
        -------
        x : torch.Tensor
            Output tensor.
        self_attn : list
            List of self attention values.
        """

        orig_len = x.shape[-2]
        # ffn module
        x = x + 0.5 * self.ffn_module1(x)

        # TODO: make the approach for MHA left context more efficient.
        # currently, this saves the inputs to the MHA.
        # the naive approach is suboptimal in a few ways, namely that the
        # outputs for this left padding are being recomputed even though we
        # discard them immediately after.

        # left pad `x` with our MHA left context
        if context.mha_left_context is not None:
            x = torch.cat((context.mha_left_context, x), dim=1)

        # compute new MHA left context for the next call to our function
        if context.mha_left_context_size > 0:
            context.mha_left_context = x[
                ..., -context.mha_left_context_size :, :
            ]

        # multi-head attention module
        skip = x
        x = self.norm1(x)

        x, self_attn = self.mha_layer(
            x,
            x,
            x,
            attn_mask=None,
            key_padding_mask=None,
            pos_embs=pos_embs,
        )
        x = x + skip

        # truncate outputs corresponding to the MHA left context (we only care
        # about our chunk's outputs); see above to-do
        x = x[..., -orig_len:, :]

        if context.dcconv_left_context is not None:
            x = torch.cat((context.dcconv_left_context, x), dim=1)

        # compute new DCConv left context for the next call to our function
        context.dcconv_left_context = x[
            ..., -self.convolution_module.padding :, :
        ]

        # convolution module
        x = x + self.convolution_module(x)

        # truncate outputs corresponding to the DCConv left context
        x = x[..., -orig_len:, :]

        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn

    def make_streaming_context(self, mha_left_context_size: int):
        """Creates a blank streaming context for this encoding layer.

        Arguments
        ---------
        mha_left_context_size : int
            How many left frames should be saved and used as left context to the
            current chunk when streaming

        Returns
        -------
        ConformerEncoderLayerStreamingContext
        """
        return ConformerEncoderLayerStreamingContext(
            mha_left_context_size=mha_left_context_size
        )
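

# Illustrative sketch only: the `_sketch_*` helpers below are hypothetical and
# not used anywhere in this module. On a single channel of plain floats, they
# check that carrying `padding` frames of left context between chunks, as
# `ConformerEncoderLayer.forward_streaming` does with `dcconv_left_context`,
# gives the same depthwise convolution outputs as the chunked Dynamic Chunk
# Convolution used at training time. The layer norms, the GLU bottleneck and
# the pointwise layers are left out because they act on each frame separately.
def _sketch_conv_same(seq, kernel):
    """Zero-padded 'same' 1D cross-correlation, odd kernel, dilation 1."""
    pad = (len(kernel) - 1) // 2
    padded = [0.0] * pad + list(seq) + [0.0] * pad
    return [
        sum(k * padded[t + j] for j, k in enumerate(kernel))
        for t in range(len(seq))
    ]


def _sketch_dcconv_training(seq, kernel, chunk_size):
    """Chunked convolution: past frames are visible, future chunks are zeros."""
    pad = (len(kernel) - 1) // 2
    out = []
    for start in range(0, len(seq), chunk_size):
        chunk = list(seq[start : start + chunk_size])
        # up to `pad` real frames of left context, zeros before t=0
        left = [0.0] * max(0, pad - start)
        left += list(seq[max(0, start - pad) : start])
        window = left + chunk + [0.0] * pad
        for t in range(len(chunk)):
            out.append(sum(k * window[t + j] for j, k in enumerate(kernel)))
    return out


def _sketch_dcconv_streaming(seq, kernel, chunk_size):
    """Chunk-by-chunk convolution carrying `pad` input frames of context."""
    pad = (len(kernel) - 1) // 2
    left_context = None
    out = []
    for start in range(0, len(seq), chunk_size):
        chunk = list(seq[start : start + chunk_size])
        x = (left_context or []) + chunk
        # save the last `pad` input frames for the next chunk
        left_context = x[-pad:] if pad > 0 else None
        # plain zero-padded convolution, then drop the context's outputs
        out.extend(_sketch_conv_same(x, kernel)[-len(chunk) :])
    return out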


class ConformerEncoder(nn.Module):
    """This class implements the Conformer encoder.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Embedding dimension size.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int, optional
        Kernel size of the convolution module.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    dropout : float, optional
        Dropout for the encoder.
    causal : bool, optional
        Whether the convolutions should be causal or not. With "RelPosMHAXL",
        this also masks future positions in the attention.
    attention_type : str, optional
        Type of attention layer: "regularMHA" for regular MultiheadAttention,
        "RelPosMHAXL" for Transformer-XL style relative positional attention,
        or "hypermixing" for HyperMixing.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> pos_emb = torch.rand((1, 2*60-1, 512))
    >>> net = ConformerEncoder(1, 512, 512, 8)
    >>> output, _ = net(x, pos_embs=pos_emb)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        d_ffn,
        nhead,
        kernel_size=31,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=False,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [
                ConformerEncoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    attention_type=attention_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = LayerNorm(d_model, eps=1e-6)
        self.attention_type = attention_type

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config: Optional[DynChunkTrainConfig] = None,
    ):
        """
        Arguments
        ---------
        src : torch.Tensor
            The input sequence to the encoder.
        src_mask : torch.Tensor, optional
            The mask for the src sequence.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys per batch.
        pos_embs : torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the input sequence positional
            embeddings; required when `attention_type` is "RelPosMHAXL".
            If custom pos_embs are given, they need to have the shape
            (1, 2*S-1, E), where S is the sequence length and E is the
            embedding dimension.
        dynchunktrain_config : Optional[DynChunkTrainConfig]
            Dynamic Chunk Training configuration object for streaming,
            specifically involved here to apply Dynamic Chunk Convolution to the
            convolution module.

        Returns
        -------
        output : torch.Tensor
            The output of the Conformer encoder.
        attention_lst : list
            The attention values of each layer.
        """
        if self.attention_type == "RelPosMHAXL":
            if pos_embs is None:
                raise ValueError(
                    "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
                )

        output = src
        attention_lst = []
        for enc_layer in self.layers:
            output, attention = enc_layer(
                output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                pos_embs=pos_embs,
                dynchunktrain_config=dynchunktrain_config,
            )
            attention_lst.append(attention)
        output = self.norm(output)

        return output, attention_lst

    def forward_streaming(
        self,
        src: torch.Tensor,
        context: ConformerEncoderStreamingContext,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Conformer streaming forward (typically for
        DynamicChunkTraining-trained models), which is to be used at inference
        time. Relies on a mutable context object as initialized by
        `make_streaming_context` that should be used across chunks.

        Arguments
        ---------
        src : torch.Tensor
            Input tensor. Batching is supported as long as you keep the context
            consistent.
        context : ConformerEncoderStreamingContext
            Mutable streaming context; the same object should be passed across
            calls.
        pos_embs : torch.Tensor, optional
            Positional embeddings, if used.

        Returns
        -------
        output : torch.Tensor
            The output of the streaming conformer.
        attention_lst : list
            The attention values.
        """

        if self.attention_type == "RelPosMHAXL":
            if pos_embs is None:
                raise ValueError(
                    "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
                )

        output = src
        attention_lst = []
        for i, enc_layer in enumerate(self.layers):
            output, attention = enc_layer.forward_streaming(
                output, pos_embs=pos_embs, context=context.layers[i]
            )
            attention_lst.append(attention)
        output = self.norm(output)

        return output, attention_lst

    def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig):
        """Creates a blank streaming context for the encoder.

        Arguments
        ---------
        dynchunktrain_config : DynChunkTrainConfig
            Dynamic Chunk Training configuration object for streaming. Its
            ``left_context_size_frames()`` sets the attention left context size
            of each layer's streaming context.

        Returns
        -------
        ConformerEncoderStreamingContext
        """
        return ConformerEncoderStreamingContext(
            dynchunktrain_config=dynchunktrain_config,
            layers=[
                layer.make_streaming_context(
                    mha_left_context_size=dynchunktrain_config.left_context_size_frames()
                )
                for layer in self.layers
            ],
        )
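

# Illustrative sketch, not part of the SpeechBrain API: the hypothetical helper
# below runs a toy causal depthwise 1-D convolution chunk by chunk and carries
# the last ``kernel_size - 1`` input frames from one chunk to the next. It
# mirrors, in miniature, the idea behind ``forward_streaming``: state that later
# chunks depend on lives in a mutable context reused across calls. The real
# per-layer contexts created by ``make_streaming_context`` (configured with
# ``mha_left_context_size``) are not reproduced here.
def _sketch_streaming_causal_conv(x, weight, chunk_size):
    """Runs a toy causal depthwise conv on the full input and chunk by chunk.

    Arguments
    ---------
    x : torch.Tensor
        Input of shape (batch, time, channels).
    weight : torch.Tensor
        Depthwise kernel of shape (channels, 1, kernel_size).
    chunk_size : int
        Number of frames per streamed chunk.

    Returns
    -------
    streamed : torch.Tensor
        Output computed chunk by chunk, shape (batch, time, channels).
    full : torch.Tensor
        Output computed on the whole sequence, shape (batch, time, channels).
    """
    import torch
    import torch.nn.functional as F

    channels, _, kernel_size = weight.shape
    x_ct = x.transpose(1, 2)  # (batch, channels, time)

    # Reference: left-pad with zeros so output frame t only sees frames <= t.
    full = F.conv1d(F.pad(x_ct, (kernel_size - 1, 0)), weight, groups=channels)

    # Streaming: the context starts as the same zero padding and is replaced
    # after every chunk by the most recent kernel_size - 1 input frames.
    context = x_ct.new_zeros(x_ct.shape[0], channels, kernel_size - 1)
    outputs = []
    for start in range(0, x_ct.shape[-1], chunk_size):
        chunk = x_ct[..., start : start + chunk_size]
        padded = torch.cat([context, chunk], dim=-1)
        outputs.append(F.conv1d(padded, weight, groups=channels))
        context = padded[..., padded.shape[-1] - (kernel_size - 1) :]

    streamed = torch.cat(outputs, dim=-1)
    return streamed.transpose(1, 2), full.transpose(1, 2)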


class ConformerDecoderLayer(nn.Module):
    """This is an implementation of a Conformer decoder layer.

    Arguments
    ---------
    d_model : int
        The expected size of the input embedding.
    d_ffn : int
        Hidden size of self-attention Feed Forward layer.
    nhead : int
        Number of attention heads.
    kernel_size : int
        Kernel size of the convolution module.
    kdim : int, optional
        Dimension of the key.
    vdim : int, optional
        Dimension of the value.
    activation : torch.nn.Module, optional
        Activation function used in each Conformer layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    dropout : float, optional
        Dropout rate for the decoder.
    causal : bool, optional
        Whether the convolutions should be causal or not.
    attention_type : str, optional
        Type of attention layer: "regularMHA" for regular MultiHeadAttention or
        "RelPosMHAXL" for Transformer-XL style relative positional attention.

    Example
    -------
    >>> import torch
    >>> tgt = torch.rand((8, 60, 512))
    >>> memory = torch.rand((8, 60, 512))
    >>> net = ConformerDecoderLayer(d_model=512, d_ffn=512, nhead=8, kernel_size=3, attention_type="regularMHA")
    >>> output = net(tgt, memory)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_model,
        d_ffn,
        nhead,
        kernel_size,
        kdim=None,
        vdim=None,
        activation=Swish,
        bias=True,
        dropout=0.0,
        causal=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()

        if not causal:
            warnings.warn(
                "Decoder is not causal; in most applications it should be causal. You have been warned!"
            )

        if attention_type == "regularMHA":
            self.mha_layer = MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            # transformerXL style positional encoding
            self.mha_layer = RelPosMHAXL(
                num_heads=nhead,
                embed_dim=d_model,
                dropout=dropout,
                mask_pos_future=causal,
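            )
        else:
            raise ValueError(
                f"Unsupported attention_type {attention_type!r}; expected 'regularMHA' or 'RelPosMHAXL'."
            )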

        self.convolution_module = ConvolutionModule(
            d_model, kernel_size, bias, activation, dropout, causal=causal
        )

        self.ffn_module1 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.ffn_module2 = nn.Sequential(
            nn.LayerNorm(d_model),
            PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            ),
            nn.Dropout(dropout),
        )

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt: torch.Tensor
            The sequence to the decoder layer.
        memory: torch.Tensor
            The sequence from the last layer of the encoder.
        tgt_mask: torch.Tensor, optional
            The mask for the tgt sequence. Not used by this layer.
        memory_mask: torch.Tensor, optional
            The mask for the memory sequence.
        tgt_key_padding_mask: torch.Tensor, optional
            The mask for the tgt keys per batch. Not used by this layer.
        memory_key_padding_mask: torch.Tensor, optional
            The mask for the memory keys per batch.
        pos_embs_tgt: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the target sequence positional embeddings for each attention layer. Not used by this layer.
        pos_embs_src: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the source sequence positional embeddings for each attention layer.

        Returns
        -------
        x: torch.Tensor
            The output tensor.
        self_attn : torch.Tensor
            The attention weights of the attention module, which attends from
            ``tgt`` to ``memory``.
        self_attn : torch.Tensor
            The same attention weights, returned a second time so that the
            layer yields the three values unpacked by ``ConformerDecoder``.
        """
        # ffn module
        tgt = tgt + 0.5 * self.ffn_module1(tgt)
        # multi-head attention module
        skip = tgt
        x = self.norm1(tgt)
        x, self_attn = self.mha_layer(
            x,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            pos_embs=pos_embs_src,
        )
        x = x + skip
        # convolution module
        x = x + self.convolution_module(x)
        # ffn module
        x = self.norm2(x + 0.5 * self.ffn_module2(x))
        return x, self_attn, self_attn
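

# Illustrative sketch, not part of the SpeechBrain API: the hypothetical helper
# below restates the data flow of ``ConformerDecoderLayer.forward`` with plain
# callables so the residual structure is easy to read. Both feed-forward
# modules add half-step (0.5-weighted) residuals, as in the Macaron-style
# Conformer block; the attention and convolution modules add full residuals,
# and ``norm2`` is applied last. The only attention attends from ``tgt`` to
# ``memory``, which is why ``tgt_mask`` plays no role in this layer.
def _sketch_conformer_decoder_step(
    tgt, memory, ffn1, norm1, attn, conv, ffn2, norm2
):
    """Returns the layer output for the given sub-module callables."""
    tgt = tgt + 0.5 * ffn1(tgt)  # first half-step feed-forward
    x = attn(norm1(tgt), memory) + tgt  # pre-norm attention over memory
    x = x + conv(x)  # convolution module with residual
    return norm2(x + 0.5 * ffn2(x))  # second half-step feed-forward, then norm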


class ConformerDecoder(nn.Module):
    """This class implements the Conformer decoder.

    Arguments
    ---------
    num_layers: int
        Number of layers.
    nhead: int
        Number of attention heads.
    d_ffn: int
        Hidden size of self-attention Feed Forward layer.
    d_model: int
        Embedding dimension size.
    kdim: int, optional
        Dimension for key.
    vdim: int, optional
        Dimension for value.
    dropout: float, optional
        Dropout rate.
    activation: torch.nn.Module, optional
        Activation function used in the feed-forward and convolution modules of each layer.
    kernel_size : int, optional
        Kernel size of convolutional layer.
    bias : bool, optional
        Whether to use bias in the convolution module.
    causal: bool, optional
        Whether the convolutions should be causal or not.
    attention_type: str, optional
        Type of attention layer, e.g. "regularMHA" for regular MultiHeadAttention.

    Example
    -------
    >>> import torch
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = ConformerDecoder(1, 8, 1024, 512, attention_type="regularMHA")
    >>> output, _, _ = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        nhead,
        d_ffn,
        d_model,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=Swish,
        kernel_size=3,
        bias=True,
        causal=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                ConformerDecoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=causal,
                    attention_type=attention_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt: torch.Tensor
            The sequence to the decoder layer.
        memory: torch.Tensor
            The sequence from the last layer of the encoder.
        tgt_mask: torch.Tensor, optional
            The mask for the tgt sequence.
        memory_mask: torch.Tensor, optional
            The mask for the memory sequence.
        tgt_key_padding_mask : torch.Tensor, optional
            The mask for the tgt keys per batch.
        memory_key_padding_mask : torch.Tensor, optional
            The mask for the memory keys per batch.
        pos_embs_tgt: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the target sequence positional embeddings for each attention layer.
        pos_embs_src: torch.Tensor, torch.nn.Module, optional
            Module or tensor containing the source sequence positional embeddings for each attention layer.

        Returns
        -------
        output: torch.Tensor
            Conformer decoder output.
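        self_attns : list
            Attention weights returned by each decoder layer.
        multihead_attns : list
            Attention weights returned by each decoder layer; since
            ``ConformerDecoderLayer`` returns its attention tensor twice,
            this holds the same tensors as ``self_attns``.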
        """
        output = tgt
        self_attns, multihead_attns = [], []
        for dec_layer in self.layers:
            output, self_attn, multihead_attn = dec_layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos_embs_tgt=pos_embs_tgt,
                pos_embs_src=pos_embs_src,
            )
            self_attns.append(self_attn)
            multihead_attns.append(multihead_attn)
        output = self.norm(output)

        return output, self_attns, multihead_attns