Spaces:
Sleeping
Sleeping
import asyncio
import concurrent.futures
import enum
import fnmatch
import inspect
import math
import time
from asyncio import InvalidStateError
from asyncio.tasks import FIRST_EXCEPTION
from collections import OrderedDict
from functools import partial, wraps
from threading import Lock
from types import GeneratorType
from typing import Any, Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set, Union

from ding.framework.context import Context
from ding.framework.parallel import Parallel
from ding.framework.event_loop import EventLoop
def enable_async(func: Callable) -> Callable:
    """
    Overview:
        Empower the function with async ability. The returned wrapper accepts an extra keyword-only \
        ``async_mode`` argument that decides, per call, whether ``func`` runs inline or is scheduled \
        on the task's thread pool.
    Arguments:
        - func (:obj:`Callable`): The original function, its first positional argument is the task.
    Returns:
        - runtime_handler (:obj:`Callable`): The wrap function, it keeps ``func``'s name and docstring.
    """

    @wraps(func)
    def runtime_handler(task: "Task", *args, async_mode: Optional[bool] = None, **kwargs) -> Any:
        """
        Overview:
            If async mode is enabled, schedule ``func`` in the task's event loop executor and return \
            immediately, otherwise call ``func`` synchronously.
        Arguments:
            - task (:obj:`Task`): The task instance.
            - async_mode (:obj:`Optional[bool]`): Whether using async mode, ``None`` means falling back \
                to ``task.async_mode``.
        Returns:
            - result (:obj:`Any`): In async mode the task itself (the pending future is appended to \
                ``task._async_stack``, use ``task.sync()`` to wait for it); otherwise the return value \
                of ``func``.
        """
        if async_mode is None:
            async_mode = task.async_mode
        if not async_mode:
            return func(task, *args, **kwargs)
        # ``run_in_executor`` only forwards positional arguments, so bind everything with ``partial``
        # to support keyword arguments as well.
        future = task._async_loop.run_in_executor(task._thread_pool, partial(func, task, *args, **kwargs))
        task._async_stack.append(future)
        return task

    return runtime_handler
class Role(str, enum.Enum):
    """
    Overview:
        The roles a task can take. Members are ``str`` instances as well, so every member compares \
        equal to its plain string value, e.g. ``Role.LEARNER == "learner"``.
    """

    LEARNER = "learner"
    COLLECTOR = "collector"
    EVALUATOR = "evaluator"
    FETCHER = "fetcher"
class VoidMiddleware:
    """
    Overview:
        Placeholder middleware that does nothing. ``Task.use`` detects instances of this class and \
        skips registering them.
    """

    def __call__(self, _):
        # Deliberately ignore the context argument.
        return None
