Spaces:
Running
Running
update
Browse files- examples/conv_tasnet_gan/step_2_train_model.py +5 -5
- examples/conv_tasnet_gan/yaml/discriminator_config.yaml +1 -1
- toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator → waveform_metric_discriminator}/__init__.py +0 -0
- toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator/configuration_conv_tasnet_discriminator.py → waveform_metric_discriminator/configuration_waveform_metric_discriminator.py} +3 -3
- toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator/modeling_conv_tasnet_discriminator.py → waveform_metric_discriminator/modeling_waveform_metric_discriminator.py} +11 -11
- toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/yaml/discriminator_config.yaml +10 -0
examples/conv_tasnet_gan/step_2_train_model.py
CHANGED
@@ -39,8 +39,8 @@ from tqdm import tqdm
|
|
39 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
40 |
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
|
41 |
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
42 |
-
from toolbox.torchaudio.models.discriminators.
|
43 |
-
from toolbox.torchaudio.models.discriminators.
|
44 |
from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
|
45 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
46 |
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
@@ -126,7 +126,7 @@ def main():
|
|
126 |
config = ConvTasNetConfig.from_pretrained(
|
127 |
pretrained_model_name_or_path=args.config_file,
|
128 |
)
|
129 |
-
discriminator_config =
|
130 |
pretrained_model_name_or_path=args.discriminator_config_file,
|
131 |
)
|
132 |
|
@@ -189,13 +189,13 @@ def main():
|
|
189 |
model.to(device)
|
190 |
model.train()
|
191 |
|
192 |
-
discriminator =
|
193 |
discriminator.to(device)
|
194 |
discriminator.train()
|
195 |
|
196 |
# optimizer
|
197 |
logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
|
198 |
-
optimizer = torch.optim.AdamW(model.parameters(), config.lr)
|
199 |
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
|
200 |
|
201 |
# resume training
|
|
|
39 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
40 |
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
|
41 |
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
42 |
+
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminatorPretrainedModel
|
43 |
+
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
|
44 |
from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
|
45 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
46 |
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
|
|
126 |
config = ConvTasNetConfig.from_pretrained(
|
127 |
pretrained_model_name_or_path=args.config_file,
|
128 |
)
|
129 |
+
discriminator_config = WaveformMetricDiscriminatorConfig.from_pretrained(
|
130 |
pretrained_model_name_or_path=args.discriminator_config_file,
|
131 |
)
|
132 |
|
|
|
189 |
model.to(device)
|
190 |
model.train()
|
191 |
|
192 |
+
discriminator = WaveformMetricDiscriminatorPretrainedModel(discriminator_config).to(device)
|
193 |
discriminator.to(device)
|
194 |
discriminator.train()
|
195 |
|
196 |
# optimizer
|
197 |
logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
|
198 |
+
optimizer = torch.optim.AdamW(model.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
|
199 |
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
|
200 |
|
201 |
# resume training
|
examples/conv_tasnet_gan/yaml/discriminator_config.yaml
CHANGED
@@ -6,5 +6,5 @@ n_fft: 512
|
|
6 |
win_size: 200
|
7 |
hop_size: 80
|
8 |
|
9 |
-
discriminator_dim:
|
10 |
discriminator_in_channel: 2
|
|
|
6 |
win_size: 200
|
7 |
hop_size: 80
|
8 |
|
9 |
+
discriminator_dim: 24
|
10 |
discriminator_in_channel: 2
|
toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator → waveform_metric_discriminator}/__init__.py
RENAMED
File without changes
|
toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator/configuration_conv_tasnet_discriminator.py → waveform_metric_discriminator/configuration_waveform_metric_discriminator.py}
RENAMED
@@ -3,13 +3,13 @@
|
|
3 |
from toolbox.torchaudio.configuration_utils import PretrainedConfig
|
4 |
|
5 |
|
6 |
-
class
|
7 |
"""
|
8 |
https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
|
9 |
"""
|
10 |
def __init__(self,
|
11 |
sample_rate: int = 8000,
|
12 |
-
segment_size: int =
|
13 |
n_fft: int = 512,
|
14 |
win_length: int = 200,
|
15 |
hop_length: int = 80,
|
@@ -19,7 +19,7 @@ class ConvTasNetDiscriminatorConfig(PretrainedConfig):
|
|
19 |
|
20 |
**kwargs
|
21 |
):
|
22 |
-
super(
|
23 |
self.sample_rate = sample_rate
|
24 |
self.segment_size = segment_size
|
25 |
self.n_fft = n_fft
|
|
|
3 |
from toolbox.torchaudio.configuration_utils import PretrainedConfig
|
4 |
|
5 |
|
6 |
+
class WaveformMetricDiscriminatorConfig(PretrainedConfig):
|
7 |
"""
|
8 |
https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
|
9 |
"""
|
10 |
def __init__(self,
|
11 |
sample_rate: int = 8000,
|
12 |
+
segment_size: int = 4,
|
13 |
n_fft: int = 512,
|
14 |
win_length: int = 200,
|
15 |
hop_length: int = 80,
|
|
|
19 |
|
20 |
**kwargs
|
21 |
):
|
22 |
+
super(WaveformMetricDiscriminatorConfig, self).__init__(**kwargs)
|
23 |
self.sample_rate = sample_rate
|
24 |
self.segment_size = segment_size
|
25 |
self.n_fft = n_fft
|
toolbox/torchaudio/models/discriminators/{conv_tasnet_discriminator/modeling_conv_tasnet_discriminator.py → waveform_metric_discriminator/modeling_waveform_metric_discriminator.py}
RENAMED
@@ -7,8 +7,7 @@ import torch
|
|
7 |
import torch.nn as nn
|
8 |
import torchaudio
|
9 |
|
10 |
-
from toolbox.torchaudio.
|
11 |
-
from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.configuration_conv_tasnet_discriminator import ConvTasNetDiscriminatorConfig
|
12 |
|
13 |
|
14 |
class LearnableSigmoid1d(nn.Module):
|
@@ -23,9 +22,9 @@ class LearnableSigmoid1d(nn.Module):
|
|
23 |
return self.beta * torch.sigmoid(self.slope * x)
|
24 |
|
25 |
|
26 |
-
class
|
27 |
-
def __init__(self, config:
|
28 |
-
super(
|
29 |
dim = config.discriminator_dim
|
30 |
self.in_channel = config.discriminator_in_channel
|
31 |
|
@@ -74,21 +73,22 @@ class ConvTasNetDiscriminator(nn.Module):
|
|
74 |
return self.layers(xy)
|
75 |
|
76 |
|
|
|
77 |
MODEL_FILE = "discriminator.pt"
|
78 |
|
79 |
|
80 |
-
class
|
81 |
def __init__(self,
|
82 |
-
config:
|
83 |
):
|
84 |
-
super(
|
85 |
config=config,
|
86 |
)
|
87 |
self.config = config
|
88 |
|
89 |
@classmethod
|
90 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
91 |
-
config =
|
92 |
|
93 |
model = cls(config)
|
94 |
|
@@ -125,8 +125,8 @@ class ConvTasNetDiscriminatorPretrainedModel(ConvTasNetDiscriminator):
|
|
125 |
|
126 |
|
127 |
def main():
|
128 |
-
config =
|
129 |
-
discriminator =
|
130 |
|
131 |
# shape: [batch_size, num_samples]
|
132 |
# x = torch.ones([4, int(4.5 * 16000)])
|
|
|
7 |
import torch.nn as nn
|
8 |
import torchaudio
|
9 |
|
10 |
+
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
|
|
|
11 |
|
12 |
|
13 |
class LearnableSigmoid1d(nn.Module):
|
|
|
22 |
return self.beta * torch.sigmoid(self.slope * x)
|
23 |
|
24 |
|
25 |
+
class WaveformMetricDiscriminator(nn.Module):
|
26 |
+
def __init__(self, config: WaveformMetricDiscriminatorConfig):
|
27 |
+
super(WaveformMetricDiscriminator, self).__init__()
|
28 |
dim = config.discriminator_dim
|
29 |
self.in_channel = config.discriminator_in_channel
|
30 |
|
|
|
73 |
return self.layers(xy)
|
74 |
|
75 |
|
76 |
+
CONFIG_FILE = "discriminator_config.yaml"
|
77 |
MODEL_FILE = "discriminator.pt"
|
78 |
|
79 |
|
80 |
+
class WaveformMetricDiscriminatorPretrainedModel(WaveformMetricDiscriminator):
|
81 |
def __init__(self,
|
82 |
+
config: WaveformMetricDiscriminatorConfig,
|
83 |
):
|
84 |
+
super(WaveformMetricDiscriminatorPretrainedModel, self).__init__(
|
85 |
config=config,
|
86 |
)
|
87 |
self.config = config
|
88 |
|
89 |
@classmethod
|
90 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
91 |
+
config = WaveformMetricDiscriminatorPretrainedModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
92 |
|
93 |
model = cls(config)
|
94 |
|
|
|
125 |
|
126 |
|
127 |
def main():
|
128 |
+
config = WaveformMetricDiscriminatorConfig()
|
129 |
+
discriminator = WaveformMetricDiscriminator(config=config)
|
130 |
|
131 |
# shape: [batch_size, num_samples]
|
132 |
# x = torch.ones([4, int(4.5 * 16000)])
|
toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/yaml/discriminator_config.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "waveform_metric_discriminator"
|
2 |
+
|
3 |
+
sample_rate: 8000
|
4 |
+
segment_size: 4
|
5 |
+
n_fft: 512
|
6 |
+
win_size: 200
|
7 |
+
hop_size: 80
|
8 |
+
|
9 |
+
discriminator_dim: 16
|
10 |
+
discriminator_in_channel: 2
|