Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -31,7 +31,6 @@ import torch
|
|
31 |
import torch.nn as nn
|
32 |
from torch.nn import functional as F
|
33 |
from torch.utils.data.dataloader import DataLoader
|
34 |
-
from torch_pesq import PesqLoss
|
35 |
from tqdm import tqdm
|
36 |
|
37 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
@@ -39,7 +38,7 @@ from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import Conv
|
|
39 |
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
40 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
41 |
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
42 |
-
from toolbox.torchaudio.losses.perceptual import NegSTOILoss
|
43 |
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
44 |
|
45 |
|
@@ -283,7 +282,6 @@ def main():
|
|
283 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
284 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
285 |
pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
|
286 |
-
print(f"pesq_loss: {pesq_loss}")
|
287 |
|
288 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
289 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
|
|
31 |
import torch.nn as nn
|
32 |
from torch.nn import functional as F
|
33 |
from torch.utils.data.dataloader import DataLoader
|
|
|
34 |
from tqdm import tqdm
|
35 |
|
36 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
|
|
38 |
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
39 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
40 |
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
41 |
+
from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
|
42 |
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
43 |
|
44 |
|
|
|
282 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
283 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
284 |
pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
|
|
|
285 |
|
286 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
287 |
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
toolbox/torchaudio/losses/perceptual.py
CHANGED
@@ -6,6 +6,7 @@ https://zhuanlan.zhihu.com/p/627039860
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
from torch_stoi import NegSTOILoss as TorchNegSTOILoss
|
|
|
9 |
|
10 |
|
11 |
class PMSQELoss(object):
|
@@ -55,6 +56,47 @@ class NegSTOILoss(nn.Module):
|
|
55 |
return loss
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def main():
|
59 |
sample_rate = 16000
|
60 |
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
from torch_stoi import NegSTOILoss as TorchNegSTOILoss
|
9 |
+
from torch_pesq import PesqLoss as TorchPesqLoss
|
10 |
|
11 |
|
12 |
class PMSQELoss(object):
|
|
|
56 |
return loss
|
57 |
|
58 |
|
59 |
+
class PesqLoss(nn.Module):
    """Wrap ``torch_pesq.PesqLoss`` and reduce its output to a single value.

    The constructor arguments other than ``reduction`` are stored on the
    instance and forwarded unchanged to ``torch_pesq.PesqLoss``.

    Args:
        factor: forwarded to ``torch_pesq.PesqLoss``.
        sample_rate: forwarded to ``torch_pesq.PesqLoss``.
        nbarks: forwarded to ``torch_pesq.PesqLoss``.
        win_length: forwarded to ``torch_pesq.PesqLoss``.
        n_fft: forwarded to ``torch_pesq.PesqLoss``.
        hop_length: forwarded to ``torch_pesq.PesqLoss``.
        reduction: how the wrapped loss output is reduced in ``forward``;
            ``"mean"`` or ``"sum"``.
    """

    def __init__(self,
                 factor: float,
                 sample_rate: int = 48000,
                 nbarks: int = 49,
                 win_length: int = 512,
                 n_fft: int = 512,
                 hop_length: int = 256,
                 reduction: str = "mean",
                 ):
        super(PesqLoss, self).__init__()
        self.factor = factor
        self.sample_rate = sample_rate
        self.nbarks = nbarks
        self.win_length = win_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.reduction = reduction

        self.loss_fn = TorchPesqLoss(
            factor=factor,
            sample_rate=sample_rate,
            nbarks=nbarks,
            win_length=win_length,
            n_fft=n_fft,
            hop_length=hop_length,
        )

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """Compute the reduced PESQ loss between a denoised and a clean signal.

        Note the argument order: this method takes the denoised signal first,
        while the wrapped loss is invoked with the clean signal first.

        Args:
            denoise: denoised (estimated) audio tensor.
                NOTE(review): presumably shaped (batch, samples) - confirm
                against torch_pesq.
            clean: clean (reference) audio tensor, same layout as ``denoise``.

        Returns:
            The mean or the sum of the wrapped loss output over all of its
            elements, depending on ``self.reduction``.

        Raises:
            AssertionError: if ``self.reduction`` is neither ``"mean"`` nor
                ``"sum"``.
        """
        # Invoke the module itself rather than ``.forward`` so that any hooks
        # registered on the wrapped loss are executed.
        batch_loss = self.loss_fn(clean, denoise)

        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError(
                f"unsupported reduction: {self.reduction!r}; expected 'mean' or 'sum'"
            )
        return loss
|
98 |
+
|
99 |
+
|
100 |
def main():
|
101 |
sample_rate = 16000
|
102 |
|