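"""LangChain tool that wraps the StanfordAIMI CheXagent vision-language model
for chest X-ray visual question answering and report generation."""
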
import asyncio
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import transformers
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

class XRayVQAToolInput(BaseModel):
"""Input schema for the CheXagent Tool."""
image_paths: List[str] = Field(
..., description="List of paths to chest X-ray images to analyze"
)
prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
max_new_tokens: int = Field(
512, description="Maximum number of tokens to generate in the response"
)

class XRayVQATool(BaseTool):
"""Tool that leverages CheXagent for comprehensive chest X-ray analysis."""
name: str = "chest_xray_expert"
description: str = (
"A versatile tool for analyzing chest X-rays. "
"Can perform multiple tasks including: visual question answering, report generation, "
"abnormality detection, comparative analysis, anatomical description, "
"and clinical interpretation. Input should be paths to X-ray images "
"and a natural language prompt describing the analysis needed."
)
args_schema: Type[BaseModel] = XRayVQAToolInput
return_direct: bool = True
cache_dir: Optional[str] = None
device: Optional[str] = None
dtype: torch.dtype = torch.bfloat16
    # Typed as Any: the concrete tokenizer and model classes come from
    # CheXagent's remote code, so the Auto* factory classes are not valid
    # instance types.
    tokenizer: Optional[Any] = None
    model: Optional[Any] = None

def __init__(
self,
model_name: str = "StanfordAIMI/CheXagent-2-3b",
device: Optional[str] = "cuda",
dtype: torch.dtype = torch.bfloat16,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize the XRayVQATool.
Args:
model_name: Name of the CheXagent model to use
device: Device to run model on (cuda/cpu)
dtype: Data type for model weights
cache_dir: Directory to cache downloaded models
**kwargs: Additional arguments
"""
super().__init__(**kwargs)
        # Fragile workaround: CheXagent's remote code checks
        # transformers.__version__, so temporarily report a compatible version
        # and restore the real one after loading, even if loading fails.
        original_transformers_version = transformers.__version__
        transformers.__version__ = "4.40.0"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        self.cache_dir = cache_dir
        try:
            # Load tokenizer and model; both classes ship with the model via
            # trust_remote_code.
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=self.device,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )
        finally:
            transformers.__version__ = original_transformers_version
        self.model = self.model.to(dtype=self.dtype)
        self.model.eval()

    def _generate_response(self, image_paths: List[str], prompt: str, max_new_tokens: int) -> str:
        """Generate a response with the CheXagent model.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            str: The model's response
        """
        # Interleave the images and the text prompt in the list format that
        # CheXagent's tokenizer expects, then wrap the result in a
        # single-turn conversation.
        query = self.tokenizer.from_list_format(
            [*[{"image": path} for path in image_paths], {"text": prompt}]
        )
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        ).to(device=self.device)
        # Greedy decoding (no sampling, single beam) keeps the output
        # deterministic for a given image/prompt pair.
        with torch.inference_mode():
            output = self.model.generate(
                input_ids,
                do_sample=False,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=max_new_tokens,
            )[0]
        # Strip the prompt tokens from the front and the final (typically
        # EOS) token from the end before decoding.
        response = self.tokenizer.decode(output[input_ids.size(1) : -1])
        return response

def _run(
self,
image_paths: List[str],
prompt: str,
max_new_tokens: int = 512,
run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Execute the chest X-ray analysis.

        Args:
            image_paths: List of paths to chest X-ray images
            prompt: Question or instruction about the images
            max_new_tokens: Maximum number of tokens to generate
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict[str, Any], Dict[str, Any]]: Output and metadata dictionaries
        """
try:
# Verify image paths
for path in image_paths:
if not Path(path).is_file():
raise FileNotFoundError(f"Image file not found: {path}")
response = self._generate_response(image_paths, prompt, max_new_tokens)
output = {
"response": response,
}
metadata = {
"image_paths": image_paths,
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"analysis_status": "completed",
}
return output, metadata
except Exception as e:
output = {"error": str(e)}
metadata = {
"image_paths": image_paths,
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"analysis_status": "failed",
"error_details": str(e),
}
return output, metadata

    async def _arun(
        self,
        image_paths: List[str],
        prompt: str,
        max_new_tokens: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Async version of _run.

        The underlying model call is synchronous, so it is dispatched to a
        worker thread to avoid blocking the event loop.
        """
        return await asyncio.to_thread(self._run, image_paths, prompt, max_new_tokens)
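
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the tool itself). The image path
# below is hypothetical, and the default device is "cuda", so pass
# device="cpu" on machines without a GPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tool = XRayVQATool()
    output, metadata = tool.invoke(
        {
            "image_paths": ["chest_xray.png"],
            "prompt": "Is there evidence of pleural effusion?",
            "max_new_tokens": 256,
        }
    )
    print(metadata["analysis_status"])
    print(output.get("response") or output.get("error"))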