import torch import torch.nn as nn class Patchify(nn.Module): def __init__(self, in_channels, out_channels, patch_size): super(Patchify, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(8, patch_size), stride=(8, patch_size), padding=0, bias=False) def forward(self, x): # x.shape = (batch_size, channels, height, width) x = self.conv(x) return x if __name__ == "__main__": model = Patchify(1, 32, 2) print(model) dummy_input = torch.randn(1, 1, 64, 16) output = model(dummy_input) print(output.shape)