|
|
|
|
|
import math
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
|
|
class WingLoss(nn.Module):
|
|
def __init__(self, omega=0.01, epsilon=2):
|
|
super(WingLoss, self).__init__()
|
|
self.omega = omega
|
|
self.epsilon = epsilon
|
|
|
|
def forward(self, pred, target):
|
|
y = target
|
|
y_hat = pred
|
|
delta_2 = (y - y_hat).pow(2).sum(dim=-1, keepdim=False)
|
|
|
|
delta = delta_2.clamp(min=1e-6).sqrt()
|
|
C = self.omega - self.omega * math.log(1 + self.omega / self.epsilon)
|
|
loss = torch.where(
|
|
delta < self.omega,
|
|
self.omega * torch.log(1 + delta / self.epsilon),
|
|
delta - C
|
|
)
|
|
return loss.mean()
|
|
|