HoneyTian's picture
add frcrn model
cba47e4
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/layers/uni_deep_fsmn.py
https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/mossformer2_se/fsmn.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UniDeepFsmn(nn.Module):
    """
    Unidirectional (causal) deep FSMN layer with a residual connection.

    The input is passed through a ReLU-activated linear layer and a bias-free
    projection back to ``input_dim``; a depthwise convolution over the time
    axis (looking only at the current and the previous ``lorder - 1`` frames)
    is then added as a memory term, and finally the block input is added back.

    :param input_dim: int, size of the last (feature) axis of the input.
    :param hidden_size: int, width of the intermediate linear layer.
    :param lorder: int, number of frames (current + past) covered by the
        memory convolution kernel.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_size: int,
                 lorder: int = 1,
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.lorder = lorder

        self.linear = nn.Linear(input_dim, hidden_size)
        self.project = nn.Linear(hidden_size, input_dim, bias=False)
        # Depthwise (groups == channels) convolution along the time axis only:
        # every feature channel gets its own length-`lorder` filter.
        self.conv1 = nn.Conv2d(
            input_dim,
            input_dim,
            kernel_size=(lorder, 1),
            stride=(1, 1),
            groups=input_dim,
            bias=False
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: torch.Tensor, shape: [b, t, h]
        :return: torch.Tensor, shape: [b, t, h]
        """
        x = F.relu(self.linear(inputs))
        x = self.project(x)

        # x shape: [b, t, h] -> [b, h, t, 1], the layout Conv2d expects.
        x = x.transpose(1, 2).unsqueeze(-1)

        # Left-pad the time axis so the convolution is causal and keeps length t.
        y = F.pad(x, [0, 0, self.lorder - 1, 0])
        x = x + self.conv1(y)

        # x shape: [b, h, t, 1] -> [b, t, h].
        # Only the trailing dummy axis is removed; a bare squeeze() would also
        # drop batch/time/feature axes of size 1 and corrupt the residual sum.
        x = x.squeeze(-1).transpose(1, 2)

        result = inputs + x
        return result
def main():
    """Smoke test: push a random [1, 200, 32] tensor through the layer and print the output shape."""
    features = torch.rand(size=(1, 200, 32))

    layer_kwargs = {
        "input_dim": 32,
        "hidden_size": 64,
        "lorder": 3,
    }
    model = UniDeepFsmn(**layer_kwargs)

    output = model.forward(features)
    print(output.shape)
# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()