|
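"""LangChain tool for grounding medical finding phrases in chest X-ray images using MAIRA-2."""
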
import tempfile
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

# pyplot defaults to an interactive backend; in headless deployments consider
# calling matplotlib.use("Agg") before this import.
import matplotlib.pyplot as plt
import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
|
|
class XRayPhraseGroundingInput(BaseModel):
    """Input schema for the XRay Phrase Grounding Tool. Only JPG and PNG images are supported."""

    image_path: str = Field(
        ...,
        description="Path to the frontal chest X-ray image file, only supports JPG or PNG images",
    )
    phrase: str = Field(
        ...,
        description="Medical finding or condition to locate in the image (e.g., 'Pleural effusion')",
    )
    max_new_tokens: int = Field(
        default=300, description="Maximum number of new tokens to generate"
    )
|
|
class XRayPhraseGroundingTool(BaseTool):
    """Tool for grounding medical findings in chest X-ray images using the MAIRA-2 model.

    This tool processes chest X-ray images and locates specific medical findings mentioned
    in the input phrase. It returns both the bounding box coordinates and a visualization
    of the finding's location in the image.
    """

    name: str = "xray_phrase_grounding"
    description: str = (
        "Locates and visualizes specific medical findings in chest X-ray images. "
        "Takes a chest X-ray image and a medical phrase to locate (e.g., 'Pleural effusion', 'Cardiomegaly'). "
        "Returns bounding box coordinates in the format [x_topleft, y_topleft, x_bottomright, y_bottomright], "
        "where each value is between 0 and 1 and gives a position relative to the image, "
        "along with a visualization of the finding's location and processing metadata. "
        "Example input: {'image_path': '/path/to/xray.png', 'phrase': 'Pleural effusion', 'max_new_tokens': 300}"
    )
    args_schema: Type[BaseModel] = XRayPhraseGroundingInput
    # _run returns a (content, artifact) tuple; declare that so agents receive both parts.
    response_format: str = "content_and_artifact"

    model: Any = None
    processor: Any = None
    device: Optional[torch.device] = None
    temp_dir: Optional[Path] = None
|
    def __init__(
        self,
        model_path: str = "microsoft/maira-2",
        cache_dir: Optional[str] = None,
        temp_dir: Optional[str] = None,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        device: Optional[str] = None,
    ):
        """Initialize the XRay Phrase Grounding Tool, loading the MAIRA-2 model and processor."""
        super().__init__()

        # Fall back to CPU when CUDA is unavailable instead of failing at load time.
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)

        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            quantization_config = None

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=self.device,
            cache_dir=cache_dir,
            trust_remote_code=True,
            quantization_config=quantization_config,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_path, cache_dir=cache_dir, trust_remote_code=True
        )
        self.model = self.model.eval()

        self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)
|
    def _visualize_bboxes(
        self, image: Image.Image, bboxes: List[Tuple[float, float, float, float]], phrase: str
    ) -> str:
        """Create and save a visualization of multiple bounding boxes on the image."""
        plt.figure(figsize=(12, 12))
        plt.imshow(image, cmap="gray")

        # Boxes arrive as relative (0-1) coordinates, so scale them to pixel space.
        for x1, y1, x2, y2 in bboxes:
            plt.gca().add_patch(
                plt.Rectangle(
                    (x1 * image.width, y1 * image.height),
                    (x2 - x1) * image.width,
                    (y2 - y1) * image.height,
                    fill=False,
                    color="red",
                    linewidth=2,
                )
            )

        plt.title(f"Located: {phrase}", pad=20)
        plt.axis("off")

        # Random suffix so repeated runs do not overwrite earlier visualizations.
        viz_path = self.temp_dir / f"grounding_{uuid.uuid4().hex[:8]}.png"
        plt.savefig(viz_path, bbox_inches="tight", dpi=150)
        plt.close()

        return str(viz_path)
|
    def _run(
        self,
        image_path: str,
        phrase: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Ground a medical finding phrase in an X-ray image.

        Args:
            image_path: Path to the chest X-ray image file
            phrase: Medical finding to locate in the image
            max_new_tokens: Maximum number of new tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict, Dict]: Output dictionary and metadata dictionary
        """
        try:
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")

            inputs = self.processor.format_and_preprocess_phrase_grounding_input(
                frontal_image=image, phrase=phrase, return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                )

            # Decode only the newly generated tokens, skipping the prompt.
            prompt_length = inputs["input_ids"].shape[-1]
            decoded_text = self.processor.decode(
                generated_ids[0][prompt_length:], skip_special_tokens=True
            )
            predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(
                decoded_text
            )

            metadata = {
                "image_path": image_path,
                "original_size": image.size,
                "model_input_size": tuple(inputs["pixel_values"].shape[-2:]),
                "device": str(self.device),
                "analysis_status": "completed",
            }

            if not predictions:
                output = {
                    "predictions": [],
                    "visualization_path": None,
                }
                metadata["analysis_status"] = "completed_no_finding"
                return output, metadata

            processed_predictions = []
            for pred_phrase, pred_bboxes in predictions:
                if not pred_bboxes:
                    continue

                # Map boxes from the processor's cropped/padded frame back to the
                # original image frame (still relative 0-1 coordinates).
                model_bboxes = [list(bbox) for bbox in pred_bboxes]
                original_bboxes = [
                    self.processor.adjust_box_for_original_image_size(
                        bbox, width=image.size[0], height=image.size[1]
                    )
                    for bbox in model_bboxes
                ]

                processed_predictions.append(
                    {
                        "phrase": pred_phrase,
                        "bounding_boxes": {
                            "model_coordinates": model_bboxes,
                            "image_coordinates": original_bboxes,
                        },
                    }
                )

            if processed_predictions:
                all_bboxes = []
                for pred in processed_predictions:
                    all_bboxes.extend(pred["bounding_boxes"]["image_coordinates"])
                viz_path = self._visualize_bboxes(image, all_bboxes, phrase)
            else:
                viz_path = None
                metadata["analysis_status"] = "completed_no_finding"

            output = {
                "predictions": processed_predictions,
                "visualization_path": viz_path,
            }

            return output, metadata

        except Exception as e:
            output = {"error": str(e)}
            metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_details": str(e),
            }
            return output, metadata
|
    async def _arun(
        self,
        image_path: str,
        phrase: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous version of _run (delegates to the synchronous implementation)."""
        return self._run(image_path, phrase, max_new_tokens=max_new_tokens)
|
|
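
if __name__ == "__main__":
    # Minimal usage sketch, not part of the tool itself. The image path below is
    # a hypothetical placeholder; point it at a real frontal chest X-ray (JPG/PNG).
    # The first run downloads the MAIRA-2 weights from the Hugging Face Hub, and
    # 4-bit loading assumes a CUDA-capable GPU with bitsandbytes installed.
    tool = XRayPhraseGroundingTool(load_in_4bit=torch.cuda.is_available())
    output, metadata = tool._run(
        image_path="/path/to/frontal_cxr.png",  # hypothetical example path
        phrase="Pleural effusion",
    )
    print(output)
    print(metadata)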