File size: 34,715 Bytes
1d4c9c3
 
 
 
 
 
 
 
 
da78a0e
1d4c9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da78a0e
1d4c9c3
 
 
da78a0e
1d4c9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da78a0e
 
 
 
 
 
 
1d4c9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da78a0e
1d4c9c3
da78a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d4c9c3
 
da78a0e
 
1d4c9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da78a0e
1d4c9c3
 
 
 
 
 
 
 
da78a0e
1d4c9c3
 
da78a0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d4c9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da78a0e
1d4c9c3
da78a0e
1d4c9c3
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT


# File name of the saved model weights (presumably written/read by save/load code elsewhere — confirm).
MODEL_FILE = "model.pt"


# Config string -> normalization layer class; the class is called with the channel count.
norm_layer_dict = dict(
    batch_norm_2d=torch.nn.BatchNorm2d,
)


# Config string -> activation layer class; the class is instantiated with no arguments.
activation_layer_dict = dict(
    relu=torch.nn.ReLU,
    identity=torch.nn.Identity,
    sigmoid=torch.nn.Sigmoid,
)


class CausalConv2d(nn.Sequential):
    # Layer stack: [time padding] -> Conv2d [-> 1x1 pointwise Conv2d] [-> norm] [-> activation].
    # The layers live in an nn.Sequential, so their order determines the state_dict keys
    # ("0.weight", "1.weight", ...).
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: Optional[str] = "batch_norm_2d",
                 activation_layer: Optional[str] = "relu",
                 lookahead: int = 0
                 ):
        """
        Causal Conv2d by delaying the signal for any lookahead.

        Expected input format: [batch_size, channels, time_steps, spec_dim]

        The time axis is zero-padded with `kernel_size[0] - 1 - lookahead` frames in front and
        `lookahead` frames at the end. The time stride is 1, so the number of time steps is
        preserved and each output frame sees `lookahead` future frames. Stride and dilation act
        on the frequency axis only.

        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: int, or (time, freq) kernel size.
        :param fstride: stride along the frequency axis.
        :param dilation: dilation along the frequency axis.
        :param fpad: if True, pad the frequency axis by `kernel_size[1] // 2 + dilation - 1`
            on both sides; otherwise do not pad it.
        :param bias: whether the main convolution has a bias term.
        :param separable: if True, the main conv uses groups=gcd(in_channels, out_channels) and is
            followed by a bias-free 1x1 pointwise conv. The pointwise conv is omitted when the
            gcd is 1 or the kernel is 1x1.
        :param norm_layer: key into `norm_layer_dict`, or None for no normalization.
        :param activation_layer: key into `activation_layer_dict`, or None for no activation.
        :param lookahead: number of future time frames each output frame may see.
        """
        # Redundant with the super().__init__(*layers) call at the end. It is harmless because
        # nothing is registered on self in between.
        super(CausalConv2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2 + dilation - 1
        else:
            fpad_ = 0

        # for last 2 dim, pad (left, right, top, bottom).
        # Only the time axis is padded here; total padding is kernel_size[0] - 1 frames.
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = list()
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        # NOTE(review): for a 1x1 kernel, `groups` keeps the gcd value but the pointwise conv is
        # dropped, so channels are only mixed within each group. Confirm this is intended;
        # changing it would alter weight shapes of existing checkpoints.
        if max(kernel_size) == 1:
            separable = False

        layers.append(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            # Pointwise conv mixing channels across the groups of the conv above.
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)

    def forward(self, inputs):
        # Apply every layer in registration order (same as nn.Sequential.forward).
        for module in self:
            inputs = module(inputs)
        return inputs


class CausalConvTranspose2d(nn.Sequential):
    # Layer stack: [time padding] -> ConvTranspose2d [-> 1x1 pointwise Conv2d] [-> norm] [-> activation].
    # The layers live in an nn.Sequential, so their order determines the state_dict keys.
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: Optional[str] = "batch_norm_2d",
                 activation_layer: Optional[str] = "relu",
                 lookahead: int = 0
                 ):
        """
        Causal ConvTranspose2d.

        Expected input format: [batch_size, channels, time_steps, spec_dim]

        The time axis is zero-padded with `kernel_size[0] - 1 - lookahead` frames in front and
        `lookahead` frames at the end. The transposed conv's time padding of
        `kernel_size[0] - 1` then keeps the number of time steps unchanged. Stride and dilation
        act on the frequency axis only.

        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: int, or any iterable of (time, freq) kernel sizes.
        :param fstride: stride (upsampling factor) along the frequency axis.
        :param dilation: dilation along the frequency axis.
        :param fpad: if True, use `kernel_size[1] // 2` as frequency padding / output padding.
        :param bias: whether the transposed conv has a bias term.
        :param separable: if True, use groups=gcd(in_channels, out_channels) followed by a
            bias-free 1x1 pointwise conv (omitted when the gcd is 1).
        :param norm_layer: key into `norm_layer_dict`, or None for no normalization.
        :param activation_layer: key into `activation_layer_dict`, or None for no activation.
        :param lookahead: number of future time frames each output frame may see.
        """
        super(CausalConvTranspose2d, self).__init__()

        # Normalize to a (time, freq) tuple. Like CausalConv2d, this accepts any iterable,
        # not only sequences that support indexing.
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2
        else:
            fpad_ = 0

        # for last 2 dim, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                # Time padding of kernel_size[0] - 1 undoes the explicit time padding above.
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            # Pointwise conv mixing channels across the groups of the transposed conv.
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        # Re-initializes the nn.Sequential with the assembled layers.
        super().__init__(*layers)


