Spaces:
svjack
/
Runtime error

yuandong513
feat: init
17cd746
raw
history blame contribute delete
791 Bytes
# -*- coding: utf-8 -*-
import math
import torch
from torch import nn
# Both torch.log and math.log compute the natural (base-e) logarithm.
class WingLoss(nn.Module):
    """Wing loss computed on the Euclidean distance between ``pred`` and ``target``.

    For every vector along the last dimension the distance
    ``d = ||target - pred||_2`` is computed and mapped through::

        loss(d) = omega * ln(1 + d / epsilon)    if d < omega
                  d - C                          otherwise

    with ``C = omega - omega * ln(1 + omega / epsilon)``, which makes the two
    pieces meet at ``d == omega``.  The per-vector losses are then averaged
    into a single scalar.

    Args:
        omega: Distance threshold below which the logarithmic piece is used.
        epsilon: Scale applied to the distance inside the logarithm.
    """

    def __init__(self, omega: float = 0.01, epsilon: float = 2):
        super().__init__()
        self.omega = omega
        self.epsilon = epsilon

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean wing loss between ``pred`` and ``target``.

        Args:
            pred: Predicted values; the last dimension is reduced to a distance.
            target: Ground-truth values, broadcastable against ``pred``.

        Returns:
            A scalar tensor holding the mean loss over all remaining dimensions.
        """
        squared_dist = (target - pred).pow(2).sum(dim=-1, keepdim=False)
        # Clamp before the square root so the distance never reaches 0, where
        # the derivative of sqrt is unbounded. Distances below 1e-3 are
        # therefore treated as 1e-3.
        dist = squared_dist.clamp(min=1e-6).sqrt()

        # Offset that joins the linear piece to the logarithmic one at omega.
        # Recomputed on every call so later changes to omega/epsilon are honoured.
        offset = self.omega - self.omega * math.log1p(self.omega / self.epsilon)

        # log1p(x) stays accurate for the small ratios seen here; log(1 + x)
        # would first round 1 + x and lose most of x's significant digits in
        # float32.
        log_piece = self.omega * torch.log1p(dist / self.epsilon)
        linear_piece = dist - offset

        loss = torch.where(dist < self.omega, log_piece, linear_piece)
        return loss.mean()