# Source: SincVAD_Demo/model/tinyvad.py
# Author: jethrowang (uploaded with "Upload 18 files", commit 1423dc8, verified)
import torch
import torch.nn as nn
from .sinc_conv import TimeSincExtractor, FreqSincExtractor
from .patchify import Patchify
from .csp_tiny_layer import CSPTinyLayer
class TinyVAD(nn.Module):
    """Compact voice-activity-detection network that emits one logit per example.

    Pipeline: optional ``FreqSincExtractor`` front end -> ``Patchify`` ->
    three stacked ``CSPTinyLayer`` stages -> ``AdaptiveAvgPool2d(1)`` ->
    single-output linear head.

    Args:
        in_channels: Input channel count handed to ``Patchify``.
        hidden_channels: Output width of ``Patchify`` and of the first two
            CSP stages.
        out_channels: Output width of the third CSP stage; also the input
            width of the linear head.
        patch_size: Patch size handed to ``Patchify``.
        num_blocks: Block count handed to every ``CSPTinyLayer``.
        sinc_conv: When truthy, a sinc-filter extractor runs before
            ``Patchify``.
        ssm: Flag forwarded unchanged to every ``CSPTinyLayer``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, patch_size,
                 num_blocks, sinc_conv, ssm):
        super().__init__()
        self.sinc_conv = sinc_conv
        if sinc_conv:
            # Frequency-domain sinc filter bank. A TimeSincExtractor built with
            # the same keyword arguments is the alternative time-domain front end.
            # NOTE(review): presumably `in_channels` must then match what the
            # extractor emits (out_channels=64 here) -- confirm.
            self.extractor = FreqSincExtractor(out_channels=64, kernel_size=101,
                                               range_constraint=True, stride=2)
        # Registration order below is kept stable so state_dict keys and
        # print(model) output do not change.
        self.patchify = Patchify(in_channels, hidden_channels, patch_size)
        self.csp_tiny_layer1 = CSPTinyLayer(hidden_channels, hidden_channels, num_blocks, ssm)
        self.csp_tiny_layer2 = CSPTinyLayer(hidden_channels, hidden_channels, num_blocks, ssm)
        self.csp_tiny_layer3 = CSPTinyLayer(hidden_channels, out_channels, num_blocks, ssm)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Bare linear head with no Sigmoid: forward() returns logits and
        # predict() applies the sigmoid.
        self.classifier = nn.Sequential(nn.Linear(out_channels, 1))

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)`` for the input batch ``x``."""
        if self.sinc_conv:
            # The extractor takes a second (here ``None``) argument and returns
            # a tuple; only its first element feeds the rest of the network.
            x = self.extractor(x, None)[0]
        features = x
        stages = (self.patchify, self.csp_tiny_layer1,
                  self.csp_tiny_layer2, self.csp_tiny_layer3)
        for stage in stages:
            features = stage(features)
        pooled = self.avg_pool(features)
        # Collapse everything after the batch dimension before the linear head.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

    def predict(self, inputs):
        """Return sigmoid probabilities derived from ``forward``'s logits."""
        return torch.sigmoid(self.forward(inputs))
if __name__ == "__main__":
    # Smoke test: build a small TinyVAD and push one random batch through it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = TinyVAD(
        in_channels=1,
        hidden_channels=32,
        out_channels=64,
        patch_size=2,
        num_blocks=2,
        sinc_conv=False,
        ssm=False,
    ).to(device)
    # eval() switches any dropout/normalization sub-layers to inference
    # behavior; no_grad() skips building an autograd graph the demo never uses.
    model.eval()
    print(model)
    dummy_input = torch.randn(1, 1, 64, 16).to(device)
    with torch.no_grad():
        output = model(dummy_input)
    print(output)