""" | |
Implementation of ``POPART`` algorithm for reward rescale. | |
<link https://arxiv.org/abs/1602.07714 link> | |
POPART is an adaptive normalization algorithm to normalize the targets used in the learning updates. | |
The two main components in POPART are: | |
**ART**: to update scale and shift such that the return is appropriately normalized, | |
**POP**: to preserve the outputs of the unnormalized function when we change the scale and shift. | |
""" | |
from typing import Dict, Optional
import math

import torch
import torch.nn as nn


class PopArt(nn.Module):
    """
    Overview:
        A linear layer with PopArt normalization. This class implements a linear transformation followed by
        PopArt normalization, a method that automatically adapts the contribution of each task to the agent's
        updates in multi-task learning, as described in the paper <https://arxiv.org/abs/1809.04474>.
    Interfaces:
        ``__init__``, ``reset_parameters``, ``forward``, ``update_parameters``
    """

    def __init__(
            self,
            input_features: Optional[int] = None,
            output_features: Optional[int] = None,
            beta: float = 0.5
    ) -> None:
        """
        Overview:
            Initialize the class with input features, output features, and the beta parameter.
        Arguments:
            - input_features (:obj:`Optional[int]`): The size of each input sample.
            - output_features (:obj:`Optional[int]`): The size of each output sample.
            - beta (:obj:`float`): The moving-average coefficient used to update the normalization statistics.
        """
        super(PopArt, self).__init__()
        self.beta = beta
        self.input_features = input_features
        self.output_features = output_features
        # Initialize the linear layer parameters: weight and bias.
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        self.bias = nn.Parameter(torch.Tensor(output_features))
        # Register buffers for the normalization statistics, which are part of the module's state
        # but are not updated by the optimizer. They store the target value's shift (mu),
        # scale (sigma), and second raw moment (v).
        self.register_buffer('mu', torch.zeros(output_features, requires_grad=False))
        self.register_buffer('sigma', torch.ones(output_features, requires_grad=False))
        self.register_buffer('v', torch.ones(output_features, requires_grad=False))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Overview:
            Reset the weight and bias, using ``kaiming_uniform_`` and ``uniform_`` initialization
            respectively, as in ``torch.nn.Linear``.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Implement the forward computation of the linear layer and return both the normalized
            and the unnormalized output of the layer.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor of shape ``(B, input_features)``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): A dictionary containing the normalized prediction ``'pred'`` and the unnormalized prediction ``'unnormalized_pred'``.
        """
        normalized_output = x.mm(self.weight.t())
        normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output)
        # Unnormalize the output with the current scale (sigma) and shift (mu).
        # No gradient flows through this step: the loss is computed on the normalized prediction.
        with torch.no_grad():
            output = normalized_output * self.sigma + self.mu
        return {'pred': normalized_output.squeeze(1), 'unnormalized_pred': output.squeeze(1)}

    def update_parameters(self, value: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Update the normalization statistics from the given target values and rescale the layer's
            weight and bias accordingly, returning the new mean and standard deviation.
        Arguments:
            - value (:obj:`torch.Tensor`): The target tensor used to update the statistics, of shape ``(B, output_features)``.
        Returns:
            - update_results (:obj:`Dict[str, torch.Tensor]`): A dictionary containing ``'new_mean'`` and ``'new_std'``.
        """
        # Move the normalization statistics to the device of the incoming value.
        self.mu = self.mu.to(value.device)
        self.sigma = self.sigma.to(value.device)
        self.v = self.v.to(value.device)
        old_mu = self.mu
        old_std = self.sigma
        # ART: estimate the first and second raw moments of the target value over the batch.
        batch_mean = torch.mean(value, 0)
        batch_v = torch.mean(torch.pow(value, 2), 0)
        # Replace NaN entries (e.g. from an empty batch) with the old statistics.
        batch_mean[torch.isnan(batch_mean)] = self.mu[torch.isnan(batch_mean)]
        batch_v[torch.isnan(batch_v)] = self.v[torch.isnan(batch_v)]
        # Moving-average update of the statistics with coefficient beta.
        batch_mean = (1 - self.beta) * self.mu + self.beta * batch_mean
        batch_v = (1 - self.beta) * self.v + self.beta * batch_v
        # std = sqrt(E[x^2] - E[x]^2)
        batch_std = torch.sqrt(batch_v - (batch_mean ** 2))
        # Clamp the standard deviation to reject outliers and avoid numerical issues.
        batch_std = torch.clamp(batch_std, min=1e-4, max=1e+6)
        # Replace any remaining NaN values with the old standard deviation.
        batch_std[torch.isnan(batch_std)] = self.sigma[torch.isnan(batch_std)]
        self.mu = batch_mean
        self.v = batch_v
        self.sigma = batch_std
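        # POP invariant: for every input x the unnormalized prediction must satisfy
        #     sigma_new * (x @ W_new.T + b_new) + mu_new == sigma_old * (x @ W_old.T + b_old) + mu_old,
        # which yields the weight and bias rescaling below.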
        # Update weight and bias with the old and new statistics to preserve the unnormalized outputs.
        self.weight.data = (self.weight.data.t() * old_std / self.sigma).t()
        self.bias.data = (old_std * self.bias.data + old_mu - self.mu) / self.sigma
        return {'new_mean': batch_mean, 'new_std': batch_std}
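

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only; the shapes and values below are
    # assumptions for the demo, not part of the module's API): a PopArt value head
    # on top of 16-dimensional features with a scalar value output.
    torch.manual_seed(0)
    head = PopArt(input_features=16, output_features=1, beta=0.5)
    feature = torch.randn(4, 16)
    out = head(feature)
    # 'pred' is the normalized prediction the loss should be computed on;
    # 'unnormalized_pred' is the same prediction on the original target scale.
    assert out['pred'].shape == (4, ) and out['unnormalized_pred'].shape == (4, )
    # After observing new targets (e.g. returns), update the statistics; the POP
    # step rescales weight and bias so the unnormalized predictions are unchanged.
    target = torch.randn(4, 1) * 10 + 5
    before = head(feature)['unnormalized_pred']
    stats = head.update_parameters(target)
    after = head(feature)['unnormalized_pred']
    assert torch.allclose(before, after, atol=1e-5)
    print(stats['new_mean'], stats['new_std'])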