# NOTE: web-page residue captured with this file ("Spaces: Sleeping / Sleeping"); not part of the module.
from typing import Iterable, Any, Optional, List | |
from collections.abc import Sequence | |
import numbers | |
import time | |
import copy | |
from threading import Thread | |
from queue import Queue | |
import numpy as np | |
import torch | |
import treetensor.torch as ttorch | |
from ding.utils.default_helper import get_shape0 | |
def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
    """
    Overview:
        Move ``item``, or the supported elements nested inside it, onto ``device``.
    Arguments:
        - item (:obj:`Any`): The object to transfer. Modules and tensors are moved with ``.to(device)``; \
sequences and dicts are walked recursively.
        - device (:obj:`str`): The target device, e.g. ``'cpu'`` or ``'cuda'``.
        - ignore_keys (:obj:`list`): Keys of the dict passed in directly whose values are returned untouched. \
The list is not forwarded to nested containers.
    Returns:
        - item (:obj:`Any`): The transferred item. Non-string sequences come back as ``list``, dicts as a new \
``dict``.
    Raises:
        - TypeError: If ``item`` is of an unsupported type.
    Examples:
        >>> setup_data_dict['module'] = nn.Linear(3, 5)
        >>> cuda_d = to_device(setup_data_dict, 'cuda', ignore_keys=['module'])
        >>> assert cuda_d['module'].weight.device == torch.device('cpu')

    .. note::
        Numbers, ``np.ndarray``, ``np.bool_``, ``str``, ``None`` and ``torch.distributions.Distribution`` objects
        are returned unchanged.
    """
    if isinstance(item, torch.nn.Module):
        return item.to(device)

    if isinstance(item, ttorch.Tensor):
        if 'prev_state' not in item:
            return item.to(device)
        # ``prev_state`` is transferred through ``to_device`` on its own: it is removed from the tree,
        # the remainder is moved, and the transferred state is attached to the moved tree.
        moved_state = to_device(item.prev_state, device)
        del item.prev_state
        item = item.to(device)
        item.prev_state = moved_state
        return item

    if isinstance(item, torch.Tensor):
        return item.to(device)

    if isinstance(item, Sequence):
        # A str is a Sequence too, but it is handed back as-is instead of being walked.
        if isinstance(item, str):
            return item
        return [to_device(element, device) for element in item]

    if isinstance(item, dict):
        return {
            key: (value if key in ignore_keys else to_device(value, device))
            for key, value in item.items()
        }

    # Types that carry no device information are passed through untouched.
    if item is None or isinstance(item, (numbers.Integral, numbers.Real, np.ndarray, np.bool_, str)):
        return item
    if isinstance(item, torch.distributions.Distribution):  # for compatibility
        return item

    raise TypeError("not support item type: {}".format(type(item)))
def to_dtype(item: Any, dtype: type) -> Any:
    """
    Overview:
        Change the dtype of a tensor, or of every tensor nested in a sequence / dict.
    Arguments:
        - item (:obj:`Any`): A ``torch.Tensor``, or a (possibly nested) sequence / dict of tensors.
        - dtype (:obj:`type`): The wanted dtype, e.g. ``torch.float``.
    Returns:
        - item (:obj:`object`): The converted item. Sequences are returned as ``list``, dicts as a new ``dict``.
    Raises:
        - TypeError: If ``item`` (or a nested element) is neither a tensor, a non-string sequence nor a dict.
    Examples (tensor):
        >>> t = torch.randint(0, 10, (3, 5))
        >>> tfloat = to_dtype(t, torch.float)
        >>> assert tfloat.dtype == torch.float
    Examples (list):
        >>> tlist = [torch.randint(0, 10, (3, 5))]
        >>> tlfloat = to_dtype(tlist, torch.float)
        >>> assert tlfloat[0].dtype == torch.float
    Examples (dict):
        >>> tdict = {'t': torch.randint(0, 10, (3, 5))}
        >>> tdictf = to_dtype(tdict, torch.float)
        >>> assert tdictf['t'].dtype == torch.float
    """
    if isinstance(item, torch.Tensor):
        return item.to(dtype=dtype)
    # A str is a Sequence whose elements are again str, so recursing into it never terminates.
    # It is rejected with the same TypeError as any other unsupported type.
    if isinstance(item, Sequence) and not isinstance(item, str):
        return [to_dtype(element, dtype) for element in item]
    if isinstance(item, dict):
        return {key: to_dtype(value, dtype) for key, value in item.items()}
    raise TypeError("not support item type: {}".format(type(item)))
