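# Unit tests for the DI-engine logger middlewares: online_logger/offline_logger (which write
# through DistributedWriter) and the wandb_online_logger/wandb_offline_logger variants,
# exercised here against mocked writers, envs and contexts.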
from os import path
import os
import copy
from easydict import EasyDict
from collections import deque
import pytest
import shutil
import wandb
import h5py
import torch.nn as nn
from unittest.mock import MagicMock
from unittest.mock import Mock, patch

from ding.utils import DistributedWriter
from ding.framework.middleware.tests import MockPolicy, CONFIG
from ding.framework import OnlineRLContext, OfflineRLContext
from ding.framework.middleware.functional import online_logger, offline_logger, wandb_online_logger, \
    wandb_offline_logger

test_folder = "test_exp"
test_path = path.join(os.getcwd(), test_folder)
cfg = EasyDict({"exp_name": "test_exp"})


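# Build an OnlineRLContext pre-filled with the eval/train fields that online_logger is expected to read.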
def get_online_ctx():
    ctx = OnlineRLContext()
    ctx.eval_value = -10000
    ctx.train_iter = 34
    ctx.env_step = 78
    ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15}
    return ctx


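# Fixtures exposing the online context with train_output as a plain dict, a deque, a list,
# and as an unsupported '[scalars]' payload.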
@pytest.fixture
def online_ctx_output_dict():
    ctx = get_online_ctx()
    return ctx


@pytest.fixture
def online_ctx_output_deque():
    ctx = get_online_ctx()
    ctx.train_output = deque([ctx.train_output])
    return ctx


@pytest.fixture
def online_ctx_output_list():
    ctx = get_online_ctx()
    ctx.train_output = [ctx.train_output]
    return ctx


@pytest.fixture
def online_scalar_ctx():
    ctx = get_online_ctx()
    ctx.train_output = {'[scalars]': 1}
    return ctx


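# Fake writer that stands in for DistributedWriter: every add_scalar/add_histogram call is
# checked against the values stored in the online context, and unexpected tags raise.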
class MockOnlineWriter:

    def __init__(self):
        self.ctx = get_online_ctx()

    def add_scalar(self, tag, scalar_value, global_step):
        if tag in ['basic/eval_episode_return_mean-env_step', 'basic/eval_episode_return_mean']:
            assert scalar_value == self.ctx.eval_value
            assert global_step == self.ctx.env_step
        elif tag == 'basic/eval_episode_return_mean-train_iter':
            assert scalar_value == self.ctx.eval_value
            assert global_step == self.ctx.train_iter
        elif tag in ['basic/train_td_error-env_step', 'basic/train_td_error']:
            assert scalar_value == self.ctx.train_output['td_error']
            assert global_step == self.ctx.env_step
        elif tag == 'basic/train_td_error-train_iter':
            assert scalar_value == self.ctx.train_output['td_error']
            assert global_step == self.ctx.train_iter
        else:
            raise NotImplementedError('tag should be in the tags defined')

    def add_histogram(self, tag, values, global_step):
        assert tag == 'test_histogram'
        assert values == [1, 2, 3, 4, 5, 6]
        assert global_step in [self.ctx.train_iter, self.ctx.env_step]

    def close(self):
        pass


def mock_get_online_instance():
    return MockOnlineWriter()


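# online_logger should emit the expected scalars/histograms whether train_output is a dict or a
# deque, with and without the additional per-train-iter records.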
class TestOnlineLogger:

    def test_online_logger_output_dict(self, online_ctx_output_dict):
        with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance):
            online_logger()(online_ctx_output_dict)

    def test_online_logger_record_output_dict(self, online_ctx_output_dict):
        with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance):
            online_logger(record_train_iter=True)(online_ctx_output_dict)

    def test_online_logger_record_output_deque(self, online_ctx_output_deque):
        with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance):
            online_logger()(online_ctx_output_deque)


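# Offline counterpart: an OfflineRLContext pre-filled for offline_logger, plus fixtures for the
# dict and unsupported '[scalars]' train_output variants.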
def get_offline_ctx():
    ctx = OfflineRLContext()
    ctx.eval_value = -10000000000
    ctx.train_iter = 3333
    ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15}
    return ctx


@pytest.fixture
def offline_ctx_output_dict():
    ctx = get_offline_ctx()
    return ctx


@pytest.fixture
def offline_scalar_ctx():
    ctx = get_offline_ctx()
    ctx.train_output = {'[scalars]': 1}
    return ctx


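# Fake writer for the offline case: all scalars and histograms must be logged against train_iter.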
class MockOfflineWriter:

    def __init__(self):
        self.ctx = get_offline_ctx()

    def add_scalar(self, tag, scalar_value, global_step):
        assert global_step == self.ctx.train_iter
        if tag == 'basic/eval_episode_return_mean-train_iter':
            assert scalar_value == self.ctx.eval_value
        elif tag == 'basic/train_td_error-train_iter':
            assert scalar_value == self.ctx.train_output['td_error']
        else:
            raise NotImplementedError('tag should be in the tags defined')

    def add_histogram(self, tag, values, global_step):
        assert tag == 'test_histogram'
        assert values == [1, 2, 3, 4, 5, 6]
        assert global_step == self.ctx.train_iter

    def close(self):
        pass


def mock_get_offline_instance():
    return MockOfflineWriter()


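# offline_logger should log the dict output normally and raise NotImplementedError on a
# '[scalars]' payload.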
class TestOfflineLogger:

    def test_offline_logger_no_scalars(self, offline_ctx_output_dict):
        with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance):
            offline_logger()(offline_ctx_output_dict)

    def test_offline_logger_scalars(self, offline_scalar_ctx):
        with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance):
            with pytest.raises(NotImplementedError) as exc_info:
                offline_logger()(offline_scalar_ctx)


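# Minimal stand-ins for the model, env and HDF5 dataset columns used by the wandb logger tests.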
class TheModelClass(nn.Module):

    def state_dict(self):
        return 'fake_state_dict'


class TheEnvClass(Mock):

    def enable_save_replay(self, replay_path):
        return


class TheObsDataClass(Mock):

    def __getitem__(self, index):
        return [[1, 1, 1]] * 50


class The1DDataClass(Mock):

    def __getitem__(self, index):
        return [[1]] * 50


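# wandb_online_logger: wandb.log should only receive keys from the expected metric list, and
# wandb.watch should be handed the model under test.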
def test_wandb_online_logger():
    record_path = './video_qbert_dqn'
    cfg = EasyDict(
        dict(
            gradient_logger=True,
            plot_logger=True,
            action_logger=True,
            return_logger=True,
            video_logger=True,
        )
    )
    env = TheEnvClass()
    ctx = OnlineRLContext()
    ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
    model = TheModelClass()
    wandb.init(config=cfg, anonymous="must")

    def mock_metric_logger(data, step):
        metric_list = [
            "q_value",
            "target q_value",
            "loss",
            "lr",
            "entropy",
            "reward",
            "q value",
            "video",
            "q value distribution",
            "train iter",
            "episode return mean",
            "env step",
            "action",
            "actions_of_trajectory_0",
            "actions_of_trajectory_1",
            "actions_of_trajectory_2",
            "actions_of_trajectory_3",
            "return distribution",
        ]
        assert set(data.keys()) <= set(metric_list)

    def mock_gradient_logger(input_model, log, log_freq, log_graph):
        assert input_model == model

    def test_wandb_online_logger_metric():
        with patch.object(wandb, 'log', new=mock_metric_logger):
            wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)

    def test_wandb_online_logger_gradient():
        with patch.object(wandb, 'watch', new=mock_gradient_logger):
            wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)

    test_wandb_online_logger_metric()
    test_wandb_online_logger_gradient()


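# wandb_offline_logger: same metric/gradient checks as the online case, plus dataset
# visualisation, with h5py.File and wandb.Image patched so no real dataset file is needed.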
def test_wandb_offline_logger():
    record_path = './video_pendulum_cql'
    cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=True))
    env = TheEnvClass()
    ctx = OfflineRLContext()
    ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
    model = TheModelClass()
    wandb.init(config=cfg, anonymous="must")
    exp_config = EasyDict(dict(dataset_path='dataset.h5'))

    def mock_metric_logger(data, step=None):
        metric_list = [
            "q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video",
            "q value distribution", "train iter", 'dataset'
        ]
        assert set(data.keys()) < set(metric_list)

    def mock_gradient_logger(input_model, log, log_freq, log_graph):
        assert input_model == model

    def mock_image_logger(imagepath):
        assert os.path.splitext(imagepath)[-1] == '.png'

    def test_wandb_offline_logger_gradient():
        cfg.vis_dataset = False
        print(cfg)
        with patch.object(wandb, 'watch', new=mock_gradient_logger):
            wandb_offline_logger(
                record_path=record_path, cfg=cfg, exp_config=exp_config, env=env, model=model, anonymous=True
            )(ctx)

    def test_wandb_offline_logger_dataset():
        cfg.vis_dataset = True
        m = MagicMock()
        m.__enter__.return_value = {'obs': TheObsDataClass(), 'action': The1DDataClass(), 'reward': The1DDataClass()}
        with patch.object(wandb, 'log', new=mock_metric_logger):
            with patch.object(wandb, 'Image', new=mock_image_logger):
                with patch('h5py.File', return_value=m):
                    wandb_offline_logger(
                        record_path=record_path, cfg=cfg, exp_config=exp_config, env=env, model=model, anonymous=True
                    )(ctx)

    test_wandb_offline_logger_gradient()
    test_wandb_offline_logger_dataset()