File size: 11,542 Bytes
1b032b9
 
 
 
cba47e4
 
 
 
1b032b9
cba47e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2323d2
 
cba47e4
 
 
 
1d4c9c3
cba47e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d4c9c3
cba47e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d4c9c3
 
cba47e4
1d4c9c3
 
cba47e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b032b9
 
 
cba47e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/2206.07293

https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py
https://huggingface.co./spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py

"""
import os
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
from toolbox.torchaudio.models.frcrn.conv_stft import ConviSTFT, ConvSTFT
from toolbox.torchaudio.models.frcrn.unet import UNet


class FRCRN(nn.Module):
    """Frequency Recurrent CRN (FRCRN) speech enhancement model.

    Pipeline: ConvSTFT -> two cascaded UNets that predict complex masks ->
    the sum of both tanh-squashed masks is applied to the noisy spectrum by
    complex multiplication -> ConviSTFT back to a waveform.
    """

    def __init__(self,
                 use_complex_networks: bool = True,
                 model_complexity: int = 45,
                 model_depth: int = 14,
                 padding_mode: str = "zeros",
                 nfft: int = 640,
                 win_size: int = 640,
                 hop_size: int = 320,
                 win_type: str = "hann",
                 ):
        """
        :param use_complex_networks: bool, whether the UNets use complex-valued networks.
        :param model_complexity: int, complexity/size setting forwarded to both UNets.
        :param model_depth: int, depth setting forwarded to both UNets. The set of
            supported values is decided by ``UNet`` (NOTE(review): an older comment
            claimed only 10 and 20, but the default here is 14 -- confirm against UNet).
        :param padding_mode: str, padding mode forwarded to both UNets (encoder
            convolutions), e.g. 'zeros', 'reflect'.
        :param nfft: int, number of STFT points; the model works on ``nfft // 2 + 1`` bins.
        :param win_size: int, STFT window length in samples.
        :param hop_size: int, number of samples between successive STFT frames.
        :param win_type: str, window type passed to ConvSTFT/ConviSTFT, e.g. 'hann'.
        """
        super().__init__()
        self.freq_bins = nfft // 2 + 1

        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # Guards the division by |Y|^2 in mask_loss_fn.
        self.eps = 1e-8

        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            feature_type="complex",
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            feature_type="complex",
            requires_grad=False
        )
        self.unet = UNet(
            in_channels=1,
            use_complex_networks=use_complex_networks,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode
        )
        self.unet2 = UNet(
            in_channels=1,
            use_complex_networks=use_complex_networks,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode
        )

    def _pad_to_frames(self, wav: torch.Tensor) -> torch.Tensor:
        """Return ``wav`` as [b, c, n] (2-D input gets a channel axis), right-padded
        with zeros so that ``(n - win_size)`` is a multiple of ``hop_size``.

        Already-aligned inputs are returned unpadded, so this is idempotent.
        """
        if wav.dim() == 2:
            wav = torch.unsqueeze(wav, dim=1)
        n_samples = wav.shape[-1]
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            wav = F.pad(wav, pad=(0, self.hop_size - remainder), mode="constant", value=0)
        return wav

    def forward(self, noisy: torch.Tensor):
        """
        Enhance a noisy waveform.

        :param noisy: torch.Tensor, shape [b, n_samples] or [b, 1, n_samples]
            (the UNets are built with ``in_channels=1``).
        :return: tuple ``(est_spec, est_wav, est_mask)``:
            est_spec: [b, freq_bins * 2, t], real half stacked above imaginary half;
            est_wav: [b, n_samples], trimmed to the input length;
            est_mask: [b, freq_bins * 2, t], the applied complex mask in the same layout.
        """
        n_samples = noisy.shape[-1]
        noisy = self._pad_to_frames(noisy)

        # [b, freq_bins * 2, t]: real half stacked above imaginary half.
        cmp_spec = self.stft(noisy)
        # [b, 2, freq_bins, t]: move real / imaginary halves into separate channels.
        cmp_spec = torch.stack([
            cmp_spec[:, :self.freq_bins, :],
            cmp_spec[:, self.freq_bins:, :],
        ], dim=1)
        # [b, 1, freq_bins, t, 2]: layout expected by the UNets (complex pair last).
        cmp_spec = torch.unsqueeze(cmp_spec, dim=4).transpose(1, 4)

        unet1_out = self.unet(cmp_spec)
        cmp_mask1 = torch.tanh(unet1_out)
        # The second UNet refines the (pre-tanh) output of the first one.
        unet2_out = self.unet2(unet1_out)
        cmp_mask2 = torch.tanh(unet2_out)

        # Only the combined mask of both stages is applied and returned.
        cmp_mask = cmp_mask2 + cmp_mask1
        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask)

        # Drop the samples introduced by frame-alignment padding.
        est_wav = est_wav[:, :n_samples]
        return est_spec, est_wav, est_mask

    def apply_mask(self,
                   cmp_spec: torch.Tensor,
                   cmp_mask: torch.Tensor,
                   ):
        """
        Multiply the spectrum by the mask (complex product) and resynthesize.

        :param cmp_spec: torch.Tensor, shape [b, 1, freq_bins, t, 2] (real, imag last).
        :param cmp_mask: torch.Tensor, shape [b, 1, freq_bins, t, 2] (real, imag last).
        :return: tuple ``(est_spec, est_wav, cmp_mask)`` with est_spec [b, freq_bins * 2, t],
            est_wav [b, n_samples] and cmp_mask flattened to [b, freq_bins * 2, t].
        """
        spec_re, spec_im = cmp_spec[..., 0], cmp_spec[..., 1]
        mask_re, mask_im = cmp_mask[..., 0], cmp_mask[..., 1]
        # (a + jb)(c + jd) = (ac - bd) + j(ad + bc); each term is [b, 1, freq_bins, t].
        est_spec = torch.cat(
            tensors=[
                spec_re * mask_re - spec_im * mask_im,
                spec_re * mask_im + spec_im * mask_re,
            ], dim=1
        )
        # [b, 2, freq_bins, t] -> [b, freq_bins * 2, t]
        est_spec = torch.cat(tensors=[est_spec[:, 0, :, :], est_spec[:, 1, :, :]], dim=1)

        # [b, 1, freq_bins, t, 2] -> [b, freq_bins, t, 2] -> [b, freq_bins * 2, t]
        cmp_mask = torch.squeeze(cmp_mask, dim=1)
        cmp_mask = torch.cat(tensors=[cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], dim=1)

        # [b, 1, n_samples] -> [b, n_samples]
        est_wav = self.istft(est_spec)
        est_wav = torch.squeeze(est_wav, 1)
        return est_spec, est_wav, cmp_mask

    def get_params(self, weight_decay=0.0):
        """
        Build optimizer parameter groups: every parameter whose name contains
        "bias" gets ``weight_decay=0.0``; all other parameters get ``weight_decay``.

        Weight decay acts as L2 regularization. It penalizes large weights
        (update: weight = weight - lr * (gradient + weight_decay * weight)), which
        limits model complexity and reduces overfitting. With SGD it is equivalent
        to L2 regularization. With Adam it is coupled with the adaptive update and
        may be weakened; AdamW decouples it from the learning rate.

        Notes: too large a value underfits and too small a value barely regularizes;
        1e-4 to 1e-2 is a common range. Layers such as BatchNorm may warrant their
        own group.

        :param weight_decay: float, decay applied to the non-bias group.
        :return: list of two dicts usable as ``torch.optim`` parameter groups.
        """
        weights, biases = [], []
        for name, param in self.named_parameters():
            if "bias" in name:
                biases.append(param)
            else:
                weights.append(param)

        return [
            {"params": weights, "weight_decay": weight_decay},
            {"params": biases, "weight_decay": 0.0},
        ]

    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
        """
        MSE between the estimated mask and the ideal complex ratio mask S / Y.

        :param est_mask: torch.Tensor, shape [b, freq_bins * 2, t], as returned by ``forward``.
        :param clean: torch.Tensor, clean waveform, [b, n_samples] or [b, 1, n_samples].
        :param noisy: torch.Tensor, noisy waveform shaped like ``clean``.
        :return: torch.Tensor, scalar = MSE(real parts) + MSE(imaginary parts).
        """
        # Pre-process exactly like forward() so the STFT frames line up with est_mask.
        clean = self._pad_to_frames(clean)
        noisy = self._pad_to_frames(noisy)

        clean_stft = self.stft(clean)
        sr = clean_stft[:, :self.freq_bins, :]
        si = clean_stft[:, self.freq_bins:, :]

        noisy_stft = self.stft(noisy)
        yr = noisy_stft[:, :self.freq_bins, :]
        yi = noisy_stft[:, self.freq_bins:, :]

        # S / Y = S * conj(Y) / |Y|^2
        #   real: (Sr * Yr + Si * Yi) / |Y|^2
        #   imag: (Si * Yr - Sr * Yi) / |Y|^2
        y_pow = yr ** 2 + yi ** 2 + self.eps
        gth_mask_re = (sr * yr + si * yi) / y_pow
        gth_mask_im = (si * yr - sr * yi) / y_pow

        # Outliers are replaced (not clamped): > 2 becomes 1, < -2 becomes -1.
        gth_mask_re[gth_mask_re > 2] = 1
        gth_mask_re[gth_mask_re < -2] = -1
        gth_mask_im[gth_mask_im > 2] = 1
        gth_mask_im[gth_mask_im < -2] = -1

        mask_re = est_mask[:, :self.freq_bins, :]
        mask_im = est_mask[:, self.freq_bins:, :]

        loss_re = F.mse_loss(gth_mask_re, mask_re)
        loss_im = F.mse_loss(gth_mask_im, mask_im)
        return loss_re + loss_im


# File name of the serialized state dict inside a pretrained-model directory;
# read by FRCRNPretrainedModel.from_pretrained and written by save_pretrained.
MODEL_FILE = "model.pt"


class FRCRNPretrainedModel(FRCRN):
    """FRCRN built from an ``FRCRNConfig``, with directory-based save/load helpers."""

    def __init__(self,
                 config: FRCRNConfig,
                 ):
        """
        :param config: FRCRNConfig, hyper-parameters for the underlying FRCRN.
        """
        super().__init__(
            use_complex_networks=config.use_complex_networks,
            model_complexity=config.model_complexity,
            model_depth=config.model_depth,
            # Previously this was never forwarded, so any configured value was ignored.
            # getattr keeps configs without the field on FRCRN's default.
            padding_mode=getattr(config, "padding_mode", "zeros"),
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a model from ``FRCRNConfig.from_pretrained`` and load its weights.

        :param pretrained_model_name_or_path: a directory containing MODEL_FILE, or
            a path to the checkpoint file itself.
        :param kwargs: forwarded to ``FRCRNConfig.from_pretrained``.
        :return: the model with weights loaded (on CPU, ``strict=True``).
        """
        config = FRCRNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors / plain containers.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """
        Write the weights to ``save_directory/MODEL_FILE`` and the config to
        ``save_directory/CONFIG_FILE`` (via ``to_yaml_file``). Creates the directory.

        :param save_directory: target directory.
        :param state_dict: weights to save; defaults to ``self.state_dict()``.
        :return: ``save_directory``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    """Smoke test: push one random 8000-sample waveform through FRCRN and
    print the shapes of the estimated spectrum, waveform and mask."""
    hparams = dict(
        use_complex_networks=True,
        model_complexity=45,
        model_depth=14,
        padding_mode="zeros",
        nfft=640,
        win_size=640,
        hop_size=320,
        win_type="hann",
    )
    model = FRCRN(**hparams)
    waveform = torch.rand(size=(1, 8000), dtype=torch.float32)

    est_spec, est_wav, est_mask = model.forward(waveform)
    for output in (est_spec, est_wav, est_mask):
        print(output.shape)

    return


# Run the shape smoke test in main() only when this file is executed as a script.
if __name__ == "__main__":
    main()