Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -359,64 +359,64 @@ def main():
|
|
359 |
|
360 |
})
|
361 |
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
|
421 |
return
|
422 |
|
|
|
359 |
|
360 |
})
|
361 |
|
362 |
+
# save path
|
363 |
+
save_dir = serialization_dir / "steps-{}".format(total_steps)
|
364 |
+
save_dir.mkdir(parents=True, exist_ok=False)
|
365 |
+
|
366 |
+
# save models
|
367 |
+
model.save_pretrained(save_dir.as_posix())
|
368 |
+
|
369 |
+
model_list.append(save_dir)
|
370 |
+
if len(model_list) >= args.num_serialized_models_to_keep:
|
371 |
+
model_to_delete: Path = model_list.pop(0)
|
372 |
+
shutil.rmtree(model_to_delete.as_posix())
|
373 |
+
|
374 |
+
# save optim
|
375 |
+
torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
|
376 |
+
|
377 |
+
# save metric
|
378 |
+
if best_metric is None:
|
379 |
+
best_idx_epoch = idx_epoch
|
380 |
+
best_metric = average_pesq_score
|
381 |
+
elif average_pesq_score > best_metric:
|
382 |
+
# great is better.
|
383 |
+
best_idx_epoch = idx_epoch
|
384 |
+
best_metric = average_pesq_score
|
385 |
+
else:
|
386 |
+
pass
|
387 |
+
|
388 |
+
metrics = {
|
389 |
+
"idx_epoch": idx_epoch,
|
390 |
+
"best_idx_epoch": best_idx_epoch,
|
391 |
+
"pesq_score": average_pesq_score,
|
392 |
+
"loss": average_loss,
|
393 |
+
"ae_loss": average_ae_loss,
|
394 |
+
"neg_si_snr_loss": average_neg_si_snr_loss,
|
395 |
+
"neg_stoi_loss": average_neg_stoi_loss,
|
396 |
+
}
|
397 |
+
metrics_filename = save_dir / "metrics_epoch.json"
|
398 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
399 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
400 |
+
|
401 |
+
# save best
|
402 |
+
best_dir = serialization_dir / "best"
|
403 |
+
if best_idx_epoch == idx_epoch:
|
404 |
+
if best_dir.exists():
|
405 |
+
shutil.rmtree(best_dir)
|
406 |
+
shutil.copytree(save_dir, best_dir)
|
407 |
+
|
408 |
+
# early stop
|
409 |
+
early_stop_flag = False
|
410 |
+
if best_idx_epoch == idx_epoch:
|
411 |
+
patience_count = 0
|
412 |
+
else:
|
413 |
+
patience_count += 1
|
414 |
+
if patience_count >= args.patience:
|
415 |
+
early_stop_flag = True
|
416 |
+
|
417 |
+
# early stop
|
418 |
+
if early_stop_flag:
|
419 |
+
break
|
420 |
|
421 |
return
|
422 |
|
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -39,8 +39,8 @@ class LSDLoss(nn.Module):
|
|
39 |
|
40 |
def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
|
41 |
"""
|
42 |
-
:param denoise_power:
|
43 |
-
:param clean_power:
|
44 |
:return:
|
45 |
"""
|
46 |
denoise_power = denoise_power + self.eps
|
|
|
39 |
|
40 |
def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
|
41 |
"""
|
42 |
+
:param denoise_power: power spectrum of the estimated signal power spectrum (batch_size, ...)
|
43 |
+
:param clean_power: power spectrum of the target signal (batch_size, ...)
|
44 |
:return:
|
45 |
"""
|
46 |
denoise_power = denoise_power + self.eps
|