|
import asyncio |
|
import json |
|
|
|
from DABench import DABench |
|
|
|
from metagpt.logs import logger |
|
from metagpt.roles.di.data_interpreter import DataInterpreter |
|
|
|
|
|
async def get_prediction(agent, requirement):
    """Run ``agent`` on one requirement and extract its final prediction.

    The caller-supplied ``agent`` is awaited via ``agent.run(requirement)``
    (this function does not create a new agent). The string form of the
    result is expected to embed a JSON plan between the markers
    ``"Current Plan"`` and ``"## Current Task"``; the ``"result"`` field of
    the last entry of that plan is returned as the prediction.

    Args:
        agent: Object exposing an awaitable ``run(requirement)`` method.
        requirement: The prompt/requirement handed to ``agent.run``.

    Returns:
        The extracted prediction, or ``None`` if running the agent or parsing
        its output fails. Failures are logged as warnings.
    """
    try:
        result = await agent.run(requirement)

        # Slice the JSON plan out of the textual result: the part after the
        # "Current Plan" heading, up to the "## Current Task" heading.
        plan_text = str(result).split("Current Plan")[1].split("## Current Task")[0]
        plan = json.loads(plan_text)

        # The last step of the plan carries the final answer.
        return plan[-1]["result"]
    except Exception as e:
        # Deliberately broad: a single failing task must not abort the whole
        # benchmark run. Logged at warning level so failures are visible.
        logger.warning(f"Error processing requirement: {requirement}. Error: {e}")
        return None
|
|
|
|
|
async def evaluate_all(agent, k, *, agent_factory=None):
    """Evaluate all tasks in DABench and log the evaluation result.

    One prompt is generated per task id in ``bench.answers``. Prompts are
    processed in consecutive groups of ``k``, and the tasks within a group
    are awaited concurrently with ``asyncio.gather``. A task whose prediction
    could not be obtained (``get_prediction`` returned ``None``) is dropped
    together with its task id. This keeps the id and prediction lists passed
    to ``bench.eval_all`` index-aligned, and it excludes failed tasks from
    what ``eval_all`` receives.

    Args:
        agent: Agent instance handed to every task when ``agent_factory`` is
            not given. In that case all tasks, including the ones running
            concurrently in the same group, share this single object.
        k (int): Number of tasks awaited concurrently per group; must be >= 1.
        agent_factory: Optional zero-argument callable returning a new agent.
            When given, every task gets its own agent from it and ``agent``
            is not used. This avoids concurrent tasks sharing agent state.

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    bench = DABench()
    task_ids = list(bench.answers)
    # Build every prompt up front so prompt-generation errors surface before
    # any agent is run.
    requirements = [bench.generate_formatted_prompt(task_id) for task_id in task_ids]

    def make_agent():
        # NOTE(review): a shared agent presumably keeps per-run state (memory,
        # plan); sharing one across concurrent tasks may interleave runs.
        return agent_factory() if agent_factory is not None else agent

    id_list, predictions = [], []
    for start in range(0, len(task_ids), k):
        group_ids = task_ids[start : start + k]
        group_requirements = requirements[start : start + k]
        group_predictions = await asyncio.gather(
            *(get_prediction(make_agent(), requirement) for requirement in group_requirements)
        )
        for task_id, prediction in zip(group_ids, group_predictions):
            # Drop a failed task's id together with its prediction so both
            # lists stay aligned position-by-position for eval_all.
            if prediction is not None:
                id_list.append(task_id)
                predictions.append(prediction)

    failed = len(task_ids) - len(id_list)
    if failed:
        logger.warning(f"{failed} of {len(task_ids)} tasks produced no prediction and were excluded")

    logger.info(bench.eval_all(id_list, predictions))
|
|
|
|
|
def main(k=5):
    """Run the DABench evaluation end to end.

    Constructs a single DataInterpreter and passes it to ``evaluate_all``,
    which is driven to completion on a fresh event loop via ``asyncio.run``.

    Args:
        k (int): Group size forwarded to ``evaluate_all``. Defaults to 5.
    """
    interpreter = DataInterpreter()
    evaluation = evaluate_all(interpreter, k)
    asyncio.run(evaluation)
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported;
    # main() is called with its default group size.
    main()
|
|