HoneyTian commited on
Commit
2c1a5a6
·
1 Parent(s): 7b7acb0
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -359,64 +359,64 @@ def main():
359
 
360
  })
361
 
362
- # save path
363
- save_dir = serialization_dir / "steps-{}".format(total_steps)
364
- save_dir.mkdir(parents=True, exist_ok=False)
365
-
366
- # save models
367
- model.save_pretrained(save_dir.as_posix())
368
-
369
- model_list.append(save_dir)
370
- if len(model_list) >= args.num_serialized_models_to_keep:
371
- model_to_delete: Path = model_list.pop(0)
372
- shutil.rmtree(model_to_delete.as_posix())
373
-
374
- # save optim
375
- torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
376
-
377
- # save metric
378
- if best_metric is None:
379
- best_idx_epoch = idx_epoch
380
- best_metric = average_pesq_score
381
- elif average_pesq_score > best_metric:
382
- # great is better.
383
- best_idx_epoch = idx_epoch
384
- best_metric = average_pesq_score
385
- else:
386
- pass
387
-
388
- metrics = {
389
- "idx_epoch": idx_epoch,
390
- "best_idx_epoch": best_idx_epoch,
391
- "pesq_score": average_pesq_score,
392
- "loss": average_loss,
393
- "ae_loss": average_ae_loss,
394
- "neg_si_snr_loss": average_neg_si_snr_loss,
395
- "neg_stoi_loss": average_neg_stoi_loss,
396
- }
397
- metrics_filename = save_dir / "metrics_epoch.json"
398
- with open(metrics_filename, "w", encoding="utf-8") as f:
399
- json.dump(metrics, f, indent=4, ensure_ascii=False)
400
-
401
- # save best
402
- best_dir = serialization_dir / "best"
403
- if best_idx_epoch == idx_epoch:
404
- if best_dir.exists():
405
- shutil.rmtree(best_dir)
406
- shutil.copytree(save_dir, best_dir)
407
-
408
- # early stop
409
- early_stop_flag = False
410
- if best_idx_epoch == idx_epoch:
411
- patience_count = 0
412
- else:
413
- patience_count += 1
414
- if patience_count >= args.patience:
415
- early_stop_flag = True
416
-
417
- # early stop
418
- if early_stop_flag:
419
- break
420
 
421
  return
422
 
 
359
 
360
  })
361
 
362
+ # save path
363
+ save_dir = serialization_dir / "steps-{}".format(total_steps)
364
+ save_dir.mkdir(parents=True, exist_ok=False)
365
+
366
+ # save models
367
+ model.save_pretrained(save_dir.as_posix())
368
+
369
+ model_list.append(save_dir)
370
+ if len(model_list) >= args.num_serialized_models_to_keep:
371
+ model_to_delete: Path = model_list.pop(0)
372
+ shutil.rmtree(model_to_delete.as_posix())
373
+
374
+ # save optim
375
+ torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
376
+
377
+ # save metric
378
+ if best_metric is None:
379
+ best_idx_epoch = idx_epoch
380
+ best_metric = average_pesq_score
381
+ elif average_pesq_score > best_metric:
382
+ # great is better.
383
+ best_idx_epoch = idx_epoch
384
+ best_metric = average_pesq_score
385
+ else:
386
+ pass
387
+
388
+ metrics = {
389
+ "idx_epoch": idx_epoch,
390
+ "best_idx_epoch": best_idx_epoch,
391
+ "pesq_score": average_pesq_score,
392
+ "loss": average_loss,
393
+ "ae_loss": average_ae_loss,
394
+ "neg_si_snr_loss": average_neg_si_snr_loss,
395
+ "neg_stoi_loss": average_neg_stoi_loss,
396
+ }
397
+ metrics_filename = save_dir / "metrics_epoch.json"
398
+ with open(metrics_filename, "w", encoding="utf-8") as f:
399
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
400
+
401
+ # save best
402
+ best_dir = serialization_dir / "best"
403
+ if best_idx_epoch == idx_epoch:
404
+ if best_dir.exists():
405
+ shutil.rmtree(best_dir)
406
+ shutil.copytree(save_dir, best_dir)
407
+
408
+ # early stop
409
+ early_stop_flag = False
410
+ if best_idx_epoch == idx_epoch:
411
+ patience_count = 0
412
+ else:
413
+ patience_count += 1
414
+ if patience_count >= args.patience:
415
+ early_stop_flag = True
416
+
417
+ # early stop
418
+ if early_stop_flag:
419
+ break
420
 
421
  return
422
 
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -39,8 +39,8 @@ class LSDLoss(nn.Module):
39
 
40
  def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
41
  """
42
- :param denoise_power: The estimated signal (batch_size, signal_length)
43
- :param clean_power: The target signal (batch_size, signal_length)
44
  :return:
45
  """
46
  denoise_power = denoise_power + self.eps
 
39
 
40
  def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
41
  """
42
+ :param denoise_power: power spectrum of the estimated signal power spectrum (batch_size, ...)
43
+ :param clean_power: power spectrum of the target signal (batch_size, ...)
44
  :return:
45
  """
46
  denoise_power = denoise_power + self.eps