HoneyTian commited on
Commit
b0fda13
·
1 Parent(s): c3d4076
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -183,6 +183,10 @@ class SpectralConvergenceLoss(torch.nn.Module):
183
  loss = torch.sum(batch_loss)
184
  else:
185
  raise AssertionError
 
 
 
 
186
  return loss
187
 
188
 
@@ -209,7 +213,13 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
209
  :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
210
  :return:
211
  """
212
- return F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
 
 
 
 
 
 
213
 
214
 
215
  class STFTLoss(torch.nn.Module):
 
183
  loss = torch.sum(batch_loss)
184
  else:
185
  raise AssertionError
186
+
187
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
188
+ raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
189
+
190
  return loss
191
 
192
 
 
213
  :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
214
  :return:
215
  """
216
+
217
+ loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
218
+
219
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
220
+ raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
221
+
222
+ return loss
223
 
224
 
225
  class STFTLoss(torch.nn.Module):