#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
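    """
    Build the convolution kernels used by ConvSTFT / ConviSTFT.

    The real and imaginary parts of the one-sided DFT basis are stacked and
    multiplied by the analysis window, giving a kernel of shape
    [nfft + 2, 1, win_size] for F.conv1d. With inverse=True, the pseudo-inverse
    of the basis is taken before the window is applied, for use with
    F.conv_transpose1d. The window is returned with shape [1, win_size, 1].
    """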
    if win_type == "None" or win_type is None:
        window = np.ones(win_size)
    else:
        window = get_window(win_type, win_size, fftbins=True)**0.5

    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    real_kernel = np.real(fourier_basis)
    image_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, image_kernel], 1).T

    if inverse:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]
    result = (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32))
    )
    return result


class ConvSTFT(nn.Module):
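    """
    Short-time Fourier transform implemented as a 1-D convolution with
    fixed DFT kernels. If nfft is None, it defaults to the next power of
    two that is >= win_size.
    """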

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConvSTFT, self).__init__()

        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size

        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

    def forward(self, inputs: torch.Tensor):
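        """
        :param inputs: torch.Tensor, shape: [b, t] or [b, 1, t], waveform
        :return: complex spec of shape [b, nfft+2, t] if feature_type is "complex",
                 otherwise a (mags, phase) tuple, each of shape [b, nfft//2+1, t]
        """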
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)

        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == "complex":
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase


class ConviSTFT(nn.Module):
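    """
    Inverse short-time Fourier transform implemented as a transposed 1-D
    convolution, followed by overlap-add normalization with the squared
    analysis window.
    """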

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConviSTFT, self).__init__()
        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

        self.register_buffer("window", window)
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])

    def forward(self,
                inputs: torch.Tensor,
                phase: torch.Tensor = None):
        """
        :param inputs: torch.Tensor, shape: [b, n+2, t] (complex spec) or [b, n//2+1, t] (mags)
        :param phase: torch.Tensor, shape: [b, n//2+1, t]
        :return:
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
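        # normalize by the overlap-added squares of the analysis window so that
        # overlapping synthesis frames do not rescale the output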
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        return outputs


def main():
    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")

    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)

    spec = stft.forward(mixture)
    # shape: [batch_size, nfft + 2, time_steps] (real and imaginary parts stacked)
    print(spec.shape)

    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
    print(waveform.shape)
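
    # optional sanity check (not part of the original): aside from float32
    # rounding noise, the inverse transform should reproduce the samples it covers
    num_samples = waveform.size(-1)
    error = (waveform[:, 0, :] - mixture[:, :num_samples]).abs().max()
    print("max reconstruction error:", error.item())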

    return

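
def check_against_torch_stft():
    """
    Sanity-check sketch (not part of the original file): under the same
    settings as main() (sqrt-hamming window, nfft = win_size = 512,
    hop_size = 200, no centering), ConvSTFT should match torch.stft up to
    float32 rounding noise. The function name is an illustrative addition;
    it is not called anywhere and can be run manually.
    """
    win_size = 512
    hop_size = 200
    stft = ConvSTFT(win_size=win_size, hop_size=hop_size, feature_type="complex")

    signal = torch.rand(size=(1, 16000), dtype=torch.float32)

    spec = stft.forward(signal)
    dim = stft.nfft // 2 + 1
    real = spec[:, :dim, :]
    imag = spec[:, dim:, :]

    # same analysis window as init_kernels: square root of a periodic hamming window
    window = torch.from_numpy(
        (get_window("hamming", win_size, fftbins=True) ** 0.5).astype(np.float32)
    )
    reference = torch.stft(
        signal,
        n_fft=win_size,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=False,
        return_complex=True,
    )
    print("max |real - ref.real|:", (real - reference.real).abs().max().item())
    print("max |imag - ref.imag|:", (imag - reference.imag).abs().max().item())
    return
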

if __name__ == "__main__":
    main()