Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -158,9 +158,11 @@ class SpectralConvergenceLoss(torch.nn.Module):
|
|
158 |
|
159 |
def __init__(self,
|
160 |
reduction: str = "mean",
|
|
|
161 |
):
|
162 |
super(SpectralConvergenceLoss, self).__init__()
|
163 |
self.reduction = reduction
|
|
|
164 |
|
165 |
if reduction not in ("sum", "mean"):
|
166 |
raise AssertionError(f"param reduction must be sum or mean.")
|
@@ -175,19 +177,9 @@ class SpectralConvergenceLoss(torch.nn.Module):
|
|
175 |
:return:
|
176 |
"""
|
177 |
error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
|
178 |
-
if torch.any(torch.isnan(error_norm)) or torch.any(torch.isinf(error_norm)):
|
179 |
-
raise AssertionError("SpectralConvergenceLoss, nan or inf in error_norm")
|
180 |
truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
|
181 |
-
if torch.any(torch.isnan(truth_norm)):
|
182 |
-
raise AssertionError("SpectralConvergenceLoss, nan in truth_norm")
|
183 |
-
if torch.any(torch.isinf(truth_norm)):
|
184 |
-
raise AssertionError("SpectralConvergenceLoss, inf in truth_norm")
|
185 |
|
186 |
-
batch_loss = error_norm / truth_norm
|
187 |
-
if torch.any(torch.isnan(batch_loss)):
|
188 |
-
raise AssertionError("SpectralConvergenceLoss, nan in batch_loss")
|
189 |
-
if torch.any(torch.isinf(batch_loss)):
|
190 |
-
raise AssertionError("SpectralConvergenceLoss, inf in batch_loss")
|
191 |
|
192 |
if self.reduction == "mean":
|
193 |
loss = torch.mean(batch_loss)
|
@@ -196,9 +188,6 @@ class SpectralConvergenceLoss(torch.nn.Module):
|
|
196 |
else:
|
197 |
raise AssertionError
|
198 |
|
199 |
-
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
200 |
-
raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
|
201 |
-
|
202 |
return loss
|
203 |
|
204 |
|
|
|
158 |
|
159 |
def __init__(self,
|
160 |
reduction: str = "mean",
|
161 |
+
eps: float = 1e-8,
|
162 |
):
|
163 |
super(SpectralConvergenceLoss, self).__init__()
|
164 |
self.reduction = reduction
|
165 |
+
self.eps = eps
|
166 |
|
167 |
if reduction not in ("sum", "mean"):
|
168 |
raise AssertionError(f"param reduction must be sum or mean.")
|
|
|
177 |
:return:
|
178 |
"""
|
179 |
error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
|
|
|
|
|
180 |
truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
|
|
|
|
|
|
|
|
|
181 |
|
182 |
+
batch_loss = error_norm / (truth_norm + self.eps)
|
|
|
|
|
|
|
|
|
183 |
|
184 |
if self.reduction == "mean":
|
185 |
loss = torch.mean(batch_loss)
|
|
|
188 |
else:
|
189 |
raise AssertionError
|
190 |
|
|
|
|
|
|
|
191 |
return loss
|
192 |
|
193 |
|