HoneyTian committed on
Commit
7c192b8
·
1 Parent(s): 20fa6bf
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -31,7 +31,6 @@ import torch
31
  import torch.nn as nn
32
  from torch.nn import functional as F
33
  from torch.utils.data.dataloader import DataLoader
34
- from torch_pesq import PesqLoss
35
  from tqdm import tqdm
36
 
37
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
@@ -39,7 +38,7 @@ from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import Conv
39
  from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
40
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
41
  from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
42
- from toolbox.torchaudio.losses.perceptual import NegSTOILoss
43
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
44
 
45
 
@@ -283,7 +282,6 @@ def main():
283
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
284
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
285
  pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
286
- print(f"pesq_loss: {pesq_loss}")
287
 
288
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
289
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
 
31
  import torch.nn as nn
32
  from torch.nn import functional as F
33
  from torch.utils.data.dataloader import DataLoader
 
34
  from tqdm import tqdm
35
 
36
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 
38
  from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
39
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
40
  from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
41
+ from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
42
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
43
 
44
 
 
282
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
283
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
284
  pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
 
285
 
286
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
287
  # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
toolbox/torchaudio/losses/perceptual.py CHANGED
@@ -6,6 +6,7 @@ https://zhuanlan.zhihu.com/p/627039860
6
  import torch
7
  import torch.nn as nn
8
  from torch_stoi import NegSTOILoss as TorchNegSTOILoss
 
9
 
10
 
11
  class PMSQELoss(object):
@@ -55,6 +56,47 @@ class NegSTOILoss(nn.Module):
55
  return loss
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def main():
59
  sample_rate = 16000
60
 
 
6
  import torch
7
  import torch.nn as nn
8
  from torch_stoi import NegSTOILoss as TorchNegSTOILoss
9
+ from torch_pesq import PesqLoss as TorchPesqLoss
10
 
11
 
12
  class PMSQELoss(object):
 
56
  return loss
57
 
58
 
59
class PesqLoss(nn.Module):
    """
    Wrapper around ``torch_pesq.PesqLoss`` that reduces its output.

    The constructor arguments ``factor``, ``sample_rate``, ``nbarks``,
    ``win_length``, ``n_fft`` and ``hop_length`` are stored on the instance
    and forwarded unchanged to the wrapped loss. ``reduction`` selects how
    the wrapped loss output is collapsed in :meth:`forward`.
    """

    def __init__(self,
                 factor: float,
                 sample_rate: int = 48000,
                 nbarks: int = 49,
                 win_length: int = 512,
                 n_fft: int = 512,
                 hop_length: int = 256,
                 reduction: str = "mean",
                 ):
        """
        :param factor: forwarded to ``torch_pesq.PesqLoss``.
        :param sample_rate: forwarded to ``torch_pesq.PesqLoss``.
        :param nbarks: forwarded to ``torch_pesq.PesqLoss``.
        :param win_length: forwarded to ``torch_pesq.PesqLoss``.
        :param n_fft: forwarded to ``torch_pesq.PesqLoss``.
        :param hop_length: forwarded to ``torch_pesq.PesqLoss``.
        :param reduction: ``"mean"`` or ``"sum"``; any other value makes
            :meth:`forward` raise ``AssertionError``.
        """
        super(PesqLoss, self).__init__()
        self.factor = factor
        self.sample_rate = sample_rate
        self.nbarks = nbarks
        self.win_length = win_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.reduction = reduction

        self.loss_fn = TorchPesqLoss(
            factor=factor,
            sample_rate=sample_rate,
            nbarks=nbarks,
            win_length=win_length,
            n_fft=n_fft,
            hop_length=hop_length,
        )

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the reduced loss between an enhanced and a clean signal.

        :param denoise: the enhanced (model output) signal; passed FIRST here.
        :param clean: the clean target signal; passed SECOND here.
        :return: ``torch.mean`` or ``torch.sum`` of the wrapped loss output,
            depending on ``self.reduction``.
        :raises AssertionError: if ``self.reduction`` is not ``"mean"`` or ``"sum"``.
        """
        # The wrapped loss is called with the arguments in the opposite order
        # (clean first, denoise second), so the swap below is intentional.
        batch_loss = self.loss_fn.forward(clean, denoise)

        if self.reduction == "mean":
            return torch.mean(batch_loss)
        if self.reduction == "sum":
            return torch.sum(batch_loss)

        # Same exception type as before, but now it says what went wrong.
        raise AssertionError(
            f"unsupported reduction: {self.reduction!r}; expected 'mean' or 'sum'."
        )
98
+
99
+
100
  def main():
101
  sample_rate = 16000
102