def to_tensor(
        item: Any, dtype: Optional[torch.dtype] = None, ignore_keys: list = [], transform_scalar: bool = True
) -> Any:
    """
    Overview:
        Convert ``numpy.ndarray`` objects, numeric lists and scalars to ``torch.Tensor``, recursing into
        dicts, lists, tuples and namedtuples.
    Arguments:
        - item (:obj:`Any`): The object to convert. It can be a ``numpy.ndarray``, a scalar, a tensor, or a \
container (list, tuple, namedtuple or dict) of such objects.
        - dtype (:obj:`torch.dtype`): The wanted tensor dtype. If ``None``, the dtype is inferred, except that \
a ``float64`` ndarray is converted through ``torch.FloatTensor``.
        - ignore_keys (:obj:`list`): For dicts, values whose keys are in ``ignore_keys`` are kept unconverted. \
The list is forwarded to nested containers.
        - transform_scalar (:obj:`bool`): If ``True``, scalars are converted to tensors too; otherwise they are \
returned unchanged.
    Returns:
        - item (:obj:`Any`): The converted object. An empty list / tuple yields ``[]``; a list / tuple whose \
first element is a number is converted to a single tensor; a namedtuple keeps its type.
    Raises:
        - TypeError: If ``item`` is of an unsupported type.
    Examples (scalar):
        >>> i = 10
        >>> t = to_tensor(i)
        >>> assert t.item() == i
    Examples (dict):
        >>> d = {'i': i}
        >>> dt = to_tensor(d, torch.int)
        >>> assert dt['i'].item() == i
    Examples (named tuple):
        >>> data_type = namedtuple('data_type', ['x', 'y'])
        >>> inputs = data_type(np.random.random(3), 4)
        >>> outputs = to_tensor(inputs, torch.float32)
        >>> assert type(outputs) == data_type
        >>> assert outputs.x.dtype == torch.float32
        >>> assert outputs.y.dtype == torch.float32

    .. note::
        ``bool``, ``str`` and ``None`` are returned unchanged.
    """

    def transform(d):
        # Turn a flat numeric list / tuple into one tensor.
        if dtype is None:
            return torch.as_tensor(d)
        else:
            return torch.tensor(d, dtype=dtype)

    if isinstance(item, dict):
        new_data = {}
        for k, v in item.items():
            if k in ignore_keys:
                new_data[k] = v
            else:
                new_data[k] = to_tensor(v, dtype, ignore_keys, transform_scalar)
        return new_data
    elif isinstance(item, (list, tuple)):
        if len(item) == 0:
            return []
        elif isinstance(item[0], (numbers.Integral, numbers.Real)):
            return transform(item)
        elif hasattr(item, '_fields'):  # namedtuple
            # Forward every option so that namedtuple fields are converted by the same rules as
            # list / dict elements (previously ``ignore_keys`` and ``transform_scalar`` were dropped here).
            return type(item)(*[to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item])
        else:
            return [to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item]
    elif isinstance(item, np.ndarray):
        if dtype is None:
            if item.dtype == np.float64:
                return torch.FloatTensor(item)
            else:
                return torch.from_numpy(item)
        else:
            return torch.from_numpy(item).to(dtype)
    elif isinstance(item, (bool, str)):
        return item
    elif np.isscalar(item):
        if not transform_scalar:
            return item
        scalar_tensor = torch.as_tensor(item)
        return scalar_tensor if dtype is None else scalar_tensor.to(dtype)
    elif item is None:
        return None
    elif isinstance(item, torch.Tensor):
        return item if dtype is None else item.to(dtype)
    else:
        raise TypeError("not support item type: {}".format(type(item)))
