from typing import List, Union, Callable, Any, Dict, Tuple
from contextlib import nullcontext
from itertools import repeat
from collections import UserDict
import logging

import torch
from torch import nn, Tensor
from torch.cuda.amp import GradScaler, autocast

from grad_cache.context_managers import RandContext

from src.model.biencoder import BiEncoder
from utils import dist_utils

logger = logging.getLogger(__name__)


def is_binary_tensor(tensor):
    """Return True iff the tensor's unique values are exactly {0, 1}."""
    unique_elements = torch.unique(tensor)
    return torch.equal(unique_elements, torch.tensor([0, 1], dtype=tensor.dtype, device=unique_elements.device))
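
# Example (illustrative): a 0/1 label tensor is detected as binary, while a dense
# representation is not. Note that an all-zeros (or all-ones) tensor is *not*
# treated as binary here, since torch.unique would return a single element:
#   is_binary_tensor(torch.tensor([0, 1, 1, 0]))  -> True
#   is_binary_tensor(torch.randn(4, 8))           -> False (almost surely)
#   is_binary_tensor(torch.zeros(4))              -> False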


class BiEncoderGradCache(nn.Module):
    """
    Gradient Cache class. Implements input chunking, a first graph-less forward pass, gradient cache creation, and a
    second forward & backward pass for gradient computation. The optimizer step is not included. Native torch
    automatic mixed precision is supported; the user needs to handle gradient unscaling and the scaler update after a
    gradient cache step.
    """

    def __init__(
        self,
        models: List[nn.Module],
        chunk_sizes: Union[int, List[int]],
        loss_fns: Dict[str, Callable],
        split_input_fn: Callable[[Any, int], Any] = None,
        get_rep_fn: Callable[..., Tensor] = None,
        fp16_or_bf16: bool = False,
        dtype=torch.float32,
        scaler: GradScaler = None,
    ):
        """
        Initialize the Gradient Cache class instance.
        :param models: A list of all encoder models to be updated by the current cache.
        :param chunk_sizes: An integer indicating the chunk size, or a list of integers giving the chunk size for
        each model.
        :param loss_fns: A dict of loss functions that take arbitrary numbers of representation tensors and
        arbitrary numbers of keyword arguments as input. They must not in any case modify the input tensors'
        relations in the autograd graph, which are later relied upon to create the gradient cache.
        :param split_input_fn: An optional function that splits generic model input into chunks. If not provided,
        this class will try its best to split inputs of supported types. See the `split_inputs` function.
        :param get_rep_fn: An optional function that takes generic model output and returns representation tensors.
        If not provided, the generic output is assumed to be the representation tensor.
        :param fp16_or_bf16: If True, run mixed precision training; for fp16 this requires `scaler` to also be set.
        :param dtype: The autocast compute dtype used when `fp16_or_bf16` is True.
        :param scaler: A GradScaler object for automatic mixed precision training.
        """
        super().__init__()
        self.models = models
        self.q_encoder = models[0]
        self.k_encoder = models[1]
        if isinstance(chunk_sizes, int):
            self.chunk_sizes = [chunk_sizes for _ in range(len(models))]
        else:
            self.chunk_sizes = chunk_sizes
        self.split_input_fn = split_input_fn
        self.get_rep_fn = get_rep_fn
        self.loss_fns = loss_fns
        self.fp16_or_bf16 = fp16_or_bf16
        self.dtype = dtype
        self.scaler = scaler
        self._get_input_tensors_strict = False

    def __call__(self, *args, **kwargs):
        """
        Call the cache_step function.
        :return: Current step loss and logging stats.
        """
        return self.cache_step(*args, **kwargs)

    def split_inputs(self, model_input, chunk_size: int) -> List:
        """
        Split input into chunks. Will call the user-provided `split_input_fn` if specified. Otherwise,
        it can handle inputs of type tensor, list of tensors, and dictionary of tensors.
        :param model_input: Generic pytorch input.
        :param chunk_size: Size of each chunk.
        :return: A list of chunked pytorch inputs.
        """
        # delegate splitting to the user-provided function
        if self.split_input_fn is not None:
            return self.split_input_fn(model_input, chunk_size)

        if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()):
            keys = list(model_input.keys())
            chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
            return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
        elif isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input):
            chunked_x = [t.split(chunk_size, dim=0) for t in model_input]
            return [list(s) for s in zip(*chunked_x)]
        elif isinstance(model_input, Tensor):
            return list(model_input.split(chunk_size, dim=0))
        elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
            args_chunks = self.split_inputs(model_input[0], chunk_size)
            kwargs_chunks = self.split_inputs(model_input[1], chunk_size)
            return list(zip(args_chunks, kwargs_chunks))
        else:
            raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')
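
    # Example (illustrative): splitting a dict of tensors with batch size 4 into
    # chunks of 2 yields two dicts, each holding the matching tensor slices:
    #   batch = {"input_ids": torch.ones(4, 8, dtype=torch.long),
    #            "attention_mask": torch.ones(4, 8, dtype=torch.long)}
    #   gc.split_inputs(batch, chunk_size=2)
    #   -> [{"input_ids": <2x8>, "attention_mask": <2x8>},
    #       {"input_ids": <2x8>, "attention_mask": <2x8>}]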

    def get_input_tensors(self, model_input) -> List[Tensor]:
        """
        Recursively go through the model input and grab all tensors, which are then used to record the current device
        random states. This method will do its best to parse types of Tensor, tuple, list, dict and UserDict. Other
        types will be ignored unless self._get_input_tensors_strict is set to True, in which case an exception will
        be raised.
        :param model_input: input to model
        :return: all torch tensors in model_input
        """
        if isinstance(model_input, Tensor):
            return [model_input]
        elif isinstance(model_input, (list, tuple)):
            return sum((self.get_input_tensors(x) for x in model_input), [])
        elif isinstance(model_input, (dict, UserDict)):
            return sum((self.get_input_tensors(x) for x in model_input.values()), [])
        elif self._get_input_tensors_strict:
            raise NotImplementedError(f'get_input_tensors not implemented for type {type(model_input)}')
        else:
            return []

    def model_call(self, model: nn.Module, model_input):
        """
        Literally call the model's __call__ method.
        :param model: model to be called
        :param model_input: input to the model call
        :return: model output
        """
        with autocast(dtype=self.dtype) if self.fp16_or_bf16 else nullcontext():
            if isinstance(model_input, Tensor):
                return model(model_input)
            elif isinstance(model_input, list):
                return model(*model_input)
            elif isinstance(model_input, (dict, UserDict)):
                return model(**model_input)
            elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
                model_args, model_kwargs = model_input
                return model(*model_args, **model_kwargs)
            elif isinstance(model_input, tuple):
                return model(*model_input)
            else:
                raise NotImplementedError(f'model_call not implemented for type {type(model_input)}')

    def get_reps(self, model_out) -> Tensor:
        """
        Return the representation tensor from generic model output.
        :param model_out: generic model output
        :return: a single tensor corresponding to the model representation output
        """
        if self.get_rep_fn is not None:
            return self.get_rep_fn(model_out)
        else:
            return model_out

    def compute_loss(self, loss_mapping=None, *reps: Tensor, **loss_kwargs) -> Tuple[Tensor, Dict]:
        """
        Compute the loss based on the representation tensors. The tensors should be ordered the same as the list of
        models registered in this GradCache class instance.
        :param loss_mapping: Optional dict mapping a loss name (a key of `self.loss_fns`) to the list of in-batch
        sample indices that loss applies to. If None, the "distributed_inbatch_contrastive" loss is applied to the
        full batch.
        :param reps: Representations for computing the loss.
            reps[0]: query vectors, shape=[B, H]
            reps[1]: doc vectors, shape=[B * num_neg, H]
        :param loss_kwargs: Keyword arguments input to the loss function.
        :return: A tuple of the loss tensor and a dict of loss details.
        """
        if loss_mapping is None:
            loss_fn = self.loss_fns["distributed_inbatch_contrastive"]
            loss, loss_details = loss_fn(*reps, **loss_kwargs)
        else:
            bsz, hdim = reps[0].shape
            loss, loss_details = 0.0, {}
            preds = torch.zeros(bsz * dist_utils.get_world_size(), dtype=torch.long, device=reps[0].device)
            labels = torch.zeros(bsz * dist_utils.get_world_size(), dtype=torch.long, device=reps[0].device)
            for loss_name, data_idxs in loss_mapping.items():
                data_idxs = torch.tensor(data_idxs).to(reps[0].device)
                q = reps[0].index_select(0, index=data_idxs)
                if len(reps[1].shape) == 1 or is_binary_tensor(reps[1]):
                    # d is a one-hot label tensor, e.g. for a classification loss
                    d = reps[1]
                else:
                    d = reps[1].view(bsz, -1, hdim).index_select(0, index=data_idxs)
                    d = d.view(-1, hdim)
                _loss, _loss_details = self.loss_fns[loss_name](q, d, **loss_kwargs)
                loss += _loss
                if "labels" in _loss_details:
                    # Losses are computed per group/loss-type (as stored in loss_mapping), so the data is reordered
                    # by group and we need to gather preds/labels back into batch order.
                    if torch.distributed.is_initialized():
                        data_idxs = data_idxs + bsz * dist_utils.get_rank()
                        data_idxs = dist_utils.dist_gather(data_idxs)
                    # TODO: this might not work correctly for classification loss
                    preds.index_copy_(0, data_idxs, _loss_details["preds"])
                    labels.index_copy_(0, data_idxs, _loss_details["labels"])
                    loss_details["preds"] = preds
                    loss_details["labels"] = labels
        return loss, loss_details
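
    # Example (illustrative; "classification" is a placeholder key): with a batch of
    # 4 queries where the first two use the in-batch contrastive loss and the last
    # two a classification-style loss, loss_mapping could be
    #   {"distributed_inbatch_contrastive": [0, 1], "classification": [2, 3]}
    # Every key must exist in self.loss_fns, and the index lists partition range(B).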

    def forward_no_grad(
        self,
        model: nn.Module,
        model_inputs,
    ) -> Tuple[Tensor, List[RandContext]]:
        """
        The first forward pass without gradient computation.
        :param model: Encoder model.
        :param model_inputs: Model input already broken into chunks; a tuple of two lists (ids, masks).
        :return: A tuple of a) representations and b) recorded random states.
        """
        rnd_states = []
        model_reps = []
        with torch.no_grad():
            for x in zip(*model_inputs):
                # record the device RNG state so dropout etc. can be replayed in the second forward pass
                rnd_states.append(RandContext(*self.get_input_tensors(x)))
                y = self.model_call(model, x)
                model_reps.append(self.get_reps(y))
        # concatenate all sub-batch representations
        model_reps = torch.cat(model_reps, dim=0)
        return model_reps, rnd_states
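
    # A minimal sketch of why caching works (assuming a scalar loss L over the
    # concatenated representations R): build_cache computes g = dL/dR once from the
    # graph-less representations above. Re-encoding a chunk r with grad enabled and
    # backpropagating the surrogate torch.dot(r.flatten(), g_chunk.flatten()) then
    # yields exactly dL/d(params) for that chunk, since the surrogate's gradient
    # with respect to r is g_chunk by construction.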

    def build_cache(self, deepspeed=None, loss_mapping=None, *reps: Tensor, **loss_kwargs) -> Tuple[List[Tensor], Tensor, Dict]:
        """
        Compute the gradient cache.
        :param deepspeed: Optional deepspeed engine; if given, its backward is used instead of loss.backward().
        :param loss_mapping: Optional per-loss index mapping, passed through to `compute_loss`.
        :param reps: Computed representations from all encoder models.
        :param loss_kwargs: Extra keyword arguments to the loss function.
        :return: A tuple of a) the gradient cache for each encoder model, b) the loss tensor, and c) loss details.
        """
        new_reps = []
        for r in reps:
            if isinstance(r, torch.Tensor) and r.ndim == 2:
                # detach from the graph-less forward and re-enable grad so r.grad holds the cache
                new_reps.append(r.detach().requires_grad_())
            elif isinstance(r, list):
                new_reps.append(torch.cat(r, dim=0))
        reps = tuple(new_reps)
        with autocast(dtype=self.dtype) if self.fp16_or_bf16 else nullcontext():
            loss, loss_details = self.compute_loss(loss_mapping, *reps, **loss_kwargs)
        if deepspeed is None:
            if self.scaler:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
        else:
            deepspeed.backward(loss)
        # label tensors (1-D or binary) carry no cacheable gradient
        cache = [r.grad for r in reps if len(r.shape) > 1 and not is_binary_tensor(r[0])]
        return cache, loss.detach(), loss_details
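
    # Example (illustrative): with models = [q_encoder, k_encoder], B queries of
    # dim H and B * num_neg docs, build_cache returns
    #   cache = [dL/dQ of shape [B, H], dL/dD of shape [B * num_neg, H]],
    # which cache_step later re-splits per chunk. If reps[1] is a binary label
    # tensor instead, it is filtered out and cache holds a single entry.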

    def forward_backward(
        self,
        model: nn.Module,
        model_inputs,
        cached_gradients: List[Tensor],
        random_states: List[RandContext],
        no_sync_except_last: bool = False,
        deepspeed=None,
    ):
        """
        Run the second forward and the backward pass to compute gradients for a model.
        :param model: Encoder model.
        :param model_inputs: Chunked input to the encoder model.
        :param cached_gradients: Chunked gradient cache tensor for each input.
        :param random_states: Each input's device random state during the first forward.
        :param no_sync_except_last: If True, under a distributed setup, only trigger gradient reduction across
        processes for the last sub-batch's forward-backward pass.
        :param deepspeed: Optional deepspeed engine; if given, its backward is used instead of surrogate.backward().
        """
        if no_sync_except_last and deepspeed is None:
            sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext]
        else:
            sync_contexts = [nullcontext for _ in range(len(model_inputs))]

        for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts):
            with sync_context():
                # replay the recorded RNG state so the second forward matches the first
                with state:
                    y = self.model_call(model, x)
                reps = self.get_reps(y)
                # backpropagating the surrogate injects the cached gradient into the model parameters
                surrogate = torch.dot(reps.flatten(), gradient.flatten())
                if deepspeed is None:
                    surrogate.backward()
                else:
                    deepspeed.backward(surrogate)

    def cache_step(
        self,
        inputs,
        masks,
        no_sync_except_last: bool = False,
        deepspeed: object = None,
        loss_mapping=None,
        **loss_kwargs,
    ) -> Tuple[Tensor, Dict]:
        """
        Run a cached step to compute gradients over the inputs.
        :param inputs: Input ids for each encoder model, in the same order as the class's models.
        :param masks: Attention masks for each encoder model, in the same order as the inputs.
        :param no_sync_except_last: If True, under a distributed setup, for each model, only trigger gradient
        reduction across processes for the last sub-batch's forward-backward pass.
        :param deepspeed: Optional deepspeed engine used for the backward passes.
        :param loss_mapping: Optional per-loss index mapping; see `compute_loss`.
        :param loss_kwargs: Additional keyword arguments to the loss function.
        :return: The current step's loss and a dict of training metrics.
        """
        all_reps = []
        all_rnd_states = []
        inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(inputs, self.chunk_sizes)]
        masks = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(masks, self.chunk_sizes)]
        for model, input, mask in zip(self.models, inputs, masks):
            if len(input[0].shape) == 1 or is_binary_tensor(input[0]):
                # the input is a label tensor: pass it through untouched, no encoding needed
                all_reps.append(input)
                all_rnd_states.append(input)
            else:
                model_reps, rnd_states = self.forward_no_grad(model, model_inputs=(input, mask))
                all_reps.append(model_reps)
                all_rnd_states.append(rnd_states)
        cache, loss, loss_details = self.build_cache(deepspeed, loss_mapping, *all_reps, **loss_kwargs)
        cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]
        for model, input, mask, model_cache, rnd_states in zip(self.models, inputs, masks, cache, all_rnd_states):
            self.forward_backward(model, model_inputs=list(zip(input, mask)),
                                  cached_gradients=model_cache, random_states=rnd_states,
                                  no_sync_except_last=no_sync_except_last,
                                  deepspeed=deepspeed,
                                  )
        log_stats = BiEncoder._report_train_metrics(q=all_reps[0], k=all_reps[1],
                                                    preds=loss_details["preds"], labels=loss_details["labels"],
                                                    loss_details=loss_details)
        return loss, log_stats
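

# A minimal usage sketch (illustrative; the encoder builders, loss function, and
# batch tensors below are placeholders, not part of this module):
#
#   q_encoder, k_encoder = build_query_encoder(), build_doc_encoder()
#   gc = BiEncoderGradCache(
#       models=[q_encoder, k_encoder],
#       chunk_sizes=16,
#       loss_fns={"distributed_inbatch_contrastive": contrastive_loss},
#       fp16_or_bf16=True,
#       dtype=torch.bfloat16,
#   )
#   optimizer.zero_grad()
#   loss, log_stats = gc(inputs=[q_ids, k_ids], masks=[q_mask, k_mask])
#   optimizer.step()  # gradients were accumulated chunk by chunk inside cache_step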