File size: 1,557 Bytes
0667c13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from torch import nn
from torchvision import models
from einops import rearrange
from torchvision.models._utils import IntermediateLayerGetter
class Vgg(nn.Module):
    """VGG (batch-norm) feature extractor that emits a sequence of features.

    The ``features`` stack of a torchvision VGG model is reused, but every
    ``MaxPool2d`` layer in it is replaced by an ``AvgPool2d`` whose kernel
    size and stride come from ``ks`` and ``ss``. A 1x1 convolution then
    projects the 512 feature channels to ``hidden`` channels.

    Args:
        name: Backbone to build, ``'vgg11_bn'`` or ``'vgg19_bn'``.
        ss: Strides for the replacement pooling layers, indexed in the order
            the pooling layers appear in the backbone's ``features``.
        ks: Kernel sizes for the replacement pooling layers, indexed the
            same way as ``ss``.
        hidden: Number of output channels of the final 1x1 convolution.
        pretrained: When true, load torchvision's ``'DEFAULT'`` weights;
            when false, build the backbone without pretrained weights.
        dropout: Dropout probability applied before the 1x1 projection.

    Raises:
        ValueError: If ``name`` is not one of the supported backbones.
    """

    # Backbones this wrapper knows how to build (attribute names on
    # ``torchvision.models``).
    _SUPPORTED = ('vgg11_bn', 'vgg19_bn')

    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        # Fail early with a clear message instead of an unbound-variable
        # error further down when the name matches no known backbone.
        if name not in self._SUPPORTED:
            raise ValueError(
                f"Unsupported backbone {name!r}; "
                f"expected one of {self._SUPPORTED}"
            )

        super().__init__()

        # Honour ``pretrained``: ``weights=None`` skips loading pretrained
        # weights, ``'DEFAULT'`` loads torchvision's default checkpoint.
        weights = 'DEFAULT' if pretrained else None
        cnn = getattr(models, name)(weights=weights)

        # Swap each max-pool for an average-pool with the caller-supplied
        # kernel size and stride; ``pool_idx`` counts pooling layers seen.
        pool_idx = 0
        for i, layer in enumerate(cnn.features):
            if isinstance(layer, torch.nn.MaxPool2d):
                cnn.features[i] = torch.nn.AvgPool2d(
                    kernel_size=ks[pool_idx],
                    stride=ss[pool_idx],
                    padding=0,
                )
                pool_idx += 1

        self.features = cnn.features
        self.dropout = nn.Dropout(dropout)
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """Run the backbone and flatten the feature map into a sequence.

        Shape:
            - x: (N, C, H, W)
            - output: (W' * H', N, hidden), where H' and W' are the spatial
              dimensions of the feature map after ``self.features``.
        """
        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)

        # Equivalent to rearranging 'b d h w -> (w h) b d': swap the two
        # spatial axes, merge them into one, then move that axis to the front.
        conv = conv.transpose(-1, -2)
        conv = conv.flatten(2)
        conv = conv.permute(-1, 0, 1)
        return conv
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` module using the ``'vgg11_bn'`` backbone.

    All arguments are forwarded unchanged to ``Vgg``.
    """
    return Vgg(
        name='vgg11_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` module using the ``'vgg19_bn'`` backbone.

    All arguments are forwarded unchanged to ``Vgg``.
    """
    return Vgg(
        name='vgg19_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
|