import numpy as np
import dataclasses
import treetensor.torch as ttorch
from typing import Union, Dict, List


@dataclasses.dataclass
class Context:
    """
    Overview:
        Context is an object that passes contextual data between middlewares, and its life cycle
        is a single training iteration. It works like a dict of attributes, so you can set
        any property you wish.
        Note that the initial value of each property should evaluate to False (e.g. ``None``,
        ``0`` or an empty container), so middlewares can check whether it has been filled in yet.
    """
    _kept_keys: set = dataclasses.field(default_factory=set)
    total_step: int = 0

    def renew(self) -> 'Context':  # noqa
        """
        Overview:
            Create a new context of the same type for the next iteration: increment ``total_step``
            and copy the kept properties to the new instance.
        """
        total_step = self.total_step
        ctx = type(self)()
        for key in self._kept_keys:
            if self.has_attr(key):
                setattr(ctx, key, getattr(self, key))
        ctx.total_step = total_step + 1
        return ctx

    def keep(self, *keys: str) -> None:
        """
        Overview:
            Keep the given key(s) so that their values are carried over to the next iteration.
        """
        for key in keys:
            self._kept_keys.add(key)

    def has_attr(self, key: str) -> bool:
        return hasattr(self, key)
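

# Example usage (illustrative sketch only, not part of the original module): a middleware-style
# loop that writes into the context, marks a field to keep, and renews the context once per
# iteration. The helper name ``_demo_context_usage`` and the ``new_data`` attribute are
# hypothetical and exist purely for documentation.
def _demo_context_usage(num_iterations: int = 3) -> None:
    ctx = Context()
    for _ in range(num_iterations):
        ctx.new_data = [1, 2, 3]  # any attribute can be set on the fly
        ctx.keep('new_data')  # carry this attribute into the next iteration
        ctx = ctx.renew()  # total_step is incremented, kept keys are copied over
    assert ctx.total_step == num_iterations
    assert ctx.new_data == [1, 2, 3]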


# TODO: Restrict data to specific types
@dataclasses.dataclass
class OnlineRLContext(Context):
    # common
    total_step: int = 0
    env_step: int = 0
    env_episode: int = 0
    train_iter: int = 0
    train_data: Union[Dict, List] = None
    train_output: Union[Dict, List[Dict]] = None
    # collect
    collect_kwargs: Dict = dataclasses.field(default_factory=dict)
    obs: ttorch.Tensor = None
    action: List = None
    inference_output: Dict[int, Dict] = None
    trajectories: List = None
    episodes: List = None
    trajectory_end_idx: List = dataclasses.field(default_factory=list)
    # eval
    eval_value: float = -np.inf
    last_eval_iter: int = -1
    last_eval_value: float = -np.inf
    eval_output: List = dataclasses.field(default_factory=dict)
    # wandb
    wandb_url: str = ""

    def __post_init__(self):
        # __post_init__ is called right after the generated __init__ has set all fields.
        # We use it to register the fields that must be carried over to the next iteration.
        self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')
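

# Illustrative sketch (hypothetical helper, not part of the original module): the fields listed in
# ``OnlineRLContext.__post_init__`` survive ``renew``, while per-iteration fields such as
# ``train_data`` fall back to their defaults.
def _demo_online_rl_context_renew() -> None:
    ctx = OnlineRLContext()
    ctx.env_step += 128  # kept field, e.g. advanced by a collector middleware
    ctx.train_data = {'obs': None}  # per-iteration field, never kept
    ctx = ctx.renew()
    assert ctx.total_step == 1 and ctx.env_step == 128  # kept across iterations
    assert ctx.train_data is None  # reset, because 'train_data' was not kept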


@dataclasses.dataclass
class OfflineRLContext(Context):
    # common
    total_step: int = 0
    trained_env_step: int = 0
    train_epoch: int = 0
    train_iter: int = 0
    train_data: Union[Dict, List] = None
    train_output: Union[Dict, List[Dict]] = None
    # eval
    eval_value: float = -np.inf
    last_eval_iter: int = -1
    last_eval_value: float = -np.inf
    eval_output: List = dataclasses.field(default_factory=dict)
    # wandb
    wandb_url: str = ""

    def __post_init__(self):
        # __post_init__ is called right after the generated __init__ has set all fields.
        # We use it to register the fields that must be carried over to the next iteration.
        self.keep('trained_env_step', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')
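

if __name__ == '__main__':
    # Illustrative walk-through (assumed entry point, not part of the original module): run the
    # sketches above and show that OfflineRLContext behaves the same way, with
    # ``trained_env_step`` among its kept fields.
    _demo_context_usage()
    _demo_online_rl_context_renew()

    ctx = OfflineRLContext()
    ctx.trained_env_step += 256  # kept field, e.g. updated by a trainer middleware
    ctx.train_data = {'obs': None}  # per-iteration field, reset on renew
    ctx = ctx.renew()
    assert ctx.total_step == 1 and ctx.trained_env_step == 256
    assert ctx.train_data is None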