import asyncio
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import nest_asyncio

from examples.di.requirements_prompt import DABENCH
from metagpt.const import DABENCH_PATH
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception


def evaluate_accuracy_by_question(results: list) -> float:
    """
    Calculate the accuracy of results based on complete correctness of each question.

    Referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
    A result counts as correct only if all of its sub-questions are answered correctly. The
    function returns the number of fully correct results divided by the total number of results.

    Args:
        results (list): A list of result dicts, each of which may contain a 'correctness' field.

    Returns:
        float: The proportion of fully correct results, rounded to four decimal places.
            Returns 0 if there are no results.
    """
    correct = sum("correctness" in result and all(result["correctness"].values()) for result in results)
    total = len(results)
    return round(correct / total, 4) if total > 0 else 0


def evaluate_accuracy_by_sub_question(results: list) -> float:
    """
    Evaluate the correctness of all sub-questions across the results.

    Referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
    The function counts the correct sub-questions and the total number of sub-questions across
    all results, and returns the ratio of the two.

    Args:
        results (list): A list of result dicts, each of which may contain a 'correctness' field.

    Returns:
        float: The ratio of correct sub-questions, rounded to four decimal places.
            Returns 0 if there are no sub-questions.
    """
    correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result)
    total = sum(len(result["correctness"]) for result in results if "correctness" in result)
    return round(correct / total, 4) if total > 0 else 0


def evaluate_accuracy_proportional_by_sub_question_adjusted(results: list) -> float:
    """
    Score each result in proportion to the number of sub-questions it contains.

    Referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
    Each sub-question within a result is worth 1 divided by the number of sub-questions in that
    result, so every result contributes at most 1 regardless of how many sub-questions it has.
    The function returns the average of these per-result scores.

    Args:
        results (list): A list of result dicts, each of which may contain a 'correctness' field.

    Returns:
        float: The average score across all results, rounded to four decimal places.
            Returns 0 if there are no results.
    """
    total_score = 0
    for result in results:
        if "correctness" in result:
            sub_question_count = len(result["correctness"])
            score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0
            question_score = sum(result["correctness"].values()) * score_per_sub_question
            total_score += question_score
    return round(total_score / len(results), 4) if results else 0
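
# Illustrative sketch of how the three metrics differ (hypothetical correctness values, not taken
# from the dataset): given
#   results = [{"correctness": {"a": True, "b": False}}, {"correctness": {"c": True}}]
# the helpers above would score them as
#   evaluate_accuracy_by_question(results)                           -> 0.5    (1 of 2 fully correct)
#   evaluate_accuracy_by_sub_question(results)                       -> 0.6667 (2 of 3 sub-questions)
#   evaluate_accuracy_proportional_by_sub_question_adjusted(results) -> 0.75   ((0.5 + 1.0) / 2)
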
async def reformat(question: str, format: str, response: str) -> str:
    """
    Asynchronously reformat a response so that it follows the required answer format.

    Referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
    The function builds a prompt for the LLM (Large Language Model) consisting of a system
    prompt, a few formatting demonstrations, and a template describing the expected output
    structure, then asks the LLM to rewrite the response accordingly.

    Args:
        question (str): The original question posed by the user.
        format (str): The formatting requirements that the response must adhere to.
        response (str): The initial LLM response that needs to be reformatted.

    Returns:
        str: The reformatted response generated by the LLM based on the provided question
            and formatting requirements.
    """
    system_prompt = "You are a helpful assistant."
    demons = r"""\Format{{
@shapiro_wilk_statistic[test_statistic]
@shapiro_wilk_p_value[p_value]
where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places.
where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places.
}}
\Answer{{
@shapiro_wilk_statistic[0.56]
@shapiro_wilk_p_value[0.0002]
}}

\Format{{
@total_votes_outliers_num[outlier_num]
where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column.
}}
\Answer{{
@total_votes_outliers_num[10]
}}
"""
    reformat_template = """You should strictly follow the output requirements in the Format part. Here are some examples: {demons}.
Your answer should contain all the "@answer_name[answer]" entries in the order mentioned, and each "answer" should be in the range of values required. You need to keep the original numbers and text, just reformat without making any changes.
The format requirements of this question are:
{format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": response},
        {"role": "user", "content": reformat_template.format(demons=demons, format=format)},
    ]
    rsp = await ask(messages, system_prompt)
    return rsp
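
# A note on the expected round-trip (sketch, format string illustrative): reformat() is expected
# to return text whose answer lines look like "@mean_fare[34.89]", which is exactly the
# "@answer_name[answer]" shape that parse_prediction() below knows how to split apart.
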
def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]:
    """
    Load data from a JSONL file into a list of dictionaries.

    Args:
        file_path (Union[Path, str]): The path to the JSONL file to be loaded.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the data from the JSONL file.
    """
    if isinstance(file_path, str):
        file_path = Path(file_path)

    data = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            data.append(json.loads(line))
    return data
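
# Minimal sketch of the records this loader is expected to return (field values are illustrative,
# not copied from the benchmark): each line of da-dev-questions.jsonl decodes to something like
#   {"id": 0, "question": "...", "constraints": "...", "format": "...", "file_name": "titanic.csv", "level": "easy"}
# and each line of da-dev-labels.jsonl to
#   {"id": 0, "common_answers": [["mean_fare", "34.65"]]}
# which is what DABench.__init__ indexes by id further down.
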
def compare_predictions(pred_dict: dict, true_label: list) -> bool:
    """
    Compare each prediction against the corresponding true label.

    The true labels are sorted by metric name so the comparison is made in a deterministic
    order. The function returns True only if every prediction matches its true value, within
    a small tolerance for numerical values or case-insensitively for string values.

    Args:
        pred_dict (dict): A dictionary of predicted metrics and their values.
        true_label (list): A list of (metric, value) pairs containing the true answers.

    Returns:
        bool: True if all predictions match the true labels, False otherwise.
    """
    sorted_true_label = sorted(true_label, key=lambda x: x[0])

    for metric, true_value in sorted_true_label:
        try:
            true_value = float(true_value)
        except ValueError:
            true_value = true_value.replace(",", "")

        if isinstance(true_value, (int, float)) and (
            metric not in pred_dict
            or not isinstance(pred_dict[metric], (int, float))
            or abs(pred_dict[metric] - true_value) > 1e-6
        ):
            return False

        if isinstance(true_value, str) and (
            metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower()
        ):
            return False

    return True
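
# Illustrative check (values invented for the sketch): with
#   pred_dict = {"mean_fare": 34.65, "embark_town": "Southampton"}
#   true_label = [["mean_fare", "34.65"], ["embark_town", "southampton"]]
# compare_predictions returns True -- the number matches within 1e-6 and the string matches
# case-insensitively; changing either value would make it return False.
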
async def ask(question: Union[str, list], system_prompt: str) -> str:
    """
    Asynchronously send a question to the LLM (Large Language Model) and retrieve the response.

    This function initializes an LLM instance and forwards the question, together with a system
    prompt, to it. The response from the LLM is awaited and returned.

    Args:
        question (Union[str, list]): The question to ask the LLM, either as a plain string or as
            a list of chat messages (as built in `reformat`).
        system_prompt (str): A prompt that provides context or instructions to the LLM.

    Returns:
        str: The response from the LLM based on the provided question and system prompt.
    """
    from metagpt.llm import LLM

    llm = LLM()
    rsp = await llm.aask(question, system_msgs=[system_prompt])
    return rsp


def parse_prediction(prediction: str) -> dict:
    """
    Parse a prediction string into a dictionary of metric-value pairs.

    The input is a formatted string of "@metric[value]" entries separated by the "@" symbol.
    Each entry is split on the square brackets, commas are stripped, and the value is converted
    to a float when possible; otherwise it is kept as a string.

    Args:
        prediction (str): A string representation of metrics and their values.

    Returns:
        dict: A dictionary mapping each metric name to its value, either as a float or a string.
    """
    pred_dict = {}
    for pred in prediction.split("@"):
        if pred == "":
            continue
        temp = re.split(r"[\[\]]", pred.strip())
        temp = [s.replace(",", "") for s in temp]
        parts = [s for s in temp if s]
        metric = parts[0].strip().replace(",", "")
        value = parts[-1].replace(",", "").replace(":", "")

        try:
            value = float(value)
        except ValueError:
            pass

        pred_dict[metric] = value
    return pred_dict
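
# Worked example, reusing the prediction strings from the __main__ demo at the bottom of this file:
#   parse_prediction("@mean_fare_child[31.09], @mean_fare_teenager[31.98]")
#   -> {"mean_fare_child": 31.09, "mean_fare_teenager": 31.98}
# Values that cannot be converted to float (e.g. a hypothetical "@embarked_mode[S]") are kept as strings.
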

class DABench:
    def __init__(
        self,
        questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl",
        answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl",
        template: str = "",
    ):
        """
        Initialize the DABench instance with questions and answers.

        The constructor loads questions and answers from the specified JSONL files and indexes
        them by their integer id. It also sets the template used for formatting prompts; if no
        template is provided, the default DABENCH template is used.

        Args:
            questions_file (Path): The path to the JSONL file containing questions.
            answers_file (Path): The path to the JSONL file containing answers.
            template (str): A string template for formatting prompts.
        """
        self.questions = {int(line["id"]): line for line in load_jsonl(questions_file)}
        self.answers = {int(line["id"]): line for line in load_jsonl(answers_file)}
        self.template = template if template else DABENCH

    def get_question(self, question_id: int) -> dict:
        """
        Retrieve the question associated with the given ID.

        Args:
            question_id (int): The unique identifier for the question.

        Returns:
            dict: The question data if found, otherwise a "Question not found." message.
        """
        return self.questions.get(question_id, "Question not found.")

    def generate_formatted_prompt(self, question_id: int) -> str:
        """
        Generate a formatted prompt for the specified question ID.

        The question data is filled into the configured template, which includes the question,
        constraints, format, data file path, and difficulty level.

        Args:
            question_id (int): The unique identifier for the question.

        Returns:
            str: A formatted prompt string based on the question data.
        """
        temp = self.get_question(question_id)
        return self.template.format(
            question=temp["question"],
            constraints=temp["constraints"],
            format=temp["format"],
            file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"],
            level=temp["level"],
        )

    def get_answer(self, answer_id: int) -> dict:
        """
        Retrieve the answer record associated with the given ID.

        Args:
            answer_id (int): The unique identifier for the answer.

        Returns:
            dict: The answer record (including 'common_answers') if found, otherwise an
                "Answer not found." message.
        """
        return self.answers.get(answer_id, "Answer not found.")

    @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False))
    def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]:
        """
        Parse the cleaned prediction and compare it with the true label.

        Args:
            cleaned_prediction (str): The cleaned prediction string.
            true_label (Any): The true label to compare against.

        Returns:
            Tuple[str, bool]: A tuple containing the cleaned prediction and a boolean indicating
                whether it matches the true label.
        """
        if cleaned_prediction:
            pred_dict = parse_prediction(cleaned_prediction)
            if pred_dict is not None and compare_predictions(pred_dict, true_label):
                return cleaned_prediction, True
        return cleaned_prediction, False

    @handle_exception(exception_msg="Error during async reformat", default_return="")
    def async_reformat_prediction(self, id: int, result: str) -> str:
        """
        Reformat the prediction via the LLM and extract the answer section.

        Args:
            id (int): The identifier for the question.
            result (str): The original prediction result.

        Returns:
            str: The reformatted prediction, or the full LLM output if the answer section
                cannot be extracted.
        """
        question = self.get_question(id)["question"]
        question_format = self.get_question(id)["format"]
        prediction = asyncio.run(reformat(question, question_format, result))

        # Keep only the text between "Answer{{" and the closing "}}" if the LLM echoed the template.
        answer_part = prediction.split("Answer{{") if "Answer{{" in prediction else []
        if len(answer_part) > 1:
            return answer_part[1].split("}}")[0].strip()

        return prediction

    def eval(self, id: int, result: str) -> Tuple[str, bool]:
        """
        Evaluate the prediction against the true label.

        Args:
            id (int): The identifier for the question.
            result (str): The raw agent output or prediction string.

        Returns:
            Tuple[str, bool]: A tuple containing the final prediction and a boolean indicating
                whether it matches the true label.
        """
        true_label = self.get_answer(id)["common_answers"]
        nest_asyncio.apply()
        if "Current Plan" in str(result):
            # Extract the result of the last task from the plan section of the agent's output.
            result = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"].strip()
        cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "")

        # First try to parse the cleaned prediction directly.
        parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label)
        if parsed_result[1]:
            return parsed_result

        # Otherwise ask the LLM to reformat the prediction and try again.
        prediction = self.async_reformat_prediction(id, result)

        pred_dict = parse_prediction(prediction)
        if pred_dict is not None and compare_predictions(pred_dict, true_label):
            return prediction, True

        return prediction, False

    @handle_exception(exception_msg="Error evaluating single prediction", default_return={})
    def single_eval(self, id: int, prediction: str) -> dict:
        """
        Evaluate a prediction against the true label for a single question.

        Only used by eval_all.

        Args:
            id (int): The identifier for the question.
            prediction (str): The prediction string to evaluate.

        Returns:
            dict: A dictionary indicating the correctness of each metric.
        """
        true_label = self.get_answer(id)["common_answers"]
        prediction = prediction.replace("{", "").replace("}", "").replace("'", "")
        pred_dict = parse_prediction(prediction)

        # Every metric starts out as incorrect until a matching prediction is found.
        correctness = {metric: False for metric, _ in true_label}

        for metric, true_value in true_label:
            try:
                true_value = float(true_value)
            except ValueError:
                true_value = true_value.replace(",", "")

            if metric in pred_dict:
                # Numeric values must match within a small tolerance.
                if (
                    isinstance(true_value, (int, float))
                    and isinstance(pred_dict[metric], (int, float))
                    and abs(pred_dict[metric] - true_value) < 1e-6
                ):
                    correctness[metric] = True

                # String values must match case-insensitively.
                if isinstance(true_value, str) and str(pred_dict[metric]).lower() == str(true_value).lower():
                    correctness[metric] = True

        return correctness
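
    # Illustrative outcome (labels invented for the sketch): if the true label for a question is
    # [["mean_fare", "34.65"], ["median_fare", "14.45"]] and the prediction is "@mean_fare[34.65]",
    # single_eval returns {"mean_fare": True, "median_fare": False}; eval_all then feeds these
    # per-question dicts into the three accuracy helpers at the top of the module.
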

    def eval_all(self, id_list: list, predictions: list) -> dict:
        """
        Evaluate all predictions and calculate accuracy rates.

        Args:
            id_list (list): A list of question identifiers.
            predictions (list): A list of prediction strings corresponding to the questions.

        Returns:
            dict: A dictionary containing accuracy rates by question and by sub-question.
        """
        results = []

        for id, prediction in zip(id_list, predictions):
            correct = self.single_eval(id, prediction)
            results.append({"id": id, "correctness": correct})

        accuracy_by_question = evaluate_accuracy_by_question(results)
        accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
        proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)

        return {
            "accuracy_by_question": accuracy_by_question,
            "accuracy_by_sub_question": accuracy_by_sub_question,
            "proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
        }
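
    # eval_all returns a plain summary dict, e.g. (numbers illustrative):
    #   {"accuracy_by_question": 0.5, "accuracy_by_sub_question": 0.6667,
    #    "proportional_accuracy_by_sub_question": 0.75}
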


if __name__ == "__main__":
    bench = DABench()
    id = 0
    prediction = "@mean_fare[34.65]"
    logger.info(bench.eval(id, prediction))
    ids = [0, 5, 6]
    predictions = [
        "@mean_fare[34.89]",
        "@correlation_coefficient[0.21]",
        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
    ]
    logger.info(bench.eval_all(ids, predictions))