# Code adapted from SimCSE (https://github.com/princeton-nlp/SimCSE) governed by the MIT license.
# Copyright (c) 2023, Salesforce, Inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all processes, supporting backward propagation.
    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # all_gather does not track gradients, so route the gradient of this
        # rank's slice of the gathered output back to the local input.
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
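
# Note: a plain dist.all_gather returns tensors detached from the autograd
# graph, so a loss computed on the gathered result would not backpropagate to
# the encoder. GatherLayer restores the gradient path by handing each rank the
# gradient of its own slice of the gathered output; no all_reduce of gradients
# is performed here, matching the referenced SimCLR implementation.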


def dist_gather(x: torch.Tensor):
    if not dist.is_initialized():
        return x
    if len(x.shape) == 0:
        # all_gather needs at least a 1-D tensor.
        x = x.reshape(1)
    x_gather = GatherLayer.apply(x)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather
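
# Usage sketch (illustrative, not from the original module): gathering
# per-rank embeddings lets every rank's examples serve as in-batch negatives
# for a contrastive loss. `encoder` and `contrastive_loss` are hypothetical
# placeholders here.
#
#   q = encoder(query_batch)     # (B, D) embeddings on this rank
#   all_q = dist_gather(q)       # (world_size * B, D), gradients preserved
#   loss = contrastive_loss(all_q, labels)
#   loss.backward()              # gradients flow back to the local encoder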


def dist_gather_nograd(x: torch.Tensor):
    if not dist.is_initialized():
        return x
    # The buffers are placeholders; dist.all_gather overwrites them in place.
    x_gather = [torch.ones_like(x) for _ in range(get_world_size())]
    dist.all_gather(x_gather, x, async_op=False)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main():
    return get_rank() == 0


def get_world_size():
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def barrier():
    if dist.is_initialized():
        dist.barrier()


def varsize_gather_nograd(x: torch.Tensor):
    """Gather tensors of different sizes along the first dimension."""
    if not dist.is_initialized():
        return x

    # Determine the largest first-dimension size across all ranks.
    size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
    allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
    dist.all_gather(allsizes, size)
    allsizes = [int(s) for s in allsizes]
    max_size = max(allsizes)

    # Pad every rank's tensor to the common size so the all_gather shapes match.
    padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    padded[: x.shape[0]] = x
    output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(output, padded)

    # Trim each gathered tensor back to its true size before concatenating.
    output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
    output = torch.cat(output, dim=0)
    return output
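

if __name__ == "__main__":
    # Minimal single-process smoke test (a sketch, not part of the original
    # module). Without dist.init_process_group every helper falls back to a
    # no-op; launch with torchrun --nproc_per_node=2 and an initialized
    # process group to exercise the real all_gather paths.
    x = torch.randn(3, 4)
    assert dist_gather(x).shape == (3, 4)            # returned unchanged
    assert varsize_gather_nograd(x).shape == (3, 4)  # returned unchanged
    print(f"rank={get_rank()} world_size={get_world_size()} is_main={is_main()}")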