HoneyTian committed
Commit 9de2c38 · 1 Parent(s): 35dd947
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -10,6 +10,9 @@ https://github.com/kaituoxu/Conv-TasNet/tree/master/src
  High-demand scenarios (e.g. medical hearing aids, speech recognition):
  require SI-SNR ≥ 14 dB, together with PESQ ≥ 3.0 and STOI ≥ 0.85.

+ The DeepFilterNet2 model was trained for 100 epochs on the DNS4 dataset, more than 500 hours of audio.
+ https://arxiv.org/abs/2205.05474
+
  """
  import argparse
  import json
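
The SI-SNR targets quoted in the docstring above can be checked with a few lines of numpy. A minimal sketch, assuming the usual zero-mean projection definition of SI-SNR; the helper below is illustrative and not part of this commit:

    import numpy as np

    def si_snr(estimate: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
        """Scale-Invariant SNR in dB between 1-D estimate and target signals."""
        estimate = estimate - np.mean(estimate)
        target = target - np.mean(target)
        # project the estimate onto the target: the scaled "signal" component
        s_target = (np.dot(estimate, target) / (np.dot(target, target) + eps)) * target
        e_noise = estimate - s_target
        return 10.0 * np.log10((np.sum(s_target ** 2) + eps) / (np.sum(e_noise ** 2) + eps))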
examples/conv_tasnet/yaml/config.yaml CHANGED
@@ -19,10 +19,10 @@ mask_nonlinear: "relu"
  min_snr_db: -10
  max_snr_db: 20

- lr: 0.001
+ lr: 0.005
  lr_scheduler: "CosineAnnealingLR"
  lr_scheduler_kwargs:
    T_max: 250000
-   eta_min: 0.00001
+   eta_min: 0.00005

  eval_steps: 25000
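
For reference, these keys are passed straight into torch.optim.lr_scheduler.CosineAnnealingLR (the GAN training script added below spreads them via **config.lr_scheduler_kwargs), and the scheduler is stepped once per batch, so the learning rate decays from 0.005 toward eta_min over T_max = 250000 steps. A minimal sketch with a stand-in model:

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for ConvTasNet
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=250000,     # length of the cosine ramp, in scheduler steps
        eta_min=0.00005,  # floor the learning rate decays toward
    )
    optimizer.step()
    lr_scheduler.step()  # called once per training batch in these scripts
    print(lr_scheduler.get_last_lr()[0])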
examples/conv_tasnet_gan/run.sh ADDED
@@ -0,0 +1,156 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
+ --max_epochs 400
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0  # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ discriminator_config_file="yaml/discriminator_config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ max_count=10000000
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ train_dataset="${file_dir}/train.jsonl"
+ valid_dataset="${file_dir}/valid.jsonl"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ "$system_version" == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ "$system_version" == "centos" ] || [ "$system_version" == "ubuntu" ]; then
+   #source /data/local/bin/nx_denoise/bin/activate
+   alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --file_dir "${file_dir}" \
+     --noise_dir "${noise_dir}" \
+     --speech_dir "${speech_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --max_count "${max_count}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: train model"
+   cd "${work_dir}" || exit 1
+   python3 step_2_train_model.py \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}" \
+     --config_file "${config_file}" \
+     --discriminator_config_file "${discriminator_config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: test model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_evaluation.py \
+     --valid_dataset "${valid_dataset}" \
+     --model_dir "${file_dir}/best" \
+     --evaluation_audio_dir "${evaluation_audio_dir}" \
+     --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: collect files"
+   cd "${work_dir}" || exit 1
+
+   mkdir -p "${final_model_dir}"
+
+   cp "${file_dir}/best"/* "${final_model_dir}"
+   cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+   cd "${final_model_dir}/.." || exit 1;
+
+   if [ -e "${final_model_name}.zip" ]; then
+     rm -rf "${final_model_name}_backup.zip"
+     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+   fi
+
+   zip -r "${final_model_name}.zip" "${final_model_name}"
+   rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+   $verbose && echo "stage 5: clear file_dir"
+   cd "${work_dir}" || exit 1
+
+   rm -rf "${file_dir}";
+
+ fi
examples/conv_tasnet_gan/step_1_prepare_data.py ADDED
@@ -0,0 +1,162 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--file_dir", default="./", type=str)
+
+     parser.add_argument(
+         "--noise_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+         type=str
+     )
+     parser.add_argument(
+         "--speech_dir",
+         default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+         type=str
+     )
+
+     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+     parser.add_argument("--duration", default=4.0, type=float)
+     parser.add_argument("--min_snr_db", default=-10, type=float)
+     parser.add_argument("--max_snr_db", default=20, type=float)
+
+     parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+     parser.add_argument("--max_count", default=10000, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def filename_generator(data_dir: str):
+     data_dir = Path(data_dir)
+     for filename in data_dir.glob("**/*.wav"):
+         yield filename.as_posix()
+
+
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 20000):
+     data_dir = Path(data_dir)
+     for epoch_idx in range(max_epoch):
+         for filename in data_dir.glob("**/*.wav"):
+             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+             if raw_duration < duration:
+                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                 continue
+             if signal.ndim != 1:
+                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+             signal_length = len(signal)
+             win_size = int(duration * sample_rate)
+             for begin in range(0, signal_length - win_size, win_size):
+                 if np.sum(signal[begin: begin+win_size]) == 0:
+                     continue
+                 row = {
+                     "epoch_idx": epoch_idx,
+                     "filename": filename.as_posix(),
+                     "raw_duration": round(raw_duration, 4),
+                     "offset": round(begin / sample_rate, 4),
+                     "duration": round(duration, 4),
+                 }
+                 yield row
+
+
+ def main():
+     args = get_args()
+
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     noise_dir = Path(args.noise_dir)
+     speech_dir = Path(args.speech_dir)
+
+     noise_generator = target_second_signal_generator(
+         noise_dir.as_posix(),
+         duration=args.duration,
+         sample_rate=args.target_sample_rate,
+         max_epoch=100000,
+     )
+     speech_generator = target_second_signal_generator(
+         speech_dir.as_posix(),
+         duration=args.duration,
+         sample_rate=args.target_sample_rate,
+         max_epoch=1,
+     )
+
+     dataset = list()
+
+     count = 0
+     process_bar = tqdm(desc="build dataset jsonl")
+     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+         for noise, speech in zip(noise_generator, speech_generator):
+             if count >= args.max_count:
+                 break
+
+             noise_filename = noise["filename"]
+             noise_raw_duration = noise["raw_duration"]
+             noise_offset = noise["offset"]
+             noise_duration = noise["duration"]
+
+             speech_filename = speech["filename"]
+             speech_raw_duration = speech["raw_duration"]
+             speech_offset = speech["offset"]
+             speech_duration = speech["duration"]
+
+             random1 = random.random()
+             random2 = random.random()
+
+             row = {
+                 "noise_filename": noise_filename,
+                 "noise_raw_duration": noise_raw_duration,
+                 "noise_offset": noise_offset,
+                 "noise_duration": noise_duration,
+
+                 "speech_filename": speech_filename,
+                 "speech_raw_duration": speech_raw_duration,
+                 "speech_offset": speech_offset,
+                 "speech_duration": speech_duration,
+
+                 "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+                 "random1": random1,
+             }
+             row = json.dumps(row, ensure_ascii=False)
+             if random2 < (1 / 300 / 1):
+                 fvalid.write(f"{row}\n")
+             else:
+                 ftrain.write(f"{row}\n")
+
+             count += 1
+             duration_seconds = count * args.duration
+             duration_hours = duration_seconds / 3600
+
+             process_bar.update(n=1)
+             process_bar.set_postfix({
+                 # "duration_seconds": round(duration_seconds, 4),
+                 "duration_hours": round(duration_hours, 4),
+
+             })
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
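
Each jsonl row stores an (offset, duration) window into a wav file rather than audio samples, so consumers can re-slice the source lazily. A hedged sketch of how one row could be read back; the real consumer is the DenoiseJsonlDataset used by the training script, this is only illustrative:

    import json

    import librosa

    with open("train.jsonl", "r", encoding="utf-8") as f:
        row = json.loads(next(f))

    # re-read exactly the 4-second window described by the row
    speech, _ = librosa.load(
        row["speech_filename"],
        sr=8000,
        offset=row["speech_offset"],
        duration=row["speech_duration"],
    )
    print(speech.shape)  # (32000,) for 4 s at 8 kHz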
examples/conv_tasnet_gan/step_2_train_model.py ADDED
@@ -0,0 +1,579 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://github.com/kaituoxu/Conv-TasNet/tree/master/src
+
+ General scenarios:
+
+ target SI-SNR ≥ 10 dB, suitable for telephone communication, basic voice assistants, etc.
+
+ High-demand scenarios (e.g. medical hearing aids, speech recognition):
+ require SI-SNR ≥ 14 dB, together with PESQ ≥ 3.0 and STOI ≥ 0.85.
+
+ The DeepFilterNet2 model was trained for 100 epochs on the DNS4 dataset, more than 500 hours of audio.
+ https://arxiv.org/abs/2205.05474
+
+ """
+ import argparse
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
+ from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
+ from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
+ from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.modeling_conv_tasnet_discriminator import ConvTasNetDiscriminatorPretrainedModel
+ from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.configuration_conv_tasnet_discriminator import ConvTasNetDiscriminatorConfig
+ from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+ from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--max_epochs", default=200, type=int)
+
+     parser.add_argument("--batch_size", default=8, type=int)
+     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+     parser.add_argument("--patience", default=5, type=int)
+     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+     parser.add_argument("--seed", default=1234, type=int)
+
+     parser.add_argument("--config_file", default="config.yaml", type=str)
+     parser.add_argument("--discriminator_config_file", default="discriminator_config.yaml", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.INFO)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, batch: List[dict]):
+         clean_audios = list()
+         noisy_audios = list()
+
+         for sample in batch:
+             # noise_wave: torch.Tensor = sample["noise_wave"]
+             clean_audio: torch.Tensor = sample["speech_wave"]
+             noisy_audio: torch.Tensor = sample["mix_wave"]
+             # snr_db: float = sample["snr_db"]
+
+             clean_audios.append(clean_audio)
+             noisy_audios.append(noisy_audio)
+
+         clean_audios = torch.stack(clean_audios)
+         noisy_audios = torch.stack(noisy_audios)
+
+         # assert
+         if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+             raise AssertionError("nan or inf in clean_audios")
+         if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
+             raise AssertionError("nan or inf in noisy_audios")
+         return clean_audios, noisy_audios
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+     args = get_args()
+
+     config = ConvTasNetConfig.from_pretrained(
+         pretrained_model_name_or_path=args.config_file,
+     )
+     discriminator_config = ConvTasNetDiscriminatorConfig.from_pretrained(
+         pretrained_model_name_or_path=args.discriminator_config_file,
+     )
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(serialization_dir)
+
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+     logger.info(f"set seed: {args.seed}")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+     # datasets
+     train_dataset = DenoiseJsonlDataset(
+         jsonl_file=args.train_dataset,
+         expected_sample_rate=config.sample_rate,
+         max_wave_value=32768.0,
+         min_snr_db=config.min_snr_db,
+         max_snr_db=config.max_snr_db,
+         # skip=825000,
+     )
+     valid_dataset = DenoiseJsonlDataset(
+         jsonl_file=args.valid_dataset,
+         expected_sample_rate=config.sample_rate,
+         max_wave_value=32768.0,
+         min_snr_db=config.min_snr_db,
+         max_snr_db=config.max_snr_db,
+     )
+     train_data_loader = DataLoader(
+         dataset=train_dataset,
+         batch_size=args.batch_size,
+         # shuffle=True,
+         sampler=None,
+         # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         prefetch_factor=2,
+     )
+     valid_data_loader = DataLoader(
+         dataset=valid_dataset,
+         batch_size=args.batch_size,
+         # shuffle=True,
+         sampler=None,
+         # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         prefetch_factor=2,
+     )
+
+     # models
+     logger.info(f"prepare models. config_file: {args.config_file}")
+     model = ConvTasNetPretrainedModel(config).to(device)
+     model.to(device)
+     model.train()
+
+     discriminator = ConvTasNetDiscriminatorPretrainedModel(discriminator_config).to(device)
+     discriminator.to(device)
+     discriminator.train()
+
+     # optimizer
+     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
+     optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+     discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
+
+     # resume training
+     last_step_idx = -1
+     last_epoch = -1
+     for step_idx_str in serialization_dir.glob("steps-*"):
+         step_idx_str = Path(step_idx_str)
+         step_idx = step_idx_str.stem.split("-")[1]
+         step_idx = int(step_idx)
+         if step_idx > last_step_idx:
+             last_step_idx = step_idx
+
+     if last_step_idx != -1:
+         logger.info(f"resume from steps-{last_step_idx}.")
+         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+         discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
+
+         optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
+         discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
+
+         logger.info(f"load state dict for model.")
+         with open(model_pt.as_posix(), "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         model.load_state_dict(state_dict, strict=True)
+
+         logger.info(f"load state dict for optimizer.")
+         with open(optimizer_pth.as_posix(), "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         optimizer.load_state_dict(state_dict)
+
+         if discriminator_pt.exists():
+             logger.info(f"load state dict for discriminator.")
+             with open(discriminator_pt.as_posix(), "rb") as f:
+                 state_dict = torch.load(f, map_location="cpu", weights_only=True)
+             discriminator.load_state_dict(state_dict, strict=True)
+
+         if discriminator_optimizer_pth.exists():
+             logger.info(f"load state dict for discriminator_optimizer.")
+             with open(discriminator_optimizer_pth.as_posix(), "rb") as f:
+                 state_dict = torch.load(f, map_location="cpu", weights_only=True)
+             discriminator_optimizer.load_state_dict(state_dict)
+
+     if config.lr_scheduler == "CosineAnnealingLR":
+         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer,
+             last_epoch=last_epoch,
+             # T_max=10 * config.eval_steps,
+             # eta_min=0.01 * config.lr,
+             **config.lr_scheduler_kwargs,
+         )
+         discriminator_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             discriminator_optimizer,
+             last_epoch=last_epoch,
+             # T_max=10 * config.eval_steps,
+             # eta_min=0.01 * config.lr,
+             **config.lr_scheduler_kwargs,
+         )
+     elif config.lr_scheduler == "MultiStepLR":
+         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+             optimizer,
+             last_epoch=last_epoch,
+             milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+         )
+         discriminator_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+             discriminator_optimizer,
+             last_epoch=last_epoch,
+             milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+         )
+     else:
+         raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+     ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
+     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
+     neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
+     mr_stft_loss_fn = MultiResolutionSTFTLoss(
+         fft_size_list=[256, 512, 1024],
+         win_size_list=[120, 240, 480],
+         hop_size_list=[25, 50, 100],
+         factor_sc=1.5,
+         factor_mag=1.0,
+         reduction="mean"
+     ).to(device)
+     pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)
+
+     # training loop
+
+     # state
+     average_pesq_score = 1000000000
+     average_loss = 1000000000
+     average_ae_loss = 1000000000
+     average_neg_si_snr_loss = 1000000000
+     average_neg_stoi_loss = 1000000000
+     average_mr_stft_loss = 1000000000
+     average_pesq_loss = 1000000000
+     average_discriminator_g_loss = 1000000000
+     average_discriminator_d_loss = 1000000000
+
+     model_list = list()
+     best_epoch_idx = None
+     best_step_idx = None
+     best_metric = None
+     patience_count = 0
+
+     step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+     logger.info("training")
+     for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
+         # train
+         model.train()
+
+         total_pesq_score = 0.
+         total_loss = 0.
+         total_ae_loss = 0.
+         total_neg_si_snr_loss = 0.
+         total_neg_stoi_loss = 0.
+         total_mr_stft_loss = 0.
+         total_pesq_loss = 0.
+         total_discriminator_g_loss = 0.
+         total_discriminator_d_loss = 0.
+         total_batches = 0.
+
+         progress_bar_train = tqdm(
+             initial=step_idx,
+             desc="Training; epoch-{}".format(epoch_idx),
+         )
+         for train_batch in train_data_loader:
+             clean_audios, noisy_audios = train_batch
+             clean_audios: torch.Tensor = clean_audios.to(device)
+             noisy_audios: torch.Tensor = noisy_audios.to(device)
+             one_labels = torch.ones(clean_audios.shape[0]).to(device)
+
+             denoise_audios = model.forward(noisy_audios)
+             denoise_audios = torch.squeeze(denoise_audios, dim=1)
+
+             if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
+                 raise AssertionError("nan or inf in denoise_audios")
+
+             # Discriminator
+             clean_audio_list = torch.split(clean_audios, 1, dim=0)
+             enhanced_audio_list = torch.split(denoise_audios, 1, dim=0)
+             clean_audio_list = [t.squeeze().detach().cpu().numpy() for t in clean_audio_list]
+             enhanced_audio_list = [t.squeeze().detach().cpu().numpy() for t in enhanced_audio_list]
+
+             pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb")
+
+             metric_r = discriminator.forward(clean_audios, clean_audios)
+             metric_g = discriminator.forward(clean_audios, denoise_audios.detach())
+             loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
+
+             if -1 in pesq_score_list:
+                 # print("-1 in batch_pesq_score!")
+                 loss_disc_g = 0
+             else:
+                 pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
+                 loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
+
+             discriminator_d_loss = loss_disc_r + loss_disc_g
+             discriminator_optimizer.zero_grad()
+             discriminator_d_loss.backward()
+             discriminator_optimizer.step()
+             discriminator_lr_scheduler.step()
+
+             # Generator
+             ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
+             neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
+             neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
+             mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+             pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
+
+             metric_g = discriminator.forward(denoise_audios, clean_audios)
+             discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
+
+             loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss + 0.2 * discriminator_g_loss
+             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                 logger.info(f"find nan or inf in loss.")
+                 continue
+
+             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+             pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+             lr_scheduler.step()
+
+             total_pesq_score += pesq_score
+             total_loss += loss.item()
+             total_ae_loss += ae_loss.item()
+             total_neg_si_snr_loss += neg_si_snr_loss.item()
+             total_neg_stoi_loss += neg_stoi_loss.item()
+             total_mr_stft_loss += mr_stft_loss.item()
+             total_pesq_loss += pesq_loss.item()
+             total_discriminator_g_loss += discriminator_g_loss.item()
+             total_discriminator_d_loss += discriminator_d_loss.item()
+             total_batches += 1
+
+             average_pesq_score = round(total_pesq_score / total_batches, 4)
+             average_loss = round(total_loss / total_batches, 4)
+             average_ae_loss = round(total_ae_loss / total_batches, 4)
+             average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+             average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
+             average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+             average_pesq_loss = round(total_pesq_loss / total_batches, 4)
+             average_discriminator_g_loss = round(total_discriminator_g_loss / total_batches, 4)
+             average_discriminator_d_loss = round(total_discriminator_d_loss / total_batches, 4)
+
+             progress_bar_train.update(1)
+             progress_bar_train.set_postfix({
+                 "lr": lr_scheduler.get_last_lr()[0],
+                 "pesq_score": average_pesq_score,
+                 "loss": average_loss,
+                 "ae_loss": average_ae_loss,
+                 "neg_si_snr_loss": average_neg_si_snr_loss,
+                 "neg_stoi_loss": average_neg_stoi_loss,
+                 "mr_stft_loss": average_mr_stft_loss,
+                 "pesq_loss": average_pesq_loss,
+                 "disc_g_loss": average_discriminator_g_loss,
+                 "disc_d_loss": average_discriminator_d_loss,
+
+             })
+
+             # evaluation
+             step_idx += 1
+             if step_idx % config.eval_steps == 0:
+                 with torch.no_grad():
+                     torch.cuda.empty_cache()
+
+                     total_pesq_score = 0.
+                     total_loss = 0.
+                     total_ae_loss = 0.
+                     total_neg_si_snr_loss = 0.
+                     total_neg_stoi_loss = 0.
+                     total_mr_stft_loss = 0.
+                     total_pesq_loss = 0.
+                     total_batches = 0.
+
+                     progress_bar_train.close()
+                     progress_bar_eval = tqdm(
+                         desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
+                     )
+                     for eval_batch in valid_data_loader:
+                         clean_audios, noisy_audios = eval_batch
+                         clean_audios = clean_audios.to(device)
+                         noisy_audios = noisy_audios.to(device)
+
+                         denoise_audios = model.forward(noisy_audios)
+                         denoise_audios = torch.squeeze(denoise_audios, dim=1)
+
+                         # Generator
+                         ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
+                         neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
+                         neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
+                         mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+                         pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
+
+                         loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
+                         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                             logger.info(f"find nan or inf in loss.")
+                             continue
+
+                         denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+                         clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+                         pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+                         total_pesq_score += pesq_score
+                         total_loss += loss.item()
+                         total_ae_loss += ae_loss.item()
+                         total_neg_si_snr_loss += neg_si_snr_loss.item()
+                         total_neg_stoi_loss += neg_stoi_loss.item()
+                         total_mr_stft_loss += mr_stft_loss.item()
+                         total_pesq_loss += pesq_loss.item()
+                         total_batches += 1
+
+                         average_pesq_score = round(total_pesq_score / total_batches, 4)
+                         average_loss = round(total_loss / total_batches, 4)
+                         average_ae_loss = round(total_ae_loss / total_batches, 4)
+                         average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+                         average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
+                         average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+                         average_pesq_loss = round(total_pesq_loss / total_batches, 4)
+
+                         progress_bar_eval.update(1)
+                         progress_bar_eval.set_postfix({
+                             "lr": lr_scheduler.get_last_lr()[0],
+                             "pesq_score": average_pesq_score,
+                             "loss": average_loss,
+                             "ae_loss": average_ae_loss,
+                             "neg_si_snr_loss": average_neg_si_snr_loss,
+                             "neg_stoi_loss": average_neg_stoi_loss,
+                             "mr_stft_loss": average_mr_stft_loss,
+                             "pesq_loss": average_pesq_loss,
+                         })
+
+                     total_pesq_score = 0.
+                     total_loss = 0.
+                     total_ae_loss = 0.
+                     total_neg_si_snr_loss = 0.
+                     total_neg_stoi_loss = 0.
+                     total_mr_stft_loss = 0.
+                     total_pesq_loss = 0.
+                     total_batches = 0.
+
+                     progress_bar_eval.close()
+                     progress_bar_train = tqdm(
+                         initial=progress_bar_train.n,
+                         postfix=progress_bar_train.postfix,
+                         desc=progress_bar_train.desc,
+                     )
+
+                     # save path
+                     save_dir = serialization_dir / "steps-{}".format(step_idx)
+                     save_dir.mkdir(parents=True, exist_ok=False)
+
+                     # save models
+                     model.save_pretrained(save_dir.as_posix())
+                     discriminator.save_pretrained(save_dir.as_posix())
+
+                     # save optim
+                     torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
+                     torch.save(discriminator_optimizer.state_dict(), (save_dir / "discriminator_optimizer.pth").as_posix())
+
+                     model_list.append(save_dir)
+                     if len(model_list) >= args.num_serialized_models_to_keep:
+                         model_to_delete: Path = model_list.pop(0)
+                         shutil.rmtree(model_to_delete.as_posix())
+
+                     # save metric
+                     if best_metric is None:
+                         best_epoch_idx = epoch_idx
+                         best_step_idx = step_idx
+                         best_metric = average_pesq_score
+                     elif average_pesq_score > best_metric:
+                         # greater is better.
+                         best_epoch_idx = epoch_idx
+                         best_step_idx = step_idx
+                         best_metric = average_pesq_score
+                     else:
+                         pass
+
+                     metrics = {
+                         "epoch_idx": epoch_idx,
+                         "best_epoch_idx": best_epoch_idx,
+                         "best_step_idx": best_step_idx,
+                         "pesq_score": average_pesq_score,
+                         "loss": average_loss,
+                         "ae_loss": average_ae_loss,
+                         "neg_si_snr_loss": average_neg_si_snr_loss,
+                         "neg_stoi_loss": average_neg_stoi_loss,
+                         "mr_stft_loss": average_mr_stft_loss,
+                         "pesq_loss": average_pesq_loss,
+                     }
+                     metrics_filename = save_dir / "metrics_epoch.json"
+                     with open(metrics_filename, "w", encoding="utf-8") as f:
+                         json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                     # save best
+                     best_dir = serialization_dir / "best"
+                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                         if best_dir.exists():
+                             shutil.rmtree(best_dir)
+                         shutil.copytree(save_dir, best_dir)
+
+                     # early stop
+                     early_stop_flag = False
+                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                         patience_count = 0
+                     else:
+                         patience_count += 1
+                         if patience_count >= args.patience:
+                             early_stop_flag = True
+
+                     # early stop
+                     if early_stop_flag:
+                         break
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
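
The discriminator in this script regresses a normalized PESQ value: (score - 1) / 3.5 maps the PESQ range of roughly [1.0, 4.5] onto [0, 1], which is why a clean/clean pair is trained against one_labels. A quick worked check of that mapping:

    def normalize_pesq(score: float) -> float:
        # maps PESQ in [1.0, 4.5] onto [0.0, 1.0], as in the training loop above
        return (score - 1.0) / 3.5

    print(normalize_pesq(1.0))  # 0.0    -> worst quality
    print(normalize_pesq(3.0))  # ~0.5714
    print(normalize_pesq(4.5))  # 1.0    -> best quality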
examples/conv_tasnet_gan/yaml/config.yaml ADDED
@@ -0,0 +1,31 @@
+ model_name: "conv_tasnet_gan"
+
+ sample_rate: 8000
+ segment_size: 4
+
+ win_size: 20
+ freq_bins: 256
+ bottleneck_channels: 256
+ num_speakers: 1
+ num_blocks: 4
+ num_sub_blocks: 8
+ sub_blocks_channels: 512
+ sub_blocks_kernel_size: 3
+
+ norm_type: "gLN"
+ causal: false
+ mask_nonlinear: "relu"
+
+ min_snr_db: -10
+ max_snr_db: 20
+
+ lr: 0.005
+ adam_b1: 0.8
+ adam_b2: 0.99
+
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.00005
+
+ eval_steps: 25000
examples/conv_tasnet_gan/yaml/discriminator_config.yaml ADDED
@@ -0,0 +1,10 @@
+ model_name: "conv_tasnet_gan"
+
+ sample_rate: 8000
+ segment_size: 16000
+ n_fft: 512
+ win_length: 200
+ hop_length: 80
+
+ discriminator_dim: 32
+ discriminator_in_channel: 2
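
With n_fft=512, win_length=200, hop_length=80 and 2-second segments at 8 kHz (segment_size 16000 samples), the magnitude spectrogram fed to the discriminator has 257 frequency bins and 201 frames, and stacking the denoised and clean branches yields the 2-channel input that discriminator_in_channel expects. A quick sanity check, assuming torchaudio's default centered STFT:

    import torch
    import torchaudio

    transform = torchaudio.transforms.Spectrogram(
        n_fft=512, win_length=200, hop_length=80, power=1.0,
        window_fn=torch.hann_window,
    )
    x = torch.randn(4, 16000)        # [batch, samples]
    spec = transform(x)              # [4, 257, 201]: n_fft//2+1 bins, samples//hop+1 frames
    xy = torch.stack((spec, spec), dim=1)
    print(spec.shape, xy.shape)      # torch.Size([4, 257, 201]) torch.Size([4, 2, 257, 201])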
examples/nx_clean_unet/step_2_train_model.py CHANGED
@@ -285,7 +285,8 @@ def main():
          # Time Loss
          loss_time = F.l1_loss(clean_audios, audio_g)
          # Metric Loss
-         metric_g = discriminator.forward(clean_audios, audio_g.detach())
+         metric_g = discriminator.forward(clean_audios, audio_g)
+         # metric_g = discriminator.forward(clean_audios, audio_g.detach())
          loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

          # loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2
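
The change above matters because .detach() cuts the autograd graph: with the detached input, loss_metric could not contribute any gradient to the generator, so the metric term was effectively a no-op in the generator update. A tiny illustration with toy tensors (not the repo's models):

    import torch

    g_out = torch.randn(4, requires_grad=True)  # stands in for audio_g

    d_in = g_out.detach()                       # detached: no path back to the generator
    print(d_in.requires_grad)                   # False -> a loss built on d_in cannot update g_out

    g_score = (g_out ** 2).mean()               # attached: gradients reach the generator
    g_score.backward()
    print(g_out.grad is not None)               # True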
toolbox/torchaudio/configuration_utils.py CHANGED
@@ -8,6 +8,7 @@ import yaml


  CONFIG_FILE = "config.yaml"
+ DISCRIMINATOR_CONFIG_FILE = "discriminator_config.yaml"


  class PretrainedConfig(object):
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py CHANGED
@@ -31,6 +31,9 @@ class ConvTasNetConfig(PretrainedConfig):
                   max_snr_db: float = 20,

                   lr: float = 1e-3,
+                  adam_b1: float = 0.8,
+                  adam_b2: float = 0.99,
+
                   lr_scheduler: str = "CosineAnnealingLR",
                   lr_scheduler_kwargs: dict = None,

@@ -60,6 +63,9 @@ class ConvTasNetConfig(PretrainedConfig):
          self.max_snr_db = max_snr_db

          self.lr = lr
+         self.adam_b1 = adam_b1
+         self.adam_b2 = adam_b2
+
          self.lr_scheduler = lr_scheduler
          self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()

toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py CHANGED
@@ -83,7 +83,7 @@ class InferenceConvTasNet(object):


  def main():
-     model_zip_file = project_path / "trained_models/conv-tasnet-dns3-575k-steps.zip"
+     model_zip_file = project_path / "trained_models/conv-tasnet-dns3-1025k-steps.zip"
      infer_conv_tasnet = InferenceConvTasNet(model_zip_file)

      sample_rate = 8000
toolbox/torchaudio/models/discriminators/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/models/discriminators/conv_tasnet_discriminator/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/models/discriminators/conv_tasnet_discriminator/configuration_conv_tasnet_discriminator.py ADDED
@@ -0,0 +1,34 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
+
+
+ class ConvTasNetDiscriminatorConfig(PretrainedConfig):
+     """
+     https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
+     """
+     def __init__(self,
+                  sample_rate: int = 8000,
+                  segment_size: int = 16000,
+                  n_fft: int = 512,
+                  win_length: int = 200,
+                  hop_length: int = 80,
+
+                  discriminator_dim: int = 16,
+                  discriminator_in_channel: int = 2,
+
+                  **kwargs
+                  ):
+         super(ConvTasNetDiscriminatorConfig, self).__init__(**kwargs)
+         self.sample_rate = sample_rate
+         self.segment_size = segment_size
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+
+         self.discriminator_dim = discriminator_dim
+         self.discriminator_in_channel = discriminator_in_channel
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/models/discriminators/conv_tasnet_discriminator/modeling_conv_tasnet_discriminator.py ADDED
@@ -0,0 +1,145 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+ import torchaudio
+
+ from toolbox.torchaudio.configuration_utils import DISCRIMINATOR_CONFIG_FILE
+ from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.configuration_conv_tasnet_discriminator import ConvTasNetDiscriminatorConfig
+
+
+ class LearnableSigmoid1d(nn.Module):
+     def __init__(self, in_features, beta=1):
+         super().__init__()
+         self.beta = beta
+         self.slope = nn.Parameter(torch.ones(in_features))
+         self.slope.requires_grad = True
+
+     def forward(self, x):
+         # x shape: [batch_size, time_steps, spec_bins]
+         return self.beta * torch.sigmoid(self.slope * x)
+
+
+ class ConvTasNetDiscriminator(nn.Module):
+     def __init__(self, config: ConvTasNetDiscriminatorConfig):
+         super(ConvTasNetDiscriminator, self).__init__()
+         dim = config.discriminator_dim
+         self.in_channel = config.discriminator_in_channel
+
+         self.n_fft = config.n_fft
+         self.win_length = config.win_length
+         self.hop_length = config.hop_length
+
+         self.transform = torchaudio.transforms.Spectrogram(
+             n_fft=self.n_fft,
+             win_length=self.win_length,
+             hop_length=self.hop_length,
+             power=1.0,
+             window_fn=torch.hann_window,
+             # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+         )
+
+         self.layers = nn.Sequential(
+             nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
+             nn.InstanceNorm2d(dim, affine=True),
+             nn.PReLU(dim),
+             nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4, 4), (2, 2), (1, 1), bias=False)),
+             nn.InstanceNorm2d(dim*2, affine=True),
+             nn.PReLU(dim*2),
+             nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4, 4), (2, 2), (1, 1), bias=False)),
+             nn.InstanceNorm2d(dim*4, affine=True),
+             nn.PReLU(dim*4),
+             nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4, 4), (2, 2), (1, 1), bias=False)),
+             nn.InstanceNorm2d(dim*8, affine=True),
+             nn.PReLU(dim*8),
+             nn.AdaptiveMaxPool2d(1),
+             nn.Flatten(),
+             nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
+             nn.Dropout(0.3),
+             nn.PReLU(dim*4),
+             nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
+             LearnableSigmoid1d(1)
+         )
+
+     def forward(self, denoise_audios, clean_audios):
+         x = denoise_audios
+         y = clean_audios
+         x = self.transform.forward(x)
+         y = self.transform.forward(y)
+
+         xy = torch.stack((x, y), dim=1)
+         return self.layers(xy)
+
+
+ MODEL_FILE = "discriminator.pt"
+
+
+ class ConvTasNetDiscriminatorPretrainedModel(ConvTasNetDiscriminator):
+     def __init__(self,
+                  config: ConvTasNetDiscriminatorConfig,
+                  ):
+         super(ConvTasNetDiscriminatorPretrainedModel, self).__init__(
+             config=config,
+         )
+         self.config = config
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         config = ConvTasNetDiscriminatorConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         model = cls(config)
+
+         if os.path.isdir(pretrained_model_name_or_path):
+             ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+         else:
+             ckpt_file = pretrained_model_name_or_path
+
+         with open(ckpt_file, "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         model.load_state_dict(state_dict, strict=True)
+         return model
+
+     def save_pretrained(self,
+                         save_directory: Union[str, os.PathLike],
+                         state_dict: Optional[dict] = None,
+                         ):
+
+         model = self
+
+         if state_dict is None:
+             state_dict = model.state_dict()
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # save state dict
+         model_file = os.path.join(save_directory, MODEL_FILE)
+         torch.save(state_dict, model_file)
+
+         # save config
+         config_file = os.path.join(save_directory, DISCRIMINATOR_CONFIG_FILE)
+         self.config.to_yaml_file(config_file)
+         return save_directory
+
+
+ def main():
+     config = ConvTasNetDiscriminatorConfig()
+     discriminator = ConvTasNetDiscriminator(config=config)
+
+     # shape: [batch_size, num_samples]
+     # x = torch.ones([4, int(4.5 * 16000)])
+     # y = torch.ones([4, int(4.5 * 16000)])
+     x = torch.ones([4, 16000])
+     y = torch.ones([4, 16000])
+
+     output = discriminator.forward(x, y)
+     print(output.shape)
+     print(output)
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
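
A short usage sketch of the pretrained wrapper above: save_pretrained writes discriminator.pt plus discriminator_config.yaml into a directory, and the weights can then be loaded back into a freshly built model. The path here is illustrative:

    import torch

    from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.configuration_conv_tasnet_discriminator import ConvTasNetDiscriminatorConfig
    from toolbox.torchaudio.models.discriminators.conv_tasnet_discriminator.modeling_conv_tasnet_discriminator import ConvTasNetDiscriminatorPretrainedModel

    config = ConvTasNetDiscriminatorConfig()
    discriminator = ConvTasNetDiscriminatorPretrainedModel(config)

    save_dir = "serialization_dir/steps-100000"  # illustrative path
    discriminator.save_pretrained(save_dir)

    # restore: load the saved state dict into a newly constructed model
    restored = ConvTasNetDiscriminatorPretrainedModel(ConvTasNetDiscriminatorConfig())
    state_dict = torch.load(f"{save_dir}/discriminator.pt", map_location="cpu", weights_only=True)
    restored.load_state_dict(state_dict, strict=True)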