Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ding.torch_utils.network import one_hot | |
class MultiLogitsLoss(nn.Module):
    """
    Overview:
        Loss for a batch of logits whose one-to-one pairing with the labels is not fixed in advance.
        A cost matrix is built whose entry ``[i, j]`` is the loss of logits row ``i`` against label ``j``
        (cross entropy, or label-smoothed cross entropy). Rows and labels are then paired one-to-one by a
        minimum-cost assignment (Kuhn-Munkres with vertex labels, see ``_match``), and the matched costs
        are averaged.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
        """
        Overview:
            Initialization method, use cross_entropy as default criterion.
        Arguments:
            - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce']. \
                ``None`` selects 'cross_entropy'.
            - smooth_ratio (:obj:`float`): Smoothing ratio for label smoothing. Only stored (as \
                ``self.ratio``) when ``criterion`` is 'label_smooth_ce'.
        """
        super(MultiLogitsLoss, self).__init__()
        if criterion is None:
            criterion = 'cross_entropy'
        assert (criterion in ['cross_entropy', 'label_smooth_ce'])
        self.criterion = criterion
        if self.criterion == 'label_smooth_ce':
            self.ratio = smooth_ratio

    def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Turn integer class labels into per-class target weights according to the criterion.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits of shape (M, N); only ``N`` (number of classes), \
                dtype and device are used.
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape (M, ).
        Returns:
            - ret (:obj:`torch.Tensor`): Target weights of shape (M, N): one-hot for 'cross_entropy', \
                smoothed for 'label_smooth_ce'.
        """
        N = logits.shape[1]
        if self.criterion == 'cross_entropy':
            return one_hot(labels, num=N)
        elif self.criterion == 'label_smooth_ce':
            val = float(self.ratio) / (N - 1)
            ret = torch.full_like(logits, val)
            # NOTE(review): the target entry receives ``1 - val`` rather than ``1 - self.ratio``, so each
            # row sums to ``1 + ratio - val`` (exactly 1 only when N == 2). Kept unchanged to preserve
            # existing loss values; confirm whether this is intended.
            ret.scatter_(1, labels.unsqueeze(1), 1 - val)
            return ret

    def _nll_loss(self, nlls: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Weighted negative log likelihood: ``-sum(log_probs * targets)`` over the last (class) dimension.
        Arguments:
            - nlls (:obj:`torch.Tensor`): Log-probabilities (e.g. ``log_softmax`` output); negated here.
            - labels (:obj:`torch.Tensor`): Target weights broadcastable against ``nlls``; gradients are \
                not propagated through them.
        Returns:
            - ret (:obj:`torch.Tensor`): Loss with the last dimension reduced.
        """
        ret = (-nlls * (labels.detach()))
        # Reduce over the class dimension; for 2-D inputs this is the same as dim=1.
        return ret.sum(dim=-1)

    def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Build the pairwise cost matrix between logits rows and labels.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits of shape (M, N).
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape (M, ).
        Returns:
            - metric (:obj:`torch.Tensor`): Cost matrix of shape (M, M); entry ``[i, j]`` is the loss of \
                logits row ``i`` evaluated against label ``j``.
        """
        targets = self._label_process(logits, labels)
        log_probs = F.log_softmax(logits, dim=1)
        # Broadcast (M, 1, N) against (1, M, N) so every (row, label) pair is scored at once,
        # replacing a Python loop over the rows.
        return self._nll_loss(log_probs.unsqueeze(1), targets.unsqueeze(0))

    def _match(self, matrix: torch.Tensor) -> np.ndarray:
        """
        Overview:
            Minimum-cost one-to-one assignment between the rows and columns of a square cost matrix,
            solved with the Kuhn-Munkres algorithm (vertex labels ``lx``/``ly`` plus augmenting paths).
            The recursion depth of the path search grows with the matrix size, so very large matrices
            may hit Python's recursion limit.
        Arguments:
            - matrix (:obj:`torch.Tensor`): Square cost matrix of shape (M, M). It is detached and copied \
                to CPU, so no gradient flows through the matching.
        Returns:
            - index (:obj:`np.ndarray`): int32 array of length M; ``index[j]`` is the row matched to column ``j``.
        Raises:
            - RuntimeError: If the label update finds no slack to relax.
        """
        # Negate costs: the algorithm below maximizes total weight, which minimizes total cost.
        mat = -matrix.detach().cpu().numpy()
        M = mat.shape[0]
        index = np.full(M, -1, dtype=np.int32)  # index[col] = matched row, -1 means not matched yet
        lx = mat.max(axis=1)  # row labels, initialised to each row's best weight
        ly = np.zeros(M, dtype=np.float32)  # column labels
        visx = np.zeros(M, dtype=np.bool_)
        visy = np.zeros(M, dtype=np.bool_)

        def has_augmented_path(t, binary_distance_matrix):
            # DFS from row ``t`` over tight edges; when a free column is reached, the matching stored
            # in ``index`` is flipped along the path. Mutates visx, visy and index.
            visx[t] = True
            for i in range(M):
                if not visy[i] and binary_distance_matrix[t, i]:
                    visy[i] = True
                    if index[i] == -1 or has_augmented_path(index[i], binary_distance_matrix):
                        index[i] = t
                        return True
            return False

        for i in range(M):
            while True:
                visx.fill(False)
                visy.fill(False)
                distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
                # An edge is "tight" when lx[x] + ly[y] equals its weight, up to float tolerance.
                binary_distance_matrix = np.abs(distance_matrix) < 1e-4
                if has_augmented_path(i, binary_distance_matrix):
                    break
                # No augmenting path yet: shift labels by the smallest slack between visited rows and
                # unvisited columns, which makes at least one new edge tight.
                masked_distance_matrix = distance_matrix[:, ~visy][visx]
                if 0 in masked_distance_matrix.shape:  # empty matrix
                    raise RuntimeError("match error, matrix: {}".format(matrix))
                else:
                    d = masked_distance_matrix.min()
                    lx[visx] -= d
                    ly[visy] += d
        return index

    @staticmethod
    def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
        """
        Overview:
            Slack matrix for the Kuhn-Munkres labels: ``ret[x, y] = lx[x] + ly[y] - mat[x, y]``.
            Declared as a staticmethod because it uses no instance state but is called via ``self``.
        Arguments:
            - lx (:obj:`np.ndarray`): Row labels, shape (M, ).
            - ly (:obj:`np.ndarray`): Column labels, shape (M, ).
            - mat (:obj:`np.ndarray`): Weight matrix, shape (M, M).
            - M (:obj:`int`): Matrix size.
        Returns:
            - nret (:obj:`np.ndarray`): Slack matrix of shape (M, M).
        """
        nlx = np.broadcast_to(lx, [M, M]).T
        nly = np.broadcast_to(ly, [M, M])
        nret = nlx + nly - mat
        return nret

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate multiple logits loss: pair each label with a distinct logits row at minimum total
            cost and average the matched costs.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices of shape (B, ).
        Returns:
            - loss (:obj:`torch.Tensor`): Scalar loss averaged over the B matched pairs.
        """
        assert (len(logits.shape) == 2)
        metric_matrix = self._get_metric_matrix(logits, labels)
        index = self._match(metric_matrix)
        # index[j] is the logits row assigned to label j.
        loss = [metric_matrix[index[j], j] for j in range(metric_matrix.shape[0])]
        return sum(loss) / len(loss)