# IDMR-demo/src/dist_utils.py
# Code adapted from SimCSE (https://github.com/princeton-nlp/SimCSE) governed by MIT license.
# Copyright (c) 2023, Salesforce, Inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
import torch
import torch.distributed as dist
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all process, supporting backward propagation.
https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
"""
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
dist.all_gather(output, input)
return tuple(output)
    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        # all_gather is not differentiable on its own, so the backward pass
        # routes back only the gradient slice that belongs to this rank's input.
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
def dist_gather(x: torch.Tensor):
    """Gather `x` from all ranks while keeping it in the autograd graph."""
    if not dist.is_initialized():
        return x
    if len(x.shape) == 0:
        x = x.reshape(1)
    x_gather = GatherLayer.apply(x)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather
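# Example (illustrative sketch, not part of this module): in a DDP training
# loop, dist_gather lets a contrastive loss use negatives from every GPU while
# gradients still flow back to the local embeddings. The names `model`,
# `batch`, and `contrastive_loss` below are hypothetical.
#
#   local_emb = model(batch)            # [B, D] embeddings on this rank
#   all_emb = dist_gather(local_emb)    # [world_size * B, D], differentiable
#   loss = contrastive_loss(local_emb, all_emb)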
@torch.no_grad()
def dist_gather_nograd(x: torch.Tensor):
    """Gather `x` from all ranks without tracking gradients."""
    if not dist.is_initialized():
        return x
    x_gather = [torch.zeros_like(x) for _ in range(get_world_size())]
    dist.all_gather(x_gather, x, async_op=False)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather
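# Example (sketch): dist_gather_nograd is the cheaper choice when gradients are
# not needed, e.g. pooling evaluation embeddings onto every rank before scoring.
# `eval_emb` below is a hypothetical per-rank tensor of embeddings.
#
#   all_eval_emb = dist_gather_nograd(eval_emb)   # no autograd graph kept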
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def is_main():
return get_rank() == 0
def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
def barrier():
if dist.is_initialized():
dist.barrier()
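# Example (sketch): a common pattern with these helpers is to let only the main
# process write checkpoints while the other ranks wait at the barrier.
# `save_checkpoint` and `path` are hypothetical, not defined in this repo.
#
#   if is_main():
#       save_checkpoint(model, path)
#   barrier()   # every rank waits until rank 0 has finished writing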
@torch.no_grad()
def varsize_gather_nograd(x: torch.Tensor):
    """Gather tensors whose first dimension differs across ranks."""
    if not dist.is_initialized():
        return x

    # Determine the largest first-dimension size across all ranks.
    size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
    allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
    dist.all_gather(allsizes, size)
    max_size = max(int(s.item()) for s in allsizes)

    # Pad the local tensor so every rank contributes an equally shaped buffer.
    padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    padded[: x.shape[0]] = x
    output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(output, padded)

    # Trim each rank's padding off again before concatenating.
    output = [tensor[: int(allsizes[k])] for k, tensor in enumerate(output)]
    output = torch.cat(output, dim=0)
    return output
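# Example (sketch): varsize_gather_nograd covers the case where ranks hold a
# different number of rows, e.g. the last evaluation batch is smaller on some
# GPUs. Plain all_gather requires equal shapes, so tensors are padded to the
# largest size first and trimmed afterwards. `per_rank_scores` below is a
# hypothetical [N_rank, D] tensor whose N_rank differs across processes.
#
#   all_scores = varsize_gather_nograd(per_rank_scores)   # [sum(N_rank), D]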