"""Compare multiple-choice accuracy across model prediction files.

Reads one or more prediction files (a single JSON document or JSON Lines),
restricts the comparison to questions shared by all models, and reports
overall and per-category accuracy.
"""

import json
import argparse
import random
from typing import List, Dict, Any, Tuple
import re
from collections import defaultdict


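# Fixed display order for the per-category accuracy breakdown.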
CATEGORY_ORDER = [
    "detection",
    "classification",
    "localization",
    "comparison",
    "relationship",
    "diagnosis",
    "characterization",
]


def extract_letter_answer(answer: str) -> str:
    """Extract just the letter answer from various answer formats.

    Args:
        answer: The answer string to extract a letter from.

    Returns:
        str: The extracted letter in uppercase; if no A-F letter is found,
        the stripped, uppercased input is returned unchanged.
    """
    if not answer:
        return ""

    answer = str(answer).strip()

    # Bare single letter, e.g. "b".
    if len(answer) == 1 and answer.upper() in "ABCDEF":
        return answer.upper()

    # Letter at the start followed by a delimiter, e.g. "B) ..." or "C. ...".
    match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Standalone letter anywhere in the text, e.g. "The answer is C."
    matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
    if matches:
        return matches[0].upper()

    # Last resort: first occurrence of any A-F character.
    letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
    if letters:
        return letters[0].upper()

    return answer.strip().upper()


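# Illustrative behaviour of the extraction heuristics above (assumed answer formats):
#   extract_letter_answer("B) Pneumothorax")  -> "B"
#   extract_letter_answer("The answer is C.") -> "C"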
def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Parse a prediction file and extract valid predictions.

    The file is first treated as a single JSON document (the llava-med results
    format); if that fails, it is parsed as JSON Lines.

    Args:
        file_path: Path to the prediction file to parse.

    Returns:
        Tuple containing:
            - str: Model name, or the file path if no model name was found.
            - List[Dict[str, Any]]: List of valid prediction entries.
    """
    valid_predictions = []
    model_name = None

    # First attempt: a single JSON document with a top-level "results" list.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            if data.get("model") == "llava-med-v1.5-mistral-7b":
                model_name = data["model"]
            for result in data.get("results", []):
                if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
                    model_answer = (
                        result.get("model_answer")
                        or result.get("validated_answer")
                        or result.get("raw_output", "")
                    )

                    # Normalize to the JSON Lines entry shape; every question in
                    # this format is tagged with the full category list.
                    prediction = {
                        "case_id": result["case_id"],
                        "question_id": result["question_id"],
                        "model_answer": model_answer,
                        "correct_answer": result["correct_answer"],
                        "input": {
                            "question_data": {
                                "metadata": {
                                    "categories": [
                                        "detection",
                                        "classification",
                                        "localization",
                                        "comparison",
                                        "relationship",
                                        "diagnosis",
                                        "characterization",
                                    ]
                                }
                            }
                        },
                    }
                    valid_predictions.append(prediction)
            return model_name if model_name else file_path, valid_predictions
    except (json.JSONDecodeError, KeyError):
        pass

    # Fallback: JSON Lines, one prediction object per line.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("HTTP Request:"):
                continue
            try:
                data = json.loads(line.strip())
                if "model" in data:
                    model_name = data["model"]
                if all(
                    k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
                ):
                    valid_predictions.append(data)
            except json.JSONDecodeError:
                continue

    return model_name if model_name else file_path, valid_predictions


def filter_common_questions(
    predictions_list: List[List[Dict[str, Any]]],
) -> List[List[Dict[str, Any]]]:
    """Ensure only questions that exist across all models are evaluated.

    Args:
        predictions_list: List of prediction lists from different models.

    Returns:
        List[List[Dict[str, Any]]]: Filtered predictions containing only common questions.
    """
    # A question is identified by its (case_id, question_id) pair.
    question_sets = [
        set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
    ]
    common_questions = set.intersection(*question_sets)

    return [
        [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
        for preds in predictions_list
    ]


def calculate_accuracy(
    predictions: List[Dict[str, Any]],
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
    """Compute overall and category-level accuracy.

    Args:
        predictions: List of prediction entries to analyze.

    Returns:
        Tuple containing:
            - float: Overall accuracy percentage.
            - int: Number of correct predictions.
            - int: Total number of predictions.
            - Dict[str, Dict[str, float]]: Category-level accuracy statistics.
    """
    if not predictions:
        return 0.0, 0, 0, {}

    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    correct = 0
    total = 0

    # Print a small random sample of extracted answers for manual spot-checking
    # (reproducible via the --seed argument).
    sample_size = min(5, len(predictions))
    sampled_indices = random.sample(range(len(predictions)), sample_size)

    print("\nSample extracted answers:")
    for i in sampled_indices:
        pred = predictions[i]
        model_ans = extract_letter_answer(pred["model_answer"])
        correct_ans = extract_letter_answer(pred["correct_answer"])
        print(f"QID: {pred['question_id']}")
        print(f" Raw Model Answer: {pred['model_answer']}")
        print(f" Extracted Model Answer: {model_ans}")
        print(f" Raw Correct Answer: {pred['correct_answer']}")
        print(f" Extracted Correct Answer: {correct_ans}")
        print("-" * 80)

    for pred in predictions:
        try:
            model_ans = extract_letter_answer(pred["model_answer"])
            correct_ans = extract_letter_answer(pred["correct_answer"])
            categories = (
                pred.get("input", {})
                .get("question_data", {})
                .get("metadata", {})
                .get("categories", [])
            )

            # Only count questions where both answers could be extracted.
            if model_ans and correct_ans:
                total += 1
                is_correct = model_ans == correct_ans
                if is_correct:
                    correct += 1

                for category in categories:
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except KeyError:
            continue

    category_accuracies = {
        category: {
            "accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)


def compare_models(file_paths: List[str]) -> None:
    """Compare accuracy between multiple model prediction files.

    Args:
        file_paths: List of paths to model prediction files to compare.
    """
    parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
    model_names, predictions_list = zip(*parsed_results)

    # Accuracy on each model's full set of answered questions.
    print("\n📊 **Initial Accuracy**:")
    results = []
    category_results = []

    for preds, name in zip(predictions_list, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        results.append((acc, correct, total, name))
        category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Restrict to questions answered by every model so the comparison is fair.
    filtered_predictions = filter_common_questions(predictions_list)
    print(
        "\nQuestions per model after ensuring common questions: "
        f"{[len(p) for p in filtered_predictions]}"
    )

    print("\n📊 **Accuracy on Common Questions**:")
    filtered_results = []
    filtered_category_results = []

    for preds, name in zip(filtered_predictions, model_names):
        acc, correct, total, category_acc = calculate_accuracy(preds)
        filtered_results.append((acc, correct, total, name))
        filtered_category_results.append(category_acc)
        print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")

    # Per-category breakdown on the common question set.
    print("\nCategory Performance (Common Questions):")
    for category in CATEGORY_ORDER:
        print(f"\n{category.capitalize()}:")
        for model_name, category_acc in zip(model_names, filtered_category_results):
            stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
            print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")


def main():
    """Parse command-line arguments and run the comparison."""
    parser = argparse.ArgumentParser(
        description="Compare accuracy across multiple model prediction files"
    )
    parser.add_argument("files", nargs="+", help="Paths to model prediction files")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")

    args = parser.parse_args()
    random.seed(args.seed)

    compare_models(args.files)


if __name__ == "__main__":
    main()
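
# Example invocation (script and file names are placeholders):
#   python compare_model_accuracy.py model_a_results.json model_b_results.jsonl --seed 42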