import argparse
import base64
import glob
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import Dict, List, Optional, Union, Any, Tuple

import requests
from PIL import Image
from tqdm import tqdm

from llava.conversation import conv_templates


def process_image(image_path: str, target_size: int = 640) -> Image.Image:
    """Process and resize an image to match model requirements.

    Args:
        image_path: Path to the input image file
        target_size: Target size for both width and height in pixels

    Returns:
        PIL.Image: Processed and padded image with dimensions (target_size, target_size)
    """
    image = Image.open(image_path)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Scale so the longer side fits within target_size, preserving aspect ratio.
    ratio = min(target_size / image.width, target_size / image.height)
    new_size = (int(image.width * ratio), int(image.height * ratio))

    image = image.resize(new_size, Image.LANCZOS)

    # Pad to a square canvas, centering the resized image on a black background.
    new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
    offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
    new_image.paste(image, offset)

    return new_image


def validate_answer(response_text: str) -> Optional[str]:
    """Extract and validate a single-letter response from the model's output.

    Handles multiple response formats and edge cases.

    Args:
        response_text: The full text output from the model

    Returns:
        A single letter answer (A-F) or None if no valid answer found
    """
    if not response_text:
        return None

    cleaned = response_text.strip()

    # Patterns are ordered from the explicit phrasing requested in the system
    # prompt ("The SINGLE LETTER answer is: X") down to looser phrasings.
    extraction_patterns = [
        r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
        r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
        r"\b(?:answer|option)\s*([A-F])[):]\s*",
        r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
        r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
        r"characteriz[e]?d?\s+by\s+([A-F])\b",
        r"indicat[e]?s?\s+([A-F])\b",
        r"Option\s*([A-F])\b",
        r"\b([A-F])\)\s*",
        r"^\s*([A-F])\s*$",
    ]

    for pattern in extraction_patterns:
        matches = re.findall(pattern, cleaned, re.IGNORECASE)
        for match in matches:
            if isinstance(match, tuple):
                match = match[0] if match[0] in "ABCDEF" else None
            if match and match.upper() in "ABCDEF":
                return match.upper()

    # Fallback: take the first standalone A-F letter anywhere in the response.
    context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
    context_letters = [m for m in context_matches if m in "ABCDEF"]
    if context_letters:
        return context_letters[0]

    return None


def load_benchmark_questions(case_id: str) -> List[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: The ID of the medical case

    Returns:
        List of paths to question JSON files
    """
    benchmark_dir = "MedMAX/benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> Tuple[int, int]:
    """Count total number of cases and questions in benchmark.

    Returns:
        Tuple containing (total_cases, total_questions)
    """
    total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("MedMAX/benchmark/questions")
    )
    return total_cases, total_questions


def create_inference_request(
    question_data: Dict[str, Any],
    case_details: Dict[str, Any],
    case_id: str,
    question_id: str,
    worker_addr: str,
    model_name: str,
    raw_output: bool = False,
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
    """Create and send an inference request to the worker.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        worker_addr: Address of the worker endpoint
        model_name: Name of the model to use
        raw_output: Whether to return raw model output

    Returns:
        If raw_output is False: Tuple of (validated_answer, duration)
        If raw_output is True: Dictionary with full inference details
        Returns ("skipped", 0.0) when no local images can be found for the question
    """
    system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
"""

    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """
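
    # The "figures" field is inconsistent across question files: it may be a
    # JSON-encoded list, a plain string, or an actual list, so normalize it
    # into a list of "Figure N" / "Figure NA" style references first.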
    try:
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]
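
    # Resolve each figure reference (e.g. "Figure 2" or "Figure 2A") to the
    # matching subfigure entries in the case metadata and collect their local
    # image paths under MedMAX/data/.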
    image_paths = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            if figure_letter:
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "local_path" in subfig:
                    image_paths.append("MedMAX/data/" + subfig["local_path"])

    if not image_paths:
        print(f"No local images found for case {case_id}, question {question_id}")
        return "skipped", 0.0
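
    # Build the LLaVA-Med conversation: preprocess the images, append the
    # question (with an <image> token) as the user turn, and pack the rendered
    # prompt plus the conversation's images into the worker payload. Note that
    # only the first processed image is attached to the message.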
    try:
        start_time = time.time()

        processed_images = [process_image(path) for path in image_paths]

        conv = conv_templates["mistral_instruct"].copy()

        if "<image>" not in prompt:
            text = prompt + "\n<image>"
        else:
            text = prompt

        message = (text, processed_images[0], "Default")
        conv.append_message(conv.roles[0], message)
        conv.append_message(conv.roles[1], None)

        prompt = conv.get_prompt()
        headers = {"User-Agent": "LLaVA-Med Client"}
        pload = {
            "model": model_name,
            "prompt": prompt,
            "max_new_tokens": 150,
            "temperature": 0.5,
            "stop": conv.sep2,
            "images": conv.get_images(),
            "top_p": 1,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }
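
        # Query the worker with simple retry handling: up to max_retries
        # attempts, sleeping retry_delay seconds between them. The worker
        # streams NUL-delimited JSON chunks; the text after the final
        # "[/INST]" marker of the last successful chunk is kept as the answer.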
        max_retries = 3
        retry_delay = 5
        response_text = None

        for attempt in range(max_retries):
            try:
                response = requests.post(
                    worker_addr + "/worker_generate_stream",
                    headers=headers,
                    json=pload,
                    stream=True,
                    timeout=30,
                )

                complete_output = ""
                for chunk in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if chunk:
                        data = json.loads(chunk.decode("utf-8"))
                        if data["error_code"] == 0:
                            output = data["text"].split("[/INST]")[-1]
                            complete_output = output
                        else:
                            print(f"\nError: {data['text']} (error_code: {data['error_code']})")
                            if attempt < max_retries - 1:
                                time.sleep(retry_delay)
                                break
                            return None, None

                if complete_output:
                    response_text = complete_output
                    break

            except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
                if attempt < max_retries - 1:
                    print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    print(f"\nFailed after {max_retries} attempts: {str(e)}")
                    return None, None

        duration = time.time() - start_time

        if raw_output:
            inference_details = {
                "raw_output": response_text,
                "validated_answer": validate_answer(response_text),
                "duration": duration,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "image_paths": image_paths,
                "payload": pload,
            }
            return inference_details

        return validate_answer(response_text), duration

    except Exception as e:
        print(f"Error in inference request: {str(e)}")
        return None, None


def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Remove image-related and large data from the payload to keep the log lean.

    Args:
        payload: Original request payload dictionary

    Returns:
        Cleaned payload dictionary with large data removed
    """
    if not payload:
        return None

    cleaned_payload = payload.copy()

    if "images" in cleaned_payload:
        del cleaned_payload["images"]

    return cleaned_payload


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
    parser.add_argument("--output-dir", type=str, default="benchmark_results")
    parser.add_argument(
        "--raw-output", action="store_true", help="Return raw model output without validation"
    )
    parser.add_argument(
        "--num-cases",
        type=int,
        help="Number of cases to process if looking at raw outputs",
        default=2,
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
    final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")
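
    # The live log is an incrementally written JSON array: "[\n" is written up
    # front, each entry is appended followed by ",\n", and the array is closed
    # with "\n]" once all cases have been processed.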
    with open(live_log_filename, "w") as live_log_file:
        live_log_file.write("[\n")

    logging.basicConfig(
        filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
        level=logging.INFO,
        format="%(message)s",
    )
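
    # Resolve the worker address: use the --worker-address flag if given,
    # otherwise ask the controller which worker serves the requested model.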
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        try:
            requests.post(args.controller_address + "/refresh_all_workers")
            ret = requests.post(args.controller_address + "/list_models")
            models = ret.json()["models"]
            ret = requests.post(
                args.controller_address + "/get_worker_address", json={"model": args.model_name}
            )
            worker_addr = ret.json()["address"]
            print(f"Worker address: {worker_addr}")
        except requests.exceptions.RequestException as e:
            print(f"Failed to connect to controller: {e}")
            return

    if worker_addr == "":
        print("No available worker")
        return

    with open("MedMAX/data/updated_cases.json", "r") as file:
        data = json.load(file)

    total_cases, total_questions = count_total_questions()
    print(f"\nStarting benchmark with {args.model_name}")
    print(f"Found {total_cases} cases with {total_questions} total questions")

    results = {
        "model": args.model_name,
        "timestamp": datetime.now().isoformat(),
        "total_cases": total_cases,
        "total_questions": total_questions,
        "results": [],
    }

    cases_processed = 0
    questions_processed = 0
    correct_answers = 0
    skipped_questions = 0
    total_processed_entries = 0
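
    # Iterate over every case and its question files. raw_output=True is used
    # so the full model output, prompts, and image paths are logged for each
    # question, not just the validated letter.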
    for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in tqdm(
            question_files, desc=f"Processing questions for case {case_id}", leave=False
        ):
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1

            inference_result = create_inference_request(
                question_data,
                case_details,
                case_id,
                question_id,
                worker_addr,
                args.model_name,
                raw_output=True,
            )

            if inference_result == ("skipped", 0.0):
                skipped_questions += 1
                print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")

                skipped_entry = {
                    "case_id": case_id,
                    "question_id": question_id,
                    "status": "skipped",
                    "reason": "No images found",
                }
                with open(live_log_filename, "a") as live_log_file:
                    json.dump(skipped_entry, live_log_file, indent=2)
                    live_log_file.write(",\n")

                continue

            # Inference errors return (None, None) rather than a dict; log them
            # and move on instead of crashing on the dict lookups below.
            if inference_result == (None, None):
                skipped_questions += 1
                print(f"\nCase {case_id}, Question {question_id}: Failed (inference error)")

                failed_entry = {
                    "case_id": case_id,
                    "question_id": question_id,
                    "status": "failed",
                    "reason": "Inference request failed",
                }
                with open(live_log_filename, "a") as live_log_file:
                    json.dump(failed_entry, live_log_file, indent=2)
                    live_log_file.write(",\n")

                continue

            answer = inference_result["validated_answer"]
            duration = inference_result["duration"]

            log_entry = {
                "case_id": case_id,
                "question_id": question_id,
                "question": question_data["question"],
                "correct_answer": question_data["answer"],
                "raw_output": inference_result["raw_output"],
                "validated_answer": answer,
                "model_answer": answer,
                "is_correct": answer == question_data["answer"] if answer else False,
                "duration": duration,
                "system_prompt": inference_result["system_prompt"],
                "input_prompt": inference_result["prompt"],
                "image_paths": inference_result["image_paths"],
                "payload": clean_payload(inference_result["payload"]),
            }

            with open(live_log_filename, "a") as live_log_file:
                json.dump(log_entry, live_log_file, indent=2)
                live_log_file.write(",\n")

            print(f"\nCase {case_id}, Question {question_id}")
            print(f"Model Answer: {answer}")
            print(f"Correct Answer: {question_data['answer']}")
            print(f"Time taken: {duration:.2f}s")

            if answer == question_data["answer"]:
                correct_answers += 1

            results["results"].append(log_entry)
            total_processed_entries += 1

            if args.raw_output and cases_processed == args.num_cases:
                break

        if args.raw_output and cases_processed == args.num_cases:
            break

    # Close the JSON array in the live log. "rb+" is required here: files
    # opened in append mode ignore seek() for writes, so the trailing ",\n"
    # left by the last entry would never be overwritten.
    with open(live_log_filename, "rb+") as live_log_file:
        live_log_file.seek(0, os.SEEK_END)
        end = live_log_file.tell()
        if end > 2:
            # Strip the trailing ",\n" left after the last entry.
            live_log_file.seek(end - 2)
            live_log_file.truncate()
        live_log_file.write(b"\n]")
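
    # Accuracy is computed over answered questions only, i.e. skipped and
    # failed questions are excluded from the denominator.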
    results["summary"] = {
        "cases_processed": cases_processed,
        "questions_processed": questions_processed,
        "total_processed_entries": total_processed_entries,
        "correct_answers": correct_answers,
        "skipped_questions": skipped_questions,
        "accuracy": (
            correct_answers / (questions_processed - skipped_questions)
            if (questions_processed - skipped_questions) > 0
            else 0
        ),
    }

    with open(final_results_filename, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Processed Entries: {total_processed_entries}")
    print(f"Correct Answers: {correct_answers}")
    print(f"Skipped Questions: {skipped_questions}")
    print(f"Accuracy: {results['summary']['accuracy'] * 100:.2f}%")
    print(f"\nResults saved to {args.output_dir}")
    print(f"Live log: {live_log_filename}")
    print(f"Final results: {final_results_filename}")


if __name__ == "__main__":
    main()