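# Train a behavioral cloning (BC) policy on HalfCheetah by distilling a
# pre-trained TD3 expert: load the expert checkpoint, roll out demonstration
# episodes, convert them to transitions, and run DI-engine's BC serial pipeline.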
from copy import deepcopy
from typing import Union, Optional, List, Any, Tuple

import torch

from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.entry import serial_pipeline_bc, collect_episodic_demo_data, episode_to_transitions
from dizoo.mujoco.config.halfcheetah_td3_config import main_config, create_config


def load_policy(
        input_cfg: Union[str, Tuple[dict, dict]],
        load_path: str,
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
) -> 'Policy':  # noqa
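    """
    Create a DI-engine policy from ``input_cfg`` (a config file path or a
    ``(cfg, create_cfg)`` pair) and restore its collect-mode weights from the
    checkpoint at ``load_path``.
    """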
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    # DI-engine expects the command-mode variant of the policy type here.
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    # Load the expert weights into collect mode so it can generate demonstrations.
    sd = torch.load(load_path, map_location='cpu')
    policy.collect_mode.load_state_dict(sd)
    return policy


def main():
    # Load the pre-trained TD3 expert from its best checkpoint.
    half_td3_config, half_td3_create_config = main_config, create_config
    train_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    exp_path = 'DI-engine/halfcheetah_td3_seed0/ckpt/ckpt_best.pth.tar'
    expert_policy = load_policy(train_config, load_path=exp_path, seed=0)
    # Roll out the expert to collect demonstration episodes, then flatten the
    # episodes into one-step transitions for BC training.
    collect_count = 100
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    collect_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    collect_episodic_demo_data(
        deepcopy(collect_config),
        seed=0,
        state_dict=state_dict,
        expert_data_path=expert_data_path,
        collect_count=collect_count
    )
    episode_to_transitions(expert_data_path, expert_data_path, nstep=1)
    # Imitation learning: train a continuous-action BC policy on the expert
    # transitions via the serial BC pipeline.
    il_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    il_config[0].policy.learn.train_epoch = 1000000
    il_config[0].policy.type = 'bc'
    il_config[0].policy.continuous = True
    il_config[0].exp_name = "continuous_bc_seed0"
    il_config[0].env.stop_value = 50000
    il_config[0].multi_agent = False
    bc_policy, converge_stop_flag = serial_pipeline_bc(il_config, seed=314, data_path=expert_data_path, max_iter=4e6)
    return bc_policy


if __name__ == '__main__':
    policy = main()