Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -183,6 +183,10 @@ class SpectralConvergenceLoss(torch.nn.Module):
|
|
183 |
loss = torch.sum(batch_loss)
|
184 |
else:
|
185 |
raise AssertionError
|
|
|
|
|
|
|
|
|
186 |
return loss
|
187 |
|
188 |
|
@@ -209,7 +213,13 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
|
|
209 |
:param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
|
210 |
:return:
|
211 |
"""
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
|
215 |
class STFTLoss(torch.nn.Module):
|
|
|
183 |
loss = torch.sum(batch_loss)
|
184 |
else:
|
185 |
raise AssertionError
|
186 |
+
|
187 |
+
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
188 |
+
raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
|
189 |
+
|
190 |
return loss
|
191 |
|
192 |
|
|
|
213 |
:param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
|
214 |
:return:
|
215 |
"""
|
216 |
+
|
217 |
+
loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
|
218 |
+
|
219 |
+
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
220 |
+
raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
|
221 |
+
|
222 |
+
return loss
|
223 |
|
224 |
|
225 |
class STFTLoss(torch.nn.Module):
|