def to_ndarray(item: Any, dtype: Optional[np.dtype] = None) -> Any:
    """
    Overview:
        Convert ``torch.Tensor`` objects, numeric lists and scalars to ``numpy.ndarray``, recursing into
        dicts, lists, tuples and namedtuples.
    Arguments:
        - item (:obj:`Any`): The object to convert. It can be a ``torch.Tensor``, an ndarray, a scalar, or a \
container (list, tuple, namedtuple or dict) of such objects.
        - dtype (:obj:`np.dtype`): The wanted array dtype. If ``None``, the dtype is left unchanged.
    Returns:
        - item (:obj:`object`): The converted object. An empty list / tuple yields ``None``; a list / tuple \
whose first element is a number is converted to a single array; a namedtuple keeps its type.
    Raises:
        - TypeError: If ``item`` is of an unsupported type.
    Examples (ndarray):
        >>> t = torch.randn(3, 5)
        >>> tarray1 = to_ndarray(t)
        >>> assert tarray1.shape == (3, 5)
        >>> assert isinstance(tarray1, np.ndarray)
    Examples (list):
        >>> t = [torch.randn(5, ) for i in range(3)]
        >>> tarray1 = to_ndarray(t, np.float32)
        >>> assert isinstance(tarray1, list)
        >>> assert tarray1[0].shape == (5, )
        >>> assert isinstance(tarray1[0], np.ndarray)

    .. note::
        ``bool``, ``str`` and ``None`` are returned unchanged.
    """

    def transform(d):
        # Turn a flat numeric list / tuple into one array.
        if dtype is None:
            return np.array(d)
        else:
            return np.array(d, dtype=dtype)

    if isinstance(item, dict):
        return {k: to_ndarray(v, dtype) for k, v in item.items()}
    elif isinstance(item, (list, tuple)):
        if len(item) == 0:
            return None
        elif isinstance(item[0], (numbers.Integral, numbers.Real)):
            return transform(item)
        elif hasattr(item, '_fields'):  # namedtuple
            return type(item)(*[to_ndarray(t, dtype) for t in item])
        else:
            return [to_ndarray(t, dtype) for t in item]
    elif isinstance(item, torch.Tensor):
        # ``Tensor.numpy()`` raises for tensors that require grad or live on a non-CPU device,
        # so detach from the graph and move to host memory first.
        array = item.detach().cpu().numpy()
        return array if dtype is None else array.astype(dtype)
    elif isinstance(item, np.ndarray):
        return item if dtype is None else item.astype(dtype)
    elif isinstance(item, (bool, str)):
        return item
    elif np.isscalar(item):
        return transform(item)
    elif item is None:
        return None
    else:
        raise TypeError("not support item type: {}".format(type(item)))
def to_list(item: Any) -> Any:
    """
    Overview:
        Turn ``torch.Tensor`` and ``numpy.ndarray`` objects into plain python lists, recursing into
        lists, tuples and dicts.
    Arguments:
        - item (:obj:`Any`): The object to convert.
    Returns:
        - item (:obj:`Any`): The converted object. Tensors and arrays go through ``tolist()``, lists and \
tuples become ``list``, dicts become a new ``dict``; scalars and ``None`` are returned unchanged.
    Raises:
        - TypeError: If ``item`` is of an unsupported type.
    Examples:
        >>> data = {'tensor': torch.randn(4), 'tuple': (4, 5, 6), 'array': np.random.randn(4), 'none': None}
        >>> transformed_data = to_list(data)
    """
    if item is None:
        return item
    if isinstance(item, torch.Tensor):
        return item.tolist()
    if isinstance(item, np.ndarray):
        return item.tolist()
    if isinstance(item, (list, tuple)):
        return [to_list(element) for element in item]
    if isinstance(item, dict):
        return {key: to_list(value) for key, value in item.items()}
    if np.isscalar(item):
        return item
    raise TypeError("not support item type: {}".format(type(item)))
