from typing import Optional, Union

import numpy as np

from ding.utils import PLAYER_REGISTRY

from .player import ActivePlayer, HistoricalPlayer
from .algorithm import pfsp
class MainPlayer(ActivePlayer):
    """
    Overview:
        Main player in league training.
        Default branch (0.5 pfsp, 0.35 sp, 0.15 veri).
        Default snapshot every 2e9 steps.
        Default mutate prob = 0 (never mutate).
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        race, payoff, checkpoint_path, player_id, train_iteration
    """
    _name = "MainPlayer"

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Select prioritized fictitious self-play opponent, should be a historical player.
        Returns:
            - player (:obj:`HistoricalPlayer`): the selected historical player
        """
        historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
        win_rates = self._payoff[self, historical]
        # 'squared' weighting concentrates sampling probability on opponents this player
        # does NOT reliably beat (hard, informative opponents)
        p = pfsp(win_rates, weighting='squared')
        return self._get_opponent(historical, p)

    def _sp_branch(self):
        """
        Overview:
            Select normal self-play opponent.
        Returns:
            - player (:obj:`Player`): the active main player, or one of its historical \
                snapshots if the active main player is currently too strong
        """
        main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
        main_opponent = self._get_opponent(main_players)
        # TODO(nyz) if only one main_player, self-play win_rates are constantly equal to 0.5
        # main_opponent is not too strong
        if self._payoff[self, main_opponent] > 1 - self._strong_win_rate:
            return main_opponent
        # if the main_opponent is too strong, select a past alternative among its snapshots
        historical = self._get_players(
            lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == main_opponent.player_id
        )
        win_rates = self._payoff[self, historical]
        # 'variance' weighting prefers snapshots whose outcome is most uncertain (win rate near 0.5)
        p = pfsp(win_rates, weighting='variance')
        return self._get_opponent(historical, p)

    def _verification_branch(self):
        """
        Overview:
            Verify no strong historical main exploiter and no forgotten historical past main player.
        Returns:
            - player (:obj:`Player`): a historical player that still needs to be handled, \
                otherwise a normal self-play opponent
        """
        # check exploitation
        main_exploiters = self._get_players(lambda p: isinstance(p, MainExploiter))
        exp_historical = self._get_players(
            lambda p: isinstance(p, HistoricalPlayer) and any([p.parent_id == m.player_id for m in main_exploiters])
        )
        win_rates = self._payoff[self, exp_historical]
        # TODO(nyz) why min win_rates 0.3
        if len(win_rates) and win_rates.min() < 1 - self._strong_win_rate:
            p = pfsp(win_rates, weighting='squared')
            return self._get_opponent(exp_historical, p)
        # check forgotten
        main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
        main_opponent = self._get_opponent(main_players)  # only one main player
        main_historical = self._get_players(
            lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == main_opponent.player_id
        )
        win_rates = self._payoff[self, main_historical]
        # TODO(nyz) whether the method `_get_players` should return players with some sequence(such as step)
        # win_rates, historical = self._remove_monotonic_suffix(win_rates, historical)
        if len(win_rates) and win_rates.min() < self._strong_win_rate:
            p = pfsp(win_rates, weighting='squared')
            return self._get_opponent(main_historical, p)
        # no forgotten main players or strong main exploiters, use self-play instead
        return self._sp_branch()

    # def _remove_monotonic_suffix(self, win_rates, players):
    #     if not len(win_rates):
    #         return win_rates, players
    #     for i in range(len(win_rates) - 1, 0, -1):
    #         if win_rates[i - 1] < win_rates[i]:
    #             return win_rates[:i + 1], players[:i + 1]
    #     return np.array([]), []

    # override
    def is_trained_enough(self) -> bool:
        # ``_pfsp_branch`` and ``_verification_branch`` are played against historical players
        return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))

    # override
    def mutate(self, info: dict) -> None:
        """
        Overview:
            MainPlayer does not mutate.
        """
        pass
class MainExploiter(ActivePlayer):
    """
    Overview:
        Main exploiter in league training. Can identify weaknesses of main agents, and consequently make them
        more robust.
        Default branch (1.0 main_players).
        Default snapshot when defeating all 3 main players in the league in more than 70% of games,
        or timeout of 4e9 steps.
        Default mutate prob = 1 (must mutate).
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        race, payoff, checkpoint_path, player_id, train_iteration
    """
    _name = "MainExploiter"

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize ``min_valid_win_rate`` additionally
        Note:
            - min_valid_win_rate (:obj:`float`): only when win rate against the main player is greater than this, \
                can the main player be regarded as able to produce valid training signals to be selected
        """
        super(MainExploiter, self).__init__(*args, **kwargs)
        self._min_valid_win_rate = self._cfg.min_valid_win_rate

    def _main_players_branch(self):
        """
        Overview:
            Select main player or historical player snapshot from main player as opponent
        Returns:
            - player (:obj:`Player`): the selected main player (active/historical)
        """
        # get the main player (only one)
        main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
        main_opponent = self._get_opponent(main_players)
        # if this main_opponent can produce valid training signals
        if self._payoff[self, main_opponent] >= self._min_valid_win_rate:
            return main_opponent
        # otherwise, curriculum learning, select a historical version
        historical = self._get_players(
            lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == main_opponent.player_id
        )
        win_rates = self._payoff[self, historical]
        # 'variance' weighting prefers snapshots whose outcome is most uncertain (win rate near 0.5)
        p = pfsp(win_rates, weighting='variance')
        return self._get_opponent(historical, p)

    # override
    def is_trained_enough(self) -> bool:
        # would play against main player, or historical main player (if main player is too strong)
        return super().is_trained_enough(select_fn=lambda p: isinstance(p, MainPlayer))

    # override
    def mutate(self, info: dict) -> str:
        """
        Overview:
            Main exploiter is sure to mutate(reset) to the supervised learning player
        Returns:
            - mutate_ckpt_path (:obj:`str`): mutation target checkpoint path
        """
        return info['reset_checkpoint_path']
