multimodalart's picture
Upload 83 files
38e20ed verified
# The implementation is adapted from TFace, made publicly available under the Apache-2.0 license at
# https://github.com/Tencent/TFace/blob/master/recognition/torchkit/backbone/model_irse.py
from collections import namedtuple
from torch.nn import BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, Sequential
from .common import Flatten, SEModule, initialize_weights
class BasicBlockIR(Module):
    """Basic (two 3x3 convolution) residual unit for IRNet.

    The residual branch is ``BN -> Conv3x3 -> BN -> PReLU -> Conv3x3(stride) -> BN``
    and its output is summed with a shortcut branch. When the channel count
    is unchanged the shortcut is a kernel-size-1 ``MaxPool2d`` (pure strided
    subsampling); otherwise it is a strided 1x1 convolution followed by BN.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride used by the shortcut and by the second 3x3 convolution.
    """

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel != depth:
            # Channel count changes: project the identity with a strided 1x1 conv.
            shortcut = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # Same channel count: a kernel-1 max-pool only subsamples by `stride`.
            shortcut = MaxPool2d(1, stride)
        self.shortcut_layer = shortcut

        residual = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual)

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BottleneckIR(Module):
    """Bottleneck (1x1 -> 3x3 -> 1x1) residual unit for IRNet.

    The inner convolutions run at ``depth // 4`` channels. The stride is
    applied by the final 1x1 convolution of the residual branch and by the
    shortcut. The shortcut is a kernel-size-1 ``MaxPool2d`` when the channel
    count is unchanged, otherwise a strided 1x1 convolution followed by BN.

    Args:
        in_channel: number of input channels.
        depth: number of output channels.
        stride: stride used by the shortcut and the last 1x1 convolution.
    """

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        mid = depth // 4
        if in_channel != depth:
            # Channel count changes: project the identity with a strided 1x1 conv.
            shortcut = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # Same channel count: a kernel-1 max-pool only subsamples by `stride`.
            shortcut = MaxPool2d(1, stride)
        self.shortcut_layer = shortcut

        residual = [
            BatchNorm2d(in_channel),
            # 1x1 reduction to `mid` channels.
            Conv2d(in_channel, mid, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            # 3x3 spatial convolution at the reduced width.
            Conv2d(mid, mid, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            # 1x1 expansion back to `depth`, carrying the stride.
            Conv2d(mid, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual)

    def forward(self, x):
        """Return ``res_layer(x) + shortcut_layer(x)``."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BasicBlockIRSE(BasicBlockIR):
    """``BasicBlockIR`` whose residual branch ends with an ``SEModule``.

    The SE module is appended to ``res_layer`` under the name ``se_block``,
    constructed as ``SEModule(depth, 16)`` (16 is presumably the channel
    reduction ratio -- confirm against ``.common.SEModule``).
    """

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIRSE, self).__init__(in_channel, depth, stride)
        squeeze_excite = SEModule(depth, 16)
        self.res_layer.add_module('se_block', squeeze_excite)
class BottleneckIRSE(BottleneckIR):
    """``BottleneckIR`` whose residual branch ends with an ``SEModule``.

    The SE module is appended to ``res_layer`` under the name ``se_block``,
    constructed as ``SEModule(depth, 16)`` (16 is presumably the channel
    reduction ratio -- confirm against ``.common.SEModule``).
    """

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIRSE, self).__init__(in_channel, depth, stride)
        squeeze_excite = SEModule(depth, 16)
        self.res_layer.add_module('se_block', squeeze_excite)
class Bottleneck(namedtuple('Block', ('in_channel', 'depth', 'stride'))):
    """Immutable spec of one residual unit: ``(in_channel, depth, stride)``."""
def get_block(in_channel, depth, num_units, stride=2):
    """Describe one stage of the network as a list of ``Bottleneck`` specs.

    The first unit maps ``in_channel -> depth`` with the given ``stride``;
    the remaining ``num_units - 1`` units keep ``depth`` channels at stride 1.
    """
    units = [Bottleneck(in_channel, depth, stride)]
    units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return units
def get_blocks(num_layers):
    """Return the four stage specifications for an IR network of a given depth.

    Args:
        num_layers: network depth; one of 18, 34, 50, 100, 152 or 200.

    Returns:
        A list of four stages, each a list of ``Bottleneck`` specs built by
        ``get_block`` (so the first unit of each stage uses its default
        stride of 2).

    Raises:
        ValueError: if ``num_layers`` is not a supported depth. (Previously an
            unsupported value fell through every branch and surfaced as an
            ``UnboundLocalError``.)
    """
    # depth -> (channel widths at the stage boundaries, units per stage).
    # Stage i maps widths[i] -> widths[i + 1] channels.
    stage_specs = {
        18: ((64, 64, 128, 256, 512), (2, 2, 2, 2)),
        34: ((64, 64, 128, 256, 512), (3, 4, 6, 3)),
        50: ((64, 64, 128, 256, 512), (3, 4, 14, 3)),
        100: ((64, 64, 128, 256, 512), (3, 13, 30, 3)),
        152: ((64, 256, 512, 1024, 2048), (3, 8, 36, 3)),
        200: ((64, 256, 512, 1024, 2048), (3, 24, 36, 3)),
    }
    try:
        widths, units = stage_specs[num_layers]
    except KeyError:
        raise ValueError(
            f'num_layers should be one of {sorted(stage_specs)}, '
            f'got {num_layers!r}') from None
    return [
        get_block(in_channel=c_in, depth=c_out, num_units=n)
        for c_in, c_out, n in zip(widths, widths[1:], units)
    ]
class Backbone(Module):
    """IR / IR-SE residual backbone mapping an image batch to a 512-d embedding.

    Layout: a 3x3 stem conv (3 -> 64 channels), four stages of residual units
    (``self.body``), and an output head
    ``BN2d -> Dropout(0.4) -> Flatten -> Linear(.., 512) -> BN1d(affine=False)``.
    """

    def __init__(self, input_size, num_layers, mode='ir'):
        """ Args:
            input_size: input_size of backbone; ``input_size[0]`` must be 112
                or 224.
            num_layers: num_layers of backbone; one of 18, 34, 50, 100, 152,
                200. Depths <= 100 use basic units (512 output channels),
                deeper ones use bottleneck units (2048 output channels).
            mode: support ir or ir_se (the latter adds an SE module to every
                residual unit).
        """
        super(Backbone, self).__init__()
        # NOTE(review): asserts are stripped under `python -O`; kept as asserts
        # so callers still see AssertionError for invalid arguments.
        assert input_size[0] in [112, 224], \
            'input_size should be [112, 112] or [224, 224]'
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            'num_layers should be 18, 34, 50, 100, 152 or 200'
        assert mode in ['ir', 'ir_se'], \
            'mode should be ir or ir_se'
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            unit_module = {'ir': BasicBlockIR, 'ir_se': BasicBlockIRSE}[mode]
            output_channel = 512
        else:
            unit_module = {'ir': BottleneckIR, 'ir_se': BottleneckIRSE}[mode]
            output_channel = 2048
        # The stem keeps the resolution and each of the four stages starts
        # with a stride-2 unit, so the final map is input_size / 16 per side
        # (7 for 112, 14 for 224).
        feat_size = input_size[0] // 16
        self.output_layer = Sequential(
            BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
            Linear(output_channel * feat_size * feat_size, 512),
            BatchNorm1d(512, affine=False))

        modules = []
        # Index (within self.body) of the last unit of every stage,
        # e.g. [2, 15, 45, 48] for IR101 (49 units in total).
        mid_layer_indices = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
            mid_layer_indices.append(len(modules) - 1)
        self.body = Sequential(*modules)
        self.mid_layer_indices = mid_layer_indices[-4:]
        initialize_weights(self.modules())

    def device(self):
        """Return the device of the first registered parameter."""
        return next(self.parameters()).device

    def dtype(self):
        """Return the dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    def forward(self, x, return_mid_feats=False):
        """Run the backbone.

        Args:
            x: 4-D image batch with 3 channels (the stem ``Conv2d`` expects 3
                input channels); the spatial size must match ``input_size``
                for the head's ``Linear`` layer to line up.
            return_mid_feats: if true, also return the output of the last unit
                of each stage.

        Returns:
            The 512-d embedding batch, or ``(embedding, out_feats)`` where
            ``out_feats`` lists the stage-end feature maps in order.
        """
        x = self.input_layer(x)
        if not return_mid_feats:
            x = self.body(x)
            return self.output_layer(x)
        out_feats = []
        for idx, module in enumerate(self.body):
            x = module(x)
            if idx in self.mid_layer_indices:
                out_feats.append(x)
        return self.output_layer(x), out_feats
def IR_18(input_size):
    """Build an 18-layer IR backbone (``mode='ir'``).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 18, 'ir')
def IR_34(input_size):
    """Build a 34-layer IR backbone (``mode='ir'``).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 34, 'ir')
def IR_50(input_size):
    """Build a 50-layer IR backbone (``mode='ir'``).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 50, 'ir')
def IR_101(input_size):
    """Build an IR-101 backbone (``mode='ir'``).

    IR-101 corresponds to the ``num_layers=100`` stage configuration.

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 100, 'ir')
def IR_152(input_size):
    """Build a 152-layer IR backbone (``mode='ir'``, bottleneck units).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 152, 'ir')
def IR_200(input_size):
    """Build a 200-layer IR backbone (``mode='ir'``, bottleneck units).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 200, 'ir')
def IR_SE_50(input_size):
    """Build a 50-layer IR-SE backbone (``mode='ir_se'``).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 50, 'ir_se')
def IR_SE_101(input_size):
    """Build an IR-SE-101 backbone (``mode='ir_se'``).

    IR-SE-101 corresponds to the ``num_layers=100`` stage configuration.

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 100, 'ir_se')
def IR_SE_152(input_size):
    """Build a 152-layer IR-SE backbone (``mode='ir_se'``, bottleneck units).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 152, 'ir_se')
def IR_SE_200(input_size):
    """Build a 200-layer IR-SE backbone (``mode='ir_se'``, bottleneck units).

    Args:
        input_size: passed to ``Backbone``, e.g. ``[112, 112]`` or ``[224, 224]``.
    """
    return Backbone(input_size, 200, 'ir_se')