def tensor_to_list(item: Any) -> Any:
    """
    Overview:
        Turn ``torch.Tensor`` objects into plain python lists, recursing into lists, tuples and dicts.
    Arguments:
        - item (:obj:`Any`): The object to convert.
    Returns:
        - item (:obj:`Any`): The converted object. Tensors go through ``tolist()``, lists and tuples become \
``list``, dicts become a new ``dict``; scalars and ``None`` are returned unchanged.
    Raises:
        - TypeError: If ``item`` is of an unsupported type.
    Examples (2d-tensor):
        >>> t = torch.randn(3, 5)
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
        >>> assert len(tlist1[0]) == 5
    Examples (list):
        >>> t = [torch.randn(5, ) for i in range(3)]
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
        >>> assert len(tlist1[0]) == 5
    Examples (dict):
        >>> td = {'t': torch.randn(3, 5)}
        >>> tdlist1 = tensor_to_list(td)
        >>> assert len(tdlist1['t']) == 3
    """
    if item is None:
        return item
    if isinstance(item, torch.Tensor):
        return item.tolist()
    if isinstance(item, (list, tuple)):
        return [tensor_to_list(element) for element in item]
    if isinstance(item, dict):
        return {key: tensor_to_list(value) for key, value in item.items()}
    if np.isscalar(item):
        return item
    raise TypeError("not support item type: {}".format(type(item)))
def to_item(data: Any, ignore_error: bool = True) -> Any:
    """
    Overview:
        Convert data to python native scalars (via ``.item()``), recursing into lists, tuples and dicts.
    Arguments:
        - data (:obj:`Any`): The data to convert.
        - ignore_error (:obj:`bool`): If ``True``, dict entries whose conversion raises ``ValueError`` or \
``RuntimeError`` are left out of the result. If ``False``, the error is raised. The flag is applied at every \
nesting level.
    Returns:
        - data (:obj:`Any`): The converted data. Lists and tuples become ``list``, dicts become a new ``dict``; \
``None``, ``bool``, ``str`` and scalars are returned unchanged.
    Raises:
        - TypeError: If ``data`` is of an unsupported type.
    Examples:
        >>> data = {'tensor': torch.randn(1), 'list': [True, False, torch.randn(1)], 'array': np.random.randn(1)}
        >>> new_data = to_item(data)
        >>> assert np.isscalar(new_data['tensor'])
        >>> assert np.isscalar(new_data['array'])
        >>> assert np.isscalar(new_data['list'][-1])
    """
    if data is None:
        return data
    if isinstance(data, (bool, str)):
        return data
    if np.isscalar(data):
        return data
    if isinstance(data, (np.ndarray, torch.Tensor)):
        return data.item()
    if isinstance(data, (list, tuple)):
        # Forward ``ignore_error`` so nested dicts follow the caller's choice rather than the default.
        return [to_item(element, ignore_error) for element in data]
    if isinstance(data, dict):
        new_data = {}
        for key, value in data.items():
            try:
                new_data[key] = to_item(value, ignore_error)
            except (ValueError, RuntimeError):
                # Best-effort mode drops entries that cannot be reduced to a scalar.
                if not ignore_error:
                    raise
        return new_data
    # Checked after the built-in containers so that plain containers never reach the tree-tensor test.
    if isinstance(data, ttorch.Tensor):
        return data.item()
    raise TypeError("not support data type: {}".format(data))
def same_shape(data: list) -> bool:
    """
    Overview:
        Check whether every element of a list has the same ``shape`` attribute.
    Arguments:
        - data (:obj:`list`): The list of data; each element must expose a hashable ``shape``.
    Returns:
        - same (:obj:`bool`): ``True`` if exactly one distinct shape is found. An empty list gives ``False``.
    Examples:
        >>> tlist = [torch.randn(3, 5) for i in range(5)]
        >>> assert same_shape(tlist)
        >>> tlist = [torch.randn(3, 5), torch.randn(4, 5)]
        >>> assert not same_shape(tlist)
    """
    assert (isinstance(data, list))
    distinct_shapes = {element.shape for element in data}
    return len(distinct_shapes) == 1
