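"""LangChain tool that generates chest X-ray radiology reports, producing
findings and impression sections with two IAMJB vision-encoder-decoder
checkpoints from the Hugging Face Hub."""
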
import os
from typing import Any, Dict, Optional, Tuple, Type

import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field
from transformers import (
    BertTokenizer,
    GenerationConfig,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)


class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only JPG or PNG images are supported."""

    image_path: str = Field(
        ...,
        description="Path to the radiology image file; only JPG or PNG images are supported",
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """Tool that turns a chest X-ray image into a structured report with
    findings and impression sections, using two vision-encoder-decoder models."""

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    device: Optional[str] = "cpu"
    args_schema: Type[BaseModel] = ChestXRayInput
    findings_model: Optional[VisionEncoderDecoderModel] = None
    impression_model: Optional[VisionEncoderDecoderModel] = None
    findings_tokenizer: Optional[BertTokenizer] = None
    impression_tokenizer: Optional[BertTokenizer] = None
    findings_processor: Optional[ViTImageProcessor] = None
    impression_processor: Optional[ViTImageProcessor] = None
    generation_args: Optional[Dict[str, Any]] = None

    def __init__(self, cache_dir: str = "./model_weights", device: Optional[str] = "cpu"):
        """Download (or reuse) the findings and impression checkpoints into
        `cache_dir` and prepare both models for inference on `device`."""
        super().__init__()
        os.makedirs(cache_dir, exist_ok=True)  # Ensure the local cache folder exists
        self.device = torch.device(device) if device else torch.device("cpu")

        # Load the findings model, tokenizer, and image processor
        self.findings_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        ).eval()
        self.findings_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )
        self.findings_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )

        # Load the impression model, tokenizer, and image processor
        self.impression_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        ).eval()
        self.impression_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )
        self.impression_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )

        # Move both models to the configured device
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            # `beam_width` is not a valid transformers generation argument;
            # `num_beams` is what actually enables beam search.
            "num_beams": 2,
        }

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Load an image, run the ViT processor, and resize the pixel values
        if they do not match the encoder's expected input resolution."""
        image = Image.open(image_path).convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Resize to the encoder's expected resolution if the processor disagrees
        expected_size = model.config.encoder.image_size
        actual_size = pixel_values.shape[-1]
        if expected_size != actual_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=(expected_size, expected_size),
                mode="bilinear",
                align_corners=False,
            )
        return pixel_values.to(self.device)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Run beam-search generation for one report section and decode the
        first returned sequence to plain text."""
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )
        generated_ids = model.generate(pixel_values, generation_config=generation_config)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate the full report for `image_path`, returning the report text
        and a metadata dict; on failure, return an error message instead."""
        try:
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )
            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )
            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }
            return report, metadata
        except Exception as e:
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Async entry point; inference here is synchronous, so delegate to `_run`."""
        return self._run(image_path)
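

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original tool: the image path below
    # is a hypothetical placeholder; point it at a real JPG/PNG chest X-ray
    # before running. The first run downloads both checkpoints into cache_dir.
    tool = ChestXRayReportGeneratorTool(cache_dir="./model_weights", device="cpu")
    report, metadata = tool._run("sample_cxr.png")  # hypothetical file name
    print(report)
    print(metadata)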