Spaces:
Sleeping
Sleeping
import torch | |
import math | |
from torch.nn.utils import clip_grad_norm_, clip_grad_value_ | |
from typing import Union, Iterable, Tuple, Callable, List | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import pdb | |
import numpy as np | |
import copy | |
import random | |
inf = math.inf | |
def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
    """
    Overview:
        Compute the global norm of the gradients of every parameter in ``model`` whose ``grad`` is not None.
    Arguments:
        - model (:obj:`torch.nn.Module`): the model whose parameter gradients are measured
        - norm_type (:obj:`int`, :obj:`float` or :obj:`str`): order of the norm. ``math.inf`` / ``float('inf')``
          and the legacy string ``'inf'`` all select the infinity (max-abs) norm
    Returns:
        - total_norm (:obj:`float`): the gradient norm, ``0.0`` when no parameter has a gradient
    """
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    # The string 'inf' is kept for backward compatibility; the numeric infinity must be handled here too,
    # otherwise the generic formula below evaluates ``x ** (1 / inf)`` and always yields 1.0.
    if norm_type == 'inf' or norm_type == math.inf:
        return float(max(g.abs().max() for g in grads))
    # p-norm of the concatenation of all gradients: (sum_i ||g_i||_p ** p) ** (1 / p)
    total = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return float(total ** (1. / norm_type))
def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float:
    """
    Overview:
        Compute the 2-norm of the gradients of all trainable parameters whose name does not contain ``'bias'``.
        As soon as one such parameter has no gradient yet, ``0`` is returned.
    Arguments:
        - model (:obj:`torch.nn.Module`): the model whose parameter gradients are measured
    Returns:
        - norm (:obj:`float`): the 2-norm over the selected gradients, or ``0`` if any of them is missing
    """
    squared_norms = []
    for param_name, tensor in model.named_parameters():
        # skip biases and frozen parameters
        if 'bias' in param_name or not tensor.requires_grad:
            continue
        if tensor.grad is None:
            # a selected parameter without gradient short-circuits the whole computation
            return 0
        two_norm = tensor.grad.data.norm(2).item()
        squared_norms.append(two_norm ** 2)
    total = sum(squared_norms)
    return float(total ** (1. / 2))
def grad_ignore_norm(parameters, max_norm, norm_type=2):
    """
    Overview:
        Ignore (zero out) the gradients of an iterable of parameters when their global norm exceeds
        ``max_norm``. Unlike clipping, the gradients are not rescaled: they are either kept or set to zero.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor, or a single torch.Tensor
        - max_norm (:obj:`float`): the largest gradient norm that is still accepted
        - norm_type (:obj:`float`): order of the norm, 2.0 means 2-norm, ``inf`` means max-abs norm
    Returns:
        - total_norm: the gradient norm measured before zeroing. A ``float`` for finite ``norm_type``
          (and for empty input), a 0-dim tensor for the infinity norm
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not parameters:
        # Nothing to measure or zero; ``max()`` over an empty sequence would raise for the inf norm.
        return 0.0
    if norm_type == math.inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = sum(p.grad.data.norm(norm_type).item() ** norm_type for p in parameters)
        total_norm = total_norm ** (1. / norm_type)
    # Same ratio as used for norm clipping; a value below 1 means the norm is above ``max_norm``.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.zero_()
    return total_norm
def grad_ignore_value(parameters, clip_value):
    """
    Overview:
        Ignore (zero out) the gradients of an iterable of parameters when any single gradient entry
        has an absolute value greater than or equal to ``clip_value``.
    Arguments:
        - parameters (:obj:`Iterable`): an iterable of torch.Tensor (generators are accepted), or a single
          torch.Tensor
        - clip_value (:obj:`float`): the absolute value from which on the gradients are ignored
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    # Materialize once: the parameters are traversed twice (detect, then zero), which would silently
    # do nothing on the second pass if a one-shot iterator such as ``model.parameters()`` was passed.
    with_grad = [p for p in parameters if p.grad is not None]
    if any(p.grad.data.abs().max() >= clip_value for p in with_grad):
        for p in with_grad:
            p.grad.data.zero_()
