# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, pow, sin
from torch.nn import Parameter


class Snake(nn.Module):
    r"""Implementation of a sine-based periodic activation function.

    Alpha is initialized to 1 by default, higher values means higher frequency.

    It will be trained along with the rest of your model.



    Args:

        in_features: shape of the input

        alpha: trainable parameter



    Shape:

        - Input: (B, C, T)

        - Output: (B, C, T), same shape as the input



    References:

        This activation function is from this paper by Liu Ziyin, Tilman Hartwig,

        Masahito Ueda: https://arxiv.org/abs/2006.08195



    Examples:

        >>> a1 = Snake(256)

        >>> x = torch.randn(256)

        >>> x = a1(x)

    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros (so exp(alpha) starts at 1)
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 1e-9  # guard against division by (near-)zero alpha

    def forward(self, x):
        r"""Forward pass of the function. Applies the function to the input elementwise.

        Snake ∶= x + 1/a * sin^2 (ax)

        """

        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
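
# A minimal usage sketch (illustrative only, not part of the Amphion API): Snake
# is applied channel-wise, so `in_features` must match the channel dimension C
# of a [B, C, T] input; the per-channel alpha then broadcasts over batch and time.
#
#     snake = Snake(in_features=64)
#     x = torch.randn(8, 64, 100)  # [B, C, T]
#     y = snake(x)                 # [8, 64, 100], same shape as the input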


class SnakeBeta(nn.Module):
    r"""A modified Snake function which uses separate parameters for the magnitude

    of the periodic components. Alpha is initialized to 1 by default,

    higher values means higher frequency. Beta is initialized to 1 by default,

    higher values means higher magnitude. Both will be trained along with the

    rest of your model.



    Args:

        in_features: shape of the input

        alpha: trainable parameter that controls frequency

        beta: trainable parameter that controls magnitude



    Shape:

        - Input: (B, C, T)

        - Output: (B, C, T), same shape as the input



    References:

        This activation function is a modified version based on this paper by Liu Ziyin,

        Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195



    Examples:

        >>> a1 = SnakeBeta(256)

        >>> x = torch.randn(256)

        >>> x = a1(x)

    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros (so exp(alpha) starts at 1)
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 1e-9  # guard against division by (near-)zero beta

    def forward(self, x):
        r"""Forward pass of the function. Applies the function to the input elementwise.

        SnakeBeta ∶= x + 1/b * sin^2 (xa)

        """

        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
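

if __name__ == "__main__":
    # Quick sanity check (an illustrative sketch, not part of the Amphion API):
    # both activations preserve the [B, C, T] shape, and with the default
    # alpha = beta = 1 initialization, SnakeBeta coincides with Snake.
    x = torch.randn(2, 4, 8)  # [B, C, T]
    snake, snake_beta = Snake(4), SnakeBeta(4)
    y1, y2 = snake(x), snake_beta(x)
    assert y1.shape == x.shape and y2.shape == x.shape
    assert torch.allclose(y1, y2)  # identical at the default initialization
    print("output shapes:", tuple(y1.shape), tuple(y2.shape))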