"""Chest X-ray report generation tool for LangChain agents.

Wraps two Hugging Face vision-encoder-decoder checkpoints that produce the
findings and impression sections of a chest X-ray radiology report.
"""

import os
from typing import Any, Dict, Optional, Tuple, Type

import torch
from PIL import Image
from pydantic import BaseModel, Field

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool

from transformers import (
    BertTokenizer,
    GenerationConfig,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)


class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    image_path: str = Field(
        ..., description="Path to the radiology image file, only supports JPG or PNG images"
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """Tool that generates a structured radiology report for a chest X-ray image.

    Two vision-encoder-decoder checkpoints are used: one fine-tuned to produce
    the detailed FINDINGS section and one to produce the IMPRESSION summary.
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    args_schema: Type[BaseModel] = ChestXRayInput
    # Set to a string at class level; __init__ replaces it with a torch.device.
    device: Any = "cpu"
    findings_model: Optional[VisionEncoderDecoderModel] = None
    impression_model: Optional[VisionEncoderDecoderModel] = None
    findings_tokenizer: Optional[BertTokenizer] = None
    impression_tokenizer: Optional[BertTokenizer] = None
    findings_processor: Optional[ViTImageProcessor] = None
    impression_processor: Optional[ViTImageProcessor] = None
    generation_args: Optional[Dict[str, Any]] = None

    def __init__(self, cache_dir: str = "./model_weights", device: Optional[str] = "cpu"):
        super().__init__()
        os.makedirs(cache_dir, exist_ok=True)
        self.device = torch.device(device) if device else torch.device("cpu")

        # Findings model: produces the detailed observations section.
        self.findings_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        ).eval()
        self.findings_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )
        self.findings_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )

        # Impression model: produces the summary of key clinical conclusions.
        self.impression_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        ).eval()
        self.impression_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )
        self.impression_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )

        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        # Decoding settings shared by both models. Note that the Hugging Face
        # generation API calls the beam-search width `num_beams`, not `beam_width`.
        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            "num_beams": 2,
        }

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Load an image and convert it to pixel values sized for the model's encoder."""
        image = Image.open(image_path).convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Resize if the processor's output does not match the ViT encoder's
        # expected input resolution.
        expected_size = model.config.encoder.image_size
        actual_size = pixel_values.shape[-1]
        if expected_size != actual_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=(expected_size, expected_size),
                mode="bilinear",
                align_corners=False,
            )

        return pixel_values.to(self.device)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Run beam-search generation and decode the result to plain text."""
        # Merge the shared decoding settings with the model- and tokenizer-specific
        # special-token ids required by the VisionEncoderDecoder architecture.
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )
        generated_ids = model.generate(pixel_values, generation_config=generation_config)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate both report sections and return the assembled report plus metadata."""
        try:
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )

            # inference_mode disables autograd bookkeeping for faster generation.
            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }
            return report, metadata

        except Exception as e:
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Async entry point; generation itself is synchronous, so delegate to _run."""
        return self._run(image_path)
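
# Minimal usage sketch, assuming the checkpoints above can be downloaded and
# that "path/to/chest_xray.png" (a hypothetical path) points at a real image.
# A LangChain tool's default invoke() returns whatever _run returns, so the
# (report, metadata) tuple can be unpacked directly.
if __name__ == "__main__":
    tool = ChestXRayReportGeneratorTool(cache_dir="./model_weights", device="cpu")
    report, metadata = tool.invoke({"image_path": "path/to/chest_xray.png"})
    print(report)
    print(metadata)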