class LogDict(dict):
    """
    Overview:
        Derived from ``dict``. Converts ``torch.Tensor`` values to ``list`` on ``__setitem__`` and ``update``
        for convenient logging.
    Interfaces:
        ``_transform``, ``__setitem__``, ``update``.
    """

    def _transform(self, data: Any) -> Any:
        """
        Overview:
            Convert a tensor to a list for better logging; any other object is returned unchanged.
        Arguments:
            - data (:obj:`Any`): The input data to be converted.
        Returns:
            - new_data (:obj:`Any`): ``data.tolist()`` for a ``torch.Tensor``, otherwise ``data`` itself.
        """
        if isinstance(data, torch.Tensor):
            return data.tolist()
        return data

    def __setitem__(self, key: Any, value: Any) -> None:
        """
        Overview:
            Override the ``__setitem__`` function of built-in dict so that the value is transformed first.
        Arguments:
            - key (:obj:`Any`): The key of the data item.
            - value (:obj:`Any`): The value of the data item.
        """
        super().__setitem__(key, self._transform(value))

    def update(self, data: Any = (), **kwargs) -> None:
        """
        Overview:
            Override the ``update`` function of built-in dict so that every value goes through
            ``__setitem__`` (and is therefore transformed).
        Arguments:
            - data (:obj:`Any`): A dict (anything exposing ``items()``) or an iterable of ``(key, value)`` \
pairs. Defaults to an empty update.
            - kwargs: Extra ``key=value`` items, applied after ``data``.
        """
        pairs = data.items() if hasattr(data, 'items') else data
        for key, value in pairs:
            self[key] = value
        for key, value in kwargs.items():
            self[key] = value
def build_log_buffer() -> LogDict:
    """
    Overview:
        Create an empty ``LogDict``, a dict subclass that stores tensors as lists for logging.
    Returns:
        - log_buffer (:obj:`LogDict`): A new, empty log buffer.
    Examples:
        >>> log_buffer = build_log_buffer()
        >>> log_buffer['not_tensor'] = torch.randn(3)
        >>> assert isinstance(log_buffer['not_tensor'], list)
        >>> assert len(log_buffer['not_tensor']) == 3
        >>> log_buffer.update({'not_tensor': 4, 'a': 5})
        >>> assert log_buffer['not_tensor'] == 4
    """
    log_buffer = LogDict()
    return log_buffer
class CudaFetcher(object):
    """
    Overview:
        Fetch data from a source in a background thread and transfer it to a specified device.
    Interfaces:
        ``__init__``, ``__next__``, ``run``, ``close``.
    """

    def __init__(self, data_source: Iterable, device: str, queue_size: int = 4, sleep: float = 0.1) -> None:
        """
        Overview:
            Initialize the CudaFetcher object using the given arguments.
        Arguments:
            - data_source (:obj:`Iterable`): The data source; it is consumed with ``next()``.
            - device (:obj:`str`): The device to put data to, such as "cuda:0".
            - queue_size (:obj:`int`): The maximum size of the internal queue, such as 4.
            - sleep (:obj:`float`): Sleeping time (seconds) when the internal queue is full.
        """
        self._source = data_source
        self._queue = Queue(maxsize=queue_size)
        self._stream = torch.cuda.Stream()
        self._producer_thread = Thread(target=self._producer, args=(), name='cuda_fetcher_producer')
        self._sleep = sleep
        self._device = device
        # Defined here so the attribute exists for the whole lifetime of the object, not only after ``run()``.
        self._end_flag = False

    def __next__(self) -> Any:
        """
        Overview:
            Return one data item from the internal queue, blocking until an item is available.
        Returns:
            - item (:obj:`Any`): The data item on the required device.
        """
        return self._queue.get()

    def run(self) -> None:
        """
        Overview:
            Start the producer thread, which keeps fetching data from the source, moves it to the device
            and puts it into the queue.
        Examples:
            >>> dataloader = iter([torch.randn(3, 3) for _ in range(10)])
            >>> dataloader = CudaFetcher(dataloader, device='cuda', sleep=0.1)
            >>> dataloader.run()
            >>> data = next(dataloader)
        """
        self._end_flag = False
        self._producer_thread.start()

    def close(self) -> None:
        """
        Overview:
            Ask the producer thread to stop by setting ``end_flag`` to ``True``. The thread is not joined.
        """
        self._end_flag = True

    def _producer(self) -> None:
        """
        Overview:
            Keep fetching data from source, change the device, and put into ``queue`` for request.
            The loop ends when ``close`` is called or when the source is exhausted.
        """
        with torch.cuda.stream(self._stream):
            while not self._end_flag:
                if self._queue.full():
                    time.sleep(self._sleep)
                    continue
                try:
                    data = next(self._source)
                except StopIteration:
                    # An exhausted source ends the producer cleanly instead of killing the thread with an
                    # uncaught exception. Items already queued stay available to ``__next__``.
                    self._end_flag = True
                    break
                self._queue.put(to_device(data, self._device))
