HoneyTian commited on
Commit
c797dfd
·
1 Parent(s): ce6b38e
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -158,9 +158,11 @@ class SpectralConvergenceLoss(torch.nn.Module):
158
 
159
  def __init__(self,
160
  reduction: str = "mean",
 
161
  ):
162
  super(SpectralConvergenceLoss, self).__init__()
163
  self.reduction = reduction
 
164
 
165
  if reduction not in ("sum", "mean"):
166
  raise AssertionError(f"param reduction must be sum or mean.")
@@ -175,19 +177,9 @@ class SpectralConvergenceLoss(torch.nn.Module):
175
  :return:
176
  """
177
  error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
178
- if torch.any(torch.isnan(error_norm)) or torch.any(torch.isinf(error_norm)):
179
- raise AssertionError("SpectralConvergenceLoss, nan or inf in error_norm")
180
  truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
181
- if torch.any(torch.isnan(truth_norm)):
182
- raise AssertionError("SpectralConvergenceLoss, nan in truth_norm")
183
- if torch.any(torch.isinf(truth_norm)):
184
- raise AssertionError("SpectralConvergenceLoss, inf in truth_norm")
185
 
186
- batch_loss = error_norm / truth_norm
187
- if torch.any(torch.isnan(batch_loss)):
188
- raise AssertionError("SpectralConvergenceLoss, nan in batch_loss")
189
- if torch.any(torch.isinf(batch_loss)):
190
- raise AssertionError("SpectralConvergenceLoss, inf in batch_loss")
191
 
192
  if self.reduction == "mean":
193
  loss = torch.mean(batch_loss)
@@ -196,9 +188,6 @@ class SpectralConvergenceLoss(torch.nn.Module):
196
  else:
197
  raise AssertionError
198
 
199
- if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
200
- raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
201
-
202
  return loss
203
 
204
 
 
158
 
159
  def __init__(self,
160
  reduction: str = "mean",
161
+ eps: float = 1e-8,
162
  ):
163
  super(SpectralConvergenceLoss, self).__init__()
164
  self.reduction = reduction
165
+ self.eps = eps
166
 
167
  if reduction not in ("sum", "mean"):
168
  raise AssertionError(f"param reduction must be sum or mean.")
 
177
  :return:
178
  """
179
  error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
 
 
180
  truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
 
 
 
 
181
 
182
+ batch_loss = error_norm / (truth_norm + self.eps)
 
 
 
 
183
 
184
  if self.reduction == "mean":
185
  loss = torch.mean(batch_loss)
 
188
  else:
189
  raise AssertionError
190
 
 
 
 
191
  return loss
192
 
193