# Spaces: Sleeping
import torch | |
import torch.nn as nn | |
from typing import Tuple, List | |
from ding.hpc_rl import hpc_wrapper | |
def shape_fn_scatter_connection(args, kwargs) -> List[int]:
    """
    Overview:
        Return the shape of scatter_connection for HPC.
    Arguments:
        - args (:obj:`Tuple`): The positional arguments passed to the scatter_connection function. \
            ``args[0]`` is the ``ScatterConnection`` instance itself, ``args[1]`` (if present) is ``x`` and \
            ``args[2]`` (if present) is ``spatial_size``.
        - kwargs (:obj:`Dict`): The keyword arguments passed to the scatter_connection function. ``x`` and \
            ``spatial_size`` are looked up here when they were not passed positionally.
    Returns:
        - shape (:obj:`List[int]`): A list representing the shape of scatter_connection, \
            in the form of [B, M, N, H, W, scatter_type]. Note that the last element is the \
            ``scatter_type`` attribute of ``args[0]``, not an integer.
    """
    # args[0] is the ScatterConnection object (bound ``self``), so a positional ``x`` sits at index 1
    # and a positional ``spatial_size`` at index 2; anything missing positionally must be a keyword.
    x = args[1] if len(args) > 1 else kwargs['x']
    spatial_size = args[2] if len(args) > 2 else kwargs['spatial_size']

    shape = list(x.shape)
    shape.extend(spatial_size)
    shape.append(args[0].scatter_type)
    return shape
class ScatterConnection(nn.Module):
    """
    Overview:
        Scatter feature to its corresponding location. In AlphaStar, each entity is embedded into a tensor,
        and these tensors are scattered into a feature map with map size.
    Interfaces:
        ``__init__``, ``forward``, ``xy_forward``
    """

    def __init__(self, scatter_type: str) -> None:
        """
        Overview:
            Initialize the ScatterConnection object.
        Arguments:
            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when two entities have the \
                same location. It can be either 'add' or 'cover'. If 'add', the features of the overlapping \
                entities are summed. If 'cover', only one of the overlapping entities is kept \
                (``torch.Tensor.scatter_`` is used, which does not guarantee which one wins when indices collide).
        """
        super(ScatterConnection, self).__init__()
        self.scatter_type = scatter_type
        assert self.scatter_type in ['cover', 'add']

    def _scatter(self, x: torch.Tensor, spatial_size: Tuple[int, int], index: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Shared implementation of ``forward`` and ``xy_forward``: scatter 'x' into a zero-initialized \
            flattened map according to a per-entity flat index, then reshape the map to `(B, N, H, W)`.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map.
            - index (:obj:`torch.Tensor`): The flat location index of shape `(B, M)`, each value in `[0, H * W)`.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`, with the same \
                dtype and device as 'x'.
        """
        B, M, N = x.shape
        H, W = spatial_size
        # Put the feature dimension before the entity dimension so that entities are scattered along dim=2.
        src = x.permute(0, 2, 1)
        # scatter_/scatter_add_ require an int64 index with the same number of dims as src: (B, M) -> (B, N, M).
        index = index.long().unsqueeze(dim=1).repeat(1, N, 1)
        # The buffer must share dtype with src, otherwise scatter_ raises for non-float32 inputs.
        output = torch.zeros(size=(B, N, H * W), dtype=x.dtype, device=x.device)
        if self.scatter_type == 'cover':
            output.scatter_(dim=2, index=index, src=src)
        elif self.scatter_type == 'add':
            output.scatter_add_(dim=2, index=index, src=src)
        return output.view(B, N, H, W)

    def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - location (:obj:`torch.Tensor`): The tensor of locations of shape `(B, M, 2)`. \
                Each location should be (y, x).
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.
        """
        W = spatial_size[1]
        # Row-major flat index: y * W + x, with location[..., 0] = y and location[..., 1] = x.
        index = location[:, :, 1] + location[:, :, 0] * W
        return self._scatter(x, spatial_size, index)

    def xy_forward(
            self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y: torch.Tensor
    ) -> torch.Tensor:
        """
        Overview:
            Scatter input tensor 'x' into a spatial feature map using separate x and y coordinates.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
                is the number of entities, and `N` is the dimension of entity attributes.
            - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
                will be scattered, where `H` is the height and `W` is the width.
            - coord_x (:obj:`torch.Tensor`): The x-coordinates tensor of shape `(B, M)`.
            - coord_y (:obj:`torch.Tensor`): The y-coordinates tensor of shape `(B, M)`.
        Returns:
            - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
        Note:
            When there are some overlapping in locations, 'cover' mode will result in the loss of information.
            'add' mode is used as a temporary substitute.

            The flat index is computed as ``coord_x * W + coord_y``, i.e. ``coord_x`` selects the row of the \
            output map and ``coord_y`` the column. This differs from ``forward``, which uses ``y * W + x``. \
            NOTE(review): confirm with callers which convention is intended before changing it.
        """
        W = spatial_size[1]
        index = coord_x * W + coord_y
        return self._scatter(x, spatial_size, index)