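"""Analyze multiple-choice benchmark results and report accuracy statistics.

Reads a line-delimited JSON (JSONL) results file produced in GPT-4, Llama, or
CheXagent format, scores each question by comparing extracted answer letters,
aggregates accuracy per category and per question type, prints a summary, and
saves the correct and incorrect question IDs to "<model>_question_ids.json".

Usage (the script filename here is a placeholder):
    python analyze_results.py RESULTS_FILE [BENCHMARK_DIR] [--model gpt4] [--max-questions N]
"""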
import argparse
import json
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

from tqdm import tqdm

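# Composite question types and the metadata categories each one aggregates over.
# These category names must match the "categories" lists in the question metadata.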
QUESTION_TYPES = {
    "Detailed Finding Analysis": ["detection", "localization", "characterization"],
    "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
    "Spatial Understanding": ["localization", "comparison", "relationship"],
    "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
    "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
}


def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
    """
    Extract just the letter from various answer formats.

    Args:
        answer: The answer text to extract the letter from

    Returns:
        Optional[str]: The extracted letter in uppercase, or None if no letter is found
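
    Examples (behaviour follows the parsing rules in the function body):
        >>> extract_answer_letter("c")
        'C'
        >>> extract_answer_letter("B) Pneumonia")
        'B'
        >>> extract_answer_letter("The answer is B") is None
        True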
    """
    if not answer:
        return None

    answer = str(answer).strip()

    # Single-letter answer, e.g. "b"
    if len(answer) == 1 and answer.isalpha():
        return answer.upper()

    # A letter followed by a delimiter (")", ".", ":", "-" or a space), e.g. "B) Pneumonia"
    if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
        return answer[0].upper()

    # Fallback for answers that start with an explicit option label such as "A)"
    if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
        return answer[0].upper()

    return None


def analyze_gpt4_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in GPT-4 format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
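
    Notes:
        Each scored line is expected to be a JSON object with "model_answer",
        "correct_answer", "question_id" and nested
        input.question_data.metadata.categories fields; lines starting with
        "HTTP Request:" and lines that fail to parse as JSON are skipped.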
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # Counts only questions that were actually scored (both letters extracted).
    processed_questions = 0

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        if max_questions is not None and processed_questions >= max_questions:
            break
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                processed_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_llama_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in Llama format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # Unlike the GPT-4 analyzer, the question limit here is applied to raw lines before filtering.
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_chexagent_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in CheXagent format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # The question limit is applied to raw lines before any filtering.
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            # Lines that are not valid JSON (e.g. logging output) are skipped.
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def process_results(
    category_performance: Dict,
    all_questions: int,
    all_correct: int,
    correct_ids: Optional[List[str]] = None,
    incorrect_ids: Optional[List[str]] = None,
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Process raw results into final statistics.

    Args:
        category_performance: Dict containing performance by category
        all_questions: Total number of questions
        all_correct: Total number of correct answers
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
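
    Example:
        With category_performance == {"detection": {"total": 2, "correct": 1}},
        all_questions == 2 and all_correct == 1, the overall accuracy is 50.0
        and every question type that includes "detection" reports 50.0 as well.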
    """
    # Per-category accuracy.
    category_accuracies = {
        category: {
            "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    # Aggregate categories into the higher-level question types.
    question_type_stats = {}
    for qtype, categories in QUESTION_TYPES.items():
        total = sum(
            category_performance[cat]["total"] for cat in categories if cat in category_performance
        )
        correct = sum(
            category_performance[cat]["correct"]
            for cat in categories
            if cat in category_performance
        )

        question_type_stats[qtype] = {
            "accuracy": (correct / total * 100) if total > 0 else 0,
            "total": total,
            "correct": correct,
        }

    overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0

    return (
        overall_accuracy,
        category_accuracies,
        question_type_stats,
        correct_ids or [],
        incorrect_ids or [],
    )


def print_analysis(
    overall_accuracy: float,
    category_accuracies: Dict,
    question_type_stats: Dict,
    correct_ids: List[str],
    incorrect_ids: List[str],
    model_name: str,
) -> None:
    """
    Print analysis results.

    Args:
        overall_accuracy: Overall accuracy percentage
        category_accuracies: Dict containing accuracy metrics by category
        question_type_stats: Dict containing stats by question type
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions
        model_name: Name of the model being analyzed
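
    Side effects:
        Writes the correct/incorrect question IDs to
        "<model_name>_question_ids.json" in the current working directory.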
    """
    total_questions = len(correct_ids) + len(incorrect_ids)
    print(
        f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
    )

    print("\nCategory Performance:")
    sorted_categories = sorted(
        category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
    )
    for category, metrics in sorted_categories:
        print(f"{category}:")
        print(f" Accuracy: {metrics['accuracy']:.2f}%")
        print(f" Total Questions: {metrics['total']}")
        print(f" Correct Questions: {metrics['correct']}")

    print("\nQuestion Type Performance:")
    sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
    for qtype, metrics in sorted_types:
        print(f"\n{qtype}:")
        print(f" Accuracy: {metrics['accuracy']:.2f}%")
        print(f" Total Questions: {metrics['total']}")
        print(f" Correct Questions: {metrics['correct']}")
        print(f" Categories: {', '.join(QUESTION_TYPES[qtype])}")

    # Save question IDs so individual answers can be inspected later.
    question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}

    output_filename = f"{model_name}_question_ids.json"
    with open(output_filename, "w") as f:
        json.dump(question_ids, f, indent=2)

    print(f"\nQuestion IDs have been saved to {output_filename}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze benchmark results")
    parser.add_argument("results_file", help="Path to results file")
    # Currently unused by this script.
    parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
    parser.add_argument(
        "--model",
        choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
        default="gpt4",
        help="Specify model format (default: gpt4)",
    )
    parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
    args = parser.parse_args()

    if args.model == "gpt4":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    elif args.model == "llama":
        results = analyze_llama_results(args.results_file, args.max_questions)
    elif args.model == "chexagent":
        results = analyze_chexagent_results(args.results_file, args.max_questions)
    elif args.model == "medrax":
        # MedRAX results are analyzed with the GPT-4 format parser.
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    else:
        # Remaining choices (e.g. "llava-med") are accepted by argparse but have no handler yet.
        parser.error(f"Unsupported model: {args.model}")

    print_analysis(*results, args.model)