Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
import functools | |
import torch.multiprocessing as mp | |
from multiprocessing.context import BaseContext | |
import threading | |
import queue | |
import platform | |
import traceback | |
import uuid | |
import time | |
from ditk import logging | |
from dataclasses import dataclass, field | |
from typing import Any, Callable, Dict, List, Optional, Union | |
from enum import Enum | |
def get_mp_ctx() -> BaseContext:
    """
    Overview:
        Build the multiprocessing context used for child processes.
    Returns:
        - mp_ctx (:obj:`BaseContext`): A ``spawn`` context on Windows, a ``fork`` context on every other platform.
    """
    on_windows = platform.system().lower() == 'windows'
    start_method = 'spawn' if on_windows else 'fork'
    return mp.get_context(start_method)
@dataclass
class SendPayload:
    """
    Overview:
        Request message sent from the supervisor to one child worker.
    """
    # Index of the child that should handle this request.
    proc_id: int
    # Use uuid1 here to include the timestamp
    req_id: str = field(default_factory=lambda: uuid.uuid1().hex)
    # Name of the method to call on the child instance; control requests put a
    # ``ReserveMethod`` member here instead of a plain string.
    method: Optional[str] = None
    # Positional and keyword arguments forwarded to the called method. Built per
    # instance through ``default_factory`` so payloads never share a container.
    args: List = field(default_factory=list)
    kwargs: Dict = field(default_factory=dict)
@dataclass
class RecvPayload:
    """
    Overview:
        Response message sent from a child worker back to the supervisor.
    """
    # Index of the child that produced this response.
    proc_id: int
    # Copied from the request so the supervisor can match response to request.
    req_id: Optional[str] = None
    # Method carried by the request that produced this response.
    method: Optional[str] = None
    # Return value of the call; stays None when the call failed.
    data: Any = None
    # Exception raised while serving the request, or None on success.
    err: Optional[Exception] = None
    # Free slot; nothing in this file writes it.
    extra: Any = None
class ReserveMethod(Enum):
    """
    Overview:
        Control requests handled by the worker loop itself instead of being
        forwarded as a method call on the wrapped instance.
    """
    # Ask the worker loop to exit.
    SHUTDOWN = "_shutdown"
    # Read an attribute (name passed as the first positional argument) instead of calling a method.
    GETATTR = "_getattr"
class ChildType(Enum):
    """
    Overview:
        Kind of worker a supervisor manages.
    """
    # Each child runs in its own process.
    PROCESS = "process"
    # Each child runs in a thread of the current process.
    THREAD = "thread"
