|
import asyncio |
|
from typing import AsyncGenerator, Awaitable, Callable |
|
|
|
from pydantic import BaseModel, ConfigDict, Field |
|
|
|
from metagpt.logs import logger |
|
from metagpt.roles import Role |
|
from metagpt.schema import Message |
|
|
|
|
|
class SubscriptionRunner(BaseModel):
    """A simple wrapper to manage subscription tasks for different roles using asyncio.

    Each subscribed role gets its own ``asyncio.Task``. The task pulls messages from a
    trigger, passes each one to ``role.run`` and hands the response to a callback.
    ``run`` then monitors these tasks and reports when they finish or fail.

    Example:
        >>> import asyncio
        >>> from metagpt.address import SubscriptionRunner
        >>> from metagpt.roles import Searcher
        >>> from metagpt.schema import Message

        >>> async def trigger():
        ...     while True:
        ...         yield Message(content="the latest news about OpenAI")
        ...         await asyncio.sleep(3600 * 24)

        >>> async def callback(msg: Message):
        ...     print(msg.content)

        >>> async def main():
        ...     pb = SubscriptionRunner()
        ...     await pb.subscribe(Searcher(), trigger(), callback)
        ...     await pb.run()

        >>> asyncio.run(main())
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Maps each subscribed role to the background task that drives it.
    tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)

    async def subscribe(
        self,
        role: Role,
        trigger: AsyncGenerator[Message, None],
        callback: Callable[[Message], Awaitable[None]],
    ):
        """Subscribe a role to a trigger and route each of its responses to a callback.

        A background task is created on the currently running event loop. For every
        message the trigger yields, the task awaits ``role.run(msg)`` and then awaits
        ``callback(resp)``. If ``role`` is already subscribed, its previous task is
        cancelled and replaced, so no untracked task is left running.

        Args:
            role: The role to subscribe. It is also the key under which the task is
                stored in ``self.tasks``.
            trigger: An asynchronous generator that yields Messages to be processed by the role.
            callback: An asynchronous function to be called with the response from the role.
        """
        loop = asyncio.get_running_loop()

        async def _start_role():
            async for msg in trigger:
                resp = await role.run(msg)
                await callback(resp)

        previous = self.tasks.get(role)
        if previous is not None and not previous.done():
            # Overwriting the dict entry alone would orphan the old task: it would keep
            # running but could no longer be unsubscribed or monitored by ``run``.
            previous.cancel()
        self.tasks[role] = loop.create_task(_start_role(), name=f"Subscription-{role}")

    async def unsubscribe(self, role: Role):
        """Unsubscribe a role from its trigger and cancel the associated task.

        Args:
            role: The role to unsubscribe.

        Raises:
            KeyError: If ``role`` is not currently subscribed.
        """
        task = self.tasks.pop(role)
        task.cancel()

    async def run(self, raise_exception: bool = True):
        """Monitor all subscribed tasks and react as each one finishes.

        The task table is polled about once per second. A finished task is first
        removed from ``self.tasks`` and then handled as follows:

        * If it was cancelled, a warning is logged.
        * If it raised, the exception is re-raised (``raise_exception=True``) or
          logged (``raise_exception=False``).
        * If it returned normally (its trigger was exhausted), a warning is logged.

        This coroutine never returns on its own. It keeps polling even when no
        subscriptions remain, so tasks added later via ``subscribe`` are also monitored.

        Args:
            raise_exception: Whether to propagate an exception raised by a subscription
                task. If False, the error is logged and monitoring continues.

        Raises:
            Exception: Whatever exception a subscription task raised, when
                ``raise_exception`` is True.
        """
        while True:
            # Iterate over a snapshot because finished tasks are removed from the dict below.
            for role, task in list(self.tasks.items()):
                if not task.done():
                    continue

                # Remove the finished task before handling it, so that a propagated exception
                # does not leave a stale entry that would be re-raised on every later ``run``.
                self.tasks.pop(role)

                if task.cancelled():
                    # ``task.exception()`` raises CancelledError for a cancelled task,
                    # so this case must be handled before asking for the exception.
                    logger.warning(f"Task {task.get_name()} was cancelled.")
                    continue

                exc = task.exception()
                if exc is not None:
                    if raise_exception:
                        raise exc
                    logger.opt(exception=exc).error(f"Task {task.get_name()} run error")
                else:
                    logger.warning(
                        f"Task {task.get_name()} has completed. "
                        "If this is unexpected behavior, please check the trigger function."
                    )

            await asyncio.sleep(1)
|
|