class Adam(torch.optim.Adam):
    """
    Overview:
        Adam / AdamW optimizer extended with optional gradient clipping and gradient ignoring that is
        applied to the gradients right before the regular ``torch.optim.Adam`` update.
        The momentum based modes keep an extra per-parameter running average of the squared gradient
        (state key ``thre_exp_avg_sq``); the clip and the ignore modes share that buffer.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-8,
            weight_decay: float = 0,
            amsgrad: bool = False,
            optim_type: str = 'adam',
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            Init method of the extended Adam class.
        Arguments:
            - params (:obj:`iterable`): an iterable of torch.Tensor s or dict s, specifies what Tensors
              should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-3
            - betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of the
              gradient and its square, default set to (0.9, 0.999)
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0. With ``optim_type='adam'``
              it is forwarded to ``torch.optim.Adam``; with ``'adamw'`` it is applied directly to the weights
              in ``step`` and ``torch.optim.Adam`` receives ``weight_decay=0``
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper
              On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
            - optim_type (:obj:`str`): support ["adam", "adamw"]
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm',
              'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping, required whenever ``grad_clip_type`` is set
            - clip_coef (:obj:`float`): the clipping coefficient of the momentum based modes
            - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_norm_type (:obj:`str`): only ``None`` is supported
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm',
              'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring, required whenever ``grad_ignore_type``
              is set
            - ignore_coef (:obj:`float`): the ignoring coefficient of the momentum based modes
            - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """
        self._support_type = {
            'optim': ['adam', 'adamw'],
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert optim_type in self._support_type['optim']
        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._optim_type = optim_type
        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep

        if self._optim_type == 'adamw':
            # decoupled weight decay is applied manually in ``step``
            self._weight_decay = weight_decay
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=0, amsgrad=amsgrad)
        elif self._optim_type == 'adam':
            super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        else:
            raise NotImplementedError(
                "optimizer type {} is not implemented, support type is {}".format(
                    self._optim_type, self._support_type['optim']
                )
            )

    def _state_init(self, p, amsgrad):
        """
        Overview:
            Initialize the optimizer state of one parameter: the regular Adam buffers plus the extra
            ``thre_exp_avg_sq`` buffer used by the momentum based clip/ignore modes.
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper
              On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>
        """
        state = self.state[p]
        state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device)
        # torch.optim.Adam versions that know the ``capturable`` option keep ``step`` as a tensor,
        # older ones as a plain int. Probing ``defaults`` avoids comparing version strings
        # lexicographically (where e.g. "1.9.0" would compare greater than "1.12.0").
        if 'capturable' in self.defaults:
            state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        else:
            state['step'] = 0
        # Exponential moving average of gradient values
        state['exp_avg'] = torch.zeros_like(p.data)
        # Exponential moving average of squared gradient values
        state['exp_avg_sq'] = torch.zeros_like(p.data)
        if amsgrad:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq'] = torch.zeros_like(p.data)

    def _update_threshold_state(self, p, group):
        """
        Overview:
            Lazily initialize the state of ``p`` and fold its current squared gradient into the running
            average ``thre_exp_avg_sq``.
        Returns:
            - state (:obj:`dict`): the optimizer state of ``p``
        """
        state = self.state[p]
        if len(state) == 0:
            self._state_init(p, group['amsgrad'])
        grad = p.grad.data
        beta2 = group['betas'][1]
        state['thre_exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        return state

    def _momentum_threshold(self, state, group, coef):
        """
        Overview:
            Per-element threshold ``coef * sqrt(thre_exp_avg_sq / bias_correction2)``, where the bias
            correction is computed from the current value of ``state['step']``.
        """
        beta2 = group['betas'][1]
        bias_correction2 = 1 - beta2 ** state['step']
        return (state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) * coef

    def _group_momentum_norms(self, group, coef, norm_type):
        """
        Overview:
            For all parameters of ``group`` that have a gradient, update the running second moment and
            accumulate the p-th powers of the gradient norm and of the threshold norm.
        Returns:
            - grad_pow_sum (:obj:`float`): sum of ``||grad||_p ** p``
            - thre_pow_sum (:obj:`float`): sum of ``||threshold||_p ** p``
            - min_step: the smallest ``state['step']`` seen, ``math.inf`` if no parameter has a gradient
        """
        grad_pow_sum, thre_pow_sum, min_step = 0., 0., math.inf
        for p in group['params']:
            if p.grad is None:
                continue
            state = self._update_threshold_state(p, group)
            grad_pow_sum += p.grad.data.norm(norm_type).item() ** norm_type
            threshold = self._momentum_threshold(state, group, coef)
            thre_pow_sum += threshold.norm(norm_type).item() ** norm_type
            min_step = min(min_step, state['step'])
        return grad_pow_sum, thre_pow_sum, min_step

    def _any_grad_exceeds_momentum(self) -> bool:
        """
        Overview:
            Return True as soon as any gradient entry of any parameter exceeds its momentum threshold
            (only checked for parameters whose step count reached ``ignore_momentum_timestep``).
            The scan stops at the first hit, so parameters after it are not updated in this call.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self._update_threshold_state(p, group)
                if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                    threshold = self._momentum_threshold(state, group, self._ignore_coef)
                    # ``.any()`` is required: the comparison is element-wise and a multi-element
                    # tensor cannot be used directly as a boolean.
                    if bool((p.grad.data.abs() > threshold).any()):
                        return True
        return False

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Apply the configured gradient clipping and ignoring, then perform a single optimization step.
            Note that when gradients are ignored they are zeroed, but the underlying Adam update still runs.
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default None
        Returns:
            - loss: whatever ``torch.optim.Adam.step`` returns for the given closure
        """
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]

        # clipping
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            # This implementation mimics the clip used in OPENAI, quote:
            # 'Gradients are additionally clipped per parameter to be within between ±5√v
            #  where v is the running estimate of the second moment of the (unclipped) gradient'
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self._update_threshold_state(p, group)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        grad = p.grad.data
                        threshold = self._momentum_threshold(state, group, self._clip_coef)
                        exceeded = grad.abs() > threshold
                        # keep the sign of the clipped entries so they end up within [-threshold, threshold]
                        grad.copy_(torch.where(exceeded, grad.sign() * threshold, grad))
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multi param_group, each group is handled on its own
            for group in self.param_groups:
                grad_pow_sum, thre_pow_sum, min_step = self._group_momentum_norms(
                    group, self._clip_coef, self._clip_norm_type
                )
                if min_step > self._clip_momentum_timestep:
                    total_norm = grad_pow_sum ** (1. / self._clip_norm_type)
                    total_momentum_norm = thre_pow_sum ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.data.mul_(clip_coef)

        # ignoring
        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            if self._any_grad_exceeds_momentum():
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multi param_group, each group is handled on its own
            for group in self.param_groups:
                grad_pow_sum, thre_pow_sum, min_step = self._group_momentum_norms(
                    group, self._ignore_coef, self._ignore_norm_type
                )
                if min_step > self._ignore_momentum_timestep:
                    total_norm = grad_pow_sum ** (1. / self._ignore_norm_type)
                    total_momentum_norm = thre_pow_sum ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.zero_()

        # Adam optim type
        if self._optim_type == 'adamw':
            # decoupled weight decay: w <- w - lr * weight_decay * w, only for parameters with a gradient
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.data = p.data.add(p.data, alpha=-self._weight_decay * group['lr'])
        return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Accumulate ``||grad||_p ** p`` (``p = clip_norm_type``) over all parameters that require grad
            and have a gradient. Note that no final ``1 / p`` root is applied to the sum.
        Returns:
            - total_norm (:obj:`float`): the accumulated sum of p-th powers of the gradient norms
        """
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm
class RMSprop(torch.optim.RMSprop):
    r"""
    Overview:
        RMSprop optimizer extended with optional gradient clipping and gradient ignoring that is applied
        to the gradients right before the regular ``torch.optim.RMSprop`` update.
        The momentum based modes keep an extra per-parameter running average of the squared gradient
        (state key ``thre_square_avg``); the clip and the ignore modes share that buffer.
    Interfaces:
        ``__init__``, ``step``, ``_state_init``, ``get_grad``
    """

    def __init__(
            self,
            params: Iterable,
            lr: float = 1e-2,
            alpha: float = 0.99,
            eps: float = 1e-8,
            weight_decay: float = 0,
            momentum: float = 0,
            centered: bool = False,
            grad_clip_type: str = None,
            clip_value: Union[float, None] = None,
            clip_coef: float = 5,
            clip_norm_type: float = 2.0,
            clip_momentum_timestep: int = 100,
            grad_norm_type: str = None,
            grad_ignore_type: str = None,
            ignore_value: Union[float, None] = None,
            ignore_coef: float = 5,
            ignore_norm_type: float = 2.0,
            ignore_momentum_timestep: int = 100,
    ):
        """
        Overview:
            Init method of the extended RMSprop class.
        Arguments:
            - params (:obj:`iterable`): an iterable of torch.Tensor s or dict s, specifies what Tensors
              should be optimized
            - lr (:obj:`float`): learning rate, default set to 1e-2
            - alpha (:obj:`float`): smoothing constant, default set to 0.99
            - eps (:obj:`float`): term added to the denominator to improve numerical stability, default 1e-8
            - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
            - momentum (:obj:`float`): momentum factor forwarded to ``torch.optim.RMSprop``, default set to 0
            - centered (:obj:`bool`): if True, compute the centered RMSprop, the gradient is normalized
              by an estimation of its variance
            - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm',
              'clip_momentum_norm']
            - clip_value (:obj:`float`): the value to start clipping, required whenever ``grad_clip_type`` is set
            - clip_coef (:obj:`float`): the clipping coefficient of the momentum based modes
            - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
            - clip_momentum_timestep (:obj:`int`): after how many steps the momentum clipping starts
            - grad_norm_type (:obj:`str`): only ``None`` is supported
            - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm',
              'ignore_momentum_norm']
            - ignore_value (:obj:`float`): the value to start ignoring, required whenever ``grad_ignore_type``
              is set
            - ignore_coef (:obj:`float`): the ignoring coefficient of the momentum based modes
            - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
            - ignore_momentum_timestep (:obj:`int`): after how many steps the momentum ignoring starts
        """
        self._support_type = {
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        assert grad_clip_type in self._support_type['grad_clip']
        assert grad_norm_type in self._support_type['grad_norm']
        assert grad_ignore_type in self._support_type['grad_ignore']
        if grad_clip_type:
            assert clip_value is not None
        if grad_ignore_type:
            assert ignore_value is not None

        self._grad_clip_type = grad_clip_type
        self._grad_norm_type = grad_norm_type
        self._grad_ignore_type = grad_ignore_type
        self._clip_value = clip_value
        self._clip_norm_type = clip_norm_type
        self._clip_coef = clip_coef
        self._ignore_value = ignore_value
        self._ignore_norm_type = ignore_norm_type
        self._ignore_coef = ignore_coef
        self._clip_momentum_timestep = clip_momentum_timestep
        self._ignore_momentum_timestep = ignore_momentum_timestep

        super(RMSprop, self).__init__(
            params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum, centered=centered
        )

    def _state_init(self, p, momentum, centered):
        """
        Overview:
            Initialize the optimizer state of one parameter: the regular RMSprop buffers plus the extra
            ``thre_square_avg`` buffer used by the momentum based clip/ignore modes.
        Arguments:
            - p (:obj:`torch.Tensor`): the parameter to be optimized
            - momentum (:obj:`float`): the momentum coefficient
            - centered (:obj:`bool`): if True, compute the centered RMSprop, the gradient is normalized
              by an estimation of its variance
        """
        state = self.state[p]
        # NOTE(review): torch.optim.RMSprop versions that expose the ``capturable`` option are assumed to
        # require ``step`` to be a tensor, older ones keep a plain int counter - confirm against the
        # installed torch version.
        if 'capturable' in self.defaults:
            state['step'] = torch.zeros((), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        else:
            state['step'] = 0
        state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
        if momentum:
            state['momentum_buffer'] = torch.zeros_like(p.data, device=p.data.device)
        if centered:
            state['grad_avg'] = torch.zeros_like(p.data, device=p.data.device)

    def _update_threshold_state(self, p, group):
        """
        Overview:
            Lazily initialize the state of ``p`` and fold its current squared gradient into the running
            average ``thre_square_avg``.
        Returns:
            - state (:obj:`dict`): the optimizer state of ``p``
        """
        state = self.state[p]
        if len(state) == 0:
            self._state_init(p, group['momentum'], group['centered'])
        grad = p.grad.data
        alpha = group['alpha']
        state['thre_square_avg'].mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
        return state

    def _group_momentum_norms(self, group, coef, norm_type):
        """
        Overview:
            For all parameters of ``group`` that have a gradient, update the running second moment and
            accumulate the p-th powers of the gradient norm and of the threshold norm.
        Returns:
            - grad_pow_sum (:obj:`float`): sum of ``||grad||_p ** p``
            - thre_pow_sum (:obj:`float`): sum of ``||coef * sqrt(thre_square_avg)||_p ** p``
            - min_step: the smallest ``state['step']`` seen, ``math.inf`` if no parameter has a gradient
        """
        grad_pow_sum, thre_pow_sum, min_step = 0., 0., math.inf
        for p in group['params']:
            if p.grad is None:
                continue
            state = self._update_threshold_state(p, group)
            grad_pow_sum += p.grad.data.norm(norm_type).item() ** norm_type
            threshold = state['thre_square_avg'].sqrt() * coef
            thre_pow_sum += threshold.norm(norm_type).item() ** norm_type
            min_step = min(min_step, state['step'])
        return grad_pow_sum, thre_pow_sum, min_step

    def _any_grad_exceeds_momentum(self) -> bool:
        """
        Overview:
            Return True as soon as any gradient entry of any parameter exceeds its momentum threshold
            (only checked for parameters whose step count reached ``ignore_momentum_timestep``).
            The scan stops at the first hit, so parameters after it are not updated in this call.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self._update_threshold_state(p, group)
                if state['step'] >= self._ignore_momentum_timestep:  # initial value is inaccurate
                    threshold = state['thre_square_avg'].sqrt() * self._ignore_coef
                    # ``.any()`` is required: the comparison is element-wise and a multi-element
                    # tensor cannot be used directly as a boolean.
                    if bool((p.grad.data.abs() > threshold).any()):
                        return True
        return False

    def step(self, closure: Union[Callable, None] = None):
        """
        Overview:
            Apply the configured gradient clipping and ignoring, then perform a single optimization step.
            Note that when gradients are ignored they are zeroed, but the underlying RMSprop update still runs.
        Arguments:
            - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default None
        Returns:
            - loss: whatever ``torch.optim.RMSprop.step`` returns for the given closure
        """
        new_params = [
            t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
        ]

        # clipping
        if self._grad_clip_type == 'clip_value':
            clip_grad_value_(new_params, self._clip_value)
        elif self._grad_clip_type == 'clip_norm':
            clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
        elif self._grad_clip_type == 'clip_momentum':
            # This implementation mimics the clip used in OPENAI, quote:
            # 'Gradients are additionally clipped per parameter to be within between ±5√v
            #  where v is the running estimate of the second moment of the (unclipped) gradient'
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self._update_threshold_state(p, group)
                    if state['step'] >= self._clip_momentum_timestep:  # initial value is inaccurate
                        grad = p.grad.data
                        threshold = state['thre_square_avg'].sqrt() * self._clip_coef
                        exceeded = grad.abs() > threshold
                        # keep the sign of the clipped entries so they end up within [-threshold, threshold]
                        grad.copy_(torch.where(exceeded, grad.sign() * threshold, grad))
        elif self._grad_clip_type == 'clip_momentum_norm':
            # might have multi param_group, each group is handled on its own
            for group in self.param_groups:
                grad_pow_sum, thre_pow_sum, min_step = self._group_momentum_norms(
                    group, self._clip_coef, self._clip_norm_type
                )
                if min_step > self._clip_momentum_timestep:
                    total_norm = grad_pow_sum ** (1. / self._clip_norm_type)
                    total_momentum_norm = thre_pow_sum ** (1. / self._clip_norm_type)
                    clip_coef = total_momentum_norm / (total_norm + 1e-6)
                    if clip_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.data.mul_(clip_coef)

        # ignoring
        if self._grad_ignore_type == 'ignore_value':
            grad_ignore_value(new_params, self._ignore_value)
        elif self._grad_ignore_type == 'ignore_norm':
            grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
        elif self._grad_ignore_type == 'ignore_momentum':
            if self._any_grad_exceeds_momentum():
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad.zero_()
        elif self._grad_ignore_type == 'ignore_momentum_norm':
            # might have multi param_group, each group is handled on its own
            for group in self.param_groups:
                grad_pow_sum, thre_pow_sum, min_step = self._group_momentum_norms(
                    group, self._ignore_coef, self._ignore_norm_type
                )
                if min_step > self._ignore_momentum_timestep:
                    total_norm = grad_pow_sum ** (1. / self._ignore_norm_type)
                    total_momentum_norm = thre_pow_sum ** (1. / self._ignore_norm_type)
                    ignore_coef = total_momentum_norm / (total_norm + 1e-6)
                    if ignore_coef < 1:
                        for p in group['params']:
                            if p.grad is not None:
                                p.grad.zero_()

        return super().step(closure=closure)

    def get_grad(self) -> float:
        """
        Overview:
            Accumulate ``||grad||_p ** p`` (``p = clip_norm_type``) over all parameters that require grad
            and have a gradient. Note that no final ``1 / p`` root is applied to the sum.
        Returns:
            - total_norm (:obj:`float`): the accumulated sum of p-th powers of the gradient norms
        """
        total_norm = 0.
        params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
        for p in params:
            param_norm = p.grad.data.norm(self._clip_norm_type)
            total_norm += param_norm.item() ** self._clip_norm_type
        return total_norm