class Child(ABC):
    """
    Overview:
        Abstract class of child process/thread. Subclasses provide the transport
        (``start`` / ``shutdown`` / ``send``); the request loop itself lives in ``_target``.
    """

    def __init__(self, proc_id: int, init: Union[Callable, object], **kwargs) -> None:
        """
        Arguments:
            - proc_id (:obj:`int`): Identifier of this child, echoed back in every response payload.
            - init (:obj:`Union[Callable, object]`): Either a callable that builds the wrapped instance \
                inside the worker, or the instance itself.
            - kwargs: Accepted and ignored by the base class.
        """
        self._proc_id = proc_id
        self._init = init
        self._recv_queue = None
        self._send_queue = None

    @abstractmethod
    def start(self, recv_queue: Union[mp.Queue, queue.Queue]):
        """
        Overview:
            Start the worker and make it put responses on ``recv_queue``.
        """
        raise NotImplementedError

    def restart(self):
        """
        Overview:
            Shut the worker down and start it again on the response queue it used before.
        """
        self.shutdown()
        self.start(self._recv_queue)

    @abstractmethod
    def shutdown(self, timeout: Optional[float] = None):
        """
        Overview:
            Stop the worker.
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait for the worker to finish.
        """
        raise NotImplementedError

    @abstractmethod
    def send(self, payload: "SendPayload"):
        """
        Overview:
            Hand a request payload to the worker.
        """
        raise NotImplementedError

    def _target(
        self,
        proc_id: int,
        init: Union[Callable, object],
        send_queue: Union[mp.Queue, queue.Queue],
        recv_queue: Union[mp.Queue, queue.Queue],
        shm_buffer: Optional[Any] = None,
        shm_callback: Optional[Callable] = None
    ):
        """
        Overview:
            Worker loop: take requests from ``send_queue``, run them against the wrapped instance \
            and put one response per request on ``recv_queue``.
        Arguments:
            - proc_id (:obj:`int`): Identifier written into every response.
            - init (:obj:`Union[Callable, object]`): Callable that creates the wrapped instance, or the instance.
            - send_queue (:obj:`Union[mp.Queue, queue.Queue]`): Queue the requests are read from.
            - recv_queue (:obj:`Union[mp.Queue, queue.Queue]`): Queue the responses are written to.
            - shm_buffer (:obj:`Optional[Any]`): Passed as second argument to ``shm_callback``.
            - shm_callback (:obj:`Optional[Callable]`): Called with ``(recv_payload, shm_buffer)`` before a \
                successful response is queued; only used when both it and ``shm_buffer`` are given.
        """
        # Placeholder request, so the error path below always has a payload to echo
        # even if the very first ``send_queue.get()`` fails.
        send_payload = SendPayload(proc_id=proc_id)
        child_ins = init() if callable(init) else init
        while True:
            try:
                send_payload = send_queue.get()
                if send_payload.method == ReserveMethod.SHUTDOWN:
                    break
                if send_payload.method == ReserveMethod.GETATTR:
                    # Attribute read: the attribute name is the first positional argument.
                    data = getattr(child_ins, send_payload.args[0])
                else:
                    data = getattr(child_ins, send_payload.method)(*send_payload.args, **send_payload.kwargs)
                recv_payload = RecvPayload(
                    proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, data=data
                )
                if shm_callback is not None and shm_buffer is not None:
                    shm_callback(recv_payload, shm_buffer)
                recv_queue.put(recv_payload)
            except Exception as e:
                # Report the failure to the supervisor instead of dying; the loop keeps serving.
                logging.warning(traceback.format_exc())
                logging.warning("Error in child process! id: {}, error: {}".format(proc_id, e))
                recv_payload = RecvPayload(
                    proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, err=e
                )
                recv_queue.put(recv_payload)

    def __del__(self):
        self.shutdown()
class ChildProcess(Child):
    """
    Overview:
        Child that runs the worker loop in a separate process.
    """

    def __init__(
        self,
        proc_id: int,
        init: Union[Callable, object],
        shm_buffer: Optional[Any] = None,
        shm_callback: Optional[Callable] = None,
        mp_ctx: Optional[BaseContext] = None,
        **kwargs
    ) -> None:
        """
        Arguments:
            - proc_id (:obj:`int`): Identifier of this child.
            - init (:obj:`Union[Callable, object]`): Callable that creates the wrapped instance, or the instance.
            - shm_buffer (:obj:`Optional[Any]`): Forwarded to the worker loop together with ``shm_callback``.
            - shm_callback (:obj:`Optional[Callable]`): Forwarded to the worker loop.
            - mp_ctx (:obj:`Optional[BaseContext]`): Multiprocessing context; when omitted, ``get_mp_ctx()`` \
                is used at start time.
        """
        super().__init__(proc_id, init, **kwargs)
        self._proc = None
        self._mp_ctx = mp_ctx
        self._shm_buffer = shm_buffer
        self._shm_callback = shm_callback

    def start(self, recv_queue: mp.Queue):
        """
        Overview:
            Spawn the worker process. Does nothing when a process is already attached.
        Arguments:
            - recv_queue (:obj:`mp.Queue`): Queue the worker puts its responses on.
        """
        if self._proc is None:
            self._recv_queue = recv_queue
            ctx = self._mp_ctx or get_mp_ctx()
            self._send_queue = ctx.Queue()
            proc = ctx.Process(
                target=self._target,
                args=(
                    self._proc_id, self._init, self._send_queue, self._recv_queue, self._shm_buffer, self._shm_callback
                ),
                name="supervisor_child_{}_{}".format(self._proc_id, time.time()),
                daemon=True
            )
            proc.start()
            self._proc = proc

    def shutdown(self, timeout: Optional[float] = None):
        """
        Overview:
            Stop the worker process and release its request queue. Does nothing when no process is attached.
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait in each ``join``; ``None`` waits indefinitely.
        """
        if self._proc:
            self._send_queue.put(SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN))
            self._proc.terminate()
            self._proc.join(timeout=timeout)
            # ``join`` returns silently on timeout. ``close`` raises ValueError on a process that is
            # still running, which would leave this object half torn down, so escalate first.
            if self._proc.is_alive() and hasattr(self._proc, "kill"):  # Compatible with 3.6
                self._proc.kill()
                self._proc.join(timeout=timeout)
            if hasattr(self._proc, "close") and not self._proc.is_alive():  # Compatible with 3.6
                self._proc.close()
            self._proc = None
            self._send_queue.close()
            self._send_queue.join_thread()
            self._send_queue = None

    def send(self, payload: SendPayload):
        """
        Overview:
            Queue a request for the worker. Logs a warning and drops the payload when the worker is not running.
        """
        if self._send_queue is None:
            logging.warning("Child worker has been terminated or not started.")
            return
        self._send_queue.put(payload)
