File size: 3,648 Bytes
1d4c9c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DfNetConfig(PretrainedConfig):
    """
    Hyper-parameter container for DfNet.

    Every named constructor argument is stored on the instance, unchanged and
    without validation, under an attribute of the same name. Any additional
    keyword arguments are handed to ``PretrainedConfig.__init__``.

    The arguments are organised in the groups listed below. The group labels
    follow the parameter names only; their exact meaning should be confirmed
    against the model code that reads this config.

    * transform: ``sample_rate``, ``nfft``, ``win_size``, ``hop_size``,
      ``win_type``
    * spectrum: ``spec_bins``
    * conv: ``conv_channels``, ``conv_kernel_size_input``,
      ``conv_kernel_size_inner``, ``conv_lookahead``,
      ``convt_kernel_size_inner``, ``embedding_hidden_size``
    * encoder: ``encoder_combine_op``, ``encoder_emb_*``,
      ``encoder_linear_groups``, ``lsnr_max``, ``lsnr_min``, ``norm_tau``
    * decoder: ``decoder_emb_*``
    * df decoder: ``df_*``
    * runtime: ``use_post_filter``
    """

    def __init__(
        self,
        # transform
        sample_rate: int = 8000,
        nfft: int = 512,
        win_size: int = 200,
        hop_size: int = 80,
        win_type: str = "hann",
        # spectrum
        spec_bins: int = 256,
        # conv
        conv_channels: int = 64,
        conv_kernel_size_input: Tuple[int, int] = (3, 3),
        conv_kernel_size_inner: Tuple[int, int] = (1, 3),
        conv_lookahead: int = 0,
        convt_kernel_size_inner: Tuple[int, int] = (1, 3),
        embedding_hidden_size: int = 256,
        # encoder
        encoder_combine_op: str = "concat",
        encoder_emb_skip_op: str = "none",
        encoder_emb_linear_groups: int = 16,
        encoder_emb_hidden_size: int = 256,
        encoder_linear_groups: int = 32,
        lsnr_max: int = 30,
        lsnr_min: int = -15,
        norm_tau: float = 1.0,
        # decoder
        decoder_emb_num_layers: int = 3,
        decoder_emb_skip_op: str = "none",
        decoder_emb_linear_groups: int = 16,
        decoder_emb_hidden_size: int = 256,
        # df decoder
        df_decoder_hidden_size: int = 256,
        df_num_layers: int = 2,
        df_order: int = 5,
        df_bins: int = 96,
        df_gru_skip: str = "grouped_linear",
        df_decoder_linear_groups: int = 16,
        df_pathway_kernel_size_t: int = 5,
        df_lookahead: int = 2,
        # runtime
        use_post_filter: bool = False,
        **kwargs
    ):
        # Unrecognised keyword arguments belong to the base config class.
        super().__init__(**kwargs)

        # NOTE: the assignment order below is deliberately kept stable; it
        # fixes the insertion order of the instance ``__dict__``, which may
        # matter if the base class serialises attributes in that order
        # (not verified from this file).

        # -- transform settings
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # -- spectrum settings
        self.spec_bins = spec_bins

        # -- convolution settings
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.conv_lookahead = conv_lookahead
        self.convt_kernel_size_inner = convt_kernel_size_inner
        self.embedding_hidden_size = embedding_hidden_size

        # -- encoder settings
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op
        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min
        self.norm_tau = norm_tau

        # -- decoder settings
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # -- df decoder settings
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # -- runtime settings
        self.use_post_filter = use_post_filter


# Script entry guard: running this module directly performs no action beyond
# the imports and definitions above; the body is an intentional no-op.
if __name__ == "__main__":
    pass