def get_tensor_data(data: Any) -> Any:
    """
    Overview:
        Get pure tensor data from the given data: every tensor is replaced by a detached copy, so the
        result is not part of the grad computation graph.
    Arguments:
        - data (:obj:`Any`): The original data. It can be a tensor, ``None``, or a container (non-string \
Sequence or dict) of them.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``, dicts as a new ``dict``.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type.
    Examples:
        >>> a = {
        >>>     'tensor': torch.tensor([1, 2, 3.], requires_grad=True),
        >>>     'list': [torch.tensor([1, 2, 3.], requires_grad=True) for _ in range(2)],
        >>>     'none': None
        >>> }
        >>> tensor_a = get_tensor_data(a)
        >>> assert not tensor_a['tensor'].requires_grad
        >>> for t in tensor_a['list']:
        >>>     assert not t.requires_grad
    """
    if isinstance(data, torch.Tensor):
        return data.detach().clone()
    if data is None:
        return None
    # A str is a Sequence whose elements are again str, so recursing into it never terminates.
    # It is rejected with the same TypeError as any other unsupported type.
    if isinstance(data, Sequence) and not isinstance(data, str):
        return [get_tensor_data(element) for element in data]
    if isinstance(data, dict):
        return {key: get_tensor_data(value) for key, value in data.items()}
    raise TypeError("not support type in get_tensor_data: {}".format(type(data)))
def unsqueeze(data: Any, dim: int = 0) -> Any:
    """
    Overview:
        Unsqueeze the tensor data at dimension ``dim``, recursing into sequences and dicts.
    Arguments:
        - data (:obj:`Any`): The original data. It can be a tensor or a container (non-string Sequence or \
dict) of tensors.
        - dim (:obj:`int`): The dimension to be unsqueezed; applied to every nested tensor.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``, dicts as a new ``dict``.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type.
    Examples (tensor):
        >>> t = torch.randn(3, 3)
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt.shape == torch.Size([1, 3, 3])
    Examples (list):
        >>> t = [torch.randn(3, 3)]
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt[0].shape == torch.Size([1, 3, 3])
    Examples (dict):
        >>> t = {"t": torch.randn(3, 3)}
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([1, 3, 3])
    """
    if isinstance(data, torch.Tensor):
        return data.unsqueeze(dim)
    # A str is a Sequence of str and would recurse forever; reject it like any unsupported type.
    if isinstance(data, Sequence) and not isinstance(data, str):
        # ``dim`` is forwarded so that containers honour the caller's dimension.
        return [unsqueeze(element, dim) for element in data]
    if isinstance(data, dict):
        return {key: unsqueeze(value, dim) for key, value in data.items()}
    raise TypeError("not support type in unsqueeze: {}".format(type(data)))
