HoneyTian committed
Commit 3f9acc2 · 1 Parent(s): aa6beca
examples/conv_tasnet_gan/step_2_train_model.py CHANGED
@@ -39,8 +39,8 @@ from tqdm import tqdm
 from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
 from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
-from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.modeling_conv_tasnet_discriminator import ConvTasNetDiscriminatorPretrainedModel
-from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.configuration_conv_tasnet_discriminator import ConvTasNetDiscriminatorConfig
+from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminatorPretrainedModel
+from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
 from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
 from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
 from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
@@ -126,7 +126,7 @@ def main():
     config = ConvTasNetConfig.from_pretrained(
         pretrained_model_name_or_path=args.config_file,
     )
-    discriminator_config = ConvTasNetDiscriminatorConfig.from_pretrained(
+    discriminator_config = WaveformMetricDiscriminatorConfig.from_pretrained(
         pretrained_model_name_or_path=args.discriminator_config_file,
     )
 
@@ -189,13 +189,13 @@ def main():
     model.to(device)
     model.train()
 
-    discriminator = ConvTasNetDiscriminatorPretrainedModel(discriminator_config).to(device)
+    discriminator = WaveformMetricDiscriminatorPretrainedModel(discriminator_config).to(device)
     discriminator.to(device)
     discriminator.train()
 
     # optimizer
     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+    optimizer = torch.optim.AdamW(model.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
     discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
 
     # resume training
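For orientation, a minimal sketch of how the two AdamW optimizers and the renamed discriminator set up above are typically stepped in a metric-GAN denoising loop. The loss weighting, the `discriminator(clean, estimate)` call signature, and the normalized PESQ target are assumptions modeled on MP-SENet-style training, not code from this commit; the reconstruction losses imported above (NegativeSISNRLoss, MultiResolutionSTFTLoss) would normally be added to the generator loss as well.

import torch
import torch.nn.functional as F

def train_step(model, discriminator, optimizer, discriminator_optimizer,
               noisy_audios, clean_audios, pesq_score_g):
    # generator (denoiser) update: push the discriminator's score for the
    # denoised audio toward 1.0
    denoise_audios = model(noisy_audios)
    metric_g = discriminator(clean_audios, denoise_audios)
    loss_g = F.mse_loss(metric_g.flatten(), torch.ones_like(metric_g.flatten()))
    optimizer.zero_grad()
    loss_g.backward()
    optimizer.step()

    # discriminator update: real pairs toward 1.0, denoised pairs toward the
    # normalized PESQ score of the denoised audio (pesq_score_g in [0, 1])
    metric_r = discriminator(clean_audios, clean_audios)
    metric_g = discriminator(clean_audios, denoise_audios.detach())
    loss_d = (
        F.mse_loss(metric_r.flatten(), torch.ones_like(metric_r.flatten()))
        + F.mse_loss(metric_g.flatten(), pesq_score_g)
    )
    discriminator_optimizer.zero_grad()
    loss_d.backward()
    discriminator_optimizer.step()
    return loss_g, loss_d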
examples/conv_tasnet_gan/yaml/discriminator_config.yaml CHANGED
@@ -6,5 +6,5 @@ n_fft: 512
 win_size: 200
 hop_size: 80
 
-discriminator_dim: 16
+discriminator_dim: 24
 discriminator_in_channel: 2
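Widening discriminator_dim from 16 to 24 grows the discriminator's convolutional stack. A quick way to see the size difference, using the classes renamed later in this commit; passing discriminator_dim as a keyword is an assumption suggested by the modeling code's use of config.discriminator_dim.

from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminator

# compare parameter counts at the old and new widths
for dim in (16, 24):
    config = WaveformMetricDiscriminatorConfig(discriminator_dim=dim)
    discriminator = WaveformMetricDiscriminator(config=config)
    n_params = sum(p.numel() for p in discriminator.parameters())
    print(f"discriminator_dim={dim}: {n_params} parameters")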
toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator → waveform_metric_discriminator}/__init__.py RENAMED
File without changes
toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator/configuration_conv_tasnet_discriminator.py → waveform_metric_discriminator/configuration_waveform_metric_discriminator.py} RENAMED
@@ -3,13 +3,13 @@
 from toolbox.torchaudio.configuration_utils import PretrainedConfig
 
 
-class ConvTasNetDiscriminatorConfig(PretrainedConfig):
+class WaveformMetricDiscriminatorConfig(PretrainedConfig):
     """
     https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
     """
     def __init__(self,
                  sample_rate: int = 8000,
-                 segment_size: int = 16000,
+                 segment_size: int = 4,
                  n_fft: int = 512,
                  win_length: int = 200,
                  hop_length: int = 80,
@@ -19,7 +19,7 @@ class ConvTasNetDiscriminatorConfig(PretrainedConfig):
 
                  **kwargs
                  ):
-        super(ConvTasNetDiscriminatorConfig, self).__init__(**kwargs)
+        super(WaveformMetricDiscriminatorConfig, self).__init__(**kwargs)
         self.sample_rate = sample_rate
         self.segment_size = segment_size
         self.n_fft = n_fft
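A small usage sketch for the renamed config (not part of the commit): loading it from a YAML path via from_pretrained mirrors the usage in step_2_train_model.py above, and the example path is the config file updated earlier in this commit.

from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig

# load the discriminator config from the example YAML shown above
config = WaveformMetricDiscriminatorConfig.from_pretrained(
    pretrained_model_name_or_path="examples/conv_tasnet_gan/yaml/discriminator_config.yaml",
)
print(config.sample_rate, config.segment_size, config.discriminator_dim)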
toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator/modeling_conv_tasnet_discriminator.py → waveform_metric_discriminator/modeling_waveform_metric_discriminator.py} RENAMED
@@ -7,8 +7,7 @@ import torch
 import torch.nn as nn
 import torchaudio
 
-from toolbox.torchaudio.configuration_utils import DISCRIMINATOR_CONFIG_FILE
-from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.configuration_conv_tasnet_discriminator import ConvTasNetDiscriminatorConfig
+from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
 
 
 class LearnableSigmoid1d(nn.Module):
@@ -23,9 +22,9 @@ class LearnableSigmoid1d(nn.Module):
         return self.beta * torch.sigmoid(self.slope * x)
 
 
-class ConvTasNetDiscriminator(nn.Module):
-    def __init__(self, config: ConvTasNetDiscriminatorConfig):
-        super(ConvTasNetDiscriminator, self).__init__()
+class WaveformMetricDiscriminator(nn.Module):
+    def __init__(self, config: WaveformMetricDiscriminatorConfig):
+        super(WaveformMetricDiscriminator, self).__init__()
         dim = config.discriminator_dim
         self.in_channel = config.discriminator_in_channel
 
@@ -74,21 +73,22 @@ class ConvTasNetDiscriminator(nn.Module):
         return self.layers(xy)
 
 
+CONFIG_FILE = "discriminator_config.yaml"
 MODEL_FILE = "discriminator.pt"
 
 
-class ConvTasNetDiscriminatorPretrainedModel(ConvTasNetDiscriminator):
+class WaveformMetricDiscriminatorPretrainedModel(WaveformMetricDiscriminator):
     def __init__(self,
-                 config: ConvTasNetDiscriminatorConfig,
+                 config: WaveformMetricDiscriminatorConfig,
                  ):
-        super(ConvTasNetDiscriminatorPretrainedModel, self).__init__(
+        super(WaveformMetricDiscriminatorPretrainedModel, self).__init__(
            config=config,
        )
        self.config = config
 
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        config = ConvTasNetDiscriminatorConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        config = WaveformMetricDiscriminatorConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
        model = cls(config)
 
@@ -125,8 +125,8 @@ class ConvTasNetDiscriminatorPretrainedModel(ConvTasNetDiscriminator):
 
 
 def main():
-    config = ConvTasNetDiscriminatorConfig()
-    discriminator = ConvTasNetDiscriminator(config=config)
+    config = WaveformMetricDiscriminatorConfig()
+    discriminator = WaveformMetricDiscriminator(config=config)
 
     # shape: [batch_size, num_samples]
     # x = torch.ones([4, int(4.5 * 16000)])
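Continuing from main() above, a brief smoke-test sketch. The forward call taking a (clean, denoised) waveform pair is an assumption inferred from discriminator_in_channel: 2; the exact signature is not visible in this diff.

import torch

from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminator

config = WaveformMetricDiscriminatorConfig()  # defaults: 8000 Hz, n_fft=512, win_length=200, hop_length=80
discriminator = WaveformMetricDiscriminator(config=config)

# shape: [batch_size, num_samples]
x = torch.randn(4, int(4.5 * 8000))
y = torch.randn(4, int(4.5 * 8000))

# Assumed call: the two waveforms form the discriminator's 2-channel
# spectrogram input and yield one metric score per example.
score = discriminator(x, y)
print(score.shape)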
toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/yaml/discriminator_config.yaml ADDED
@@ -0,0 +1,10 @@
+model_name: "waveform_metric_discriminator"
+
+sample_rate: 8000
+segment_size: 4
+n_fft: 512
+win_size: 200
+hop_size: 80
+
+discriminator_dim: 16
+discriminator_in_channel: 2
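Finally, a hedged sketch of how the packaged defaults above would be consumed through the pretrained-model entry point. Whether from_pretrained expects a checkpoint directory containing discriminator_config.yaml and discriminator.pt (the CONFIG_FILE / MODEL_FILE names defined in the modeling file) or a single file is an assumption, and the path below is hypothetical.

from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminatorPretrainedModel

# hypothetical checkpoint directory holding discriminator_config.yaml and discriminator.pt
discriminator = WaveformMetricDiscriminatorPretrainedModel.from_pretrained(
    pretrained_model_name_or_path="path/to/discriminator_checkpoint",
)
discriminator.eval()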