assi1 / ELR /model /loss.py
uthurumella's picture
Upload 69 files
72fc481 verified
import torch.nn.functional as F
import torch
from parse_config import ConfigParser
import torch.nn as nn
def cross_entropy(output, target):
    """Return the cross-entropy loss between ``output`` and ``target``.

    Thin wrapper that forwards both arguments, unchanged, to
    ``torch.nn.functional.cross_entropy`` using that function's defaults.
    """
    loss = F.cross_entropy(output, target)
    return loss
class elr_loss(nn.Module):
    """Cross-entropy plus a regulariser built from running prediction averages.

    The module keeps one row per training example in ``self.target``. On every
    forward pass the rows selected by ``index`` are updated as an exponential
    moving average of the model's (clamped, re-normalised) softmax output, and
    the regulariser ``log(1 - <target, prediction>)`` is added to the
    cross-entropy, scaled by ``config['train_loss']['args']['lambda']``.
    """

    def __init__(self, num_examp, num_classes=10, beta=0.3):
        """Allocate the per-example running-average buffer.

        Args:
            num_examp: Number of rows in the buffer (presumably the size of
                the training set -- confirm against the caller).
            num_classes: Number of columns in the buffer / classes.
            beta: Weight given to the previous running average on each update.
        """
        super(elr_loss, self).__init__()
        self.num_classes = num_classes
        self.config = ConfigParser.get_instance()
        self.USE_CUDA = torch.cuda.is_available()
        # Initial placement only; forward() relocates the buffer if the
        # logits turn out to live on a different device.
        self.target = torch.zeros(num_examp, self.num_classes)
        if self.USE_CUDA:
            self.target = self.target.cuda()
        self.beta = beta

    def forward(self, index, output, label):
        """Update the running averages for ``index`` and return the loss.

        Args:
            index: Row indices into ``self.target`` for the current batch.
            output: Logits with classes along ``dim=1``.
            label: Targets forwarded to ``F.cross_entropy``.

        Returns:
            Scalar tensor: cross-entropy + lambda * regulariser, where lambda
            is read from ``self.config['train_loss']['args']['lambda']``.
        """
        y_pred = F.softmax(output, dim=1)
        # Clamp so the log() below never sees an argument of exactly 0.
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)

        # The buffer was placed by CUDA *availability*, which may not match
        # the device the model actually runs on; follow the logits.
        if self.target.device != y_pred.device:
            self.target = self.target.to(y_pred.device)

        # The running average must not carry gradients.
        y_pred_ = y_pred.detach()
        normalised = y_pred_ / y_pred_.sum(dim=1, keepdim=True)
        self.target[index] = (
            self.beta * self.target[index] + (1 - self.beta) * normalised
        )

        ce_loss = F.cross_entropy(output, label)
        elr_reg = ((1 - (self.target[index] * y_pred).sum(dim=1)).log()).mean()
        final_loss = ce_loss + self.config['train_loss']['args']['lambda'] * elr_reg
        return final_loss