yuandong513
feat: init
17cd746
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .core.coord_conv import CoordConvTh
from external.landmark_detection.lib.dataset import get_decoder
class Activation(nn.Module):
def __init__(self, kind: str = 'relu', channel=None):
super().__init__()
self.kind = kind
if '+' in kind:
norm_str, act_str = kind.split('+')
else:
norm_str, act_str = 'none', kind
self.norm_fn = {
'in': F.instance_norm,
'bn': nn.BatchNorm2d(channel),
'bn_noaffine': nn.BatchNorm2d(channel, affine=False, track_running_stats=True),
'none': None
}[norm_str]
self.act_fn = {
'relu': F.relu,
'softplus': nn.Softplus(),
'exp': torch.exp,
'sigmoid': torch.sigmoid,
'tanh': torch.tanh,
'none': None
}[act_str]
self.channel = channel
def forward(self, x):
if self.norm_fn is not None:
x = self.norm_fn(x)
if self.act_fn is not None:
x = self.act_fn(x)
return x
def extra_repr(self):
return f'kind={self.kind}, channel={self.channel}'
class ConvBlock(nn.Module):
def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, groups=1):
super(ConvBlock, self).__init__()
self.inp_dim = inp_dim
self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size,
stride, padding=(kernel_size - 1) // 2, groups=groups, bias=True)
self.relu = None
self.bn = None
if relu:
self.relu = nn.ReLU()
if bn:
self.bn = nn.BatchNorm2d(out_dim)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class ResBlock(nn.Module):
def __init__(self, inp_dim, out_dim, mid_dim=None):
super(ResBlock, self).__init__()
if mid_dim is None:
mid_dim = out_dim // 2
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm2d(inp_dim)
self.conv1 = ConvBlock(inp_dim, mid_dim, 1, relu=False)
self.bn2 = nn.BatchNorm2d(mid_dim)
self.conv2 = ConvBlock(mid_dim, mid_dim, 3, relu=False)
self.bn3 = nn.BatchNorm2d(mid_dim)
self.conv3 = ConvBlock(mid_dim, out_dim, 1, relu=False)
self.skip_layer = ConvBlock(inp_dim, out_dim, 1, relu=False)
if inp_dim == out_dim:
self.need_skip = False
else:
self.need_skip = True
def forward(self, x):
if self.need_skip:
residual = self.skip_layer(x)
else:
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
out += residual
return out
class Hourglass(nn.Module):
def __init__(self, n, f, increase=0, up_mode='nearest',
add_coord=False, first_one=False, x_dim=64, y_dim=64):
super(Hourglass, self).__init__()
nf = f + increase
Block = ResBlock
if add_coord:
self.coordconv = CoordConvTh(x_dim=x_dim, y_dim=y_dim,
with_r=True, with_boundary=True,
relu=False, bn=False,
in_channels=f, out_channels=f,
first_one=first_one,
kernel_size=1,
stride=1, padding=0)
else:
self.coordconv = None
self.up1 = Block(f, f)
# Lower branch
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.low1 = Block(f, nf)
self.n = n
# Recursive hourglass
if self.n > 1:
self.low2 = Hourglass(n=n - 1, f=nf, increase=increase, up_mode=up_mode, add_coord=False)
else:
self.low2 = Block(nf, nf)
self.low3 = Block(nf, f)
self.up2 = nn.Upsample(scale_factor=2, mode=up_mode)
def forward(self, x, heatmap=None):
if self.coordconv is not None:
x = self.coordconv(x, heatmap)
up1 = self.up1(x)
pool1 = self.pool1(x)
low1 = self.low1(pool1)
low2 = self.low2(low1)
low3 = self.low3(low2)
up2 = self.up2(low3)
return up1 + up2
class E2HTransform(nn.Module):
def __init__(self, edge_info, num_points, num_edges):
super().__init__()
e2h_matrix = np.zeros([num_points, num_edges])
for edge_id, isclosed_indices in enumerate(edge_info):
is_closed, indices = isclosed_indices
for point_id in indices:
e2h_matrix[point_id, edge_id] = 1
e2h_matrix = torch.from_numpy(e2h_matrix).float()
# pn x en x 1 x 1.
self.register_buffer('weight', e2h_matrix.view(
e2h_matrix.size(0), e2h_matrix.size(1), 1, 1))
# some keypoints are not coverred by any edges,
# in these cases, we must add a constant bias to their heatmap weights.
bias = ((e2h_matrix @ torch.ones(e2h_matrix.size(1)).to(
e2h_matrix)) < 0.5).to(e2h_matrix)
# pn x 1.
self.register_buffer('bias', bias)
def forward(self, edgemaps):
# input: batch_size x en x hw x hh.
# output: batch_size x pn x hw x hh.
return F.conv2d(edgemaps, weight=self.weight, bias=self.bias)
class StackedHGNetV1(nn.Module):
def __init__(self, config, classes_num, edge_info,
nstack=4, nlevels=4, in_channel=256, increase=0,
add_coord=True, decoder_type='default'):
super(StackedHGNetV1, self).__init__()
self.cfg = config
self.coder_type = decoder_type
self.decoder = get_decoder(decoder_type=decoder_type)
self.nstack = nstack
self.add_coord = add_coord
self.num_heats = classes_num[0]
if self.add_coord:
convBlock = CoordConvTh(x_dim=self.cfg.width, y_dim=self.cfg.height,
with_r=True, with_boundary=False,
relu=True, bn=True,
in_channels=3, out_channels=64,
kernel_size=7,
stride=2, padding=3)
else:
convBlock = ConvBlock(3, 64, 7, 2, bn=True, relu=True)
pool = nn.MaxPool2d(kernel_size=2, stride=2)
Block = ResBlock
self.pre = nn.Sequential(
convBlock,
Block(64, 128),
pool,
Block(128, 128),
Block(128, in_channel)
)
self.hgs = nn.ModuleList(
[Hourglass(n=nlevels, f=in_channel, increase=increase, add_coord=self.add_coord, first_one=(_ == 0),
x_dim=int(self.cfg.width / self.nstack), y_dim=int(self.cfg.height / self.nstack))
for _ in range(nstack)])
self.features = nn.ModuleList([
nn.Sequential(
Block(in_channel, in_channel),
ConvBlock(in_channel, in_channel, 1, bn=True, relu=True)
) for _ in range(nstack)])
self.out_heatmaps = nn.ModuleList(
[ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
for _ in range(nstack)])
if self.cfg.use_AAM:
self.num_edges = classes_num[1]
self.num_points = classes_num[2]
self.e2h_transform = E2HTransform(edge_info, self.num_points, self.num_edges)
self.out_edgemaps = nn.ModuleList(
[ConvBlock(in_channel, self.num_edges, 1, relu=False, bn=False)
for _ in range(nstack)])
self.out_pointmaps = nn.ModuleList(
[ConvBlock(in_channel, self.num_points, 1, relu=False, bn=False)
for _ in range(nstack)])
self.merge_edgemaps = nn.ModuleList(
[ConvBlock(self.num_edges, in_channel, 1, relu=False, bn=False)
for _ in range(nstack - 1)])
self.merge_pointmaps = nn.ModuleList(
[ConvBlock(self.num_points, in_channel, 1, relu=False, bn=False)
for _ in range(nstack - 1)])
self.edgemap_act = Activation("sigmoid", self.num_edges)
self.pointmap_act = Activation("sigmoid", self.num_points)
self.merge_features = nn.ModuleList(
[ConvBlock(in_channel, in_channel, 1, relu=False, bn=False)
for _ in range(nstack - 1)])
self.merge_heatmaps = nn.ModuleList(
[ConvBlock(self.num_heats, in_channel, 1, relu=False, bn=False)
for _ in range(nstack - 1)])
self.nstack = nstack
self.heatmap_act = Activation("in+relu", self.num_heats)
self.inference = False
def set_inference(self, inference):
self.inference = inference
def forward(self, x):
x = self.pre(x)
y, fusionmaps = [], []
heatmaps = None
for i in range(self.nstack):
hg = self.hgs[i](x, heatmap=heatmaps)
feature = self.features[i](hg)
heatmaps0 = self.out_heatmaps[i](feature)
heatmaps = self.heatmap_act(heatmaps0)
if self.cfg.use_AAM:
pointmaps0 = self.out_pointmaps[i](feature)
pointmaps = self.pointmap_act(pointmaps0)
edgemaps0 = self.out_edgemaps[i](feature)
edgemaps = self.edgemap_act(edgemaps0)
mask = self.e2h_transform(edgemaps) * pointmaps
fusion_heatmaps = mask * heatmaps
else:
fusion_heatmaps = heatmaps
landmarks = self.decoder.get_coords_from_heatmap(fusion_heatmaps)
if i < self.nstack - 1:
x = x + self.merge_features[i](feature) + \
self.merge_heatmaps[i](heatmaps)
if self.cfg.use_AAM:
x += self.merge_pointmaps[i](pointmaps)
x += self.merge_edgemaps[i](edgemaps)
y.append(landmarks)
if self.cfg.use_AAM:
y.append(pointmaps)
y.append(edgemaps)
fusionmaps.append(fusion_heatmaps)
return y, fusionmaps, landmarks