import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
class SimpleContrastiveLoss:
    """In-batch contrastive loss over query and candidate embeddings."""

    def __init__(self, temperature: float = 0.02):
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        # x: query embeddings (n_query, dim); y: candidate embeddings (n_query * k, dim).
        # If no explicit target is given, assume positives are evenly spaced in y:
        # the positive for query i sits at index i * (y.size(0) // x.size(0)).
        if target is None:
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)
        logits = torch.matmul(x, y.transpose(0, 1))
        loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
        return loss
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Gather embeddings from all ranks so every process computes the loss
        # over the full cross-device batch of negatives.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # Compensate for the gradient averaging across ranks performed by DDP.
            loss = loss * self.world_size
        return loss

    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # all_gather does not propagate gradients; re-insert the local tensor so the
        # local slice of the concatenated result stays in the autograd graph.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
class HardNegativeContrastiveLoss:
    """Contrastive loss with optional explicitly mined hard negatives."""

    def __init__(self, temperature: float = 0.02):
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, reduction: str = 'mean') -> Tensor:
        # x: query embeddings, y: positive embeddings, z: hard-negative embeddings (optional).
        if z is None:
            # No hard negatives: fall back to in-batch negatives, as in SimpleContrastiveLoss.
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0, x.size(0) * target_per_qry, target_per_qry,
                device=x.device, dtype=torch.long)
            logits = torch.matmul(x, y.transpose(0, 1))
            return F.cross_entropy(logits / self.temperature, target, reduction=reduction)
        # Score each query against all positives and all hard negatives; the correct
        # class for query i is its own positive, i.e. column i of pos_logits.
        pos_logits = torch.matmul(x, y.transpose(0, 1))
        neg_logits = torch.matmul(x, z.transpose(0, 1))
        logits = torch.cat([pos_logits, neg_logits], dim=1)
        target = torch.arange(x.size(0), device=x.device)
        loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
        return loss
class DistributedHardNegativeContrastiveLoss(HardNegativeContrastiveLoss):
    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, **kwargs):
        # Gather queries, positives, and (if present) hard negatives across all ranks.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        dist_z = self.gather_tensor(z) if z is not None else None
        loss = super().__call__(dist_x, dist_y, dist_z, **kwargs)
        if self.scale_loss:
            # Compensate for the gradient averaging across ranks performed by DDP.
            loss = loss * self.world_size
        return loss

    def gather_tensor(self, t):
        if t is None:
            return None
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # Keep the local tensor in place so gradients flow through the local shard.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
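

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal single-process check of the local losses with random embeddings.
# The distributed variants require torch.distributed to be initialized
# (e.g. via torchrun), so only the non-distributed classes are exercised here.
if __name__ == "__main__":
    n_query, passages_per_query, dim = 4, 2, 16
    queries = torch.randn(n_query, dim)
    passages = torch.randn(n_query * passages_per_query, dim)

    simple_loss = SimpleContrastiveLoss(temperature=0.02)
    print("in-batch loss:", simple_loss(queries, passages).item())

    positives = torch.randn(n_query, dim)          # one positive per query
    hard_negatives = torch.randn(n_query * 3, dim)  # mined hard negatives, shared across queries
    hn_loss = HardNegativeContrastiveLoss(temperature=0.02)
    print("hard-negative loss:", hn_loss(queries, positives, hard_negatives).item())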