File size: 3,058 Bytes
d66c48f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch.nn as nn

from modules.general.utils import Conv1d, zero_module
from .residual_block import ResidualBlock


class BiDilConv(nn.Module):
    r"""Dilated CNN architecture with residual connections, default diffusion decoder.

    The input is lifted to ``base_channel`` channels by a kernel-size-1
    convolution followed by ReLU. It is then passed through ``n_res_block``
    residual blocks; block ``i`` uses dilation
    ``2 ** (i % dilation_cycle_length)``, i.e. the dilations cycle through
    ``1, 2, 4, ..., 2 ** (dilation_cycle_length - 1)``. The skip outputs of
    all blocks are summed, rescaled by ``1 / sqrt(n_res_block)`` and projected
    to ``output_channel`` channels.

    Args:
        input_channel: The number of input channels.
        base_channel: The number of base (hidden) channels used inside the
            residual stack.
        n_res_block: The number of residual blocks. Must be at least 1.
        conv_kernel_size: The kernel size of the convolutional layers in each
            residual block.
        dilation_cycle_length: The cycle length of dilation (see above).
        conditioner_size: The size of the conditioner, passed to each residual
            block as ``d_context``.
        output_channel: The number of output channels. Any value ``<= 0``
            (the default, ``-1``) means "same as ``input_channel``".

    Raises:
        ValueError: If ``n_res_block`` is less than 1.
    """

    def __init__(
        self,
        input_channel: int,
        base_channel: int,
        n_res_block: int,
        conv_kernel_size: int,
        dilation_cycle_length: int,
        conditioner_size: int,
        output_channel: int = -1,
    ):
        super().__init__()

        # forward() sums the skip outputs of all blocks and divides by
        # sqrt(n_res_block). With no blocks there is nothing to sum, and the
        # failure would only surface later as an opaque TypeError, so reject
        # the configuration up front.
        if n_res_block < 1:
            raise ValueError(
                f"n_res_block must be a positive integer, got {n_res_block}"
            )

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        # Non-positive output_channel falls back to input_channel.
        self.output_channel = output_channel if output_channel > 0 else input_channel

        self.input = nn.Sequential(
            Conv1d(
                input_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
        )

        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (i % dilation_cycle_length),
                    d_context=conditioner_size,
                )
                for i in range(n_res_block)
            ]
        )

        self.out_proj = nn.Sequential(
            Conv1d(
                base_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
            # NOTE(review): zero_module presumably zero-initializes this final
            # projection so the decoder starts out emitting zeros — confirm
            # against modules.general.utils.
            zero_module(
                Conv1d(
                    base_channel,
                    self.output_channel,
                    1,
                ),
            ),
        )

    def forward(self, x, y, context=None):
        """Run the residual stack and project the summed skip outputs.

        Args:
            x: Noisy mel-spectrogram [B x ``n_mel`` x L] (``n_mel`` equals
                ``input_channel``).
            y: FiLM embeddings with the shape of (B, ``base_channel``).
            context: Context with the shape of [B x ``d_context`` x L],
                default to None.

        Returns:
            The output of ``out_proj``, with ``self.output_channel`` channels
            (presumably [B x ``output_channel`` x L], assuming the project's
            ``Conv1d`` with kernel size 1 preserves length — confirm).
        """
        h = self.input(x)

        # The first output of each block feeds the next block. Only the
        # accumulated skip outputs reach the projection; the last block's
        # ``h`` is discarded.
        skip = None
        for block in self.residual_blocks:
            h, skip_connection = block(h, y, context)
            skip = skip_connection if skip is None else skip_connection + skip

        # Presumably keeps the magnitude of the summed skips from growing with
        # depth (WaveNet/DiffWave-style normalization).
        out = skip / math.sqrt(self.n_res_block)

        return self.out_proj(out)