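"""Contrastive losses for dense retrieval training.

Includes a simple in-batch softmax loss, a hard-negative variant, and
distributed versions that gather embeddings across ranks so every query
sees negatives from the full global batch.
"""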
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
class SimpleContrastiveLoss:
    """In-batch softmax contrastive loss over query/passage embeddings."""

    def __init__(self, temperature: float = 0.02):
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        if target is None:
            # Assume y holds an equal-sized block of passages per query; the
            # positive for query i sits at index i * target_per_qry.
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)
        # Similarity matrix of shape (num_queries, num_passages).
        logits = torch.matmul(x, y.transpose(0, 1))
        loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
        return loss
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Gather embeddings from every rank so negatives span the global batch.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # Counteract the gradient averaging across ranks.
            loss = loss * self.world_size
        return loss

    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # all_gather returns detached tensors; re-insert the local tensor so
        # gradients still flow through this rank's embeddings.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
class HardNegativeContrastiveLoss:
    """Contrastive loss that appends explicit hard negatives to the in-batch candidates."""

    def __init__(self, temperature: float = 0.02):
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, reduction: str = 'mean') -> Tensor:
        # x: query embeddings
        # y: positive embeddings
        # z: hard-negative embeddings (optional)
        if z is None:
            # Fall back to plain in-batch negatives.
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0, x.size(0) * target_per_qry, target_per_qry,
                device=x.device, dtype=torch.long)
            logits = torch.matmul(x, y.transpose(0, 1))
            loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
            return loss
        # Score each query against all positives and all hard negatives;
        # here y is assumed to hold exactly one positive per query, aligned
        # with x, so query i's correct class is column i.
        pos_logits = torch.matmul(x, y.transpose(0, 1))
        neg_logits = torch.matmul(x, z.transpose(0, 1))
        logits = torch.cat([pos_logits, neg_logits], dim=1)
        target = torch.arange(x.size(0), device=x.device)
        loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction)
        return loss
class DistributedHardNegativeContrastiveLoss(HardNegativeContrastiveLoss):
    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, **kwargs):
        # Gather queries, positives, and (if given) hard negatives across ranks.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        dist_z = self.gather_tensor(z) if z is not None else None
        loss = super().__call__(dist_x, dist_y, dist_z, **kwargs)
        if self.scale_loss:
            # Counteract the gradient averaging across ranks.
            loss = loss * self.world_size
        return loss

    def gather_tensor(self, t):
        if t is None:
            return None
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # Keep the local tensor so gradients flow through this rank's slice.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)
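

# Illustrative usage sketch (an assumption, not part of the original module):
# 4 queries, 2 candidate passages per query (positive listed first in each
# block), plus 8 extra hard negatives, all with made-up random embeddings.
if __name__ == "__main__":
    torch.manual_seed(0)
    queries = F.normalize(torch.randn(4, 128), dim=-1)
    passages = F.normalize(torch.randn(8, 128), dim=-1)
    negatives = F.normalize(torch.randn(8, 128), dim=-1)

    simple = SimpleContrastiveLoss(temperature=0.02)
    print("in-batch loss:", simple(queries, passages).item())

    hard = HardNegativeContrastiveLoss(temperature=0.02)
    # With explicit negatives, y must hold exactly one positive per query.
    print("hard-negative loss:", hard(queries, passages[::2], negatives).item())

    # The Distributed* variants additionally require torch.distributed to be
    # initialized (e.g. via torchrun), so they are not exercised here.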