# wilor/models/heads/refinement_net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
from typing import Optional
def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
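    """Stack Linear layers, optionally followed by BatchNorm1d (if use_bn) and ReLU.

    feat_dims lists the layer widths; the final Linear gets no BN/ReLU unless relu_final is True.
    """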
layers = []
for i in range(len(feat_dims)-1):
layers.append(nn.Linear(feat_dims[i], feat_dims[i+1]))
# Do not use ReLU for final estimation
if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and relu_final):
if use_bn:
layers.append(nn.BatchNorm1d(feat_dims[i+1]))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
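    """Stack Conv2d + BatchNorm2d + ReLU blocks; the final conv gets no BN/ReLU unless bnrelu_final is True."""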
layers = []
for i in range(len(feat_dims)-1):
layers.append(
nn.Conv2d(
in_channels=feat_dims[i],
out_channels=feat_dims[i+1],
kernel_size=kernel,
stride=stride,
padding=padding
))
# Do not use BN and ReLU for final estimation
if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
layers.append(nn.BatchNorm2d(feat_dims[i+1]))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def make_deconv_layers(feat_dims, bnrelu_final=True):
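    """Stack 2x-upsampling ConvTranspose2d (kernel 4, stride 2) + BatchNorm2d + ReLU blocks;
    the final deconv gets no BN/ReLU unless bnrelu_final is True.
    """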
layers = []
for i in range(len(feat_dims)-1):
layers.append(
nn.ConvTranspose2d(
in_channels=feat_dims[i],
out_channels=feat_dims[i+1],
kernel_size=4,
stride=2,
padding=1,
output_padding=0,
bias=False))
# Do not use BN and ReLU for final estimation
if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
layers.append(nn.BatchNorm2d(feat_dims[i+1]))
layers.append(nn.ReLU(inplace=True))
return nn.Sequential(*layers)
def sample_joint_features(img_feat, joint_xy):
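    """Bilinearly sample per-joint/vertex features from a feature map.

    img_feat: (B, C, H, W) feature map; joint_xy: (B, N, 2) xy coordinates in pixel units
    of that feature map. Returns a (B, N, C) tensor of sampled features.
    """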
height, width = img_feat.shape[2:]
x = joint_xy[:, :, 0] / (width - 1) * 2 - 1
y = joint_xy[:, :, 1] / (height - 1) * 2 - 1
grid = torch.stack((x, y), 2)[:, :, None, :]
img_feat = F.grid_sample(img_feat, grid, align_corners=True)[:, :, :, 0] # batch_size, channel_dim, joint_num
img_feat = img_feat.permute(0, 2, 1).contiguous() # batch_size, joint_num, channel_dim
return img_feat
def perspective_projection(points: torch.Tensor,
translation: torch.Tensor,
focal_length: torch.Tensor,
camera_center: Optional[torch.Tensor] = None,
rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Computes the perspective projection of a set of 3D points.
Args:
points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
Returns:
torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
"""
batch_size = points.shape[0]
if rotation is None:
rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
if camera_center is None:
camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
# Populate intrinsic camera matrix K.
K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
K[:,0,0] = focal_length[:,0]
K[:,1,1] = focal_length[:,1]
K[:,2,2] = 1.
K[:,:-1, -1] = camera_center
# Transform points
points = torch.einsum('bij,bkj->bki', rotation, points)
points = points + translation.unsqueeze(1)
# Apply perspective distortion
projected_points = points / points[:,:,-1].unsqueeze(-1)
# Apply camera intrinsics
projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
return projected_points[:, :, :-1]
class DeConvNet(nn.Module):
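    """Upsampling pyramid built on top of a backbone feature map.

    A 1x1 conv first halves the channel dim, then one transposed-conv branch is built per scale
    (int(log2(upscale)) + 1 branches). forward() returns the feature maps ordered from highest to
    lowest resolution; e.g. for feat_dim=1280, upscale=3: channels (160, 320, 640) at 4x, 2x and
    1x the input resolution.
    """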
def __init__(self, feat_dim=768, upscale=4):
super(DeConvNet, self).__init__()
self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
self.deconv = nn.ModuleList([])
for i in range(int(math.log2(upscale))+1):
if i==0:
self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4]))
elif i==1:
self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8]))
elif i==2:
self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8, feat_dim//8]))
def forward(self, img_feat):
face_img_feats = []
img_feat = self.first_conv(img_feat)
face_img_feats.append(img_feat)
        for deconv in self.deconv:
            face_img_feats.append(deconv(img_feat))
return face_img_feats[::-1] # high resolution -> low resolution
class DeConvNet_v2(nn.Module):
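    """Lighter alternative to DeConvNet: a 1x1 conv followed by a single 4x transposed conv,
    returning a one-element feature list.
    """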
def __init__(self, feat_dim=768):
super(DeConvNet_v2, self).__init__()
self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
self.deconv = nn.Sequential(*[nn.ConvTranspose2d(in_channels=feat_dim//2, out_channels=feat_dim//4, kernel_size=4, stride=4, padding=0, output_padding=0, bias=False),
nn.BatchNorm2d(feat_dim//4),
nn.ReLU(inplace=True)])
    def forward(self, img_feat):
        img_feat = self.first_conv(img_feat)
        img_feat = self.deconv(img_feat)
        return [img_feat]
class RefineNet(nn.Module):
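    """Refines MANO pose, shape and camera by sampling image features at the projected vertices.

    The initial estimates in pred_mano_feats ('hand_pose', 'betas', 'cam') are updated with
    residual deltas predicted from per-vertex features max-pooled over a multi-scale feature pyramid.
    """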
def __init__(self, cfg, feat_dim=1280, upscale=3):
super(RefineNet, self).__init__()
#self.deconv = DeConvNet_v2(feat_dim=feat_dim)
#self.out_dim = feat_dim//4
self.deconv = DeConvNet(feat_dim=feat_dim, upscale=upscale)
self.out_dim = feat_dim//8 + feat_dim//4 + feat_dim//2
self.dec_pose = nn.Linear(self.out_dim, 96)
self.dec_cam = nn.Linear(self.out_dim, 3)
self.dec_shape = nn.Linear(self.out_dim, 10)
self.cfg = cfg
self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
def forward(self, img_feat, verts_3d, pred_cam, pred_mano_feats, focal_length):
B = img_feat.shape[0]
        # Multi-scale feature maps, ordered from high to low resolution
        img_feats = self.deconv(img_feat)
        img_feat_sizes = [feat.shape[2] for feat in img_feats]
        # Convert the weak-perspective camera (s, tx, ty) into a per-scale translation,
        # then project the 3D vertices onto each feature map.
        temp_cams = [torch.stack([pred_cam[:, 1], pred_cam[:, 2],
                                  2 * focal_length[:, 0] / (img_feat_size * pred_cam[:, 0] + 1e-9)], dim=-1)
                     for img_feat_size in img_feat_sizes]
        verts_2d = [perspective_projection(verts_3d, translation=temp_cams[i],
                                           focal_length=focal_length / img_feat_sizes[i])
                    for i in range(len(img_feat_sizes))]
        # Sample features at the projected vertices and max-pool over vertices at each scale
        vert_feats = [sample_joint_features(img_feats[i], verts_2d[i]).max(1).values for i in range(len(img_feat_sizes))]
        vert_feats = torch.cat(vert_feats, dim=-1)
        # Predict residual updates to pose, shape and camera
        delta_pose = self.dec_pose(vert_feats)
        delta_betas = self.dec_shape(vert_feats)
        delta_cam = self.dec_cam(vert_feats)
pred_hand_pose = pred_mano_feats['hand_pose'] + delta_pose
pred_betas = pred_mano_feats['betas'] + delta_betas
pred_cam = pred_mano_feats['cam'] + delta_cam
joint_conversion_fn = {
'6d': rot6d_to_rotmat,
'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
}[self.joint_rep_type]
pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)
pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
'hand_pose': pred_hand_pose[:, 1:],
'betas': pred_betas}
return pred_mano_params, pred_cam
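

# Minimal smoke test for the building blocks above. Input sizes are illustrative assumptions
# (a 16x16 backbone feature map, MANO's 778 vertices); because of the relative imports, run it
# as a module from the repository root, e.g. `python -m wilor.models.heads.refinement_net`
# (module path assumed).
if __name__ == "__main__":
    # Deconvolution pyramid: expect channels (160, 320, 640) at 64x64, 32x32 and 16x16.
    feats = torch.randn(2, 1280, 16, 16)
    pyramid = DeConvNet(feat_dim=1280, upscale=3)(feats)
    print([tuple(f.shape) for f in pyramid])

    # Camera projection: points pushed away from the camera, then projected to 2D.
    points = torch.randn(2, 778, 3)
    translation = torch.zeros(2, 3)
    translation[:, 2] = 5.0
    focal = torch.full((2, 2), 5000.0)
    print(perspective_projection(points, translation, focal).shape)  # torch.Size([2, 778, 2])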