import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .core.coord_conv import CoordConvTh
from external.landmark_detection.lib.dataset import get_decoder
class Activation(nn.Module):
    """Normalization + activation selected by a string spec.

    ``kind`` is either a single activation name (e.g. ``'relu'``) or a
    ``'<norm>+<act>'`` pair (e.g. ``'in+relu'``), where norm is one of
    ``'in'``/``'bn'``/``'bn_noaffine'``/``'none'`` and act is one of
    ``'relu'``/``'softplus'``/``'exp'``/``'sigmoid'``/``'tanh'``/``'none'``.

    Args:
        kind: spec string as described above.
        channel: channel count; required only for the ``'bn'`` /
            ``'bn_noaffine'`` norms, ignored otherwise.
    """

    def __init__(self, kind: str = 'relu', channel=None):
        super().__init__()
        self.kind = kind
        if '+' in kind:
            norm_str, act_str = kind.split('+')
        else:
            norm_str, act_str = 'none', kind
        # Construct only the selected norm. The previous dict-literal form
        # eagerly instantiated every BatchNorm2d option on each call, which
        # both wasted modules and crashed when channel=None was used with an
        # activation-only kind (nn.BatchNorm2d(None) fails at construction).
        if norm_str == 'in':
            self.norm_fn = F.instance_norm
        elif norm_str == 'bn':
            self.norm_fn = nn.BatchNorm2d(channel)
        elif norm_str == 'bn_noaffine':
            self.norm_fn = nn.BatchNorm2d(channel, affine=False,
                                          track_running_stats=True)
        elif norm_str == 'none':
            self.norm_fn = None
        else:
            # Match the old dict-lookup failure mode for unknown specs.
            raise KeyError(norm_str)
        self.act_fn = {
            'relu': F.relu,
            'softplus': nn.Softplus(),
            'exp': torch.exp,
            'sigmoid': torch.sigmoid,
            'tanh': torch.tanh,
            'none': None
        }[act_str]
        self.channel = channel

    def forward(self, x):
        """Apply the norm (if any), then the activation (if any)."""
        if self.norm_fn is not None:
            x = self.norm_fn(x)
        if self.act_fn is not None:
            x = self.act_fn(x)
        return x

    def extra_repr(self):
        return f'kind={self.kind}, channel={self.channel}'