class PCGrad():
    """
    Overview:
        PCGrad wrapper around an optimizer to support multi-task learning: the gradient of every objective
        is projected onto the normal plane of each objective gradient it conflicts with before merging.
        You can view the paper in the following link https://arxiv.org/pdf/2001.06782.pdf
    Interfaces:
        ``__init__``, ``optimizer``, ``zero_grad``, ``step``, ``pc_backward``
    """

    def __init__(self, optimizer, reduction='mean'):
        """
        Overview:
            Initialization of PCGrad optimizer
        Arguments:
            - optimizer (:obj:`torch.optim`): the optimizer to be used
            - reduction (:obj:`str`): how the projected gradients of parameters shared by all objectives
              are merged, support ['mean', 'sum']. Any other value raises ``KeyError`` in ``pc_backward``
        """
        self._optim, self._reduction = optimizer, reduction

    def optimizer(self):
        """
        Overview:
            get the wrapped optimizer
        """
        return self._optim

    def zero_grad(self):
        """
        Overview:
            clear the gradient of the parameters (gradients are set to None)
        """
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        """
        Overview:
            update the parameters with the gradient
        """
        return self._optim.step()

    def pc_backward(self, objectives):
        """
        Overview:
            Back-propagate every objective separately, resolve conflicts between the resulting gradients
            and write the merged gradient into ``.grad`` of every parameter of the wrapped optimizer.
        Arguments:
            - objectives: a list of objectives (scalar loss tensors)
        """
        grads, shapes, has_grads = self._pack_grad(objectives)
        pc_grad = self._project_conflicting(grads, has_grads)
        pc_grad = self._unflatten_grad(pc_grad, shapes[0])
        self._set_grad(pc_grad)
        return

    def _project_conflicting(self, grads, has_grads, shapes=None):
        """
        Overview:
            project the conflicting gradient to the orthogonal space
        Arguments:
            - grads (:obj:`list`): a list of the flattened gradient of each objective
            - has_grads (:obj:`list`): a list of flattened masks, 1 where the objective produced a gradient
            - shapes (:obj:`list`): unused, kept for interface compatibility
        Returns:
            - merged_grad (:obj:`torch.Tensor`): the flattened merged gradient
        Raises:
            - KeyError: if the reduction is neither 'mean' nor 'sum'
        """
        # entries for which every objective produced a gradient
        shared = torch.stack(has_grads).prod(0).bool()
        pc_grad = [g.clone() for g in grads]
        # shuffle a private copy so the caller's list keeps its order
        order = list(grads)
        for g_i in pc_grad:
            random.shuffle(order)
            for g_j in order:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # negative inner product means conflict: remove the component along g_j
                    g_i -= (g_i_g_j) * g_j / (g_j.norm() ** 2)
        merged_grad = torch.zeros_like(grads[0])
        if self._reduction == 'mean':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
        elif self._reduction == 'sum':
            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
        else:
            raise KeyError("invalid reduction method")

        # entries not shared by all objectives are always summed
        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        """
        Overview:
            set the modified gradients to the network
        Arguments:
            - grads (:obj:`list`): a list of per-parameter gradients, ordered like the optimizer parameters
        """
        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1
        return

    def _pack_grad(self, objectives):
        """
        Overview:
            pack the gradient of the parameters of the network for each objective
        Arguments:
            - objectives: a list of objectives
        Returns:
            - grad: a list of the flattened gradient of each objective
            - shape: a list (one entry per objective) of the list of parameter shapes
            - has_grad: a list of flattened masks represent whether the parameter has gradient
        """
        grads, shapes, has_grads = [], [], []
        for obj in objectives:
            self._optim.zero_grad(set_to_none=True)
            # the graph is retained so that later objectives sharing it can still be back-propagated
            obj.backward(retain_graph=True)
            grad, shape, has_grad = self._retrieve_grad()
            grads.append(self._flatten_grad(grad, shape))
            has_grads.append(self._flatten_grad(has_grad, shape))
            shapes.append(shape)
        return grads, shapes, has_grads

    def _unflatten_grad(self, grads, shapes):
        """
        Overview:
            unflatten the gradient of the parameters of the network
        Arguments:
            - grads (:obj:`torch.Tensor`): the flattened gradient
            - shapes (:obj:`list`): a list of the shape of the parameters
        Returns:
            - unflatten_grad (:obj:`list`): one gradient tensor per entry of ``shapes``
        """
        unflatten_grad, idx = [], 0
        for shape in shapes:
            # int(): the product of an empty shape (0-dim parameter) is the float 1.0, which cannot
            # be used as a slice bound
            length = int(np.prod(shape))
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        """
        Overview:
            flatten the gradient of the parameters of the network
        Arguments:
            - grads (:obj:`list`): a list of the gradient of the parameters
            - shapes (:obj:`list`): a list of the shape of the parameters (unused)
        """
        flatten_grad = torch.cat([g.flatten() for g in grads])
        return flatten_grad

    def _retrieve_grad(self):
        """
        Overview:
            get the gradient of the parameters of the network with specific objective
        Returns:
            - grad: a list of the gradient of the parameters (zeros where the parameter has no gradient)
            - shape: a list of the shape of the parameters
            - has_grad: a list of mask represent whether the parameter has gradient
        """
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                # tackle the multi-head scenario: a parameter untouched by this objective
                # contributes a zero gradient and a zero mask
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p))
                    has_grad.append(torch.zeros_like(p))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p))
        return grad, shape, has_grad
