|
import torch |
|
from torch import nn |
|
from torchvision import models |
|
from einops import rearrange |
|
from torchvision.models._utils import IntermediateLayerGetter |
|
|
|
|
|
class Vgg(nn.Module):
    """VGG (batch-norm variant) backbone that emits a sequence of feature vectors.

    The ``features`` stack of a torchvision VGG is reused, with every
    ``MaxPool2d`` replaced by an ``AvgPool2d`` whose kernel size and stride
    are taken from ``ks`` and ``ss``. Dropout and a 1x1 convolution mapping
    512 channels to ``hidden`` channels are applied afterwards.

    Args:
        name: Architecture to build; ``'vgg11_bn'`` or ``'vgg19_bn'``.
        ss: Strides for the replacement pooling layers, indexed in the
            order the pooling layers appear in the network.
        ks: Kernel sizes for the replacement pooling layers, indexed the
            same way as ``ss``.
        hidden: Number of output channels of the final 1x1 convolution.
        pretrained: If true, load torchvision's ``'DEFAULT'`` weights;
            otherwise build the network without loading pretrained weights.
        dropout: Dropout probability applied to the feature map.

    Raises:
        ValueError: If ``name`` is not a supported architecture.
    """

    _SUPPORTED = ('vgg11_bn', 'vgg19_bn')

    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        super(Vgg, self).__init__()

        # Fail fast with a clear message; previously an unknown name left
        # ``cnn`` unbound and crashed later with an UnboundLocalError.
        if name not in self._SUPPORTED:
            raise ValueError(
                f"Unsupported VGG variant {name!r}; expected one of {self._SUPPORTED}"
            )

        # Honor ``pretrained``: previously 'DEFAULT' weights were always loaded.
        weights = 'DEFAULT' if pretrained else None
        if name == 'vgg11_bn':
            cnn = models.vgg11_bn(weights=weights)
        else:
            cnn = models.vgg19_bn(weights=weights)

        # Swap each max-pool for an average-pool configured from ks/ss.
        # Replacing an existing index keeps the container size unchanged,
        # so assigning while enumerating is safe here.
        pool_idx = 0
        for i, layer in enumerate(cnn.features):
            if isinstance(layer, torch.nn.MaxPool2d):
                cnn.features[i] = torch.nn.AvgPool2d(
                    kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0
                )
                pool_idx += 1

        self.features = cnn.features
        self.dropout = nn.Dropout(dropout)
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """
        Shape:
            - x: (N, C, H, W)
            - output: (W' * H', N, hidden), where (H', W') is the spatial
              size of the feature map after ``features``. This reduces to
              (W', N, hidden) when the pooling collapses the height to 1.
        """
        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)

        # (N, C, H', W') -> (N, C, W', H'): flattening then walks the
        # feature map column by column (width-major order).
        conv = conv.transpose(-1, -2)
        conv = conv.flatten(2)
        # Move the flattened spatial axis to the front: (L, N, C).
        conv = conv.permute(-1, 0, 1)
        return conv
|
|
|
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` backbone using the ``'vgg11_bn'`` architecture.

    All arguments are forwarded unchanged to ``Vgg``.
    """
    return Vgg(
        name='vgg11_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
|
|
|
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` backbone using the ``'vgg19_bn'`` architecture.

    All arguments are forwarded unchanged to ``Vgg``.
    """
    return Vgg(
        name='vgg19_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
|
|
|
|