"""Evaluate an OpenAI vision model on the CheXbench Visual Question Answering subset.

Each example's image is base64-encoded and sent to the Chat Completions API
together with a two-option multiple-choice question; the model's A/B answer is
scored against the ground truth, and results are saved to a timestamped JSON file.
"""

import base64
import json
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple

import openai
from tqdm import tqdm

DEBUG_MODE = False
OUTPUT_DIR = "results"
MODEL_NAME = "gpt-4o-2024-05-13"
TEMPERATURE = 0.2
SUBSET = "Visual Question Answering"
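
# The pinned, dated model snapshot and the low temperature are meant to keep
# repeated evaluation runs comparable.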

logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def get_mime_type(file_path: str) -> str:
    """
    Determine MIME type based on file extension.

    Args:
        file_path (str): Path to the file

    Returns:
        str: MIME type string for the file
    """
    extension = os.path.splitext(file_path)[1].lower()
    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
    }
    return mime_types.get(extension, "application/octet-stream")
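
# Example (illustrative): get_mime_type("scan.jpg") returns "image/jpeg", while an
# unrecognized extension such as ".dcm" falls back to "application/octet-stream".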


def encode_image(image_path: str) -> str:
    """
    Encode an image to base64 with extensive error checking.

    Args:
        image_path (str): Path to the image file

    Returns:
        str: Base64-encoded image string

    Raises:
        FileNotFoundError: If the image file does not exist
        ValueError: If the image file is empty, too large, or in an unsupported format
        Exception: For other image processing errors
    """
    logger.debug(f"Attempting to read image from: {image_path}")
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    file_size = os.path.getsize(image_path)
    if file_size > 20 * 1024 * 1024:
        raise ValueError("Image file size exceeds 20MB limit")
    if file_size == 0:
        raise ValueError("Image file is empty")
    logger.debug(f"Image file size: {file_size / 1024:.2f} KB")

    try:
        # Pillow is imported lazily so the dependency is only needed when
        # images are actually encoded.
        from PIL import Image

        # Verify the image is readable and in a supported format before encoding.
        with Image.open(image_path) as img:
            width, height = img.size
            img_format = img.format
            mode = img.mode
            logger.debug(
                f"Image verification - Format: {img_format}, Size: {width}x{height}, Mode: {mode}"
            )

            if img_format not in ["PNG", "JPEG", "GIF"]:
                raise ValueError(f"Unsupported image format: {img_format}")

        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")
        encoded_length = len(encoded)
        logger.debug(f"Base64 encoded length: {encoded_length} characters")

        if encoded_length == 0:
            raise ValueError("Base64 encoding produced empty string")
        # JPEG files encode to "/9j/", PNG to "iVBOR", and GIF to "R0lGOD".
        if not encoded.startswith(("/9j/", "iVBOR", "R0lGOD")):
            logger.warning("Base64 string doesn't start with an expected JPEG/PNG/GIF header")

        return encoded
    except Exception as e:
        logger.error(f"Error reading/encoding image: {str(e)}")
        raise
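
# Example (illustrative, assuming "sample.png" exists and is a valid PNG):
#     encoded = encode_image("sample.png")
#     data_url = f"data:{get_mime_type('sample.png')};base64,{encoded}"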


def create_single_request(
    image_path: str, question: str, options: Dict[str, str]
) -> List[Dict[str, Any]]:
    """
    Create a single API request with image and question.

    Args:
        image_path (str): Path to the image file
        question (str): Question text
        options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'

    Returns:
        List[Dict[str, Any]]: List of message dictionaries for the API request

    Raises:
        Exception: For errors in request creation
    """
    if DEBUG_MODE:
        logger.debug("Creating API request...")

    prompt = f"""Please answer the following multiple-choice medical examination question:

Question: {question}

Options:
A) {options['option_0']}
B) {options['option_1']}

Base your answer only on the provided image and select either A or B."""

    try:
        encoded_image = encode_image(image_path)
        mime_type = get_mime_type(image_path)

        if DEBUG_MODE:
            logger.debug(f"Image encoded with MIME type: {mime_type}")

        messages = [
            {
                "role": "system",
                "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
                    },
                ],
            },
        ]

        if DEBUG_MODE:
            # Deep-copy via a JSON round-trip so the logged payload can be
            # truncated without mutating the real request.
            log_messages = json.loads(json.dumps(messages))
            log_messages[1]["content"][1]["image_url"][
                "url"
            ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
            logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")

        return messages

    except Exception as e:
        logger.error(f"Error creating request: {str(e)}")
        raise
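
# The returned value follows the OpenAI Chat Completions vision-message format:
# a system message plus one user message whose content mixes a text part and an
# image_url part carrying the base64 data URL.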


def check_answer(model_answer: str, correct_answer: int) -> bool:
    """
    Check if the model's answer matches the correct answer.

    Args:
        model_answer (str): The model's answer (A or B)
        correct_answer (int): The correct answer index (0 for A, 1 for B)

    Returns:
        bool: True if the answer is correct, False otherwise
    """
    if not isinstance(model_answer, str):
        return False

    # Accept answers like "A", "a)", or "B." by matching only the leading letter.
    model_letter = model_answer.strip().upper()
    if model_letter.startswith("A"):
        model_index = 0
    elif model_letter.startswith("B"):
        model_index = 1
    else:
        return False

    return model_index == correct_answer
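
# Example (illustrative): check_answer("A) Pneumonia", 0) is True because only the
# leading letter is compared; check_answer("C", 0) is False.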


def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
    """
    Save results to a JSON file with timestamp.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries
        output_dir (str): Directory to save results

    Returns:
        str: Path to the saved file
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")

    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Batch results saved to {output_file}")
    return output_file
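
# Example (illustrative): a run started at 2024-05-13 12:00:00 would write to
# results/batch_results_20240513_120000.json.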


def calculate_accuracy(results: List[Dict[str, Any]]) -> Tuple[float, int, int]:
    """
    Calculate accuracy from results, handling error cases.

    Examples that failed with an error (no "output" key) count toward the
    total but never toward the correct count.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        Tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
    """
    if not results:
        return 0.0, 0, 0

    total = len(results)
    valid_results = [r for r in results if "output" in r]
    correct = sum(
        1 for result in valid_results if result.get("output", {}).get("is_correct", False)
    )

    accuracy = (correct / total * 100) if total > 0 else 0
    return accuracy, correct, total
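
# Example (illustrative): for [{"output": {"is_correct": True}}, {"error": "timeout"}],
# calculate_accuracy returns (50.0, 1, 2) -- the errored example counts as incorrect.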


def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
    """
    Calculate accuracy for the current batch, ignoring examples that errored.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        float: Accuracy percentage for the batch
    """
    valid_results = [r for r in results if "output" in r]
    if not valid_results:
        return 0.0
    return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100
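
# Note the difference from calculate_accuracy: here the denominator includes only
# examples that produced an output, so errors do not drag the batch figure down.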


def process_batch(
    data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
) -> List[Dict[str, Any]]:
    """
    Process a batch of examples and return results.

    Args:
        data (List[Dict[str, Any]]): List of data items to process
        client (openai.OpenAI): OpenAI client instance
        start_idx (int, optional): Starting index for the batch. Defaults to 0
        batch_size (int, optional): Size of the batch to process. Defaults to 50

    Returns:
        List[Dict[str, Any]]: List of processed results
    """
    batch_results = []
    end_idx = min(start_idx + batch_size, len(data))

    pbar = tqdm(
        range(start_idx, end_idx),
        desc=f"Processing batch {start_idx // batch_size + 1}",
        unit="example",
    )

    for index in pbar:
        vqa_item = data[index]
        options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}

        try:
            messages = create_single_request(
                image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
            )

            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
            )

            model_answer = response.choices[0].message.content.strip()
            is_correct = check_answer(model_answer, vqa_item["answer"])

            result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
                "output": {
                    "model_answer": model_answer,
                    "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
                    "is_correct": is_correct,
                    "usage": {
                        "prompt_tokens": response.usage.prompt_tokens,
                        "completion_tokens": response.usage.completion_tokens,
                        "total_tokens": response.usage.total_tokens,
                    },
                },
            }
            batch_results.append(result)

            # Show the running accuracy for this batch in the progress bar.
            current_accuracy = calculate_batch_accuracy(batch_results)
            pbar.set_description(
                f"Batch {start_idx // batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
                f"({len(batch_results)}/{index - start_idx + 1} examples)"
            )

        except Exception as e:
            # Record the failure so the example still appears in the saved results.
            error_result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "error": str(e),
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
            }
            batch_results.append(error_result)
            if DEBUG_MODE:
                pbar.write(f"Error processing example {index}: {str(e)}")

        # Crude rate limiting between requests.
        time.sleep(1)

    return batch_results
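
# Example (illustrative): process_batch(subset_data, client, start_idx=50, batch_size=50)
# processes examples 50-99, i.e. the second batch of the dataset.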


def main() -> None:
    """
    Main function to process the entire dataset.

    Raises:
        ValueError: If OPENAI_API_KEY is not set
        Exception: For other processing errors
    """
    logger.info("Starting full dataset processing...")
    json_path = "../data/chexbench_updated.json"

    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        client = openai.OpenAI(api_key=api_key)

        with open(json_path, "r") as f:
            data = json.load(f)

        subset_data = data[SUBSET]
        total_examples = len(subset_data)
        logger.info(f"Found {total_examples} examples in {SUBSET} subset")

        all_results = []
        batch_size = 50

        for start_idx in range(0, total_examples, batch_size):
            batch_results = process_batch(subset_data, client, start_idx, batch_size)
            all_results.extend(batch_results)

        output_file = save_results_to_json(all_results, OUTPUT_DIR)

        overall_accuracy, correct, total = calculate_accuracy(all_results)
        logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
        logger.info(f"Final Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")

        logger.info("Processing completed!")
        logger.info(f"Final results saved to: {output_file}")

    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        raise


if __name__ == "__main__":
    main()