# Source metadata from the page this file was extracted from:
# file size 612 bytes, revision 1423dc8.
import torch
import torch.nn as nn

class Patchify(nn.Module):
    """Split a 2-D input into non-overlapping patches with a strided convolution.

    A ``Conv2d`` whose kernel size equals its stride maps each
    ``patch_height x patch_size`` tile of the input to one output position
    carrying ``out_channels`` features. No padding is applied, so trailing
    rows/columns that do not fill a whole patch are dropped (``Conv2d`` uses
    floor division for the output size).

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of features produced per patch.
        patch_size: Patch extent along the width (last) dimension.
        patch_height: Patch extent along the height dimension. Defaults to 8,
            the value that was previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, in_channels, out_channels, patch_size, *, patch_height=8):
        super().__init__()
        # Kernel == stride -> tiles are non-overlapping; one shared tuple keeps
        # the two from drifting apart.
        patch = (patch_height, patch_size)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=patch,
            stride=patch,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        """Project ``x`` into a grid of patch embeddings.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, height // patch_height,
            width // patch_size)``.
        """
        return self.conv(x)

if __name__ == "__main__":
    # Quick smoke test: build a small patchifier, show its structure, and
    # report the output shape for one random single-channel sample.
    net = Patchify(1, 32, 2)
    print(net)
    sample = torch.randn(1, 1, 64, 16)
    print(net(sample).shape)