|
from typing import Dict, List, Optional, Tuple, Type, Any
from pathlib import Path
from pydantic import BaseModel, Field

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
|
|
class XRayVQAToolInput(BaseModel):
    """Input schema for the CheXagent Tool."""

    image_paths: List[str] = Field(
        ..., description="List of paths to chest X-ray images to analyze"
    )
    prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
    max_new_tokens: int = Field(
        512, description="Maximum number of tokens to generate in the response"
    )
|
|
class XRayVQATool(BaseTool):
    """Tool that leverages CheXagent for comprehensive chest X-ray analysis."""

    name: str = "chest_xray_expert"
    description: str = (
        "A versatile tool for analyzing chest X-rays. "
        "Can perform multiple tasks including: visual question answering, report generation, "
        "abnormality detection, comparative analysis, anatomical description, "
        "and clinical interpretation. Input should be paths to X-ray images "
        "and a natural language prompt describing the analysis needed."
    )
    args_schema: Type[BaseModel] = XRayVQAToolInput
    return_direct: bool = True
    cache_dir: Optional[str] = None
    device: Optional[str] = None
    dtype: torch.dtype = torch.bfloat16
    tokenizer: Optional[AutoTokenizer] = None
    model: Optional[AutoModelForCausalLM] = None
|
    def __init__(
        self,
        model_name: str = "StanfordAIMI/CheXagent-2-3b",
        device: Optional[str] = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the XRayVQATool.

        Args:
            model_name: Name of the CheXagent model to use
            device: Device to run model on (cuda/cpu)
            dtype: Data type for model weights
            cache_dir: Directory to cache downloaded models
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
|
        # Temporarily report a transformers version that CheXagent's remote code
        # expects; the real version is restored once the model has been loaded.
        original_transformers_version = transformers.__version__
        transformers.__version__ = "4.40.0"

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        self.cache_dir = cache_dir

        # Load the tokenizer and model; trust_remote_code is required because
        # CheXagent ships custom model and tokenizer code on the Hugging Face Hub.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=self.device,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.model = self.model.to(dtype=self.dtype)
        self.model.eval()

        transformers.__version__ = original_transformers_version
|
    def _generate_response(self, image_paths: List[str], prompt: str, max_new_tokens: int) -> str:
        """Generate a response using the CheXagent model.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: Model's response
        """
        # Build a multimodal query that places the images before the text prompt.
        query = self.tokenizer.from_list_format(
            [*[{"image": path} for path in image_paths], {"text": prompt}]
        )
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        ).to(device=self.device)

        # Greedy decoding (no sampling, single beam) for deterministic output.
        with torch.inference_mode():
            output = self.model.generate(
                input_ids,
                do_sample=False,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_new_tokens,
            )[0]
        # Decode only the newly generated tokens, dropping the prompt and the
        # trailing end-of-sequence token.
        response = self.tokenizer.decode(output[input_ids.size(1) : -1])

        return response
|
    def _run(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Execute the chest X-ray analysis.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict[str, Any], Dict]: Output dictionary and metadata dictionary
        """
        try:
            # Validate that every image exists before invoking the model.
            for path in image_paths:
                if not Path(path).is_file():
                    raise FileNotFoundError(f"Image file not found: {path}")

            response = self._generate_response(image_paths, prompt, max_new_tokens)

            output = {
                "response": response,
            }
            metadata = {
                "image_paths": image_paths,
                "prompt": prompt,
                "max_new_tokens": max_new_tokens,
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            # Report the failure in the output rather than raising, so the caller
            # can inspect what went wrong and retry with different inputs.
            output = {"error": str(e)}
            metadata = {
                "image_paths": image_paths,
                "prompt": prompt,
                "max_new_tokens": max_new_tokens,
                "analysis_status": "failed",
                "error_details": str(e),
            }
            return output, metadata
|
    async def _arun(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run; delegates to the synchronous implementation."""
        return self._run(image_paths, prompt, max_new_tokens)
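

# A minimal usage sketch, not part of the tool itself. It assumes the CheXagent
# weights can be downloaded from the Hugging Face Hub, that a CUDA device is
# available, and that "example_cxr.png" is a placeholder path to a local chest
# X-ray image.
if __name__ == "__main__":
    tool = XRayVQATool()
    result = tool.invoke(
        {
            "image_paths": ["example_cxr.png"],
            "prompt": "Is there evidence of pleural effusion?",
        }
    )
    print(result)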
|
|