import torch.nn as nn

# FP16 utils
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def make_master_params(model_params):
    """
    Copy model parameters into a single flattened tensor of full-precision
    parameters.
    """
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]


def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), model_params)


def model_params_to_master_params(model_params, master_params):
    """
    Copy the model parameter data into the master parameters.
    """
    master_params[0].detach().copy_(
        _flatten_dense_tensors([param.detach().float() for param in model_params])
    )


def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    for param, master_param in zip(
        model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
    ):
        param.detach().copy_(master_param)


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )


def zero_grad(model_params):
    """
    Zero out the gradients of the given model parameters in place.
    """
    for param in model_params:
        if param.grad is not None:
            # Detach the gradient from its graph (or mark it non-differentiable)
            # before zeroing it in place.
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()


# LR Schedulers
from torch.optim.lr_scheduler import LambdaLR


class LinearWarmupLRScheduler(LambdaLR):
    """
    Linearly warm up the learning rate over ``warmup_steps`` steps, then hold
    it constant.
    """

    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(LinearWarmupLRScheduler, self).__init__(
            optimizer, self.lr_lambda, last_epoch=last_epoch
        )

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            return float(current_step + 1) / self.warmup_steps
        return 1.0
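

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch (not part of the original utilities) showing
# how the helpers above compose into one mixed-precision update with a static
# loss scale. The function name and the ``opt``, ``loss``, and ``loss_scale``
# arguments are assumptions made for illustration only.
# ---------------------------------------------------------------------------
def _example_fp16_step(model_params, master_params, opt, loss, loss_scale=128.0):
    """
    One illustrative FP16 training step. Assumes ``opt`` is an optimizer
    constructed over ``master_params`` (e.g. ``torch.optim.Adam(master_params)``)
    and ``loss`` is a scalar loss computed from the half-precision model.
    """
    zero_grad(model_params)
    # Scale the loss so small FP16 gradients do not underflow, then backprop.
    (loss * loss_scale).backward()
    # Move the FP16 gradients onto the flat FP32 master parameter and unscale.
    model_grads_to_master_grads(model_params, master_params)
    master_params[0].grad.div_(loss_scale)
    # Update the FP32 master weights, then copy them back into the FP16 model.
    opt.step()
    master_params_to_model_params(model_params, master_params)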