|
import asyncio
import logging
from typing import Dict, Optional, Tuple, Type

import skimage.io
import torch
import torchvision
import torchxrayvision as xrv
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
|
|
|
|
|
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Required argument (``default=...`` marks it as having no default).
    # Filesystem path of the image to analyze; pydantic includes the
    # description text in the generated JSON schema for this model.
    image_path: str = Field(
        default=...,
        description="Path to the radiology image file, only supports JPG or PNG images",
    )
|
|
|
|
|
class ChestXRayClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images for multiple pathologies.

    This tool uses a pre-trained torchxrayvision DenseNet to analyze chest X-ray
    images and predict the likelihood of various pathologies. With the default
    ``densenet121-res224-all`` weights the model can classify the following
    18 conditions:

        Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema,
        Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration,
        Lung Lesion, Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia,
        Pneumothorax

    Output keys are taken from the loaded model's own label list, so weights
    other than the default are labeled correctly. Output slots the model leaves
    unlabeled (empty label) are omitted. The exact key spelling follows that
    label list (e.g. it may use underscores), so verify it against the model.

    The output values represent the probability (from 0 to 1) of each condition
    being present in the image. A higher value indicates a higher likelihood of
    the condition being present.
    """

    name: str = "chest_xray_classifier"
    description: str = (
        "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
        "Input should be the path to a chest X-ray image file. "
        "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
        "Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration, Lung Lesion, "
        "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
        "Higher values indicate a higher likelihood of the condition being present."
    )
    args_schema: Type[BaseModel] = ChestXRayInput
    # Populated in __init__; None only on an instance that skipped __init__.
    model: Optional[xrv.models.DenseNet] = None
    device: Optional[torch.device] = torch.device("cpu")
    transform: Optional[torchvision.transforms.Compose] = None

    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = None):
        """Load the DenseNet weights and place the model on the target device.

        Args:
            model_name: torchxrayvision weights identifier, passed to
                ``xrv.models.DenseNet(weights=...)``.
            device: Torch device string (e.g. ``"cpu"`` or ``"cuda"``). When
                omitted, ``"cuda"`` is used if available, otherwise ``"cpu"``.
        """
        super().__init__()

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = xrv.models.DenseNet(weights=model_name)
        self.model.eval()  # inference-only usage

        self.device = torch.device(device)
        self.model = self.model.to(self.device)

        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])

    def _process_image(self, image_path: str) -> torch.Tensor:
        """Load a chest X-ray image and turn it into a model-ready tensor.

        Steps:
        1. Read the image with ``skimage.io.imread``.
        2. Rescale it with ``xrv.datasets.normalize`` using a max value of 255,
           i.e. assuming 8-bit pixels.
        3. If the image has more than two dimensions, keep only its first
           channel.
        4. Add a channel axis.
        5. Apply ``self.transform`` (``XRayCenterCrop``).
        6. Add a batch axis.

        Args:
            image_path: File path to the chest X-ray image.

        Returns:
            A tensor of shape ``(1, 1, H, W)`` (H, W after cropping) on
            ``self.device``.

        Raises:
            Whatever ``skimage.io.imread`` raises for a missing or unreadable
            file (presumably ``FileNotFoundError`` for a missing path). Also
            whatever ``xrv.datasets.normalize`` raises when the pixel values do
            not fit the 8-bit assumption. TODO confirm exact exception types.
        """
        img = skimage.io.imread(image_path)
        img = xrv.datasets.normalize(img, 255)

        # Color / alpha images: keep only the first channel so a single-channel
        # image is fed to the model.
        if len(img.shape) > 2:
            img = img[:, :, 0]

        img = img[None, :, :]  # add channel axis -> (1, H, W)
        img = self.transform(img)
        img = torch.from_numpy(img).unsqueeze(0)  # add batch axis -> (1, 1, H, W)

        return img.to(self.device)

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image for multiple pathologies.

        Args:
            image_path: The path to the chest X-ray image file.
            run_manager: The callback manager for the tool run (unused).

        Returns:
            An ``(output, metadata)`` tuple.

            On success, ``output`` maps each pathology label reported by the
            loaded model to a plain Python ``float`` probability (0 to 1), and
            ``metadata["analysis_status"]`` is ``"completed"``.

            On failure, ``output`` is ``{"error": <message>}`` and
            ``metadata["analysis_status"]`` is ``"failed"``.

        Note:
            Errors are deliberately not raised. Any exception from loading,
            preprocessing, or inference is logged with its traceback and then
            reported through the return value, so the calling agent always
            receives a result.
        """
        try:
            img = self._process_image(image_path)

            with torch.inference_mode():
                preds = self.model(img).cpu()[0]

            # Label outputs with the model's own label order rather than the
            # library default list; they differ for non-default weights.
            # Unlabeled slots (empty strings) carry no meaning and are skipped.
            labels = getattr(self.model, "pathologies", xrv.datasets.default_pathologies)
            # tolist() yields plain Python floats (JSON-serializable), matching
            # the Dict[str, float] return annotation.
            output = {
                label: prob
                for label, prob in zip(labels, preds.tolist())
                if label
            }
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }
            return output, metadata
        except Exception as e:
            # Best-effort tool contract: report the failure instead of raising,
            # but keep the traceback in the logs for debugging.
            logging.getLogger(__name__).exception(
                "Chest X-ray classification failed for %s", image_path
            )
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image for multiple pathologies.

        Model inference is blocking. The synchronous ``_run`` is therefore
        executed in the event loop's default executor (a thread pool unless
        reconfigured) rather than directly on the loop. Running it on the loop
        would stall every other coroutine while inference runs.

        Args:
            image_path: The path to the chest X-ray image file.
            run_manager: The async callback manager for the tool run (unused).

        Returns:
            The same ``(output, metadata)`` tuple as ``_run``, including the
            error-dict form on failure.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, image_path)
|
|