|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class SmoothL1Loss(nn.Module):
    """Smooth-L1 (Huber-style) loss applied to the Euclidean distance between points.

    For each point, the distance ``d = ||output - groundtruth||_2`` is taken
    over the last dimension, and the per-point loss is::

        0.5 * d**2 / scale    if d < scale
        d - 0.5 * scale       otherwise

    Both pieces equal ``0.5 * scale`` at ``d == scale``, so the loss is
    continuous there.

    Args:
        scale: Distance at which the loss switches from quadratic to linear.
            Assumed positive; ``scale == 0`` makes the quadratic branch divide
            by zero.
    """

    def __init__(self, scale=0.01):
        super().__init__()
        self.scale = scale
        # Smallest value ``delta_2`` is clamped to before ``sqrt`` (used only
        # when ``scale**2`` is even smaller), so the sqrt gradient stays finite.
        self.EPSILON = 1e-10

    def __repr__(self):
        return "SmoothL1Loss()"

    def forward(self, output: torch.Tensor, groundtruth: torch.Tensor,
                reduction='mean') -> torch.Tensor:
        """Compute the smooth-L1 distance loss between ``output`` and ``groundtruth``.

        Args:
            output: Predicted coordinates; the last dimension holds the
                coordinates of one point (the original notes suggest 2-D
                points, ``b x n x 2`` — TODO confirm). If ``output`` is 4-D,
                ``groundtruth`` is reshaped to
                ``(output.shape[0], output.shape[1], 1, output.shape[3])`` so it
                broadcasts against every entry along ``output``'s dim 2.
            groundtruth: Target coordinates, broadcastable against ``output``
                (after the reshape above in the 4-D case).
            reduction: ``'mean'`` (default) or ``'sum'`` to reduce over all
                points; ``'none'`` or ``None`` to return the per-point loss,
                whose shape is the broadcast shape without the last dimension.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise the per-point
            loss tensor.

        Raises:
            ValueError: If ``reduction`` is not one of the values above.
        """
        if output.dim() == 4:
            # Insert a singleton axis at dim 2 so one target broadcasts against
            # every entry of output along that axis (presumably multiple
            # predictions per point — TODO confirm).
            shape = output.shape
            groundtruth = groundtruth.reshape(shape[0], shape[1], 1, shape[3])

        delta_2 = (output - groundtruth).pow(2).sum(dim=-1, keepdim=False)
        scale_2 = self.scale * self.scale

        # Clamp before sqrt: d(sqrt)/dx is infinite at 0, and torch.where
        # still backpropagates (a zero gradient) through the unselected branch,
        # which would produce 0 * inf = NaN. Clamping at scale**2 never alters
        # delta_2 where the linear branch is actually selected
        # (delta_2 >= scale**2), so the loss stays exact for any scale.
        floor = max(scale_2, self.EPSILON)
        delta = delta_2.clamp(min=floor).sqrt()

        loss = torch.where(
            delta_2 < scale_2,
            0.5 / self.scale * delta_2,
            delta - 0.5 * self.scale,
        )

        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        if reduction is None or reduction == 'none':
            return loss
        raise ValueError(
            f"reduction must be 'mean', 'sum', 'none' or None, got {reduction!r}"
        )
|
|
|