import torch.nn as nn
# FP16 utils
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def make_master_params(model_params):
    """
    Copy model parameters into a single flattened tensor of full-precision
    (FP32) parameters.
    """
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]

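# Usage sketch (illustration, not part of the original file): the optimizer is
# typically constructed on the flat FP32 master parameter returned here rather
# than on the model's own (possibly FP16) parameters, e.g.:
#
#     model_params = list(model.parameters())
#     master_params = make_master_params(model_params)
#     opt = torch.optim.AdamW(master_params, lr=1e-4)
#
# where `model` is assumed to be defined elsewhere.
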
def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), model_params)

def model_params_to_master_params(model_params, master_params):
    """
    Copy the model parameter data into the master parameters.
    """
    master_params[0].detach().copy_(
        _flatten_dense_tensors([param.detach().float() for param in model_params])
    )

def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    for param, master_param in zip(
        model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
    ):
        param.detach().copy_(master_param)

def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )

def zero_grad(model_params):
    """
    Zero the gradients of model_params in place, mirroring optimizer.zero_grad().
    """
    for param in model_params:
        if param.grad is not None:
            # Detach the grad from any graph (or drop its requires_grad flag)
            # before zeroing it in place.
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()

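# A typical mixed-precision training step with the helpers above would look
# roughly like the following sketch (`model`, `opt`, `batch`, and `loss_fn`
# are assumed to exist elsewhere):
#
#     zero_grad(model_params)
#     loss_fn(model(batch)).backward()        # grads accumulate on model_params
#     model_grads_to_master_grads(model_params, master_params)
#     opt.step()                              # opt updates the FP32 master copy
#     master_params_to_model_params(model_params, master_params)
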
# LR Schedulers
from torch.optim.lr_scheduler import LambdaLR
class LinearWarmupLRScheduler(LambdaLR):
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(LinearWarmupLRScheduler, self).__init__(
            optimizer, self.lr_lambda, last_epoch=last_epoch
        )

    def lr_lambda(self, current_step):
        # Linearly ramp the LR multiplier from 1/warmup_steps up to 1.0.
        if current_step < self.warmup_steps:
            return float(current_step + 1) / self.warmup_steps
        return 1.0
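

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the original training code):
    # show the linear warmup that LinearWarmupLRScheduler applies to a dummy
    # optimizer over the first few steps.
    import torch

    dummy = nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([dummy], lr=0.1)
    sched = LinearWarmupLRScheduler(opt, warmup_steps=5)
    for step in range(8):
        opt.step()
        sched.step()
        print(f"after step {step}: lr = {sched.get_last_lr()[0]:.3f}")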