import os
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional

from easydict import EasyDict
from ditk import logging
import torch
from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
from ding.utils import get_rank

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None) -> Callable:
    """
    Overview:
        Push episodes or trajectories into the buffer.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): The buffer to push the data into.
        - group_by_env (:obj:`Optional[bool]`): Whether to group the pushed transitions by their environment id.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    def _push(ctx: "OnlineRLContext"):
        """
        Overview:
            At least one of `ctx.trajectories` and `ctx.episodes` should not be None.
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): Trajectories.
            - episodes (:obj:`List[Dict]`): Episodes.
        """
        if ctx.trajectories is not None:  # each data in buffer is a transition
            if group_by_env:
                for t in ctx.trajectories:
                    buffer_.push(t, {'env': t.env_data_id.item()})
            else:
                for t in ctx.trajectories:
                    buffer_.push(t)
            ctx.trajectories = None
        elif ctx.episodes is not None:  # each data in buffer is an episode
            for t in ctx.episodes:
                buffer_.push(t)
            ctx.episodes = None
        else:
            raise RuntimeError("Either ctx.trajectories or ctx.episodes should not be None.")

    return _push
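
# Example usage (a minimal sketch, not part of this module; `DequeBuffer` exists in
# DI-engine, but the exact pipeline wiring below is illustrative):
#
#     from ding.data import DequeBuffer
#     buffer_ = DequeBuffer(size=10000)
#     task.use(data_pusher(cfg, buffer_))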


def buffer_saver(cfg: EasyDict, buffer_: Buffer, every_envstep: int = 1000, replace: bool = False) -> Callable:
    """
    Overview:
        Periodically save the current buffer data to disk.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the key `cfg.exp_name`.
        - buffer_ (:obj:`Buffer`): The buffer whose data will be saved.
        - every_envstep (:obj:`int`): Save the buffer data every `every_envstep` environment steps.
        - replace (:obj:`bool`): Whether to overwrite the previously saved file.
    """
    buffer_saver_env_counter = -every_envstep

    def _save(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, `ctx.env_step` should not be None.
        Input of ctx:
            - env_step (:obj:`int`): Current env step.
        """
        nonlocal buffer_saver_env_counter
        if ctx.env_step is not None:
            if ctx.env_step >= every_envstep + buffer_saver_env_counter:
                buffer_saver_env_counter = ctx.env_step
                if replace:
                    buffer_.save_data(os.path.join(cfg.exp_name, "replaybuffer", "data_latest.hkl"))
                else:
                    buffer_.save_data(
                        os.path.join(cfg.exp_name, "replaybuffer", "data_envstep_{}.hkl".format(ctx.env_step))
                    )
        else:
            raise RuntimeError("buffer_saver only supports collecting data by step rather than episode.")

    return _save
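
# Example usage (a minimal sketch; the ordering relative to other middleware is illustrative):
#
#     task.use(data_pusher(cfg, buffer_))
#     task.use(buffer_saver(cfg, buffer_, every_envstep=10000, replace=True))
#
# With replace=True the snapshot is always rewritten to <exp_name>/replaybuffer/data_latest.hkl;
# otherwise one file is kept per save point.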


def offpolicy_data_fetcher(
        cfg: EasyDict,
        buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
        data_shortage_warning: bool = False,
) -> Callable:
    """
    Overview:
        The returned function is a generator which mainly fetches a batch of data from a buffer, \
        a list of buffers, or a dict of buffers.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
            `cfg.policy.learn.batch_size`, `cfg.policy.collect.unroll_len`.
        - buffer_ (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \
            The buffer where the data is fetched from. \
            ``Buffer`` type means a single buffer. \
            ``List[Tuple[Buffer, float]]`` type means a list of tuples. Each tuple contains a buffer and a float; \
            the float is the ratio of `batch_size` that will be sampled from the corresponding buffer. \
            ``Dict[str, Buffer]`` type means a dict in which the value of each key-value pair is a buffer. \
            For each key-value pair of the dict, `batch_size` pieces of data will be sampled from the \
            corresponding buffer and assigned to the same key of `ctx.train_data`.
        - data_shortage_warning (:obj:`bool`): Whether to output a warning when a data shortage occurs during fetching.
    """
    def _fetch(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_output (:obj:`Union[Dict, Deque[Dict]]`): This attribute should exist \
                if `buffer_` is of type Buffer and uses the middleware `PriorityExperienceReplay`. \
                The `priority` meta data of the sampled data in `buffer_` will be updated with the \
                `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \
                or with the `priority` attribute of the element popped from `ctx.train_output` \
                if `ctx.train_output` is a deque of dicts.
        Output of ctx:
            - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \
                ``List[Dict]`` type means a list of data. \
                `train_data` is of this type if the type of `buffer_` is Buffer or List. \
                ``Dict[str, List[Dict]]`` type means a dict in which the value of each key-value pair \
                is a list of data. `train_data` is of this type if the type of `buffer_` is Dict.
        """
        try:
            unroll_len = cfg.policy.collect.unroll_len
            if isinstance(buffer_, Buffer):
                if unroll_len > 1:
                    buffered_data = buffer_.sample(
                        cfg.policy.learn.batch_size, groupby="env", unroll_len=unroll_len, replace=True
                    )
                    ctx.train_data = [[t.data for t in d] for d in buffered_data]  # B, unroll_len
                else:
                    buffered_data = buffer_.sample(cfg.policy.learn.batch_size)
                    ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, List):  # like sqil, r2d3
                assert unroll_len == 1, "unroll_len > 1 is not supported for a list of buffers"
                buffered_data = []
                for buffer_elem, p in buffer_:
                    data_elem = buffer_elem.sample(int(cfg.policy.learn.batch_size * p))
                    assert data_elem is not None
                    buffered_data.append(data_elem)
                buffered_data = sum(buffered_data, [])
                ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, Dict):  # like ppg_offpolicy
                assert unroll_len == 1, "unroll_len > 1 is not supported for a dict of buffers"
                buffered_data = {k: v.sample(cfg.policy.learn.batch_size) for k, v in buffer_.items()}
                ctx.train_data = {k: [d.data for d in v] for k, v in buffered_data.items()}
            else:
                raise TypeError("unsupported buffer argument type: {}".format(type(buffer_)))
            assert buffered_data is not None
        except (ValueError, AssertionError):
            if data_shortage_warning:
                # You can modify the data collection config to avoid this warning,
                # e.g. by increasing n_sample or n_episode. The fetcher will skip this attempt.
                logging.warning(
                    "Replay buffer's data is not enough to support training, "
                    "so this training iteration is skipped to wait for more data."
                )
            ctx.train_data = None
            return

        yield

        if isinstance(buffer_, Buffer):
            if any([isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware]):
                index = [d.index for d in buffered_data]
                meta = [d.meta for d in buffered_data]
                # write the new priority (computed during training) back into the buffer meta data
                if isinstance(ctx.train_output, List):
                    priority = ctx.train_output.pop()['priority']
                else:
                    priority = ctx.train_output['priority']
                for idx, m, p in zip(index, meta, priority):
                    m['priority'] = p
                    buffer_.update(index=idx, data=None, meta=m)

    return _fetch
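
# Example usage (a minimal sketch; the SQIL-style split below is illustrative):
#
#     # single buffer
#     task.use(offpolicy_data_fetcher(cfg, buffer_))
#
#     # list of (buffer, ratio): half of each batch from the agent buffer, half from the
#     # expert buffer (as in sqil / r2d3 pipelines)
#     task.use(offpolicy_data_fetcher(cfg, [(agent_buffer, 0.5), (expert_buffer, 0.5)]))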


def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        Fetch offline data with a background producer thread that pre-loads batches \
        from `dataset` onto the target device through a dedicated CUDA stream.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
            `cfg.policy.learn.batch_size`, `cfg.policy.cuda`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    """
    from threading import Thread
    from queue import Queue
    import time
    stream = torch.cuda.Stream()

    def producer(queue, dataset, batch_size, device):
        torch.set_num_threads(4)
        nonlocal stream
        idx_iter = iter(range(len(dataset) - batch_size))
        if len(dataset) < batch_size:
            logging.warning('batch_size is larger than the size of the dataset!')
        with torch.cuda.stream(stream):
            while True:
                if queue.full():
                    time.sleep(0.1)
                else:
                    try:
                        start_idx = next(idx_iter)
                    except StopIteration:
                        # restart from the beginning of the dataset after a full pass
                        del idx_iter
                        idx_iter = iter(range(len(dataset) - batch_size))
                        start_idx = next(idx_iter)
                    data = [dataset[idx] for idx in range(start_idx, start_idx + batch_size)]
                    # transpose the list of samples into a list of stacked tensors, one per field
                    data = [[i[j] for i in data] for j in range(len(data[0]))]
                    data = [torch.stack(x).to(device) for x in data]
                    queue.put(data)

    queue = Queue(maxsize=50)
    device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
    # the producer thread keeps pre-loading batches onto the target device in the background
    producer_thread = Thread(
        target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
    )

    def _fetch(ctx: "OfflineRLContext"):
        nonlocal queue, producer_thread
        # lazily start the producer on the first fetch, then wait for a batch to be ready
        if not producer_thread.is_alive():
            time.sleep(5)
            producer_thread.start()
        while queue.empty():
            time.sleep(0.001)
        ctx.train_data = queue.get()

    return _fetch
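
# Usage note (a minimal sketch): this fetcher constructs a torch.cuda.Stream as soon as it is
# called, so it is intended for CUDA training; on CPU-only machines prefer offline_data_fetcher below.
#
#     task.use(offline_data_fetcher_from_mem(cfg, dataset))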


def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        The outer function transforms a PyTorch `Dataset` into a `DataLoader`. \
        The returned function is a generator that fetches a batch of data from that `DataLoader` \
        each time it is iterated. \
        Please refer to https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
        and https://pytorch.org/docs/stable/data.html for more details.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    """
    # collate_fn is executed in policy now
    dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
    dataloader = iter(dataloader)

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \
            After the dataloader is exhausted, the attribute `ctx.train_epoch` will be incremented by 1.
        Input of ctx:
            - train_epoch (:obj:`int`): Number of `train_epoch`.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        nonlocal dataloader
        try:
            ctx.train_data = next(dataloader)  # noqa
        except StopIteration:
            ctx.train_epoch += 1
            del dataloader
            dataloader = DataLoader(
                dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
            )
            dataloader = iter(dataloader)
            ctx.train_data = next(dataloader)
        # TODO apply data update (e.g. priority) in offline setting when necessary
        ctx.trained_env_step += len(ctx.train_data)

    return _fetch
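
# Example usage (a minimal sketch; `create_dataset` exists in ding.utils.data, but the exact
# offline pipeline wiring below is illustrative):
#
#     from ding.utils.data import create_dataset
#     dataset = create_dataset(cfg)
#     task.use(offline_data_fetcher(cfg, dataset))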


def offline_data_saver(data_path: str, data_type: str = 'hdf5') -> Callable:
    """
    Overview:
        Save the expert data of offline RL in a directory.
    Arguments:
        - data_path (:obj:`str`): File path where the expert data will be written, usually './expert.pkl'.
        - data_type (:obj:`str`): Type of the saved data. \
            The data is saved as pkl if `data_type == 'naive'`, \
            and as hdf5 if `data_type == 'hdf5'`.
    """

    def _save(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Tensor]`): The expert data to be saved.
        """
        data = ctx.trajectories
        offline_data_save_type(data, data_path, data_type)
        ctx.trajectories = None

    return _save
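
# Example usage (a minimal sketch; the file name is just an illustration):
#
#     task.use(offline_data_saver('./expert_demos.hdf5', data_type='hdf5'))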


def sqil_data_pusher(cfg: EasyDict, buffer_: Buffer, expert: bool) -> Callable:
    """
    Overview:
        Push trajectories into the buffer in the SQIL learning pipeline.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): The buffer to push the data into.
        - expert (:obj:`bool`): Whether the pushed data is expert data or not. \
            The reward of each pushed transition is set to 1 if this argument is `True`, otherwise 0.
    """

    def _pusher(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): The trajectories to be pushed.
        """
        for t in ctx.trajectories:
            if expert:
                t.reward = torch.ones_like(t.reward)
            else:
                t.reward = torch.zeros_like(t.reward)
            buffer_.push(t)
        ctx.trajectories = None

    return _pusher
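
# Example usage (a minimal sketch; the two-buffer SQIL wiring below is illustrative and pairs
# naturally with the List[Tuple[Buffer, float]] form of offpolicy_data_fetcher above):
#
#     task.use(sqil_data_pusher(cfg, agent_buffer, expert=False))
#     task.use(sqil_data_pusher(cfg, expert_buffer, expert=True))
#     task.use(offpolicy_data_fetcher(cfg, [(agent_buffer, 0.5), (expert_buffer, 0.5)]))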