class GroupedLinear(nn.Module):
    """
    Bias-free linear layer split into independent groups.

    The last input dimension is cut into `groups` equal slices. Slice g is multiplied by its own
    matrix `weight[g]` of shape [input_size // groups, hidden_size // groups], and the per-group
    results are concatenated. Inputs must be 3-D: [batch, time, input_size].
    """

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
        assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
        # Width of one input slice.
        self.ws = input_size // groups
        # Parameter "weight": [groups, input_size // groups, hidden_size // groups].
        self.weight = nn.Parameter(
            torch.zeros(groups, self.ws, hidden_size // groups),
            requires_grad=True,
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the feature axis into groups: [B, T, I] -> [B, T, G, I/G].
        # An explicit 3-D view is used (instead of unflatten) to stay TorchScript-friendly.
        batch_size, time_steps, _ = x.shape
        grouped = x.view(batch_size, time_steps, self.groups, self.ws)
        # One matmul per group: [B, T, G, I/G] x [G, I/G, H/G] -> [B, T, G, H/G].
        projected = torch.einsum("btgi,gih->btgh", grouped, self.weight)
        # Concatenate the group outputs: [B, T, H].
        return projected.flatten(2, 3)

    def __repr__(self):
        name = type(self).__name__
        return f"{name}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"


class SqueezedGRU_S(nn.Module):
    """
    Squeezed GRU block: grouped-linear input projection -> GRU -> optional grouped-linear
    output projection, plus an optional skip connection from the block input to its output.

    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        num_layers: int = 1,
        linear_groups: int = 8,
        batch_first: bool = True,
        skip_op: str = "none",
        activation_layer: str = "identity",
    ):
        """
        :param input_size: size of the last dim of `inputs`.
        :param hidden_size: GRU hidden size; inputs are projected to this size before the GRU.
        :param output_size: if given, GRU outputs are projected to this size. Otherwise the block
            outputs `hidden_size` features.
        :param num_layers: number of stacked GRU layers.
        :param linear_groups: number of groups in every GroupedLinear projection.
        :param batch_first: forwarded to nn.GRU.
        :param skip_op: "none", "identity" or "grouped_linear". The skip branch is applied to the
            raw inputs and added to the block output.
        :param activation_layer: key into `activation_layer_dict`, applied after each projection.
        :raises AssertionError: if skip_op is "identity" and input_size differs from the block's
            output size.
        :raises NotImplementedError: if skip_op is not recognised.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Feature size produced by forward(). The skip branch must produce the same size so the
        # two can be summed.
        out_dim = hidden_size if output_size is None else output_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # gru skip operator: maps the raw inputs (input_size) to out_dim.
        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            # An identity skip only works when input and output sizes are equal.
            if input_size != out_dim:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            self.gru_skip_op = GroupedLinear(
                input_size=input_size,
                hidden_size=out_dim,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=False,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, inputs: torch.Tensor, h: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param inputs: [batch_size, time_steps, input_size] (GroupedLinear requires 3-D input).
        :param h: optional initial GRU hidden state; None lets nn.GRU start from zeros.
        :return: (output, h_n). Output has out_dim features; h_n is the final GRU hidden state.
        """
        x = self.linear_in(inputs)

        x, h = self.gru.forward(x, h)

        x = self.linear_out(x)

        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)

        return x, h


class Add(nn.Module):
    """Combine two tensors by element-wise (broadcasting) addition."""

    def forward(self, a, b):
        total = a + b
        return total


class Concat(nn.Module):
    """Combine two tensors by concatenating them along their last dimension."""

    def forward(self, a, b):
        pieces = [a, b]
        return torch.cat(pieces, dim=-1)


class Encoder(nn.Module):
    """
    Two-branch convolutional encoder followed by a squeezed-GRU embedding.

    - Spec branch: spec_conv0..spec_conv3 over `feat_power` (1 input channel). spec_conv1 and
      spec_conv2 use fstride=2 on the frequency axis.
    - DF branch: df_conv0..df_conv1 over `feat_spec` (2 input channels), then `df_fc_emb`
      projects the flattened df features to the spec-branch embedding size.

    The flattened branch outputs are combined (concat or add, per `config.encoder_combine_op`)
    and run through `emb_gru`. `lsnr_fc` then maps the result to a per-frame local SNR value in
    (lsnr_min, lsnr_max).
    """

    def __init__(self, config: DfNetConfig):
        super(Encoder, self).__init__()
        # Flattened size of e3: conv_channels * spec_bins // 4 (after two fstride=2 convs).
        self.embedding_input_size = config.conv_channels * config.spec_bins // 4
        self.embedding_output_size = config.conv_channels * config.spec_bins // 4
        self.embedding_hidden_size = config.embedding_hidden_size

        self.spec_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

        # NOTE(review): unlike the spec branch, the df convs get no `lookahead` argument
        # (default 0). Confirm this asymmetry is intended.
        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        # Projects flattened c1 (df_bins // 2 bins x channels) to the spec-branch embedding size.
        # It is built before the concat doubling below, so its output is conv_channels * spec_bins // 4.
        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.embedding_input_size,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        # Concatenating the two branches doubles the GRU input size; adding keeps it.
        if config.encoder_combine_op == "concat":
            self.embedding_input_size *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        # emb_gru
        if config.spec_bins % 8 != 0:
            raise AssertionError("spec_bins should be divisible by 8")

        self.emb_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.embedding_hidden_size,
            output_size=self.embedding_output_size,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_emb_skip_op,
            linear_groups=config.encoder_emb_linear_groups,
            activation_layer="relu",
        )

        # lsnr: sigmoid output in (0, 1) rescaled to (lsnr_min, lsnr_max) in forward().
        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.embedding_output_size, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.lsnr_max - config.lsnr_min
        self.lsnr_offset = config.lsnr_min

    def forward(self,
                feat_power: torch.Tensor,
                feat_spec: torch.Tensor,
                hidden_state: Optional[torch.Tensor] = None,
                ):
        """
        :param feat_power: spec-branch input, [batch_size, 1, time_steps, spec_dim].
        :param feat_spec: df-branch input, [batch_size, 2, time_steps, df_bins].
        :param hidden_state: optional initial hidden state for `emb_gru`. None lets nn.GRU start
            from zeros.
        :return: (e0, e1, e2, e3, emb, c0, lsnr, h).
            - e0..e3: spec-branch feature maps (decoder skip inputs).
            - emb: GRU embedding.
            - c0: first df-branch feature map.
            - lsnr: per-frame local SNR estimate.
            - h: final GRU hidden state.
        """
        # feat_power shape: (batch_size, 1, time_steps, spec_dim)
        e0 = self.spec_conv0.forward(feat_power)
        e1 = self.spec_conv1.forward(e0)
        e2 = self.spec_conv2.forward(e1)
        e3 = self.spec_conv3.forward(e2)
        # e0 shape: [batch_size, channels, time_steps, spec_dim]
        # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
        # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
        # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]

        # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
        c0 = self.df_conv0(feat_spec)
        c1 = self.df_conv1(c0)
        # c0 shape: [batch_size, channels, time_steps, df_bins]
        # c1 shape: [batch_size, channels, time_steps, df_bins // 2]

        cemb = c1.permute(0, 2, 3, 1)
        # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
        cemb = cemb.flatten(2)
        # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
        cemb = self.df_fc_emb(cemb)
        # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
        emb = e3.permute(0, 2, 3, 1)
        # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
        emb = emb.flatten(2)
        # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        emb = self.combine(emb, cemb)
        # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
        # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        emb, h = self.emb_gru.forward(emb, hidden_state)
        # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
        # h shape: [1, batch_size, embedding_hidden_size]
        # (nn.GRU's hidden state is [num_layers, batch, hidden] even with batch_first=True)

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        # lsnr shape: [batch_size, time_steps, 1]

        return e0, e1, e2, e3, emb, c0, lsnr, h


class Decoder(nn.Module):
    """
    U-Net style mask decoder for the spec branch.

    Runs `emb_gru` over the encoder embedding, reshapes it back to a feature map, and then walks
    back up through the encoder skip features e3 -> e0. At each stage a pathway conv of the skip
    feature (`convNp`, kernel_size=1) is added to the previous stage's output before the next conv.
    `conv0_out` ends with a sigmoid, so the returned mask lies in (0, 1).
    """

    def __init__(self, config: DfNetConfig):
        super(Decoder, self).__init__()

        if config.spec_bins % 8 != 0:
            raise AssertionError("spec_bins should be divisible by 8")

        # Flattened size of the encoder embedding: conv_channels * spec_bins // 4.
        self.emb_in_dim = config.conv_channels * config.spec_bins // 4
        self.emb_out_dim = config.conv_channels * config.spec_bins // 4
        self.emb_hidden_dim = config.decoder_emb_hidden_size

        # NOTE(review): the GRU gets decoder_emb_num_layers - 1 layers, so the config value
        # presumably must be >= 2 for nn.GRU to get at least one layer. Confirm.
        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.decoder_emb_skip_op,
            linear_groups=config.decoder_emb_linear_groups,
            activation_layer="relu",
        )
        self.conv3p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv2p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        # fstride=2 transposed conv; per the shape comments in forward() it doubles the frequency axis.
        self.convt2 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv1p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt1 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv0p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        # Single-channel output with sigmoid activation: the mask.
        self.conv0_out = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        """
        Estimate the spectral mask.

        :param emb: encoder embedding, [batch_size, time_steps, (freq_dim // 4) * conv_channels].
        :param e3: encoder feature map, [batch_size, conv_channels, time_steps, freq_dim // 4].
        :param e2: encoder feature map, [batch_size, conv_channels, time_steps, freq_dim // 4].
        :param e1: encoder feature map, [batch_size, conv_channels, time_steps, freq_dim // 2].
        :param e0: encoder feature map, [batch_size, conv_channels, time_steps, freq_dim].
        :return: mask, [batch_size, 1, time_steps, freq_dim], values in (0, 1).
        """
        # Estimates erb mask
        # f8: number of frequency bins of e3.
        b, _, t, f8 = e3.shape

        # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
        # The GRU hidden state is discarded; no state is carried between calls here.
        emb, _ = self.emb_gru(emb)
        # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
        emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
        # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
        e3 = self.convt3(self.conv3p(e3) + emb)
        # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
        e2 = self.convt2(self.conv2p(e2) + e3)
        # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
        e1 = self.convt1(self.conv1p(e1) + e2)
        # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
        mask = self.conv0_out(self.conv0p(e0) + e1)
        # mask shape: [batch_size, 1, time_steps, freq_dim]
        return mask


class DfDecoder(nn.Module):
    """
    Deep-filter coefficient decoder.

    Runs a SqueezedGRU_S over the encoder embedding (with an optional skip from the embedding).
    It projects the result to `df_bins * df_order * 2` tanh-bounded values per frame and adds a
    convolutional pathway computed from the encoder's first df feature map `c0`.
    """

    def __init__(self, config: DfNetConfig):
        """
        :param config: model configuration.
        :raises AssertionError: if df_gru_skip is "identity" but the embedding size
            (conv_channels * spec_bins // 4) differs from df_decoder_hidden_size.
        :raises NotImplementedError: if df_gru_skip is not recognised.
        """
        super(DfDecoder, self).__init__()

        # Flattened encoder embedding size: conv_channels * spec_bins // 4.
        self.embedding_input_size = config.conv_channels * config.spec_bins // 4
        self.df_decoder_hidden_size = config.df_decoder_hidden_size
        self.df_num_layers = config.df_num_layers

        self.df_order = config.df_order

        self.df_bins = config.df_bins
        # Real and imaginary part per filter tap.
        self.df_out_ch = config.df_order * 2

        self.df_convp = CausalConv2d(
            config.conv_channels,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )
        self.df_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.df_decoder_hidden_size,
            num_layers=self.df_num_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        if config.df_gru_skip == "none":
            self.df_skip = None
        elif config.df_gru_skip == "identity":
            # The skip adds `emb` (embedding_input_size features) to the GRU output
            # (df_decoder_hidden_size features), so these two sizes must match.
            if self.embedding_input_size != self.df_decoder_hidden_size:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif config.df_gru_skip == "grouped_linear":
            self.df_skip = GroupedLinear(
                self.embedding_input_size,
                self.df_decoder_hidden_size,
                groups=config.df_decoder_linear_groups
            )
        else:
            raise NotImplementedError()

        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch

        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_decoder_hidden_size,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups
            ),
            nn.Tanh()
        )
        # NOTE(review): df_fc_a is not used in forward(). It is kept registered, presumably so
        # existing checkpoints still load. Confirm before removing.
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_decoder_hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        """
        :param emb: encoder embedding, [batch_size, time_steps, spec_bins // 4 * channels].
        :param c0: first df-branch encoder feature map, [batch_size, channels, time_steps, df_bins].
        :return: df coefficients, [batch_size, time_steps, df_bins, df_order * 2].
        """
        # emb shape: [batch_size, time_steps, spec_bins // 4 * channels]
        b, t, _ = emb.shape
        df_coefs, _ = self.df_gru(emb)
        if self.df_skip is not None:
            df_coefs = df_coefs + self.df_skip(emb)
        # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]

        # c0 shape: [batch_size, channels, time_steps, df_bins]
        c0 = self.df_convp(c0)
        # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
        c0 = c0.permute(0, 2, 3, 1)
        # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]

        df_coefs = self.df_out(df_coefs)  # [B, T, F*O*2], O: df_order
        # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
        df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
        # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
        df_coefs = df_coefs + c0
        # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
        return df_coefs


class DfOutputReshapeMF(nn.Module):
    """
    Reshape deep-filter coefficients for the multi-frame filter.

    Input:  [B, T, F, O*2]   (O = df_order, interleaved real/imag pairs in the last axis)
    Output: [B, O, T, F, 2]
    """

    def __init__(self, df_order: int, df_bins: int):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: torch.Tensor) -> torch.Tensor:
        # Split the last axis into (O, 2): [B, T, F, O*2] -> [B, T, F, O, 2]
        target_shape = list(coefs.shape[:-1]) + [-1, 2]
        split = coefs.view(target_shape)
        # Move the order axis next to batch: [B, O, T, F, 2]
        return split.permute(0, 3, 1, 2, 4)


class Mask(nn.Module):
    """
    Apply a real-valued mask to a spectrogram stored as trailing (real, imag) pairs.

    In eval mode, and only when `use_post_filter` is set, the mask is first passed through the
    perceptual post-filter.
    """

    def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.use_post_filter = use_post_filter
        self.eps = eps

    def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter

        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        Computes (1 + beta) * m / (1 + beta * (m / (m * sin(pi * m / 2))) ** 2). The denominator
        of the inner ratio is clamped to at least `eps` to avoid division by zero.

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return: Post-filtered mask of the same shape.
        """
        sin_weighted = mask * torch.sin(np.pi * mask / 2)
        ratio = mask.div(sin_weighted.clamp_min(self.eps))
        return (1 + beta) * mask / (1 + beta * ratio.pow(2))

    def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param spec: [batch_size, 1, time_steps, spec_bins, 2] spectrogram.
        :param mask: [batch_size, 1, time_steps, spec_bins] real mask.
        :return: masked spectrogram, same shape as `spec`.
        """
        apply_post_filter = self.use_post_filter and not self.training
        if apply_post_filter:
            mask = self.post_filter(mask)
        # Broadcast the mask over the trailing (real, imag) axis.
        return spec * mask.unsqueeze(4)


class DeepFiltering(nn.Module):
    """
    Deep filtering over the lowest `df_bins` frequency bins.

    For f < df_bins the output bin is a complex FIR filter over time:
        out[t, f] = sum_n spec[t - (df_order - 1 - lookahead) + n, f] * coefs[n, t, f]
    with zero padding outside the signal. Bins >= df_bins pass through unchanged.
    In training mode a copy of `spec` is modified; otherwise `spec` is modified in place.
    """

    def __init__(self,
                 df_bins: int,
                 df_order: int,
                 lookahead: int = 0,
                 ):
        super(DeepFiltering, self).__init__()
        self.df_bins = df_bins
        self.df_order = df_order
        self.need_unfold = df_order > 1
        self.lookahead = lookahead

        # Pad only the time axis (dim -2) as (left, right, top, bottom):
        # past frames in front, `lookahead` future frames behind.
        past_frames = df_order - 1 - lookahead
        self.pad = nn.ConstantPad2d((0, 0, past_frames, lookahead), 0.0)

    def spec_unfold(self, spec: torch.Tensor):
        """
        Pad the time axis and extract sliding windows of `df_order` frames.

        :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
        :return: Tensor of shape [B, C, T, F, df_order] ([B, C, T, F, 1] when df_order == 1).
        """
        if not self.need_unfold:
            return spec.unsqueeze(-1)
        # [B, C, T + df_order - 1, F] -> windows along time: [B, C, T, F, df_order]
        padded = self.pad(spec)
        return padded.unfold(2, self.df_order, 1)

    def forward(self,
                spec: torch.Tensor,
                coefs: torch.Tensor,
                ):
        """
        :param spec: [batch_size, 1, time_steps, spec_bins, 2] spectrogram (real, imag).
        :param coefs: [batch_size, df_order, time_steps, df_bins, 2] filter taps (real, imag).
        :return: spectrogram of the same shape as `spec` with the first df_bins bins filtered.
        """
        # Complex view -> time windows: [B, 1, T, spec_bins, df_order]
        windows = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
        # Keep only the deep-filtered bins: [B, 1, T, df_bins, df_order]
        windows = windows.narrow(-2, 0, self.df_bins)

        # [B, df_order, T, df_bins] -> [B, 1, df_order, T, df_bins]
        taps = torch.view_as_complex(coefs.contiguous())
        taps = taps.view(taps.shape[0], -1, self.df_order, *taps.shape[2:])

        # [B, 1, T, df_bins]
        filtered = self.df(windows, taps)

        target = spec.clone() if self.training else spec
        target[..., :self.df_bins, :] = torch.view_as_real(filtered)
        return target

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Apply the filter taps to the unfolded spectrogram via `torch.einsum`.

        :param spec: (complex Tensor). Unfolded spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F] (sum over N).
        """
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)


class DfNet(nn.Module):
    """
    Speech-enhancement network combining a gain mask with deep filtering.

    `forward` pipeline: complex STFT (`ConvSTFT`) -> `Encoder` -> two heads:
    a real-valued mask from `Decoder` (applied by `Mask` to all bins), and
    deep-filter coefficients from `DfDecoder` (applied by `DeepFiltering` to the
    lowest `df_bins` bins) -> inverse STFT (`ConviSTFT`).
    """
    def __init__(self, config: DfNetConfig):
        super(DfNet, self).__init__()
        self.config = config

        self.freq_bins = self.config.nfft // 2 + 1

        self.nfft = config.nfft
        self.win_size = config.win_size
        self.hop_size = config.hop_size
        self.win_type = config.win_type

        # Small constant guarding the division by the noisy power in `mask_loss_fn`.
        self.eps = 1e-8

        self.stft = ConvSTFT(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            feature_type="complex",
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            feature_type="complex",
            requires_grad=False
        )

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.df_decoder = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
        self.df_op = DeepFiltering(
            df_bins=config.df_bins,
            df_order=config.df_order,
            lookahead=config.df_lookahead,
        )

        self.mask = Mask(use_post_filter=config.use_post_filter)

    def forward(self,
                noisy: torch.Tensor,
                ):
        """
        Enhance a batch of noisy waveforms.

        :param noisy: waveform tensor. A 2-D [batch_size, n_samples] input gets a
            channel axis inserted at dim 1; a 3-D input is used as-is. The signal is
            zero-padded at the end so (n_samples - win_size) is a multiple of hop_size.
        :return: tuple (est_spec, est_wav, est_mask, lsnr):
            est_spec: [b, n+2, t], real half then imaginary half along dim 1, each
                with a zero bin appended to replace the last bin dropped after the STFT;
            est_wav: [b, n_samples], trimmed back to the un-padded input length;
            est_mask: [b, spec_bins, t], the decoder's gain mask;
            lsnr: [b, 1, t] per the shape comments below.
        :raises AssertionError: if the decoder mask has values outside [0, 1].
        """
        if noisy.dim() == 2:
            noisy = torch.unsqueeze(noisy, dim=1)
        # Length before padding; used at the end to trim the reconstruction.
        _, _, n_samples = noisy.shape
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)

        # [batch_size, freq_bins * 2, time_steps]
        cmp_spec = self.stft.forward(noisy)
        # [batch_size, 1, freq_bins * 2, time_steps]
        cmp_spec = torch.unsqueeze(cmp_spec, 1)

        # [batch_size, 2, freq_bins, time_steps]
        cmp_spec = torch.cat([
            cmp_spec[:, :, :self.freq_bins, :],
            cmp_spec[:, :, self.freq_bins:, :],
        ], dim=1)
        # Drop the last bin: n//2+1 -> n//2; 257 -> 256
        cmp_spec = cmp_spec[:, :, :-1, :]

        spec = torch.unsqueeze(cmp_spec, dim=4)
        # [batch_size, 2, freq_bins, time_steps, 1]
        spec = spec.permute(0, 4, 3, 2, 1)
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]

        feat_power = torch.sum(torch.square(spec), dim=-1)
        # feat_power shape: [batch_size, 1, time_steps, spec_bins]

        feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
        # feat_spec shape: [batch_size, 2, time_steps, freq_bins]
        feat_spec = feat_spec[..., :self.df_decoder.df_bins]
        # feat_spec shape: [batch_size, 2, time_steps, df_bins]

        e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)

        mask = self.decoder.forward(emb, e3, e2, e1, e0)
        # mask shape: [batch_size, 1, time_steps, spec_bins]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError("decoder mask values must lie in [0, 1]")

        spec_m = self.mask.forward(spec, mask)

        # lsnr shape: [batch_size, time_steps, 1]
        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
        # lsnr shape: [batch_size, 1, time_steps]

        df_coefs = self.df_decoder.forward(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]

        # Clone so the deep filter's write does not touch `spec`.
        spec_e = self.df_op.forward(spec.clone(), df_coefs)
        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]

        # Bins above df_bins come from the mask branch instead of the deep filter.
        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]

        spec_e = torch.squeeze(spec_e, dim=1)
        spec_e = spec_e.permute(0, 2, 1, 3)
        # spec_e shape: [batch_size, spec_bins, time_steps, 2]

        mask = torch.squeeze(mask, dim=1)
        est_mask = mask.permute(0, 2, 1)
        # est_mask shape: [batch_size, spec_bins, time_steps]

        b, _, t, _ = spec_e.shape
        # Zero bin re-appended to each half to undo the `[:-1]` slice above.
        zero_bin = spec_e.new_zeros((b, 1, t))
        est_spec = torch.cat(tensors=[
            torch.cat(tensors=[spec_e[..., 0], zero_bin], dim=1),
            torch.cat(tensors=[spec_e[..., 1], zero_bin], dim=1),
        ], dim=1)
        # est_spec shape: [b, n+2, t]
        est_wav = self.istft.forward(est_spec)
        est_wav = torch.squeeze(est_wav, dim=1)
        est_wav = est_wav[:, :n_samples]
        # est_wav shape: [b, n_samples]
        return est_spec, est_wav, est_mask, lsnr

    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
        """
        MSE between an estimated complex mask and the complex ideal ratio mask.

        The target is S / Y = S * conj(Y) / |Y|^2, i.e.
            re = (Sr * Yr + Si * Yi) / (|Y|^2 + eps)
            im = (Si * Yr - Sr * Yi) / (|Y|^2 + eps)
        Target values above 2 are set to 1 and values below -2 are set to -1.

        :param est_mask: torch.Tensor, shape: [b, n+2, t]. The first `freq_bins`
            channels are read as the real part and the rest as the imaginary part.
        :param clean: clean waveform, passed directly to `self.stft`.
        :param noisy: noisy waveform, passed directly to `self.stft`.
        :return: scalar tensor, MSE(real) + MSE(imag).
        """
        # NOTE(review): `forward` returns a real-valued est_mask of shape
        # [b, spec_bins, t] (spec_bins = freq_bins - 1), not the [b, n+2, t]
        # complex layout sliced below — confirm which mask callers pass in.
        clean_stft = self.stft(clean)
        clean_re = clean_stft[:, :self.freq_bins, :]
        clean_im = clean_stft[:, self.freq_bins:, :]

        noisy_stft = self.stft(noisy)
        noisy_re = noisy_stft[:, :self.freq_bins, :]
        noisy_im = noisy_stft[:, self.freq_bins:, :]

        noisy_power = noisy_re ** 2 + noisy_im ** 2

        sr = clean_re
        yr = noisy_re
        si = clean_im
        yi = noisy_im
        y_pow = noisy_power
        # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
        gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
        # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)

        gth_mask_re[gth_mask_re > 2] = 1
        gth_mask_re[gth_mask_re < -2] = -1
        gth_mask_im[gth_mask_im > 2] = 1
        gth_mask_im[gth_mask_im < -2] = -1

        mask_re = est_mask[:, :self.freq_bins, :]
        mask_im = est_mask[:, self.freq_bins:, :]

        loss_re = F.mse_loss(gth_mask_re, mask_re)
        loss_im = F.mse_loss(gth_mask_im, mask_im)

        loss = loss_re + loss_im
        return loss


class DfNetPretrainedModel(DfNet):
    """
    `DfNet` with helpers to load from and save to a model directory.

    A saved directory contains the state dict (file name `MODEL_FILE`) and the
    config written by `config.to_yaml_file` (file name `CONFIG_FILE`).
    """
    def __init__(self,
                 config: DfNetConfig,
                 ):
        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a model from a config and load its weights onto the CPU.

        The config comes from `DfNetConfig.from_pretrained` with the same path and
        kwargs. If the path is a directory, the weights are read from `MODEL_FILE`
        inside it; otherwise the path itself is opened as the checkpoint. Loading is
        strict, so every key in the model must be present.
        """
        # NOTE(review): for a non-directory path the same path feeds both the
        # config loader and torch.load — confirm DfNetConfig supports that.
        config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        instance = cls(config)

        location = pretrained_model_name_or_path
        ckpt_file = os.path.join(location, MODEL_FILE) if os.path.isdir(location) else location

        with open(ckpt_file, "rb") as handle:
            weights = torch.load(handle, map_location="cpu", weights_only=True)
        instance.load_state_dict(weights, strict=True)
        return instance

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """
        Write weights and config into `save_directory`, creating it if needed.

        :param save_directory: target directory.
        :param state_dict: weights to save; defaults to `self.state_dict()`.
        :return: `save_directory`, unchanged.
        """
        weights = self.state_dict() if state_dict is None else state_dict

        os.makedirs(save_directory, exist_ok=True)

        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory


def main():
    """Smoke test: run a freshly initialised model on random input and print the waveform shape."""
    # Default config, random weights (no checkpoint is loaded).
    net = DfNetPretrainedModel(config=DfNetConfig())

    # One batch item of 16000 standard-normal samples.
    waveform = torch.randn(size=(1, 16000), dtype=torch.float32)

    # forward returns (est_spec, est_wav, est_mask, lsnr); index 1 is the waveform.
    est_wav = net.forward(waveform)[1]
    print(est_wav.shape)
    return


if __name__ == "__main__":
    main()