class ConvBlock(nn.Module):
    """2-D convolution with 'same'-style padding, plus optional BN and ReLU.

    The computation order is conv -> batchnorm (if enabled) -> relu (if
    enabled). ``self.bn`` / ``self.relu`` stay ``None`` when disabled.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, groups=1):
        super(ConvBlock, self).__init__()
        self.inp_dim = inp_dim
        pad = (kernel_size - 1) // 2  # keeps spatial size at stride 1
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size,
                              stride, padding=pad, groups=groups, bias=True)
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        out = self.conv(x)
        # Apply the optional layers in the fixed bn-then-relu order.
        for layer in (self.bn, self.relu):
            if layer is not None:
                out = layer(out)
        return out
class ResBlock(nn.Module):
    """Pre-activation bottleneck residual block: (BN -> ReLU -> conv) x 3.

    The bottleneck width defaults to ``out_dim // 2``. The identity path is
    projected through a 1x1 conv only when input and output channel counts
    differ; the projection module is always created (matching the original
    layout) so checkpoints remain compatible.
    """

    def __init__(self, inp_dim, out_dim, mid_dim=None):
        super(ResBlock, self).__init__()
        mid_dim = out_dim // 2 if mid_dim is None else mid_dim
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = ConvBlock(inp_dim, mid_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(mid_dim)
        self.conv2 = ConvBlock(mid_dim, mid_dim, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(mid_dim)
        self.conv3 = ConvBlock(mid_dim, out_dim, 1, relu=False)
        # 1x1 projection for the skip path; used only when shapes differ.
        self.skip_layer = ConvBlock(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        residual = self.skip_layer(x) if self.need_skip else x
        out = x
        for bn, conv in ((self.bn1, self.conv1),
                         (self.bn2, self.conv2),
                         (self.bn3, self.conv3)):
            out = conv(self.relu(bn(out)))
        out += residual
        return out
class Hourglass(nn.Module):
    """One recursive hourglass stage.

    Keeps a full-resolution residual branch (``up1``) alongside a pooled
    branch that recurses ``n - 1`` more levels before being upsampled and
    summed back in. When ``add_coord`` is set, a CoordConvTh layer — which
    may consume the previous stack's heatmap — preprocesses the input.
    """

    def __init__(self, n, f, increase=0, up_mode='nearest',
                 add_coord=False, first_one=False, x_dim=64, y_dim=64):
        super(Hourglass, self).__init__()
        nf = f + increase
        make_block = ResBlock
        if add_coord:
            self.coordconv = CoordConvTh(x_dim=x_dim, y_dim=y_dim,
                                         with_r=True, with_boundary=True,
                                         relu=False, bn=False,
                                         in_channels=f, out_channels=f,
                                         first_one=first_one,
                                         kernel_size=1,
                                         stride=1, padding=0)
        else:
            self.coordconv = None
        self.up1 = make_block(f, f)
        # Lower branch: downsample, recurse (or bottom out), come back up.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.low1 = make_block(f, nf)
        self.n = n
        if self.n > 1:
            self.low2 = Hourglass(n=n - 1, f=nf, increase=increase, up_mode=up_mode, add_coord=False)
        else:
            self.low2 = make_block(nf, nf)
        self.low3 = make_block(nf, f)
        self.up2 = nn.Upsample(scale_factor=2, mode=up_mode)

    def forward(self, x, heatmap=None):
        if self.coordconv is not None:
            x = self.coordconv(x, heatmap)
        skip = self.up1(x)
        down = self.low3(self.low2(self.low1(self.pool1(x))))
        return skip + self.up2(down)
class E2HTransform(nn.Module):
    """Edge-to-heatmap transform: projects edge maps onto per-point maps.

    Implemented as a fixed (buffer, non-trainable) 1x1 convolution whose
    weight encodes which edges each landmark point belongs to. Points that
    appear in no edge receive a constant bias of 1 so their output map is
    not identically zero.
    """

    def __init__(self, edge_info, num_points, num_edges):
        super().__init__()
        e2h = np.zeros([num_points, num_edges])
        for edge_id, (_is_closed, indices) in enumerate(edge_info):
            for point_id in indices:
                e2h[point_id, edge_id] = 1
        e2h = torch.from_numpy(e2h).float()
        # Stored as pn x en x 1 x 1 so it can serve as a 1x1 conv kernel.
        self.register_buffer('weight', e2h.view(
            e2h.size(0), e2h.size(1), 1, 1))
        # Points covered by zero edges (row sum < 0.5) get bias 1.
        coverage = e2h @ torch.ones(e2h.size(1)).to(e2h)
        self.register_buffer('bias', (coverage < 0.5).to(e2h))

    def forward(self, edgemaps):
        # edgemaps: batch x en x h x w  ->  batch x pn x h x w.
        return F.conv2d(edgemaaps if False else edgemaps, weight=self.weight, bias=self.bias)
class StackedHGNetV1(nn.Module):
    """Stacked-hourglass network for facial landmark detection.

    Runs ``nstack`` hourglass stages over a shared stem. Each stage predicts
    heatmaps (plus point maps and edge maps when ``config.use_AAM`` is set),
    decodes landmark coordinates from the fused heatmaps, and merges its
    features/predictions back into the input of the next stage.

    Args:
        config: experiment config object; this class reads ``width``,
            ``height`` and ``use_AAM`` from it.
        classes_num: channel counts — index 0 is the number of heatmaps;
            with AAM, index 1 is the number of edges and index 2 the number
            of point maps.
        edge_info: per-edge ``(is_closed, point_indices)`` pairs used to
            build the edge-to-heatmap transform.
        nstack: number of stacked hourglass stages.
        nlevels: recursion depth of each hourglass.
        in_channel: feature width inside the stacks.
        increase: per-level channel increase inside each hourglass.
        add_coord: use CoordConv layers in the stem and hourglasses.
        decoder_type: key passed to ``get_decoder`` selecting the
            heatmap-to-coordinates decoder.
    """
    def __init__(self, config, classes_num, edge_info,
                 nstack=4, nlevels=4, in_channel=256, increase=0,
                 add_coord=True, decoder_type='default'):
        super(StackedHGNetV1, self).__init__()
        self.cfg = config
        self.coder_type = decoder_type
        self.decoder = get_decoder(decoder_type=decoder_type)
        self.nstack = nstack
        self.add_coord = add_coord
        self.num_heats = classes_num[0]
        # Stem: 7x7 stride-2 conv (CoordConv variant when enabled).
        if self.add_coord:
            convBlock = CoordConvTh(x_dim=self.cfg.width, y_dim=self.cfg.height,
                                    with_r=True, with_boundary=False,
                                    relu=True, bn=True,
                                    in_channels=3, out_channels=64,
                                    kernel_size=7,
                                    stride=2, padding=3)
        else:
            convBlock = ConvBlock(3, 64, 7, 2, bn=True, relu=True)
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        Block = ResBlock
        # Stem downsamples by 4 overall (stride-2 conv + 2x2 max-pool).
        self.pre = nn.Sequential(
            convBlock,
            Block(64, 128),
            pool,
            Block(128, 128),
            Block(128, in_channel)
        )
        # Only the first stack's CoordConv skips the boundary-heatmap input.
        # NOTE(review): dividing width/height by nstack for the CoordConv
        # grid looks suspicious — the stem downsamples by 4, not by nstack;
        # confirm this is intended (it only matches when nstack == 4).
        self.hgs = nn.ModuleList(
            [Hourglass(n=nlevels, f=in_channel, increase=increase, add_coord=self.add_coord, first_one=(_ == 0),
                       x_dim=int(self.cfg.width / self.nstack), y_dim=int(self.cfg.height / self.nstack))
             for _ in range(nstack)])
        self.features = nn.ModuleList([
            nn.Sequential(
                Block(in_channel, in_channel),
                ConvBlock(in_channel, in_channel, 1, bn=True, relu=True)
            ) for _ in range(nstack)])
        # Per-stage prediction heads (no norm/activation; raw logits).
        self.out_heatmaps = nn.ModuleList(
            [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
             for _ in range(nstack)])
        if self.cfg.use_AAM:
            self.num_edges = classes_num[1]
            self.num_points = classes_num[2]
            self.e2h_transform = E2HTransform(edge_info, self.num_points, self.num_edges)
            self.out_edgemaps = nn.ModuleList(
                [ConvBlock(in_channel, self.num_edges, 1, relu=False, bn=False)
                 for _ in range(nstack)])
            self.out_pointmaps = nn.ModuleList(
                [ConvBlock(in_channel, self.num_points, 1, relu=False, bn=False)
                 for _ in range(nstack)])
            # 1x1 convs that fold stage predictions back into the feature
            # stream for the next stage (hence nstack - 1 of each).
            self.merge_edgemaps = nn.ModuleList(
                [ConvBlock(self.num_edges, in_channel, 1, relu=False, bn=False)
                 for _ in range(nstack - 1)])
            self.merge_pointmaps = nn.ModuleList(
                [ConvBlock(self.num_points, in_channel, 1, relu=False, bn=False)
                 for _ in range(nstack - 1)])
            self.edgemap_act = Activation("sigmoid", self.num_edges)
            self.pointmap_act = Activation("sigmoid", self.num_points)
        self.merge_features = nn.ModuleList(
            [ConvBlock(in_channel, in_channel, 1, relu=False, bn=False)
             for _ in range(nstack - 1)])
        self.merge_heatmaps = nn.ModuleList(
            [ConvBlock(self.num_heats, in_channel, 1, relu=False, bn=False)
             for _ in range(nstack - 1)])
        # NOTE(review): duplicate of the self.nstack assignment above; harmless.
        self.nstack = nstack
        # Instance norm + ReLU keeps heatmaps non-negative and normalized.
        self.heatmap_act = Activation("in+relu", self.num_heats)
        self.inference = False
    def set_inference(self, inference):
        # Flag is stored but not read in this file — presumably consumed by
        # callers or subclasses; verify before removing.
        self.inference = inference
    def forward(self, x):
        """Run all stages.

        Args:
            x: image batch with 3 input channels (per the stem conv).

        Returns:
            Tuple ``(y, fusionmaps, landmarks)``:
            ``y`` — per stage, the decoded landmarks (followed by that
            stage's pointmaps and edgemaps when AAM is enabled);
            ``fusionmaps`` — each stage's fused heatmaps;
            ``landmarks`` — the final stage's decoded coordinates.
        """
        x = self.pre(x)
        y, fusionmaps = [], []
        heatmaps = None
        for i in range(self.nstack):
            # First stage receives heatmap=None; later stages condition the
            # CoordConv layer on the previous stage's activated heatmaps.
            hg = self.hgs[i](x, heatmap=heatmaps)
            feature = self.features[i](hg)
            heatmaps0 = self.out_heatmaps[i](feature)
            heatmaps = self.heatmap_act(heatmaps0)
            if self.cfg.use_AAM:
                pointmaps0 = self.out_pointmaps[i](feature)
                pointmaps = self.pointmap_act(pointmaps0)
                edgemaps0 = self.out_edgemaps[i](feature)
                edgemaps = self.edgemap_act(edgemaps0)
                # Anatomical attention: gate heatmaps by edge/point masks.
                mask = self.e2h_transform(edgemaps) * pointmaps
                fusion_heatmaps = mask * heatmaps
            else:
                fusion_heatmaps = heatmaps
            landmarks = self.decoder.get_coords_from_heatmap(fusion_heatmaps)
            if i < self.nstack - 1:
                # Merge features and predictions into the next stage's input.
                x = x + self.merge_features[i](feature) + \
                    self.merge_heatmaps[i](heatmaps)
                if self.cfg.use_AAM:
                    x += self.merge_pointmaps[i](pointmaps)
                    x += self.merge_edgemaps[i](edgemaps)
            y.append(landmarks)
            if self.cfg.use_AAM:
                y.append(pointmaps)
                y.append(edgemaps)
            fusionmaps.append(fusion_heatmaps)
        return y, fusionmaps, landmarks