class ChildThread(Child):
    """
    Overview:
        Child that runs the worker loop in a daemon thread of the current process.
    """

    def __init__(self, proc_id: int, init: Union[Callable, object], *args, **kwargs) -> None:
        """
        Arguments:
            - proc_id (:obj:`int`): Identifier of this child.
            - init (:obj:`Union[Callable, object]`): Callable that creates the wrapped instance, or the instance.
            - args, kwargs: Forwarded unchanged to the base constructor.
        """
        super().__init__(proc_id, init, *args, **kwargs)
        self._thread = None

    def start(self, recv_queue: queue.Queue):
        """
        Overview:
            Launch the worker thread. Does nothing when a thread is already attached.
        Arguments:
            - recv_queue (:obj:`queue.Queue`): Queue the worker puts its responses on.
        """
        if self._thread is not None:
            return
        self._recv_queue = recv_queue
        self._send_queue = queue.Queue()
        worker_name = "supervisor_child_{}_{}".format(self._proc_id, time.time())
        worker_args = (self._proc_id, self._init, self._send_queue, self._recv_queue)
        worker = threading.Thread(target=self._target, args=worker_args, name=worker_name, daemon=True)
        worker.start()
        self._thread = worker

    def shutdown(self, timeout: Optional[float] = None):
        """
        Overview:
            Ask the worker loop to exit and wait for the thread. Does nothing when no thread is attached.
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait in ``join``; ``None`` waits indefinitely.
        """
        if not self._thread:
            return
        stop_request = SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN)
        self._send_queue.put(stop_request)
        self._thread.join(timeout=timeout)
        # References are dropped whether or not the join finished within the timeout.
        self._thread = None
        self._send_queue = None

    def send(self, payload: SendPayload):
        """
        Overview:
            Queue a request for the worker. Logs a warning and drops the payload when the worker is not running.
        """
        if self._send_queue is not None:
            self._send_queue.put(payload)
            return
        logging.warning("Child worker has been terminated or not started.")