class LeagueExploiter(ActivePlayer):
    """
    Overview:
        League exploiter in league training. Can identify global blind spots in the league (strategies that no player
        in the league can beat, but that are not necessarily robust themselves).
        Default branch (1.0 pfsp).
        Default snapshot when defeating all players in the league in more than 70% of games, or timeout of 2e9 steps.
        Default mutate prob = 0.25.
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        race, payoff, checkpoint_path, player_id, train_iteration
    """
    _name = "LeagueExploiter"

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize ``mutate_prob`` additionally
        Note:
            - mutate_prob (:obj:`float`): the mutation probability of league exploiter. should be in [0, 1]
        """
        super(LeagueExploiter, self).__init__(*args, **kwargs)
        assert 0 <= self._cfg.mutate_prob <= 1
        self.mutate_prob = self._cfg.mutate_prob

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Select prioritized fictitious self-play opponent
        Returns:
            - player (:obj:`HistoricalPlayer`): the selected historical player
        Note:
            This branch is the same as the pfsp branch in MainPlayer
        """
        historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
        win_rates = self._payoff[self, historical]
        # 'squared' weighting concentrates sampling probability on hard opponents
        p = pfsp(win_rates, weighting='squared')
        return self._get_opponent(historical, p)

    # override
    def is_trained_enough(self) -> bool:
        # will only play against historical players
        return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))

    # override
    def mutate(self, info: dict) -> Optional[str]:
        """
        Overview:
            League exploiter can mutate to the supervised learning player with 0.25 prob
        Returns:
            - ckpt_path (:obj:`Optional[str]`): with ``mutate_prob`` prob returns the pretrained model's ckpt path, \
                with left 1 - ``mutate_prob`` prob returns None, which means no mutation
        """
        p = np.random.uniform()
        if p < self.mutate_prob:
            return info['reset_checkpoint_path']
        return None