# HoneyTian's picture
# add frcrn model
# cba47e4
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List
import torch
import torch.nn as nn
from torch.nn import functional as F
class CIRMLoss(nn.Module):
    """
    MSE loss between a predicted complex ratio mask and the ideal one.

    The ideal mask is derived from the STFTs of the clean signal S and the
    noisy signal Y as M = S / Y = S * conj(Y) / |Y|^2, i.e.

        M_real = (Sr * Yr + Si * Yi) / (|Y|^2 + eps)
        M_imag = (Si * Yr - Sr * Yi) / (|Y|^2 + eps)

    Target values above 2 are replaced by 1 and values below -2 by -1 before
    the loss is computed.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        """
        :param n_fft: FFT size passed to torch.stft.
        :param win_size: window length passed to torch.stft; also the length of the Hann window.
        :param hop_size: hop length passed to torch.stft.
        :param center: forwarded to torch.stft as `center`.
        :param eps: added to the noisy power spectrum to avoid division by zero.
        :param reduction: "sum" or "mean"; forwarded to F.mse_loss.
        :raises AssertionError: if reduction is neither "sum" nor "mean".
        """
        super(CIRMLoss, self).__init__()
        # Validate first so that an invalid configuration fails before any state is built.
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        # Kept as a frozen parameter (not a buffer) so parameters() and state_dict() are unchanged.
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of `signal` using this module's analysis settings."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

    @staticmethod
    def _replace_outliers(mask: torch.Tensor) -> torch.Tensor:
        """Replace values > 2 with 1 and values < -2 with -1 (out of place)."""
        mask = mask.masked_fill(mask > 2, 1.0)
        mask = mask.masked_fill(mask < -2, -1.0)
        return mask

    def forward(self,
                clean: torch.Tensor,
                noisy: torch.Tensor,
                mask_real: torch.Tensor,
                mask_imag: torch.Tensor,
                ) -> torch.Tensor:
        """
        :param clean: waveform
        :param noisy: waveform, same shape as clean
        :param mask_real: predicted real part of the mask, shape: [b, f, t]
        :param mask_imag: predicted imaginary part of the mask, shape: [b, f, t]
        :return: MSE of the real part plus MSE of the imaginary part, reduced according to `reduction`.
        :raises AssertionError: if clean and noisy do not have the same shape.
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # clean_stft, noisy_stft shape: [b, f, t]
        clean_stft = self._stft(clean)
        noisy_stft = self._stft(noisy)

        sr, si = torch.real(clean_stft), torch.imag(clean_stft)
        yr, yi = torch.real(noisy_stft), torch.imag(noisy_stft)

        # |Y|^2 + eps, shared by both mask components.
        denominator = yr ** 2 + yi ** 2 + self.eps

        # (Sr * Yr + Si * Yi) / (Y_pow + eps)
        gth_mask_real = (sr * yr + si * yi) / denominator
        # (Si * Yr - Sr * Yi) / (Y_pow + eps)
        gth_mask_imag = (si * yr - sr * yi) / denominator

        gth_mask_real = self._replace_outliers(gth_mask_real)
        gth_mask_imag = self._replace_outliers(gth_mask_imag)

        real_loss = F.mse_loss(mask_real, gth_mask_real, reduction=self.reduction)
        imag_loss = F.mse_loss(mask_imag, gth_mask_imag, reduction=self.reduction)

        return real_loss + imag_loss
def main():
    """Smoke-test CIRMLoss on random waveforms and random predicted masks."""
    batch_size = 2
    signal_length = 16000

    clean_signal = torch.randn(batch_size, signal_length)
    noisy_signal = torch.randn(batch_size, signal_length)

    loss_fn = CIRMLoss()

    # The predicted masks must match the one-sided STFT shape [b, f, t]:
    # f = n_fft // 2 + 1 and, with center=True, t = signal_length // hop_size + 1.
    freq_bins = loss_fn.n_fft // 2 + 1
    num_frames = signal_length // loss_fn.hop_size + 1
    mask_real = torch.randn(batch_size, freq_bins, num_frames)
    mask_imag = torch.randn(batch_size, freq_bins, num_frames)

    # Call the module itself (not .forward) so that nn.Module hooks run.
    loss = loss_fn(clean_signal, noisy_signal, mask_real, mask_imag)
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()