class Supervisor:
    """
    Overview:
        Owns a list of child workers (all processes or all threads), routes request payloads to them
        and collects their responses from one shared response queue.
    """

    TYPE_MAPPING = {ChildType.PROCESS: ChildProcess, ChildType.THREAD: ChildThread}

    def __init__(self, type_: ChildType, mp_ctx: Optional[BaseContext] = None) -> None:
        """
        Arguments:
            - type_ (:obj:`ChildType`): Whether children are processes or threads.
            - mp_ctx (:obj:`Optional[BaseContext]`): Multiprocessing context; defaults to ``get_mp_ctx()``.
        """
        # Set first: ``shutdown`` (reached from ``__del__``) reads it even if the rest of __init__ fails.
        self._running = False
        self._children: List[Child] = []
        self._type = type_
        self._child_class = self.TYPE_MAPPING[self._type]
        self.__queue = None
        self._mp_ctx = mp_ctx or get_mp_ctx()

    def register(
        self,
        init: Union[Callable, object],
        shm_buffer: Optional[Any] = None,
        shm_callback: Optional[Callable] = None
    ) -> None:
        """
        Overview:
            Add one child. Its proc id is its position in the list of registered children.
        Arguments:
            - init (:obj:`Union[Callable, object]`): Callable that creates the wrapped instance, or the instance.
            - shm_buffer (:obj:`Optional[Any]`): Forwarded to the child constructor.
            - shm_callback (:obj:`Optional[Callable]`): Forwarded to the child constructor.
        """
        proc_id = len(self._children)
        self._children.append(
            self._child_class(proc_id, init, shm_buffer=shm_buffer, shm_callback=shm_callback, mp_ctx=self._mp_ctx)
        )

    @property
    def _recv_queue(self) -> Union[queue.Queue, mp.Queue]:
        """
        Overview:
            Shared response queue, created on first access with the type matching the child type.
        """
        if self.__queue is None:
            if self._type is ChildType.PROCESS:
                self.__queue = self._mp_ctx.Queue()
            elif self._type is ChildType.THREAD:
                self.__queue = queue.Queue()
        return self.__queue

    @_recv_queue.setter
    def _recv_queue(self, value: Union[queue.Queue, mp.Queue]):
        self.__queue = value

    def start_link(self) -> None:
        """
        Overview:
            Start every registered child on the shared response queue. Does nothing when already running.
        """
        if not self._running:
            for child in self._children:
                child.start(recv_queue=self._recv_queue)
            self._running = True

    def send(self, payload: SendPayload) -> None:
        """
        Overview:
            Send message to child process.
        Arguments:
            - payload (:obj:`SendPayload`): Send payload.
        """
        if not self._running:
            logging.warning("Please call start_link before sending any payload to child process.")
            return
        self._children[payload.proc_id].send(payload)

    def recv(self, ignore_err: bool = False, timeout: Optional[float] = None) -> RecvPayload:
        """
        Overview:
            Wait for message from child process
        Arguments:
            - ignore_err (:obj:`bool`): If ignore_err is True, put the err in the property of recv_payload. \
                Otherwise, an exception will be raised.
            - timeout (:obj:`float`): Timeout for queue.get, will raise an Empty exception if timeout.
        Returns:
            - recv_payload (:obj:`RecvPayload`): Recv payload.
        """
        recv_payload: RecvPayload = self._recv_queue.get(timeout=timeout)
        if recv_payload.err and not ignore_err:
            raise recv_payload.err
        return recv_payload

    def recv_all(
        self,
        send_payloads: List[SendPayload],
        ignore_err: bool = False,
        callback: Callable = None,
        timeout: Optional[float] = None
    ) -> List[RecvPayload]:
        """
        Overview:
            Wait for messages with specific req ids until all ids are fulfilled.
        Arguments:
            - send_payloads (:obj:`List[SendPayload]`): Request payloads.
            - ignore_err (:obj:`bool`): If ignore_err is True, \
                put the err in the property of recv_payload. Otherwise, an exception will be raised. \
                This option will also ignore timeout error.
            - callback (:obj:`Callable`): Callback for each recv payload, called with \
                ``(recv_payload, remain_payloads)``.
            - timeout (:obj:`Optional[float]`): Timeout of each single wait on the response queue.
        Returns:
            - recv_payload (:obj:`List[RecvPayload]`): Recv payload, may contain timeout error.
        """
        assert send_payloads, "Req payload is empty!"
        recv_payloads = {}
        remain_payloads = {payload.req_id: payload for payload in send_payloads}
        # Responses that belong to other requests; returned to the queue in ``finally``.
        unrelated_payloads = []
        try:
            while remain_payloads:
                try:
                    recv_payload: RecvPayload = self._recv_queue.get(block=True, timeout=timeout)
                    if recv_payload.req_id in remain_payloads:
                        del remain_payloads[recv_payload.req_id]
                        recv_payloads[recv_payload.req_id] = recv_payload
                        if recv_payload.err and not ignore_err:
                            raise recv_payload.err
                        if callback:
                            callback(recv_payload, remain_payloads)
                    else:
                        unrelated_payloads.append(recv_payload)
                except queue.Empty:
                    if ignore_err:
                        # Convert every outstanding request into a response carrying a TimeoutError.
                        req_ids = list(remain_payloads.keys())
                        logging.warning("Timeout ({}s) when receving payloads! Req ids: {}".format(timeout, req_ids))
                        for req_id in req_ids:
                            send_payload = remain_payloads.pop(req_id)
                            # If timeout error happens in timeout recover, there may not find any send_payload
                            # in the original indexed payloads.
                            recv_payload = RecvPayload(
                                proc_id=send_payload.proc_id,
                                req_id=send_payload.req_id,
                                method=send_payload.method,
                                err=TimeoutError("Timeout on req_id ({})".format(req_id))
                            )
                            recv_payloads[req_id] = recv_payload
                            if callback:
                                callback(recv_payload, remain_payloads)
                    else:
                        raise TimeoutError("Timeout ({}s) when receving payloads!".format(timeout))
        finally:
            # Put back the unrelated payload.
            for payload in unrelated_payloads:
                self._recv_queue.put(payload)
        # Keep the original order of requests.
        return [recv_payloads[p.req_id] for p in send_payloads]

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down every child, then drain and release the response queue. Does nothing when not running.
        Arguments:
            - timeout (:obj:`Optional[float]`): Passed to each child's ``shutdown``.
        """
        if self._running:
            for child in self._children:
                child.shutdown(timeout=timeout)
            self._cleanup_queue()
            self._running = False

    def _cleanup_queue(self):
        """
        Overview:
            Discard pending responses, close the queue when it supports closing, and drop the reference
            so that the next access creates a fresh queue.
        """
        while True:
            while not self._recv_queue.empty():
                self._recv_queue.get()
            time.sleep(0.1)  # mp.Queue is not reliable.
            if self._recv_queue.empty():
                break
        if hasattr(self._recv_queue, "close"):
            self._recv_queue.close()
            self._recv_queue.join_thread()
        self._recv_queue = None

    def __getattr__(self, key: str) -> List[Any]:
        """
        Overview:
            Fallback for names the supervisor itself does not define: read the attribute ``key``
            on every child and return the values ordered by proc id.
        """
        # Dunder probes (copy/pickle/hasattr machinery) must not be turned into child requests.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        # Read own state through __dict__: a plain ``self._running`` would re-enter this method
        # and recurse forever whenever the attribute has not been set yet.
        running = self.__dict__.get("_running", False)
        assert running, "Supervisor is not running, please call start_link first!"
        send_payloads = []
        for i, child in enumerate(self._children):
            payload = SendPayload(proc_id=i, method=ReserveMethod.GETATTR, args=[key])
            send_payloads.append(payload)
            child.send(payload)
        return [payload.data for payload in self.recv_all(send_payloads)]

    def get_child_attr(self, proc_id: int, key: str) -> Any:
        """
        Overview:
            Get attr of one child process instance.
        Arguments:
            - proc_id (:obj:`int`): Proc id, used as index into the registered children.
            - key (:obj:`str`): Attribute key.
        Returns:
            - attr (:obj:`Any`): Attribute of child.
        """
        assert self._running, "Supervisor is not running, please call start_link first!"
        payload = SendPayload(proc_id=proc_id, method=ReserveMethod.GETATTR, args=[key])
        self._children[proc_id].send(payload)
        payloads = self.recv_all([payload])
        return payloads[0].data

    def __del__(self) -> None:
        self.shutdown(timeout=5)
        self._children.clear()