# CRAX / medrax/tools/report_generation.py
# Last modified by Dhruv-Ty in commit e1ede20: "resolved the PermissionError".
from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field
import torch
import os # Added to create local cache dir
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from transformers import (
BertTokenizer,
ViTImageProcessor,
VisionEncoderDecoderModel,
GenerationConfig,
)
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Required field (``default=...`` marks it as mandatory in pydantic).
    # The description below is surfaced to the calling agent via the schema.
    image_path: str = Field(
        default=...,
        description="Path to the radiology image file, only supports JPG or PNG images",
    )
class ChestXRayReportGeneratorTool(BaseTool):
    """LangChain tool that writes a two-section radiology report for a chest X-ray.

    Two separate ``VisionEncoderDecoderModel`` checkpoints are used: one
    generates the FINDINGS section and the other the IMPRESSION section.
    Each checkpoint has its own tokenizer and image processor.
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    # Declared as a string for the pydantic field; ``__init__`` replaces it
    # with a ``torch.device`` instance.
    device: Optional[str] = "cpu"
    args_schema: Type[BaseModel] = ChestXRayInput
    # The models, tokenizers, processors and generation arguments below are
    # all populated in ``__init__``.
    findings_model: VisionEncoderDecoderModel = None
    impression_model: VisionEncoderDecoderModel = None
    findings_tokenizer: BertTokenizer = None
    impression_tokenizer: BertTokenizer = None
    findings_processor: ViTImageProcessor = None
    impression_processor: ViTImageProcessor = None
    generation_args: Dict[str, Any] = None

    def __init__(self, cache_dir: str = "./model_weights", device: Optional[str] = "cpu"):
        """Download/load both report models and prepare them for inference.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to every
                ``from_pretrained`` call. It is created if it does not exist.
            device: Torch device string for the models. A falsy value falls
                back to ``"cpu"``.
        """
        super().__init__()
        # Use an explicit local cache directory. Presumably this avoids a
        # PermissionError from an unwritable default cache location.
        os.makedirs(cache_dir, exist_ok=True)
        self.device = torch.device(device) if device else torch.device("cpu")

        (
            self.findings_model,
            self.findings_tokenizer,
            self.findings_processor,
        ) = self._load_components("IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir)
        (
            self.impression_model,
            self.impression_tokenizer,
            self.impression_processor,
        ) = self._load_components("IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir)

        # Move both models to the selected device (CPU unless overridden).
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            # ``num_beams`` is the Hugging Face name for the beam width. The
            # former ``beam_width`` key is not a recognized generation
            # parameter, so beam search never actually ran.
            "num_beams": 2,
        }

    @staticmethod
    def _load_components(
        repo_id: str, cache_dir: str
    ) -> Tuple[VisionEncoderDecoderModel, BertTokenizer, ViTImageProcessor]:
        """Load the model (in eval mode), tokenizer and image processor for ``repo_id``."""
        model = VisionEncoderDecoderModel.from_pretrained(repo_id, cache_dir=cache_dir).eval()
        tokenizer = BertTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
        processor = ViTImageProcessor.from_pretrained(repo_id, cache_dir=cache_dir)
        return model, tokenizer, processor

    @staticmethod
    def _load_image(image_path: str) -> Image.Image:
        """Read ``image_path`` as an RGB image and close the underlying file."""
        # ``convert`` returns a fully loaded copy, so the file can be closed
        # when the ``with`` block exits.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def _to_pixel_values(
        self, image: Image.Image, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Preprocess ``image`` for ``model`` and return pixel values on ``self.device``."""
        pixel_values = processor(image, return_tensors="pt").pixel_values

        image_size = model.config.encoder.image_size
        if isinstance(image_size, int):
            target_size = (image_size, image_size)
        else:
            target_size = tuple(image_size)

        # Resize when the processor output disagrees with the encoder's expected
        # input size in either of the last two (spatial) dimensions.
        if tuple(pixel_values.shape[-2:]) != target_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
        return pixel_values.to(self.device)

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Load ``image_path`` and convert it to pixel values for ``model``."""
        return self._to_pixel_values(self._load_image(image_path), processor, model)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Generate text from ``pixel_values`` with ``model`` and decode the first sequence.

        The shared ``generation_args`` are combined with the model's special
        token ids. The tokenizer's ``[CLS]`` id is used as the decoder start token.
        """
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )
        generated_ids = model.generate(pixel_values, generation_config=generation_config)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate the FINDINGS and IMPRESSION report for ``image_path``.

        Returns:
            A ``(report_text, metadata)`` tuple. On any exception, the error
            is not raised. Instead, the report text is an error message and
            ``metadata["analysis_status"]`` is ``"failed"``.
        """
        try:
            # Decode the file once and reuse it for both preprocessing pipelines.
            image = self._load_image(image_path)
            findings_pixels = self._to_pixel_values(
                image, self.findings_processor, self.findings_model
            )
            impression_pixels = self._to_pixel_values(
                image, self.impression_processor, self.impression_model
            )

            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }
            return report, metadata
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Async entry point. Runs the blocking ``_run`` in the default executor
        so model inference does not stall the event loop."""
        import asyncio  # local import: only needed on the async path

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, image_path)