# Source: Hugging Face Space page (status: Sleeping); file size 612 bytes; commit 1423dc8.
import torch
import torch.nn as nn
class Patchify(nn.Module):
    """Split a 2-D input into non-overlapping patches with a strided convolution.

    A bias-free ``nn.Conv2d`` whose kernel size equals its stride maps each
    ``patch_height x patch_size`` tile of the input to one output position
    carrying ``out_channels`` features.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of output channels produced for each patch.
        patch_size: Patch width, i.e. kernel size and stride along the last
            (width) axis.
        patch_height: Patch height, i.e. kernel size and stride along the
            height axis. Defaults to 8, the value that was previously
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, in_channels, out_channels, patch_size, patch_height=8):
        super().__init__()
        patch = (patch_height, patch_size)
        # kernel_size == stride with padding=0 gives non-overlapping tiles.
        # Rows/columns that do not fill a whole patch are silently dropped,
        # because the conv output size is floor(H / ph) x floor(W / pw).
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=patch,
            stride=patch,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        """Embed every patch of ``x``.

        Args:
            x: Tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor of shape
            (batch_size, out_channels, height // patch_height, width // patch_size).
        """
        return self.conv(x)
if __name__ == "__main__":
    # Quick smoke check: build a 1 -> 32 channel patchifier with width-2
    # patches, show its structure, then push one random 64x16 single-channel
    # sample through it. With 8-row patches this prints
    # torch.Size([1, 32, 8, 8]).
    patchify = Patchify(1, 32, 2)
    print(patchify)
    sample = torch.randn(1, 1, 64, 16)
    print(patchify(sample).shape)
|