def squeeze(data: Any, dim: int = 0) -> Any:
    """
    Overview:
        Squeeze the tensor data at dimension ``dim``, recursing into sequences and dicts.
    Arguments:
        - data (:obj:`Any`): The original data. It can be a tensor or a container (non-string Sequence or \
dict) of tensors.
        - dim (:obj:`int`): The dimension to be squeezed; applied to every nested tensor.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``, dicts as a new ``dict``.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type.
    Examples (tensor):
        >>> t = torch.randn(1, 3, 3)
        >>> tt = squeeze(t, dim=0)
        >>> assert tt.shape == torch.Size([3, 3])
    Examples (list):
        >>> t = [torch.randn(1, 3, 3)]
        >>> tt = squeeze(t, dim=0)
        >>> assert tt[0].shape == torch.Size([3, 3])
    Examples (dict):
        >>> t = {"t": torch.randn(1, 3, 3)}
        >>> tt = squeeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([3, 3])
    """
    if isinstance(data, torch.Tensor):
        return data.squeeze(dim)
    # A str is a Sequence of str and would recurse forever; reject it like any unsupported type.
    if isinstance(data, Sequence) and not isinstance(data, str):
        # ``dim`` is forwarded so that containers honour the caller's dimension.
        return [squeeze(element, dim) for element in data]
    if isinstance(data, dict):
        return {key: squeeze(value, dim) for key, value in data.items()}
    raise TypeError("not support type in squeeze: {}".format(type(data)))
def get_null_data(template: Any, num: int) -> List[Any]:
    """
    Overview:
        Build ``num`` null data items from a template. Each item is a deep copy of ``template`` with
        ``'null'`` and ``'done'`` set to ``True`` and ``'reward'`` zeroed in place on the copy.
    Arguments:
        - template (:obj:`Any`): The template data. It is indexed by the keys ``'null'``, ``'done'`` and \
``'reward'``, and its ``'reward'`` must provide ``zero_()``.
        - num (:obj:`int`): The number of null data items to generate.
    Returns:
        - output (:obj:`List[Any]`): The generated null data; the template itself is not modified.
    Examples:
        >>> temp = {'obs': [1, 2, 3], 'action': 1, 'done': False, 'reward': torch.tensor(1.)}
        >>> null_data = get_null_data(temp, 2)
        >>> assert len(null_data) == 2
        >>> assert null_data[0]['null'] and null_data[0]['done']
    """

    def _make_null() -> Any:
        # Work on an independent copy so neither the template nor sibling items are touched.
        null_item = copy.deepcopy(template)
        null_item['null'] = True
        null_item['done'] = True
        null_item['reward'].zero_()
        return null_item

    return [_make_null() for _ in range(num)]
def zeros_like(h: Any) -> Any:
    """
    Overview:
        Generate zero-tensors shaped like the input data, recursing into lists, tuples and dicts.
    Arguments:
        - h (:obj:`Any`): The original data. It can be a tensor or a container (list, tuple or dict) of \
tensors.
    Returns:
        - output (:obj:`Any`): The zero-tensors. Lists and tuples are returned as ``list``, dicts as a new \
``dict``.
    Raises:
        - TypeError: If ``h`` (or a nested element) is of an unsupported type.
    Examples (tensor):
        >>> t = torch.randn(3, 3)
        >>> tt = zeros_like(t)
        >>> assert tt.shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt)) < 1e-8
    Examples (list):
        >>> t = [torch.randn(3, 3)]
        >>> tt = zeros_like(t)
        >>> assert tt[0].shape == torch.Size([3, 3])
    Examples (dict):
        >>> t = {"t": torch.randn(3, 3)}
        >>> tt = zeros_like(t)
        >>> assert tt["t"].shape == torch.Size([3, 3])
    """
    if isinstance(h, torch.Tensor):
        return torch.zeros_like(h)
    if isinstance(h, dict):
        return {key: zeros_like(value) for key, value in h.items()}
    if isinstance(h, (list, tuple)):
        return [zeros_like(element) for element in h]
    raise TypeError("not support type: {}".format(h))