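"""
Smoke test for ``GameSegment``: it builds evaluator environments (Atari for EfficientZero,
TicTacToe for MuZero), runs ``initial_inference`` on stacked observations, performs an MCTS
search with the C++ tree, selects actions from the root visit-count distributions, and stores
the resulting transitions and search statistics into game segments.
"""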
import numpy as np
import pytest
import torch

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
from lzero.policy import select_action
# args = ['EfficientZero', 'MuZero']
args = ['MuZero']


@pytest.mark.parametrize('test_algo', args)
def test_game_segment(test_algo):
    # import different modules according to ``test_algo``
    if test_algo == 'EfficientZero':
        from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
        from lzero.model.efficientzero_model import EfficientZeroModel as Model
        from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config
        from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
        envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)]
    elif test_algo == 'MuZero':
        from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
        from lzero.model.muzero_model import MuZeroModel as Model
        from lzero.mcts.tests.config.tictactoe_muzero_bot_mode_config_for_test import tictactoe_muzero_config as config
        from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
        envs = [TicTacToeEnv(config.env) for _ in range(config.env.evaluator_env_num)]

    # create model
    model = Model(**config.policy.model)
    if config.policy.cuda and torch.cuda.is_available():
        config.policy.device = 'cuda'
    else:
        config.policy.device = 'cpu'
    model.to(config.policy.device)
    model.eval()
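
    # The rollout below only exercises inference and the MCTS search, so gradients are not needed.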
    with torch.no_grad():
        # initializations
        init_observations = [env.reset() for env in envs]
        dones = np.array([False for _ in range(config.env.evaluator_env_num)])
        game_segments = [
            GameSegment(
                envs[i].action_space, game_segment_length=config.policy.game_segment_length, config=config.policy
            ) for i in range(config.env.evaluator_env_num)
        ]
        for i in range(config.env.evaluator_env_num):
            game_segments[i].reset(
                [init_observations[i]['observation'] for _ in range(config.policy.model.frame_stack_num)]
            )
        episode_rewards = np.zeros(config.env.evaluator_env_num)
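
        # Step every evaluator environment once per iteration until all episodes are done.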
        while not dones.all():
            stack_obs = [game_segment.get_obs() for game_segment in game_segments]
            stack_obs = prepare_observation(stack_obs, config.policy.model.model_type)
            stack_obs = torch.from_numpy(np.array(stack_obs)).to(config.policy.device)

            # ==============================================================
            # the core initial_inference.
            # ==============================================================
            network_output = model.initial_inference(stack_obs)
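
            # ``initial_inference`` encodes the stacked observations into the root latent states and
            # predicts policy logits and values (plus value_prefix and reward_hidden_state for EfficientZero).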
            # process the network output
            policy_logits_pool = network_output.policy_logits.detach().cpu().numpy().tolist()
            latent_state_roots = network_output.latent_state.detach().cpu().numpy()
            if test_algo == 'EfficientZero':
                reward_hidden_state_roots = network_output.reward_hidden_state
                value_prefix_pool = network_output.value_prefix
                reward_hidden_state_roots = (
                    reward_hidden_state_roots[0].detach().cpu().numpy(),
                    reward_hidden_state_roots[1].detach().cpu().numpy()
                )
                # for the Atari envs, every action is legal
                legal_actions_list = [
                    [i for i in range(config.policy.model.action_space_size)]
                    for _ in range(config.env.evaluator_env_num)
                ]
            elif test_algo == 'MuZero':
                reward_pool = network_output.reward
                # for board games, the legal actions are taken from the env's action mask
                legal_actions_list = [
                    [a for a, x in enumerate(init_observations[i]['action_mask']) if x == 1]
                    for i in range(config.env.evaluator_env_num)
                ]

            # ``to_play=-1`` is the null padding used for Atari games and for board games in vs_bot_mode
            to_play = [-1 for _ in range(config.env.evaluator_env_num)]
            if test_algo == 'EfficientZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
            elif test_algo == 'MuZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, to_play)
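
            # After the search, each root holds the visit-count distribution and value estimate
            # used below for action selection and stored into the game segments.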
            roots_distributions = roots.get_distributions()
            roots_values = roots.get_values()
            for i in range(config.env.evaluator_env_num):
                distributions, value, env = roots_distributions[i], roots_values[i], envs[i]
                # ``deterministic=True`` indicates that we select the argmax action instead of sampling.
                action, _ = select_action(distributions, temperature=1, deterministic=True)
                # ==============================================================
                # the core ``env.step``.
                # ==============================================================
                obs, reward, done, info = env.step(action)
                obs = obs['observation']

                game_segments[i].store_search_stats(distributions, value)
                game_segments[i].append(action, obs, reward)
                dones[i] = done
                episode_rewards[i] += reward
                if dones[i]:
                    continue

    for env in envs:
        env.close()
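

# Optional convenience entry point (not required by pytest): lets this test file be run
# directly with ``python``; assumes pytest is installed in the environment.
if __name__ == "__main__":
    pytest.main(['-sv', __file__])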