# Standard library
import tempfile
import traceback
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Any

# Third-party: numerics, model, image I/O and plotting
import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import skimage.measure
import skimage.transform
import torch
import torchvision
import torchxrayvision as xrv

# Third-party: tool framework and schema validation
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
|
|
|
|
|
class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-ray Segmentation Tool."""

    # Required argument: where the radiograph to segment lives on disk.
    image_path: str = Field(
        default=...,
        description="Path to the chest X-ray image file to be segmented",
    )

    # Optional organ filter; the description documents that None means
    # "segment everything" and enumerates the accepted organ names.
    organs: Optional[List[str]] = Field(
        default=None,
        description=(
            "List of organs to segment. If None, all available organs will be segmented. "
            "Available organs: Left/Right Clavicle, Left/Right Scapula, Left/Right Lung, "
            "Left/Right Hilus Pulmonis, Heart, Aorta, Facies Diaphragmatica, "
            "Mediastinum, Weasand, Spine"
        ),
    )
|
|
|
|
|
class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ."""

    # --- Size and location -------------------------------------------------
    area_pixels: int = Field(default=..., description="Area in pixels")
    area_cm2: float = Field(default=..., description="Approximate area in cm²")
    centroid: Tuple[float, float] = Field(
        default=...,
        description="(y, x) coordinates of centroid",
    )
    bbox: Tuple[int, int, int, int] = Field(
        default=...,
        description="Bounding box coordinates (min_y, min_x, max_y, max_x)",
    )

    # --- Shape -------------------------------------------------------------
    width: int = Field(default=..., description="Width of the organ in pixels")
    height: int = Field(default=..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(default=..., description="Height/width ratio")

    # --- Placement within the image ------------------------------------------
    relative_position: Dict[str, float] = Field(
        default=...,
        description="Position relative to image boundaries (0-1 scale)",
    )

    # --- Intensity statistics and model confidence ---------------------------
    mean_intensity: float = Field(
        default=...,
        description="Mean pixel intensity in the organ region",
    )
    std_intensity: float = Field(
        default=...,
        description="Standard deviation of pixel intensity",
    )
    confidence_score: float = Field(
        default=...,
        description="Model confidence score for this organ",
    )
|
|
|
|
|
# Workflow implemented by this class: load an image, run the torchxrayvision
# PSPNet segmentation model on a centre-cropped 512x512 version of it, map the
# predicted masks back onto the original image grid, then report per-organ
# metrics plus a colour overlay saved as a PNG in ``temp_dir``.
class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images."""

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images to specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns segmentation visualization and comprehensive metrics. "
        "Let the user know the area is not accurate unless input has been DICOM."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput

    # Segmentation network; populated in __init__.
    model: Any = None
    # Holds a torch.device after __init__, hence ``Any`` rather than ``str``.
    device: Any = "cuda"
    # Pre-processing pipeline (centre crop + resize to 512); populated in __init__.
    transform: Any = None
    # Pixel spacing assumed when converting pixel counts to cm².
    pixel_spacing_mm: float = 0.2
    # Directory that receives the overlay PNGs.
    temp_dir: Path = Path("temp")
    # Organ name -> channel index used to slice the model output.
    organ_map: Optional[Dict[str, int]] = None

    def __init__(self, device: Optional[str] = "cuda", temp_dir: Optional[Path] = Path("temp")):
        """Initialize the segmentation tool with model and temporary directory.

        Args:
            device: Torch device string (e.g. ``"cuda"``, ``"cpu"``). A falsy
                value selects CUDA when available and the CPU otherwise.
            temp_dir: Directory for saved overlays; created (including missing
                parents) if absent. ``None`` falls back to ``Path("temp")``.
        """
        super().__init__()

        # Always store a torch.device. Previously a falsy ``device`` stored
        # the bare string "cuda" even on hosts without CUDA.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = xrv.baseline_models.chestx_det.PSPNet().to(self.device)
        self.model.eval()

        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )

        # The signature allows None, but Path(None) raises TypeError.
        self.temp_dir = Path(temp_dir) if temp_dir is not None else Path("temp")
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Values index the second axis of the model output (see ``_run``).
        self.organ_map = {
            "Left Clavicle": 0,
            "Right Clavicle": 1,
            "Left Scapula": 2,
            "Right Scapula": 3,
            "Left Lung": 4,
            "Right Lung": 5,
            "Left Hilus Pulmonis": 6,
            "Right Hilus Pulmonis": 7,
            "Heart": 8,
            "Aorta": 9,
            "Facies Diaphragmatica": 10,
            "Mediastinum": 11,
            "Weasand": 12,
            "Spine": 13,
        }

    def _align_mask_to_original(
        self, mask: np.ndarray, original_shape: Tuple[int, int]
    ) -> np.ndarray:
        """Project a mask from model space back onto the original image grid.

        Assumes the pre-processing transform centre-crops to a square of side
        ``min(height, width)`` and then resizes it to 512x512.

        Args:
            mask: 2-D mask in the cropped/resized space.
            original_shape: ``(height, width)`` of the original image.

        Returns:
            Array of shape ``original_shape`` that is zero outside the crop
            window and holds the nearest-neighbour-resized mask inside it.
        """
        orig_h, orig_w = original_shape
        crop_size = min(orig_h, orig_w)

        # NOTE(review): offsets follow torchxrayvision's XRayCenterCrop
        # (``dim // 2 - crop // 2``) — confirm against the installed version.
        # ``(dim - crop) // 2`` differs by one pixel when ``dim`` is even and
        # ``crop`` is odd.
        crop_top = orig_h // 2 - crop_size // 2
        crop_left = orig_w // 2 - crop_size // 2

        # order=0 (nearest neighbour) keeps the mask binary after resizing.
        resized_mask = skimage.transform.resize(
            mask, (crop_size, crop_size), order=0, preserve_range=True, anti_aliasing=False
        )

        full_mask = np.zeros(original_shape)
        full_mask[crop_top : crop_top + crop_size, crop_left : crop_left + crop_size] = resized_mask
        return full_mask

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute comprehensive metrics for a single organ mask.

        Args:
            mask: Binary organ mask, in model space or already at the original
                image resolution.
            original_img: 2-D original image used for intensity statistics.
            confidence: Confidence value stored unchanged in the result.

        Returns:
            The populated ``OrganMetrics``, or ``None`` when the mask contains
            no labelled region.
        """
        if mask.shape != original_img.shape:
            mask = self._align_mask_to_original(mask, original_img.shape)

        # A binary mask has a single label, so all foreground pixels form one
        # region even when they are spatially disconnected.
        regions = skimage.measure.regionprops(mask.astype(int))
        if not regions:
            return None
        region = regions[0]

        area_pixels = mask.sum()
        # mm -> cm per pixel side, squared for area.
        area_cm2 = area_pixels * (self.pixel_spacing_mm / 10) ** 2

        img_height, img_width = mask.shape
        cy, cx = region.centroid
        rel_y = cy / img_height
        rel_x = cx / img_width
        relative_pos = {
            "top": rel_y,
            "left": rel_x,
            "center_dist": np.sqrt((rel_y - 0.5) ** 2 + (rel_x - 0.5) ** 2),
        }

        organ_pixels = original_img[mask > 0]
        if organ_pixels.size:
            mean_intensity = organ_pixels.mean()
            std_intensity = organ_pixels.std()
        else:
            mean_intensity = 0
            std_intensity = 0

        min_y, min_x, max_y, max_x = region.bbox
        box_height = max_y - min_y
        box_width = max_x - min_x

        return OrganMetrics(
            area_pixels=int(area_pixels),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=tuple(map(int, region.bbox)),
            width=int(box_width),
            height=int(box_height),
            # max(1, ...) guards against division by zero.
            aspect_ratio=float(box_height / max(1, box_width)),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(
        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
    ) -> str:
        """Save visualization of original image with segmentation masks overlaid.

        Args:
            original_img: 2-D original image drawn as the grayscale background.
            pred_masks: Binary mask tensor indexed as ``[0, organ_index]``.
            organ_indices: Channel indices of the organs to draw.

        Returns:
            Path (as a string) of the PNG written into ``self.temp_dir``.
        """
        extent = [0, original_img.shape[1], original_img.shape[0], 0]
        colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
        # Reverse lookup built once instead of a linear search per organ.
        name_by_index = {index: name for name, index in self.organ_map.items()}

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(original_img, cmap="gray", extent=extent)

            drew_any = False
            for organ_idx, color in zip(organ_indices, colors):
                mask = pred_masks[0, organ_idx].cpu().numpy()
                if mask.sum() <= 0:
                    continue

                if mask.shape != original_img.shape:
                    mask = self._align_mask_to_original(mask, original_img.shape)

                # RGBA layer: transparent everywhere except the organ pixels.
                overlay = np.zeros((*original_img.shape, 4))
                overlay[mask > 0] = (*color[:3], 0.3)
                ax.imshow(overlay, extent=extent)

                # Empty line whose only purpose is to create a legend entry.
                ax.plot([], [], color=color, label=name_by_index[organ_idx], linewidth=3)
                drew_any = True

            ax.set_title("Segmentation Overlay")
            if drew_any:
                # Skipped when nothing was drawn: there are no labelled artists.
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            ax.axis("off")

            save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
            fig.savefig(save_path, bbox_inches="tight", dpi=300)
        finally:
            # Release the figure even if drawing or saving raised.
            plt.close(fig)

        return str(save_path)

    def _resolve_organs(self, organs: Optional[List[str]]) -> Tuple[List[str], List[int]]:
        """Translate requested organ names into canonical names and indices.

        Matching ignores surrounding whitespace and letter case; duplicates are
        collapsed while keeping first-seen order. A falsy request selects every
        organ in ``self.organ_map``.

        Raises:
            ValueError: If any requested name is not a known organ.
        """
        if not organs:
            return list(self.organ_map.keys()), list(self.organ_map.values())

        canonical = {name.lower(): name for name in self.organ_map}
        requested = [organ.strip() for organ in organs]

        invalid_organs = [organ for organ in requested if organ.lower() not in canonical]
        if invalid_organs:
            raise ValueError(f"Invalid organs specified: {invalid_organs}")

        names = list(dict.fromkeys(canonical[organ.lower()] for organ in requested))
        return names, [self.organ_map[name] for name in names]

    @staticmethod
    def _max_pixel_value(image: np.ndarray) -> int:
        """Return the intensity ceiling to pass to ``xrv.datasets.normalize``.

        Stays at 255 unless an integer image actually holds values above 255,
        in which case the dtype's maximum is used (e.g. 65535 for ``uint16``).
        """
        if np.issubdtype(image.dtype, np.integer) and image.size and image.max() > 255:
            return int(np.iinfo(image.dtype).max)
        return 255

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for specified organs.

        Args:
            image_path: Path of the image to read with ``skimage.io.imread``.
            organs: Organ names to analyse; falsy means all organs.
            run_manager: Accepted for the tool interface; not used here.

        Returns:
            ``(output, metadata)``. On success ``output`` holds the overlay
            path and per-organ metrics. On any exception ``output`` is
            ``{"error": ...}`` and ``metadata`` carries the traceback; the
            exception is not propagated.
        """
        try:
            organs, organ_indices = self._resolve_organs(organs)

            original_img = skimage.io.imread(image_path)
            if original_img.ndim > 2:
                # Multi-channel input: keep only the first channel.
                original_img = original_img[:, :, 0]

            img = xrv.datasets.normalize(original_img, self._max_pixel_value(original_img))
            img = self.transform(img[None, ...])
            img = torch.from_numpy(img).to(self.device)

            with torch.no_grad():
                pred = self.model(img)
                pred_probs = torch.sigmoid(pred)
                pred_masks = (pred_probs > 0.5).float()

            # Move results to the CPU once instead of once per organ.
            pred_probs = pred_probs.cpu()
            pred_masks = pred_masks.cpu()

            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)

            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                organ_mask = pred_masks[0, idx]
                foreground = organ_mask > 0
                if not bool(foreground.any()):
                    continue

                # Mean probability over the predicted organ pixels only.
                # Averaging over the whole channel would dilute the score
                # with background and make it track organ size instead.
                confidence = float(pred_probs[0, idx][foreground].mean())

                metrics = self._compute_organ_metrics(
                    organ_mask.numpy(), original_img, confidence
                )
                if metrics:
                    results[organ_name] = metrics

            output = {
                "segmentation_image_path": viz_path,
                "metrics": {
                    # model_dump() is the pydantic v2 spelling; .dict() covers v1.
                    organ: (
                        metrics.model_dump() if hasattr(metrics, "model_dump") else metrics.dict()
                    )
                    for organ, metrics in results.items()
                },
            }

            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }

            return output, metadata

        except Exception as e:
            # Tool boundary: failures are reported as data, not raised.
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run.

        Delegates directly to the synchronous ``_run``, so the call executes
        inline rather than on a separate thread.
        """
        return self._run(image_path, organs)
|
|