def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
    """
    Overview:
        Separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).
        A parameter that is shared (tied) between modules is bucketed by the single name under which
        ``model.named_parameters()`` reports it.
    Arguments:
        - model (:obj:`nn.Module`): the given PyTorch model.
        - weight_decay (:obj:`float`): weight decay value for optimizer.
    Returns:
        - optim groups (:obj:`List`): the parameter groups to be set in the latter optimizer, the decayed
          group first and the non-decayed group (``weight_decay`` 0.0) second.
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            # Because named_modules and named_parameters are recursive
            # we will see the same tensors p many times. But doing it this way
            # allows us to know which parent module any tensor p belongs to.
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            else:
                # default bucket, also hit when a parent module sees a nested parameter;
                # the no_decay verdict of the owning module wins below
                decay.add(fpn)

    param_dict = {pn: p for pn, p in model.named_parameters()}
    # model.named_parameters() reports a shared (tied) tensor only once, while the per-module walk
    # above records every alias; drop the aliases that cannot be looked up in param_dict.
    known_names = set(param_dict)
    no_decay &= known_names
    decay = (decay - no_decay) & known_names
    # validate that we considered every parameter
    union_params = decay | no_decay
    assert len(
        param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(param_dict.keys() - union_params),)

    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(decay)],
            "weight_decay": weight_decay
        },
        {
            "params": [param_dict[pn] for pn in sorted(no_decay)],
            "weight_decay": 0.0
        },
    ]
    return optim_groups