|
import torch |
|
from torch import nn |
|
|
|
import vietocr.model.backbone.vgg as vgg |
|
from vietocr.model.backbone.resnet import Resnet50 |
|
|
|
class CNN(nn.Module): |
|
def __init__(self, backbone, **kwargs): |
|
super(CNN, self).__init__() |
|
|
|
if backbone == 'vgg11_bn': |
|
self.model = vgg.vgg11_bn(**kwargs) |
|
elif backbone == 'vgg19_bn': |
|
self.model = vgg.vgg19_bn(**kwargs) |
|
elif backbone == 'resnet50': |
|
self.model = Resnet50(**kwargs) |
|
|
|
def forward(self, x): |
|
return self.model(x) |
|
|
|
def freeze(self): |
|
for name, param in self.model.features.named_parameters(): |
|
if name != 'last_conv_1x1': |
|
param.requires_grad = False |
|
|
|
def unfreeze(self): |
|
for param in self.model.features.parameters(): |
|
param.requires_grad = True |
|
|