File size: 6,419 Bytes
cb3a670 0ffa584 cb3a670 0ffa584 cb3a670 0ffa584 cb3a670 0ffa584 cb3a670 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import asyncio
from typing import Dict, Optional, Tuple, Type

import skimage.io
import torch
import torchvision
import torchxrayvision as xrv
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Required argument (explicit Ellipsis default): filesystem path of the image
    # to analyze. The description string is surfaced to the LLM via the schema.
    image_path: str = Field(
        default=...,
        description="Path to the radiology image file, only supports JPG or PNG images",
    )
class ChestXRayClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images for multiple pathologies.

    This tool uses a pre-trained torchxrayvision DenseNet to analyze chest X-ray
    images and predict the likelihood of various pathologies. With the default
    weights (``densenet121-res224-all``) it reports the following 18 conditions:
    Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema,
    Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration,
    Lung Lesion, Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, Pneumothorax

    The output values represent the probability (from 0 to 1) of each condition
    being present in the image. A higher value indicates a higher likelihood of
    the condition being present.
    """

    name: str = "chest_xray_classifier"
    description: str = (
        "A tool that analyzes chest X-ray images and classifies them for 18 different pathologies. "
        "Input should be the path to a chest X-ray image file. "
        "Output is a dictionary of pathologies and their predicted probabilities (0 to 1). "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, "
        "Enlarged Cardiomediastinum, Fibrosis, Fracture, Hernia, Infiltration, Lung Lesion, "
        "Lung Opacity, Mass, Nodule, Pleural Thickening, Pneumonia, and Pneumothorax. "
        "Higher values indicate a higher likelihood of the condition being present."
    )
    args_schema: Type[BaseModel] = ChestXRayInput
    # Both are populated in __init__; the None defaults only exist so pydantic
    # does not treat them as required constructor arguments.
    model: Optional[xrv.models.DenseNet] = None
    device: Optional[torch.device] = torch.device("cpu")  # Default to CPU
    transform: Optional[torchvision.transforms.Compose] = None

    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = None):
        """Load the DenseNet weights and move the model to the target device.

        Args:
            model_name: torchxrayvision weights identifier passed to
                ``xrv.models.DenseNet(weights=...)``.
            device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``). When
                omitted, CUDA is used if available, otherwise CPU.
        """
        super().__init__()
        # If device is not specified, use CUDA if available, else fall back to CPU.
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = xrv.models.DenseNet(weights=model_name)
        # Inference only: put the network in evaluation mode.
        self.model.eval()
        self.device = torch.device(device)
        self.model = self.model.to(self.device)
        # Center crop applied to the (1, H, W) array built in _process_image.
        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])

    def _process_image(self, image_path: str) -> torch.Tensor:
        """Load a chest X-ray and convert it into a model-ready tensor.

        The image is read from disk, its pixel values are rescaled with
        ``xrv.datasets.normalize`` (assuming a maximum value of 255, i.e. 8-bit
        images), reduced to a single channel, center-cropped, and given
        channel and batch axes.

        Args:
            image_path: Path to the chest X-ray image file.

        Returns:
            A tensor of shape ``(1, 1, H, W)`` located on ``self.device``.

        Raises:
            Any error raised by ``skimage.io.imread`` (e.g. for a missing file)
            or by ``xrv.datasets.normalize``. ``_run`` converts these into an
            error payload instead of propagating them.
        """
        img = skimage.io.imread(image_path)
        # Rescale 0..255 pixel values into the intensity range xrv models use.
        img = xrv.datasets.normalize(img, 255)
        # Multi-channel (RGB/RGBA) input: keep only the first channel.
        if img.ndim > 2:
            img = img[:, :, 0]
        # Add a channel axis -> (1, H, W) before the crop transform.
        img = img[None, :, :]
        img = self.transform(img)
        # Add a batch axis -> (1, 1, H, W) and move to the model's device.
        return torch.from_numpy(img).unsqueeze(0).to(self.device)

    def _predictions_to_dict(self, preds: torch.Tensor) -> Dict[str, float]:
        """Map a 1-D prediction vector onto ``{pathology: probability}``.

        Labels come from the loaded model's ``pathologies`` attribute so the
        mapping matches whatever ``model_name`` was loaded (falling back to
        ``xrv.datasets.default_pathologies``). Slots with an empty label are
        skipped; torchxrayvision uses empty labels for outputs a given weight
        set was not trained on. Values are converted to plain Python floats so
        the result is JSON-serializable and prints cleanly.
        """
        labels = getattr(self.model, "pathologies", None) or xrv.datasets.default_pathologies
        return {
            label: value
            for label, value in zip(labels, preds.tolist())
            if label
        }

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image for multiple pathologies.

        Args:
            image_path: The path to the chest X-ray image file.
            run_manager: The callback manager for the tool run (unused).

        Returns:
            A ``(output, metadata)`` tuple. On success ``output`` maps each
            pathology to a probability between 0 and 1, and
            ``metadata["analysis_status"]`` is ``"completed"``. On failure
            ``output`` is ``{"error": <message>}`` and
            ``metadata["analysis_status"]`` is ``"failed"``.

        Note:
            This method does not raise. Every exception from loading or
            inference is caught and reported in the returned tuple, so the
            calling agent always receives a result.
        """
        try:
            img = self._process_image(image_path)
            with torch.inference_mode():
                # Drop the batch axis; move to CPU for conversion to Python floats.
                preds = self.model(img).cpu()[0]
            output = self._predictions_to_dict(preds)
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }
            return output, metadata
        except Exception as e:
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image for multiple pathologies.

        Image loading and model inference are blocking, so the synchronous
        ``_run`` is executed in the event loop's default executor. This keeps
        the event loop responsive while the classification runs.

        Args:
            image_path: The path to the chest X-ray image file.
            run_manager: The async callback manager for the tool run (unused).

        Returns:
            The same ``(output, metadata)`` tuple as ``_run``, including the
            error payload on failure. This method does not raise for
            processing errors.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, image_path)
|