File size: 6,252 Bytes
cb3a670
 
 
 
e1ede20
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ede20
cb3a670
 
 
 
 
 
 
 
 
e1ede20
cb3a670
e1ede20
 
cb3a670
e1ede20
cb3a670
 
 
 
 
 
 
 
 
 
e1ede20
cb3a670
 
 
 
 
 
 
 
 
 
e1ede20
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ede20
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field

import torch
import os  # Added to create local cache dir

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool

from PIL import Image

from transformers import (
    BertTokenizer,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
    GenerationConfig,
)


class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Argument schema referenced by ChestXRayReportGeneratorTool.args_schema.
    # NOTE(review): pydantic is understood to surface the class docstring and
    # the field description in the generated schema, so both are treated as
    # runtime text here — confirm before rewording either.
    #
    # `image_path` is required: the `...` first argument means "no default".
    # The JPG/PNG restriction is stated in the description only; nothing in
    # this schema validates the file extension or the file contents.
    image_path: str = Field(
        ..., description="Path to the radiology image file, only supports JPG or PNG images"
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """LangChain tool that drafts a two-section chest X-ray report.

    Two pretrained vision encoder-decoder checkpoints are loaded when the
    tool is constructed: one generates the "findings" text and the other the
    "impression" text. ``_run`` passes the same image through both models and
    returns the assembled report together with a metadata dictionary.
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    # Declared as an optional string, but __init__ replaces it with a
    # torch.device instance built from the constructor argument.
    device: Optional[str] = "cpu"
    args_schema: Type[BaseModel] = ChestXRayInput
    # The fields below are placeholders; __init__ populates all of them.
    findings_model: VisionEncoderDecoderModel = None
    impression_model: VisionEncoderDecoderModel = None
    findings_tokenizer: BertTokenizer = None
    impression_tokenizer: BertTokenizer = None
    findings_processor: ViTImageProcessor = None
    impression_processor: ViTImageProcessor = None
    generation_args: Dict[str, Any] = None

    def __init__(self, cache_dir: str = "./model_weights", device: Optional[str] = "cpu"):
        """Load both report models, their tokenizers and image processors.

        Args:
            cache_dir: Directory passed to every ``from_pretrained`` call as
                ``cache_dir``. It is created if it does not exist yet.
            device: Device string handed to ``torch.device``. A falsy value
                (``None`` or ``""``) selects the CPU.
        """
        super().__init__()
        os.makedirs(cache_dir, exist_ok=True)
        self.device = torch.device(device) if device else torch.device("cpu")

        # Load findings model
        self.findings_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        ).eval()
        self.findings_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )
        self.findings_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )

        # Load impression model
        self.impression_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        ).eval()
        self.impression_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )
        self.impression_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )

        # Move both models to the selected device (CPU unless overridden).
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        # Shared generation settings; the token ids are added per model in
        # _generate_report_section.
        # NOTE(review): "beam_width" is not among the documented
        # GenerationConfig parameters (the documented beam-search setting is
        # "num_beams"), so it may have no effect on decoding. It is kept
        # unchanged because switching would alter the generated text —
        # confirm the intended decoding strategy.
        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            "beam_width": 2,
        }

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Load an image and turn it into encoder-ready pixel values.

        Args:
            image_path: Path of the image file to open.
            processor: Image processor that produces the ``pixel_values``
                tensor for ``model``.
            model: Model whose ``config.encoder.image_size`` defines the
                spatial size the tensor must have.

        Returns:
            The pixel tensor, resized when necessary and moved to
            ``self.device``.
        """
        # Decode inside a context manager so the underlying file handle is
        # released as soon as the RGB copy exists.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Accept either a single side length or an explicit (height, width).
        configured_size = model.config.encoder.image_size
        if isinstance(configured_size, int):
            target_size = (configured_size, configured_size)
        else:
            target_size = tuple(configured_size)

        # Compare height AND width: checking only the last dimension would
        # miss a tensor whose width matches but whose height does not.
        if tuple(pixel_values.shape[-2:]) != target_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )

        return pixel_values.to(self.device)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Generate and decode one report section.

        Args:
            pixel_values: Image tensor produced by ``_process_image``.
            model: Model used for generation.
            tokenizer: Tokenizer used to decode the generated ids; its CLS
                token id is used as the decoder start token.

        Returns:
            The first decoded sequence with special tokens removed.
        """
        # Combine the shared settings with this model's special-token ids.
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )
        generated_ids = model.generate(pixel_values, generation_config=generation_config)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate the full report for one image.

        Args:
            image_path: Path of the chest X-ray image.
            run_manager: Accepted for interface compatibility; not used.

        Returns:
            ``(report, metadata)`` on success. If any step raises, the
            exception is not propagated: the first element is an error
            message and the metadata carries ``"analysis_status": "failed"``
            plus the error text.
        """
        try:
            # Each model gets pixels from its own processor.
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )

            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )
            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }
            return report, metadata

        except Exception as e:
            # Tool boundary: report the failure as data instead of raising.
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Async variant of ``_run`` with the same return value.

        Args:
            image_path: Path of the chest X-ray image.
            run_manager: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``_run(image_path)`` returns.
        """
        # Local import keeps this block self-contained.
        import asyncio

        # Image decoding and model generation are blocking calls; run them in
        # the default executor so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, image_path)