#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F
class CIRMLoss(nn.Module):
    """MSE loss between a predicted complex mask and the complex ideal ratio mask.

    The target mask is computed from the STFTs of the clean signal ``S`` and
    the noisy signal ``Y`` as ``S * conj(Y) / (|Y|^2 + eps)``::

        real = (Sr * Yr + Si * Yi) / (|Y|^2 + eps)
        imag = (Si * Yr - Sr * Yi) / (|Y|^2 + eps)

    The loss is the sum of the MSE on the real part and the MSE on the
    imaginary part.

    :param n_fft: FFT size passed to ``torch.stft``.
    :param win_size: window length; a Hann window of this size is used.
    :param hop_size: hop length passed to ``torch.stft``.
    :param center: ``center`` flag passed to ``torch.stft``.
    :param eps: constant added to the noisy power to avoid division by zero.
    :param reduction: ``"mean"`` or ``"sum"``; forwarded to ``F.mse_loss``.
    :raises AssertionError: if ``reduction`` is not ``"sum"`` or ``"mean"``.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        super(CIRMLoss, self).__init__()
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        # Stored as a frozen parameter so that it follows the module across
        # devices (``.to()`` / ``.cuda()``) without being trained.
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of ``signal`` using the module's settings."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

    @staticmethod
    def _limit(mask: torch.Tensor) -> torch.Tensor:
        """Replace values above 2 with 1 and values below -2 with -1.

        NOTE(review): out-of-range values are reset to +/-1 rather than
        clamped to +/-2. This reproduces the original behavior; confirm that
        it is intended.
        """
        # Out-of-place so that no tensor in the autograd graph is mutated.
        mask = mask.masked_fill(mask > 2, 1.0)
        mask = mask.masked_fill(mask < -2, -1.0)
        return mask

    def forward(self,
                clean: torch.Tensor,
                noisy: torch.Tensor,
                mask_real: torch.Tensor,
                mask_imag: torch.Tensor,
                ) -> torch.Tensor:
        """
        :param clean: waveform
        :param noisy: waveform, same shape as ``clean``
        :param mask_real: predicted real part of the mask, shape: [b, f, t]
        :param mask_imag: predicted imaginary part of the mask, shape: [b, f, t]
        :return: real-part MSE plus imaginary-part MSE, reduced per ``reduction``.
        :raises AssertionError: if ``clean`` and ``noisy`` differ in shape.
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # clean_stft, noisy_stft shape: [b, f, t]
        clean_stft = self._stft(clean)
        noisy_stft = self._stft(noisy)

        sr, si = torch.real(clean_stft), torch.imag(clean_stft)
        yr, yi = torch.real(noisy_stft), torch.imag(noisy_stft)
        denominator = yr ** 2 + yi ** 2 + self.eps

        # (Sr * Yr + Si * Yi) / (Y_pow + eps)
        gth_mask_real = (sr * yr + si * yi) / denominator
        # (Si * Yr - Sr * Yi) / (Y_pow + eps)
        gth_mask_imag = (si * yr - sr * yi) / denominator

        gth_mask_real = self._limit(gth_mask_real)
        gth_mask_imag = self._limit(gth_mask_imag)

        real_loss = F.mse_loss(gth_mask_real, mask_real, reduction=self.reduction)
        imag_loss = F.mse_loss(gth_mask_imag, mask_imag, reduction=self.reduction)
        return real_loss + imag_loss
def main():
    """Smoke-test CIRMLoss on random waveforms and random predicted masks."""
    batch_size = 2
    signal_length = 16000

    clean = torch.randn(batch_size, signal_length)
    noisy = torch.randn(batch_size, signal_length)

    loss_fn = CIRMLoss()

    # forward() also requires the predicted mask, shaped like the STFT: [b, f, t].
    # The frame count below holds for the default center=True.
    num_bins = loss_fn.n_fft // 2 + 1
    num_frames = signal_length // loss_fn.hop_size + 1
    mask_real = torch.randn(batch_size, num_bins, num_frames)
    mask_imag = torch.randn(batch_size, num_bins, num_frames)

    # Call the module itself (not .forward) so that nn.Module hooks run.
    loss = loss_fn(clean, noisy, mask_real, mask_imag)
    print(f"loss: {loss.item()}")
    return
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    main()