File size: 3,401 Bytes
cba47e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F


class CIRMLoss(nn.Module):
    """MSE loss between predicted and ground-truth complex ideal ratio masks (cIRM).

    The ground-truth mask is M = S / Y, computed per time-frequency bin from the
    STFTs of the clean (S) and noisy (Y) waveforms:

        M_r = (S_r * Y_r + S_i * Y_i) / (|Y|^2 + eps)
        M_i = (S_i * Y_r - S_r * Y_i) / (|Y|^2 + eps)

    The returned loss is ``mse(M_r, mask_real) + mse(M_i, mask_imag)``, with each
    term reduced according to ``reduction``.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        """
        :param n_fft: FFT size passed to torch.stft.
        :param win_size: length of the Hann analysis window.
        :param hop_size: hop length between STFT frames.
        :param center: passed to torch.stft (reflect-padded when True).
        :param eps: added to |Y|^2 to avoid division by zero.
        :param reduction: "mean" or "sum"; forwarded to F.mse_loss.
        :raises AssertionError: if reduction is not "sum" or "mean".
        """
        super().__init__()
        # Validate first so a bad argument fails before any tensor is allocated.
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        # Fixed (non-trainable) window. As a Parameter it follows .to(device)
        # and is saved in state_dict under "window".
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of ``waveform`` using this module's settings."""
        return torch.stft(
            waveform,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True,
        )

    def forward(self, clean: torch.Tensor, noisy: torch.Tensor, mask_real: torch.Tensor, mask_imag: torch.Tensor):
        """
        :param clean: clean waveform (any shape accepted by torch.stft, e.g. [b, time])
        :param noisy: noisy waveform, same shape as ``clean``
        :param mask_real: predicted real part of the mask, shape: [b, f, t]
        :param mask_imag: predicted imaginary part of the mask, shape: [b, f, t]
        :return: scalar loss, the sum of the real-part and imaginary-part MSE terms
        :raises AssertionError: if ``clean`` and ``noisy`` differ in shape
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # Complex spectra, shape [b, f, t].
        clean_stft = self._stft(clean)
        noisy_stft = self._stft(noisy)

        sr, si = torch.real(clean_stft), torch.imag(clean_stft)
        yr, yi = torch.real(noisy_stft), torch.imag(noisy_stft)
        denom = yr ** 2 + yi ** 2 + self.eps

        # cIRM = S / Y = S * conj(Y) / |Y|^2
        gth_mask_real = (sr * yr + si * yi) / denom
        gth_mask_imag = (si * yr - sr * yi) / denom

        # NOTE(review): values outside [-2, 2] are replaced by +/-1 rather than
        # clamped to +/-2; behaviour preserved from the original, confirm intent.
        gth_mask_real[gth_mask_real > 2] = 1
        gth_mask_real[gth_mask_real < -2] = -1
        gth_mask_imag[gth_mask_imag > 2] = 1
        gth_mask_imag[gth_mask_imag < -2] = -1

        real_loss = F.mse_loss(gth_mask_real, mask_real, reduction=self.reduction)
        imag_loss = F.mse_loss(gth_mask_imag, mask_imag, reduction=self.reduction)

        return real_loss + imag_loss


def main():
    """Smoke-test CIRMLoss on random waveforms with all-zero predicted masks."""
    batch_size = 2
    signal_length = 16000
    clean = torch.randn(batch_size, signal_length)
    noisy = torch.randn(batch_size, signal_length)

    loss_fn = CIRMLoss()

    # Predicted masks must match the STFT layout [b, f, t]: f = n_fft // 2 + 1,
    # and with center=True (the default) torch.stft yields
    # 1 + signal_length // hop_size frames.
    num_freqs = loss_fn.n_fft // 2 + 1
    num_frames = 1 + signal_length // loss_fn.hop_size
    mask_real = torch.zeros(batch_size, num_freqs, num_frames)
    mask_imag = torch.zeros(batch_size, num_freqs, num_frames)

    loss = loss_fn(clean, noisy, mask_real, mask_imag)
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()