from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union
from ditk import logging
from easydict import EasyDict
from matplotlib import pyplot as plt
from matplotlib import animation
import os
import numpy as np
import torch
import wandb
import pickle
import treetensor.numpy as tnp
from ding.framework import task
from ding.envs import BaseEnvManagerV2
from ding.utils import DistributedWriter
from ding.torch_utils import to_ndarray
from ding.utils.default_helper import one_time_warning

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an online RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False.
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1
def _logger(ctx: "OnlineRLContext"): | |
if task.finish: | |
writer.close() | |
nonlocal last_train_show_iter | |
if not np.isinf(ctx.eval_value): | |
if record_train_iter: | |
writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step) | |
writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter) | |
else: | |
writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step) | |
if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq: | |
last_train_show_iter = ctx.train_iter | |
if isinstance(ctx.train_output, List): | |
output = ctx.train_output.pop() # only use latest output for some algorithms, like PPO | |
else: | |
output = ctx.train_output | |
for k, v in output.items(): | |
if k in ['priority', 'td_error_priority']: | |
continue | |
if "[scalars]" in k: | |
new_k = k.split(']')[-1] | |
raise NotImplementedError | |
elif "[histogram]" in k: | |
new_k = k.split(']')[-1] | |
writer.add_histogram(new_k, v, ctx.env_step) | |
if record_train_iter: | |
writer.add_histogram(new_k, v, ctx.train_iter) | |
else: | |
if record_train_iter: | |
writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) | |
writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step) | |
else: | |
writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step) | |
return _logger | |
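
# A minimal usage sketch for the tensorboard loggers above (illustrative only: the config object
# ``cfg`` and the collector/trainer/evaluator middleware are assumed to come from your own
# experiment entry file, not from this module):
#
#     from ding.framework import task, ding_init, OnlineRLContext
#
#     ding_init(cfg)  # creates the DistributedWriter instance that online_logger looks up
#     with task.start(async_mode=False, ctx=OnlineRLContext()):
#         ...  # register collector / trainer / evaluator middleware here
#         task.use(online_logger(record_train_iter=True, train_show_freq=100))
#         task.run()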


def offline_logger(train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an offline RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Defaults to 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(offline_logger(train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1
def _logger(ctx: "OfflineRLContext"): | |
nonlocal last_train_show_iter | |
if task.finish: | |
writer.close() | |
if not np.isinf(ctx.eval_value): | |
writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter) | |
if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq: | |
last_train_show_iter = ctx.train_iter | |
output = ctx.train_output | |
for k, v in output.items(): | |
if k in ['priority']: | |
continue | |
if "[scalars]" in k: | |
new_k = k.split(']')[-1] | |
raise NotImplementedError | |
elif "[histogram]" in k: | |
new_k = k.split(']')[-1] | |
writer.add_histogram(new_k, v, ctx.train_iter) | |
else: | |
writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter) | |
return _logger | |


# four utility functions for the wandb logger
def softmax(logit: np.ndarray) -> np.ndarray:
    # subtract the per-row max before exponentiation for numerical stability
    v = np.exp(logit - logit.max(axis=-1, keepdims=True))
    return v / v.sum(axis=-1, keepdims=True)


def action_prob(num, action_prob, ln):
    # animation callback: set the bar heights to the action probabilities of frame ``num``
    ax = plt.gca()
    ax.set_ylim([0, 1])
    for rect, x in zip(ln, action_prob[num]):
        rect.set_height(x)
    return ln


def return_prob(num, return_prob, ln):
    # animation callback for the (static) return-distribution bar chart
    return ln


def return_distribution(episode_return):
    # histogram of episode returns over 5 bins spanning [min - 50, max + 50]
    num = len(episode_return)
    max_return = max(episode_return)
    min_return = min(episode_return)
    hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6))
    gap = (max_return - min_return + 100) / 5
    x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
    return hist / num, x_dim
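
# A quick, hand-checked example of ``return_distribution`` (illustration only, not from the
# original source): the returned values are the normalized bin frequencies plus the left edge
# of each bin as a label.
#
#     >>> hist, x_dim = return_distribution([0, 100])
#     >>> hist
#     array([0. , 0.5, 0. , 0.5, 0. ])
#     >>> x_dim
#     ['-50.0', '-10.0', '30.0', '70.0', '110.0']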


def wandb_online_logger(
        record_path: Optional[str] = None,
        cfg: Optional[Union[dict, EasyDict]] = None,
        exp_config: Optional[Union[dict, EasyDict]] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: Optional[str] = None,
        wandb_sweep: bool = False,
) -> Callable:
""" | |
Overview: | |
Wandb visualizer to track the experiment. | |
Arguments: | |
- record_path (:obj:`str`): The path to save the replay of simulation. | |
- cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings: | |
- gradient_logger: boolean. Whether to track the gradient. | |
- plot_logger: boolean. Whether to track the metrics like reward and loss. | |
- video_logger: boolean. Whether to upload the rendering video replay. | |
- action_logger: boolean. `q_value` or `action probability`. | |
- return_logger: boolean. Whether to track the return value. | |
- metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies. | |
- env (:obj:`BaseEnvManagerV2`): Evaluator environment. | |
- model (:obj:`nn.Module`): Policy neural network model. | |
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \ | |
of data without wandb count. | |
- project_name (:obj:`str`): The name of wandb project. | |
- run_name (:obj:`str`): The name of wandb run. | |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep. | |
''' | |
Returns: | |
- _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input. | |
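    Examples:
        >>> # A minimal wiring sketch (illustrative): ``cfg``, ``evaluator_env`` and ``model`` are
        >>> # assumed to come from your own training entry and are not defined in this module.
        >>> task.use(
        >>>     wandb_online_logger(
        >>>         record_path='./video',
        >>>         cfg=EasyDict(dict(gradient_logger=False, plot_logger=True, video_logger=False,
        >>>                           action_logger=False, return_logger=False)),
        >>>         exp_config=cfg,
        >>>         env=evaluator_env,
        >>>         model=model,
        >>>         project_name='my-wandb-project',
        >>>     )
        >>> )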
""" | |
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
    # Initialize wandb with default settings.
    # These settings can be overridden by calling wandb.init() at the top of the script.
    # Only pass the optional arguments when requested: experiment config, reinit (disabled for
    # wandb sweeps), run name, and anonymous mode.
    init_kwargs = {'project': project_name}
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)
    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False
    # The visualizer is called to save the replay of the simulation,
    # which will be uploaded to wandb later.
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )
    first_plot = True
def _plot(ctx: "OnlineRLContext"): | |
nonlocal first_plot | |
if first_plot: | |
first_plot = False | |
ctx.wandb_url = wandb.run.get_project_url() | |
info_for_logging = {} | |
if cfg.plot_logger: | |
for metric in metric_list: | |
if isinstance(ctx.train_output, Dict) and metric in ctx.train_output: | |
if isinstance(ctx.train_output[metric], torch.Tensor): | |
info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()}) | |
else: | |
info_for_logging.update({metric: ctx.train_output[metric]}) | |
elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]: | |
metric_value_list = [] | |
for item in ctx.train_output: | |
if isinstance(item[metric], torch.Tensor): | |
metric_value_list.append(item[metric].cpu().detach().numpy()) | |
else: | |
metric_value_list.append(item[metric]) | |
metric_value = np.mean(metric_value_list) | |
info_for_logging.update({metric: metric_value}) | |
else: | |
one_time_warning( | |
"If you want to use wandb to visualize the result, please set plot_logger = True in the config." | |
) | |
if ctx.eval_value != -np.inf: | |
if hasattr(ctx, "eval_value_min"): | |
info_for_logging.update({ | |
"episode return min": ctx.eval_value_min, | |
}) | |
if hasattr(ctx, "eval_value_max"): | |
info_for_logging.update({ | |
"episode return max": ctx.eval_value_max, | |
}) | |
if hasattr(ctx, "eval_value_std"): | |
info_for_logging.update({ | |
"episode return std": ctx.eval_value_std, | |
}) | |
if hasattr(ctx, "eval_value"): | |
info_for_logging.update({ | |
"episode return mean": ctx.eval_value, | |
}) | |
if hasattr(ctx, "train_iter"): | |
info_for_logging.update({ | |
"train iter": ctx.train_iter, | |
}) | |
if hasattr(ctx, "env_step"): | |
info_for_logging.update({ | |
"env step": ctx.env_step, | |
}) | |
eval_output = ctx.eval_output['output'] | |
episode_return = ctx.eval_output['episode_return'] | |
episode_return = np.array(episode_return) | |
if len(episode_return.shape) == 2: | |
episode_return = episode_return.squeeze(1) | |
if cfg.video_logger: | |
if 'replay_video' in ctx.eval_output: | |
# save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format | |
# The numpy tensor must be either 4 dimensional or 5 dimensional. | |
# Channels should be (time, channel, height, width) or (batch, time, channel, height width) | |
video_images = ctx.eval_output['replay_video'] | |
video_images = video_images.astype(np.uint8) | |
info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)}) | |
elif record_path is not None: | |
file_list = [] | |
for p in os.listdir(record_path): | |
if os.path.splitext(p)[-1] == ".mp4": | |
file_list.append(p) | |
file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn))) | |
video_path = os.path.join(record_path, file_list[-2]) | |
info_for_logging.update({"video": wandb.Video(video_path, format="mp4")}) | |
if cfg.action_logger: | |
action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif")) | |
if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"): | |
if isinstance(eval_output, tnp.ndarray): | |
action_prob = softmax(eval_output.logit) | |
else: | |
action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output] | |
fig, ax = plt.subplots() | |
plt.ylim([-1, 1]) | |
action_dim = len(action_prob[1]) | |
x_range = [str(x + 1) for x in range(action_dim)] | |
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim]) | |
ani = animation.FuncAnimation( | |
fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob) | |
) | |
ani.save(action_path, writer='pillow') | |
info_for_logging.update({"action": wandb.Video(action_path, format="gif")}) | |
elif all(['action' in v for v in eval_output[0]]): | |
for i, action_trajectory in enumerate(eval_output): | |
fig, ax = plt.subplots() | |
fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)]) | |
steps = fig_data[:, 0] | |
actions = fig_data[:, 1:] | |
plt.ylim([-1, 1]) | |
for j in range(actions.shape[1]): | |
ax.scatter(steps, actions[:, j]) | |
info_for_logging.update({"actions_of_trajectory_{}".format(i): fig}) | |
if cfg.return_logger: | |
return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif")) | |
fig, ax = plt.subplots() | |
ax = plt.gca() | |
ax.set_ylim([0, 1]) | |
hist, x_dim = return_distribution(episode_return) | |
assert len(hist) == len(x_dim) | |
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7) | |
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1) | |
ani.save(return_path, writer='pillow') | |
info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")}) | |
if bool(info_for_logging): | |
wandb.log(data=info_for_logging, step=ctx.env_step) | |
plt.clf() | |
return _plot | |


def wandb_offline_logger(
        record_path: Optional[str] = None,
        cfg: Optional[Union[dict, EasyDict]] = None,
        exp_config: Optional[Union[dict, EasyDict]] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: Optional[str] = None,
        wandb_sweep: bool = False,
) -> Callable:
""" | |
Overview: | |
Wandb visualizer to track the experiment. | |
Arguments: | |
- record_path (:obj:`str`): The path to save the replay of simulation. | |
- cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings: | |
- gradient_logger: boolean. Whether to track the gradient. | |
- plot_logger: boolean. Whether to track the metrics like reward and loss. | |
- video_logger: boolean. Whether to upload the rendering video replay. | |
- action_logger: boolean. `q_value` or `action probability`. | |
- return_logger: boolean. Whether to track the return value. | |
- vis_dataset: boolean. Whether to visualize the dataset. | |
- metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies. | |
- env (:obj:`BaseEnvManagerV2`): Evaluator environment. | |
- model (:obj:`nn.Module`): Policy neural network model. | |
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \ | |
of data without wandb count. | |
- project_name (:obj:`str`): The name of wandb project. | |
- run_name (:obj:`str`): The name of wandb run. | |
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep. | |
''' | |
Returns: | |
- _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input. | |
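    Examples:
        >>> # An illustrative call focused on dataset visualization; ``cfg`` here is your own
        >>> # experiment config, and it must provide ``cfg.dataset_path`` (accessed below as
        >>> # ``exp_config.dataset_path``) when vis_dataset is True.
        >>> task.use(
        >>>     wandb_offline_logger(
        >>>         cfg=EasyDict(dict(gradient_logger=False, plot_logger=True, video_logger=False,
        >>>                           action_logger=False, return_logger=False, vis_dataset=True)),
        >>>         exp_config=cfg,
        >>>         project_name='my-offline-project',
        >>>     )
        >>> )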
""" | |
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
    # Initialize wandb with default settings.
    # These settings can be overridden by calling wandb.init() at the top of the script.
    # Only pass the optional arguments when requested: experiment config, reinit (disabled for
    # wandb sweeps), run name, and anonymous mode.
    init_kwargs = {'project': project_name}
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)
    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
                vis_dataset=True,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False
    # The visualizer is called to save the replay of the simulation,
    # which will be uploaded to wandb later.
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )
    first_plot = True

    def _vis_dataset(datasetpath: str):
        # Visualize the offline dataset with 2D t-SNE embeddings of observations and
        # observation-action pairs, colored by reward and by action, and upload the figure to wandb.
        try:
            from sklearn.manifold import TSNE
        except ImportError:
            import sys
            logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
            sys.exit(1)
        try:
            import h5py
        except ImportError:
            import sys
            logging.warning("Please install h5py first, such as `pip3 install h5py`.")
            sys.exit(1)
        assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5']
        if os.path.splitext(datasetpath)[-1] == '.pkl':
            with open(datasetpath, 'rb') as f:
                data = pickle.load(f)
            obs = []
            action = []
            reward = []
            for i in range(len(data)):
                obs.extend(data[i]['observations'])
                action.extend(data[i]['actions'])
                reward.extend(data[i]['rewards'])
        elif os.path.splitext(datasetpath)[-1] in ['.h5', '.hdf5']:
            with h5py.File(datasetpath, 'r') as f:
                obs = f['obs'][()]
                action = f['action'][()]
                reward = f['reward'][()]
        cmap = plt.cm.hsv
        obs = np.array(obs)
        reward = np.array(reward)
        obs_action = np.hstack((obs, np.array(action)))
        reward = reward / (max(reward) - min(reward))
        embedded_obs = TSNE(n_components=2).fit_transform(obs)
        embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
        x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
        embedded_obs = embedded_obs / (x_max - x_min)
        x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
        embedded_obs_action = embedded_obs_action / (x_max - x_min)
        fig = plt.figure()
        f, axes = plt.subplots(nrows=1, ncols=3)
        axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
        axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
        axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
        axes[0].set_title('state-reward')
        axes[1].set_title('state-action')
        axes[2].set_title('stateAction-reward')
        plt.savefig('dataset.png')
        wandb.log({"dataset": wandb.Image("dataset.png")})
    if cfg.vis_dataset is True:
        _vis_dataset(exp_config.dataset_path)
def _plot(ctx: "OfflineRLContext"): | |
nonlocal first_plot | |
if first_plot: | |
first_plot = False | |
ctx.wandb_url = wandb.run.get_project_url() | |
info_for_logging = {} | |
if cfg.plot_logger: | |
for metric in metric_list: | |
if isinstance(ctx.train_output, Dict) and metric in ctx.train_output: | |
if isinstance(ctx.train_output[metric], torch.Tensor): | |
info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()}) | |
else: | |
info_for_logging.update({metric: ctx.train_output[metric]}) | |
elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]: | |
metric_value_list = [] | |
for item in ctx.train_output: | |
if isinstance(item[metric], torch.Tensor): | |
metric_value_list.append(item[metric].cpu().detach().numpy()) | |
else: | |
metric_value_list.append(item[metric]) | |
metric_value = np.mean(metric_value_list) | |
info_for_logging.update({metric: metric_value}) | |
else: | |
one_time_warning( | |
"If you want to use wandb to visualize the result, please set plot_logger = True in the config." | |
) | |
if ctx.eval_value != -np.inf: | |
if hasattr(ctx, "eval_value_min"): | |
info_for_logging.update({ | |
"episode return min": ctx.eval_value_min, | |
}) | |
if hasattr(ctx, "eval_value_max"): | |
info_for_logging.update({ | |
"episode return max": ctx.eval_value_max, | |
}) | |
if hasattr(ctx, "eval_value_std"): | |
info_for_logging.update({ | |
"episode return std": ctx.eval_value_std, | |
}) | |
if hasattr(ctx, "eval_value"): | |
info_for_logging.update({ | |
"episode return mean": ctx.eval_value, | |
}) | |
if hasattr(ctx, "train_iter"): | |
info_for_logging.update({ | |
"train iter": ctx.train_iter, | |
}) | |
if hasattr(ctx, "train_epoch"): | |
info_for_logging.update({ | |
"train_epoch": ctx.train_epoch, | |
}) | |
eval_output = ctx.eval_output['output'] | |
episode_return = ctx.eval_output['episode_return'] | |
episode_return = np.array(episode_return) | |
if len(episode_return.shape) == 2: | |
episode_return = episode_return.squeeze(1) | |
if cfg.video_logger: | |
if 'replay_video' in ctx.eval_output: | |
# save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format | |
# The numpy tensor must be either 4 dimensional or 5 dimensional. | |
# Channels should be (time, channel, height, width) or (batch, time, channel, height width) | |
video_images = ctx.eval_output['replay_video'] | |
video_images = video_images.astype(np.uint8) | |
info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)}) | |
elif record_path is not None: | |
file_list = [] | |
for p in os.listdir(record_path): | |
if os.path.splitext(p)[-1] == ".mp4": | |
file_list.append(p) | |
file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn))) | |
video_path = os.path.join(record_path, file_list[-2]) | |
info_for_logging.update({"video": wandb.Video(video_path, format="mp4")}) | |
if cfg.action_logger: | |
action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif")) | |
if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"): | |
if isinstance(eval_output, tnp.ndarray): | |
action_prob = softmax(eval_output.logit) | |
else: | |
action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output] | |
fig, ax = plt.subplots() | |
plt.ylim([-1, 1]) | |
action_dim = len(action_prob[1]) | |
x_range = [str(x + 1) for x in range(action_dim)] | |
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim]) | |
ani = animation.FuncAnimation( | |
fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob) | |
) | |
ani.save(action_path, writer='pillow') | |
info_for_logging.update({"action": wandb.Video(action_path, format="gif")}) | |
elif all(['action' in v for v in eval_output[0]]): | |
for i, action_trajectory in enumerate(eval_output): | |
fig, ax = plt.subplots() | |
fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)]) | |
steps = fig_data[:, 0] | |
actions = fig_data[:, 1:] | |
plt.ylim([-1, 1]) | |
for j in range(actions.shape[1]): | |
ax.scatter(steps, actions[:, j]) | |
info_for_logging.update({"actions_of_trajectory_{}".format(i): fig}) | |
if cfg.return_logger: | |
return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif")) | |
fig, ax = plt.subplots() | |
ax = plt.gca() | |
ax.set_ylim([0, 1]) | |
hist, x_dim = return_distribution(episode_return) | |
assert len(hist) == len(x_dim) | |
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7) | |
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1) | |
ani.save(return_path, writer='pillow') | |
info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")}) | |
if bool(info_for_logging): | |
wandb.log(data=info_for_logging, step=ctx.trained_env_step) | |
plt.clf() | |
return _plot | |