"""STAR loss for heatmap-based landmark detection: the regression error is
decomposed along the principal axes of each heatmap's covariance, and an
eigenvalue penalty keeps the predicted distribution compact."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from .smoothL1Loss import SmoothL1Loss
from .wingLoss import WingLoss


def get_channel_sum(input):
    # Sum a [B, N, H, W] tensor over its two spatial dimensions -> [B, N].
    temp = torch.sum(input, dim=3)
    output = torch.sum(temp, dim=2)
    return output


def expand_two_dimensions_at_end(input, dim1, dim2):
    # Broadcast a [B, N] tensor to [B, N, dim1, dim2].
    input = input.unsqueeze(-1).unsqueeze(-1)
    input = input.expand(-1, -1, dim1, dim2)
    return input


class STARLoss(nn.Module):
    def __init__(self, w=1, dist='smoothl1', num_dim_image=2, EPSILON=1e-5):
        super(STARLoss, self).__init__()
        self.w = w
        self.num_dim_image = num_dim_image
        self.EPSILON = EPSILON
        self.dist = dist
        if self.dist == 'smoothl1':
            self.dist_func = SmoothL1Loss()
        elif self.dist == 'l1':
            self.dist_func = F.l1_loss
        elif self.dist == 'l2':
            self.dist_func = F.mse_loss
        elif self.dist == 'wing':
            self.dist_func = WingLoss()
        else:
            raise NotImplementedError

    def __repr__(self):
        return "STARLoss()"

    def _make_grid(self, h, w):
        # Coordinate grids normalized to [-1, 1] along each axis.
        yy, xx = torch.meshgrid(
            torch.arange(h).float() / (h - 1) * 2 - 1,
            torch.arange(w).float() / (w - 1) * 2 - 1)
        return yy, xx

    def weighted_mean(self, heatmap):
        # Soft-argmax: expected (x, y) coordinate under the normalized heatmap.
        batch, npoints, h, w = heatmap.shape

        yy, xx = self._make_grid(h, w)
        yy = yy.view(1, 1, h, w).to(heatmap)
        xx = xx.view(1, 1, h, w).to(heatmap)

        yy_coord = (yy * heatmap).sum([2, 3])  # batch x npoints
        xx_coord = (xx * heatmap).sum([2, 3])  # batch x npoints
        coords = torch.stack([xx_coord, yy_coord], dim=-1)
        return coords

    def unbiased_weighted_covariance(self, htp, means, num_dim_image=2, EPSILON=1e-5):
        # Weighted 2x2 covariance of the coordinate grid under the heatmap,
        # with the unbiased weighted denominator V_1 - V_2 / V_1.
        batch_size, num_points, height, width = htp.shape

        yv, xv = self._make_grid(height, width)
        xv = Variable(xv)
        yv = Variable(yv)

        if htp.is_cuda:
            xv = xv.cuda()
            yv = yv.cuda()

        xmean = means[:, :, 0]
        xv_minus_mean = xv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(xmean, height, width)  # [batch_size, 68, 64, 64]
        ymean = means[:, :, 1]
        yv_minus_mean = yv.expand(batch_size, num_points, -1, -1) - expand_two_dimensions_at_end(ymean, height, width)  # [batch_size, 68, 64, 64]

        wt_xv_minus_mean = xv_minus_mean
        wt_yv_minus_mean = yv_minus_mean

        wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, height * width)  # [batch_size*68, 4096]
        wt_xv_minus_mean = wt_xv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*68, 1, 4096]
        wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, height * width)  # [batch_size*68, 4096]
        wt_yv_minus_mean = wt_yv_minus_mean.view(batch_size * num_points, 1, height * width)  # [batch_size*68, 1, 4096]
        vec_concat = torch.cat((wt_xv_minus_mean, wt_yv_minus_mean), 1)  # [batch_size*68, 2, 4096]

        htp_vec = htp.view(batch_size * num_points, 1, height * width)
        htp_vec = htp_vec.expand(-1, 2, -1)

        covariance = torch.bmm(htp_vec * vec_concat, vec_concat.transpose(1, 2))  # [batch_size*68, 2, 2]
        covariance = covariance.view(batch_size, num_points, num_dim_image, num_dim_image)  # [batch_size, 68, 2, 2]

        V_1 = htp.sum([2, 3]) + EPSILON  # [batch_size, 68]
        V_2 = torch.pow(htp, 2).sum([2, 3]) + EPSILON  # [batch_size, 68]

        denominator = V_1 - (V_2 / V_1)
        covariance = covariance / expand_two_dimensions_at_end(denominator, num_dim_image, num_dim_image)

        return covariance

    def ambiguity_guided_decompose(self, pts, eigenvalues, eigenvectors):
        # Rotate the error into the heatmap's principal coordinate system and
        # rescale each component by the corresponding standard deviation.
        batch_size, npoints = pts.shape[:2]
        rotate = torch.matmul(pts.view(batch_size, npoints, 1, 2), eigenvectors.transpose(-1, -2))
        scale = rotate.view(batch_size, npoints, 2) / torch.sqrt(eigenvalues + self.EPSILON)
        return scale

    def eigenvalue_restriction(self, evalues, batch, npoints):
        # Penalty on the covariance eigenvalues, discouraging overly dispersed heatmaps.
        eigen_loss = torch.abs(evalues.view(batch * npoints, 2)).sum(-1)
        return eigen_loss.mean()

    def forward(self, heatmap, groundtruth):
        """
        heatmap:     b x n x 64 x 64
        groundtruth: b x n x 2
        output:      b x n x 1 => 1
        """
        # Normalize each heatmap channel to sum to 1.
        bs, npoints, h, w = heatmap.shape
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
        heatmap = heatmap / heatmap_sum.view(bs, npoints, 1, 1)

        means = self.weighted_mean(heatmap)  # [bs, 68, 2]
        covars = self.unbiased_weighted_covariance(heatmap, means)  # covars [bs, 68, 2, 2]

        # TODO: GPU-based eigen-decomposition
        # https://github.com/pytorch/pytorch/issues/60537
        _covars = covars.view(bs * npoints, 2, 2).cpu()
        # torch.linalg.eigh replaces the deprecated Tensor.symeig(eigenvectors=True):
        # eigenvalues are returned in ascending order, eigenvectors as columns.
        evalues, evectors = torch.linalg.eigh(_covars)  # evalues [bs * 68, 2], evectors [bs * 68, 2, 2]
        evalues = evalues.view(bs, npoints, 2).to(heatmap)
        evectors = evectors.view(bs, npoints, 2, 2).to(heatmap)

        # STAR Loss
        # Ambiguity-guided Decomposition
        error = self.ambiguity_guided_decompose(groundtruth - means, evalues, evectors)
        loss_trans = self.dist_func(torch.zeros_like(error).to(error), error)
        # Eigenvalue Restriction
        loss_eigen = self.eigenvalue_restriction(evalues, bs, npoints)

        star_loss = loss_trans + self.w * loss_eigen

        return star_loss
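if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # predicted heatmaps are non-negative (e.g. produced by a spatial softmax),
    # since forward() normalizes each channel by its sum, and that the target
    # landmarks live in the same normalized [-1, 1] frame as _make_grid.
    # Because of the relative imports above, run it as a module inside its
    # package (the package name is whatever this repo uses), not as a script.
    torch.manual_seed(0)
    bs, npoints, h, w = 2, 68, 64, 64
    logits = torch.randn(bs, npoints, h, w)
    heatmap = torch.softmax(logits.view(bs, npoints, -1), dim=-1).view(bs, npoints, h, w)
    groundtruth = torch.empty(bs, npoints, 2).uniform_(-1, 1)

    # 'l2' maps to F.mse_loss, so this smoke test does not depend on the
    # external SmoothL1Loss / WingLoss implementations.
    criterion = STARLoss(w=1, dist='l2')
    loss = criterion(heatmap, groundtruth)
    print(loss.item())  # scalar STAR loss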