class Task:
    """
    Overview:
        Task will manage the execution order of the entire pipeline, register new middleware,
        and generate new context objects.
    """

    role = Role

    def __init__(self) -> None:
        self.router = Parallel()
        self._finish = False
        # ``start`` sets this flag and ``stop`` clears it. Initializing it here lets the "task is running"
        # assertions fire with their message before ``start`` and makes ``stop`` a no-op on an unstarted task.
        self._running = False

    def start(
            self,
            async_mode: bool = False,
            n_async_workers: int = 3,
            ctx: Optional[Context] = None,
            labels: Optional[Set[str]] = None
    ) -> "Task":
        """
        Overview:
            Initialize the runtime state of the task (middleware, context, event loop, async executor, \
            labels) and mark it as running.
        Arguments:
            - async_mode (:obj:`bool`): Default ``async_mode`` of ``forward``/``backward``; also activates \
                the thread pool and the asyncio loop.
            - n_async_workers (:obj:`int`): Max number of worker threads of the async thread pool.
            - ctx (:obj:`Optional[Context]`): Initial context, a new ``Context()`` is created when it is None.
            - labels (:obj:`Optional[Set[str]]`): Initial labels of the task, see ``match_labels``.
        Returns:
            - task (:obj:`Task`): The task itself.
        """
        # This flag can be modified by external or associated processes
        self._finish = False
        # This flag can only be modified inside the class, it will be set to False in the end of stop
        self._running = True
        self._middleware = []
        self._wrappers = []
        self.ctx = ctx if ctx is not None else Context()
        self._backward_stack = OrderedDict()
        self._roles = set()
        # Bind event loop functions
        self._event_loop = EventLoop("task_{}".format(id(self)))
        # Async segment
        self.async_mode = async_mode
        self.n_async_workers = n_async_workers
        self._async_stack = []
        self._async_loop = None
        self._thread_pool = None
        self._exception = None
        self._thread_lock = Lock()
        self.labels = labels or set()
        # Parallel segment
        if async_mode or self.router.is_active:
            self._activate_async()
        if self.router.is_active:

            def sync_finish(value):
                # Keep the local flag in sync with "finish" events (``on`` also subscribes via the router).
                self._finish = value

            self.on("finish", sync_finish)
        self.init_labels()
        return self

    def add_role(self, role: Role):
        self._roles.add(role)

    def has_role(self, role: Role) -> bool:
        """
        Overview:
            Check whether the task has ``role``. A task without any role is treated as having every role.
        """
        if not self._roles:
            return True
        return role in self._roles

    def roles(self) -> Set[Role]:
        # NOTE(review): ``finish`` is a property; confirm whether callers expect ``roles`` to be one as well.
        return self._roles

    def void(self):
        return VoidMiddleware()

    def init_labels(self):
        """
        Overview:
            Add the built-in labels: "async" in async mode, then "distributed", "node.<id>" and the router \
            labels when the router is active, otherwise "standalone".
        """
        if self.async_mode:
            self.labels.add("async")
        if self.router.is_active:
            self.labels.add("distributed")
            self.labels.add("node.{}".format(self.router.node_id))
            for label in self.router.labels:
                self.labels.add(label)
        else:
            self.labels.add("standalone")

    def use(self, fn: Callable, lock: Union[bool, Lock] = False) -> 'Task':
        """
        Overview:
            Register middleware to task. The middleware will be executed by its registry order.
        Arguments:
            - fn (:obj:`Callable`): A middleware is a function with only one argument: ctx.
            - lock (:obj:`Union[bool, Lock]`): There can only be one middleware execution under lock at any one time.
        Returns:
            - task (:obj:`Task`): The task.
        """
        assert callable(fn), "Middleware function should be a callable object, current fn {}".format(fn)
        if isinstance(fn, VoidMiddleware):  # Skip void function
            return self
        for wrapper in self._wrappers:
            fn = wrapper(fn)
        self._middleware.append(self.wrap(fn, lock=lock))
        return self

    def use_wrapper(self, fn: Callable) -> 'Task':
        """
        Overview:
            Register wrappers to task. A wrapper works like a decorator, but task will apply this \
            decorator on top of each middleware.
        Arguments:
            - fn (:obj:`Callable`): A wrapper is a decorator, so the first argument is a callable function.
        Returns:
            - task (:obj:`Task`): The task.
        """
        # Wrap existing middlewares, later ones are wrapped in ``use``.
        for i, middleware in enumerate(self._middleware):
            self._middleware[i] = fn(middleware)
        self._wrappers.append(fn)
        return self

    def match_labels(self, patterns: Union[Iterable[str], str]) -> bool:
        """
        Overview:
            A list of patterns to match labels.
        Arguments:
            - patterns (:obj:`Union[Iterable[str], str]`): Glob like pattern, e.g. node.1, node.*.
        Returns:
            - matched (:obj:`bool`): Whether any label matches any of the patterns.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        return any(fnmatch.filter(self.labels, p) for p in patterns)

    def run(self, max_step: int = int(1e12)) -> None:
        """
        Overview:
            Execute the iterations, when reach the max_step or task.finish is true,
            The loop will be break.
        Arguments:
            - max_step (:obj:`int`): Max step of iterations.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if not self._middleware:
            return
        for i in range(max_step):
            for fn in self._middleware:
                self.forward(fn)
            # Sync should be called before backward, otherwise it is possible
            # that some generators have not been pushed to backward_stack.
            self.sync()
            self.backward()
            self.sync()
            if i == max_step - 1:
                self.finish = True
            if self.finish:
                break
            self.renew()

    def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable:
        """
        Overview:
            Wrap the middleware, make it can be called directly in other middleware.
        Arguments:
            - fn (:obj:`Callable`): The middleware.
            - lock (:obj:`Union[bool, Lock]`): There can only be one middleware execution under lock at any one time.
        Returns:
            - fn_back (:obj:`Callable`): It will return a backward function, which will call the rest part of
                the middleware after yield. If this backward function was not called, the rest part of the middleware
                will be called in the global backward step.
        """
        if lock is True:
            lock = self._thread_lock

        def forward(ctx: Context):
            if lock:
                with lock:
                    g = self.forward(fn, ctx, async_mode=False)
            else:
                g = self.forward(fn, ctx, async_mode=False)

            def backward():
                # ``g`` is None when ``fn`` is not a generator (or finished before its first yield), and its
                # entry is gone once it has been resumed; in both cases there is nothing left to run.
                pending = self._backward_stack.pop(id(g), None) if g is not None else None
                if pending is None:
                    return
                backward_stack = OrderedDict([(id(g), pending)])
                if lock:
                    with lock:
                        self.backward(backward_stack, async_mode=False)
                else:
                    self.backward(backward_stack, async_mode=False)

            return backward

        if hasattr(fn, "__name__"):
            forward = wraps(fn)(forward)
        else:
            forward = wraps(fn.__class__)(forward)
        return forward

    @enable_async
    def forward(self, fn: Callable, ctx: Optional[Context] = None) -> Optional[Generator]:
        """
        Overview:
            This function will execute the middleware until the first yield statement,
            or the end of the middleware. Decorated by ``enable_async``: callers may pass ``async_mode``, \
            in async mode the call is scheduled in the thread pool and the task itself is returned.
        Arguments:
            - fn (:obj:`Callable`): Function with contain the ctx argument in middleware.
            - ctx (:obj:`Optional[Context]`): Replace global ctx with a customized ctx.
        Returns:
            - g (:obj:`Optional[Generator]`): The generator if the return value of fn is a generator \
                that yielded, it is also pushed onto the backward stack.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if ctx is None:
            ctx = self.ctx
        g = fn(ctx)
        if not isinstance(g, GeneratorType):
            return None
        try:
            next(g)
        except StopIteration:
            # Finished without yielding: there is no backward part to keep.
            return None
        self._backward_stack[id(g)] = g
        return g

    @enable_async
    def backward(self, backward_stack: Optional[Dict[int, Generator]] = None) -> None:
        """
        Overview:
            Execute the rest part of middleware, by the reversed order of registry. Decorated by \
            ``enable_async``, so callers may pass ``async_mode``.
        Arguments:
            - backward_stack (:obj:`Optional[Dict[int, Generator]]`): Replace global backward_stack with a \
                customized stack, keyed by ``id`` of the generator.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        # NOTE: an *empty* customized stack also falls back to the global one. ``parallel`` relies on this:
        # in async mode ``forward`` returns the task rather than a generator, so its key list stays empty.
        if not backward_stack:
            backward_stack = self._backward_stack
        while backward_stack:
            # FILO
            _, g = backward_stack.popitem()
            try:
                next(g)
            except StopIteration:
                continue

    def running(self):
        # NOTE(review): ``finish`` is a property; confirm whether callers expect ``running`` to be one as well.
        return self._running

    def serial(self, *fns: Callable) -> Callable:
        """
        Overview:
            Wrap functions and keep them run in serial, Usually in order to avoid the confusion
            of dependencies in async mode.
        Arguments:
            - fns (:obj:`Callable`): Chain a serial of middleware, wrap them into one middleware function.
        Returns:
            - middleware (:obj:`Callable`): A generator middleware named ``serial<...>``.
        """

        def _serial(ctx):
            backward_keys = []
            for fn in fns:
                g = self.forward(fn, ctx, async_mode=False)
                if isinstance(g, GeneratorType):
                    backward_keys.append(id(g))
            yield
            backward_stack = OrderedDict()
            for k in backward_keys:
                backward_stack[k] = self._backward_stack.pop(k)
            self.backward(backward_stack=backward_stack, async_mode=False)

        # Callable objects may lack ``__name__`` (``wrap`` handles the same case).
        name = ",".join(getattr(fn, "__name__", type(fn).__name__) for fn in fns)
        _serial.__name__ = "serial<{}>".format(name)
        return _serial

    def parallel(self, *fns: Callable) -> Callable:
        """
        Overview:
            Wrap functions and keep them run in parallel, should not use this function in async mode.
        Arguments:
            - fns (:obj:`Callable`): Parallelized middleware, wrap them into one middleware function.
        Returns:
            - middleware (:obj:`Callable`): A generator middleware named ``parallel<...>``.
        """
        self._activate_async()

        def _parallel(ctx):
            backward_keys = []
            for fn in fns:
                # In async mode ``forward`` returns the task, so no key is recorded here; the generators are
                # pushed onto the global stack from the worker threads instead.
                g = self.forward(fn, ctx, async_mode=True)
                if isinstance(g, GeneratorType):
                    backward_keys.append(id(g))
            self.sync()
            yield
            backward_stack = OrderedDict()
            for k in backward_keys:
                backward_stack[k] = self._backward_stack.pop(k)
            # With an empty ``backward_stack`` this resumes the global stack (see ``backward``).
            self.backward(backward_stack, async_mode=True)
            self.sync()

        name = ",".join(getattr(fn, "__name__", type(fn).__name__) for fn in fns)
        _parallel.__name__ = "parallel<{}>".format(name)
        return _parallel

    def renew(self) -> 'Task':
        """
        Overview:
            Renew the context instance, this function should be called after backward in the end of iteration.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        self.ctx = self.ctx.renew()
        return self

    def __enter__(self) -> "Task":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def stop(self) -> None:
        """
        Overview:
            Stop and cleanup every thing in the runtime of task. Calling it on a task that is not running \
            (never started, or already stopped) does nothing.
        """
        if not self._running:
            return
        if self.router.is_active:
            self.emit("finish", True)
        if self._thread_pool:
            self._thread_pool.shutdown()
        self._event_loop.stop()
        self.router.off(self._wrap_event_name("*"))
        if self._async_loop:
            self._async_loop.stop()
            self._async_loop.close()
        # The middleware and listeners may contain some methods that reference to task,
        # If we do not clear them after the task exits, we may find that gc will not clean up the task object.
        self._middleware.clear()
        self._wrappers.clear()
        self._backward_stack.clear()
        self._async_stack.clear()
        self._running = False

    def sync(self) -> 'Task':
        """
        Overview:
            Block until the pending async futures are handled, see ``sync_tasks``. No-op without an async loop.
        """
        if self._async_loop:
            self._async_loop.run_until_complete(self.sync_tasks())
        return self

    async def sync_tasks(self) -> None:
        """
        Overview:
            Wait for the futures in ``_async_stack`` (returning early on the first exception), then pop them \
            in order. Unfinished futures are dropped; the first exception found is stored in \
            ``self._exception`` and raised, leaving the remaining futures on the stack.
        """
        if self._async_stack:
            await asyncio.wait(self._async_stack, return_when=FIRST_EXCEPTION)
        while self._async_stack:
            t = self._async_stack.pop(0)
            try:
                e = t.exception()
            except InvalidStateError:
                # Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
                continue
            if e is not None:
                self._exception = e
                raise e

    def async_executor(self, fn: Callable, *args, **kwargs) -> None:
        """
        Overview:
            Execute task in background, then append the future instance in _async_stack.
        Arguments:
            - fn (:obj:`Callable`): Synchronization function.
        """
        if not self._async_loop:
            raise Exception("Event loop was not initialized, please call this function in async or parallel mode")
        t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
        self._async_stack.append(t)

    def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = False, **kwargs) -> None:
        """
        Overview:
            Emit an event, call listeners.
        Arguments:
            - event (:obj:`str`): Event name.
            - only_remote (:obj:`bool`): Only broadcast the event to the connected nodes, default is False.
            - only_local (:obj:`bool`): Only emit local event, default is False. Takes precedence over \
                ``only_remote``.
            - args (:obj:`any`): Rest arguments for listeners.
        """
        # Check if need to broadcast event to connected nodes, default is True
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if only_local:
            self._event_loop.emit(event, *args, **kwargs)
        elif only_remote:
            if self.router.is_active:
                self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs)
        else:
            if self.router.is_active:
                self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs)
            self._event_loop.emit(event, *args, **kwargs)

    def on(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Subscribe to an event, execute this function every time the event is emitted.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): The function.
        """
        self._event_loop.on(event, fn)
        if self.router.is_active:
            self.router.on(self._wrap_event_name(event), self._event_loop.emit)

    def once(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Subscribe to an event, execute this function only once when the event is emitted.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): The function.
        """
        self._event_loop.once(event, fn)
        if self.router.is_active:
            self.router.on(self._wrap_event_name(event), self._event_loop.emit)

    def off(self, event: str, fn: Optional[Callable] = None) -> None:
        """
        Overview:
            Unsubscribe an event
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): The function.
        """
        self._event_loop.off(event, fn)
        if self.router.is_active:
            self.router.off(self._wrap_event_name(event))

    def wait_for(self, event: str, timeout: float = math.inf, ignore_timeout_exception: bool = True) -> Any:
        """
        Overview:
            Wait for an event and block the thread.
        Arguments:
            - event (:obj:`str`): Event name.
            - timeout (:obj:`float`): Timeout in seconds.
            - ignore_timeout_exception (:obj:`bool`): If this is False, an exception will occur when meeting timeout.
        Returns:
            - result (:obj:`Optional[Tuple[tuple, dict]]`): ``(args, kwargs)`` of the event, or None when \
                timing out (with ``ignore_timeout_exception``) or when an async exception was recorded first.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        received = False
        result = None

        def _receive_event(*args, **kwargs):
            nonlocal result, received
            result = (args, kwargs)
            received = True

        # NOTE(review): this ``once`` listener is not removed when the wait times out.
        self.once(event, _receive_event)
        start = time.time()
        while time.time() - start < timeout:
            if received or self._exception:
                return result
            time.sleep(0.01)
        if ignore_timeout_exception:
            return result
        else:
            raise TimeoutError("Timeout when waiting for event: {}".format(event))

    @property
    def finish(self) -> bool:
        return self._finish

    @finish.setter
    def finish(self, value: bool) -> None:
        self._finish = value

    def _wrap_event_name(self, event: str) -> str:
        """
        Overview:
            Wrap the event name sent to the router.
        Arguments:
            - event (:obj:`str`): Event name
        """
        return "task.{}".format(event)

    def _activate_async(self):
        # Lazily create the thread pool and the asyncio loop used by async mode, ``parallel`` and ``emit``.
        if not self._thread_pool:
            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.n_async_workers)
        if not self._async_loop:
            self._async_loop = asyncio.new_event_loop()

    def get_attch_to_len(self) -> int:
        """
        Overview:
            Get the length of the 'attach_to' list in Parallel._mq.
        Returns:
            int: the length of the Parallel._mq.
        Raises:
            RuntimeError: if the router is inactive.
        """
        if self.router.is_active:
            return self.router.get_attch_to_len()
        else:
            raise RuntimeError("The router is inactive, failed to be obtained the length of 'attch_to' list.")
# Module-level Task instance, created when this module is imported.
task: Task = Task()