|
from typing import Any, Dict, Optional, Tuple, Type

import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field

from medrax.llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from medrax.llava.conversation import conv_templates
from medrax.llava.mm_utils import process_images, tokenizer_image_token
from medrax.llava.model.builder import load_pretrained_model
|
|
class LlavaMedInput(BaseModel):
    """Input schema for the LLaVA-Med visual QA tool. Only JPG and PNG images are supported."""

    question: str = Field(..., description="The question to ask about the medical image")
    image_path: Optional[str] = Field(
        None,
        description="Path to the medical image file (optional); only JPG and PNG images are supported",
    )
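
# Example payload (hypothetical values) accepted by the schema above:
#   LlavaMedInput(question="Is there cardiomegaly?", image_path="chest_xray.png")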
|
|
class LlavaMedTool(BaseTool):
    """Tool that performs medical visual question answering using LLaVA-Med.

    This tool uses a multimodal large language model fine-tuned on biomedical data
    to answer questions about medical images. It can handle both image-based
    questions and general medical questions asked without an image.
    """

    name: str = "llava_med_qa"
    description: str = (
        "A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
        "While it can process chest X-rays, it may not be as reliable for detailed chest X-ray analysis. "
        "Input should be a question and optionally a path to a medical image file."
    )
    args_schema: Type[BaseModel] = LlavaMedInput
    tokenizer: Any = None
    model: Any = None
    image_processor: Any = None
    context_len: int = 200000
    # Declared as a field so that assigning self.device in __init__ is permitted
    # on a BaseTool, which is a pydantic model.
    device: Any = None
|
    def __init__(
        self,
        model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
        cache_dir: str = "/model-weights",
        low_cpu_mem_usage: bool = True,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        """Load the pretrained LLaVA-Med weights and prepare the model for inference."""
        super().__init__()
        self.device = torch.device(device) if device else torch.device("cuda")

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=model_path,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            cache_dir=cache_dir,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=torch_dtype,
            device=device,
            **kwargs,
        )

        # .to() is not supported for 4-bit/8-bit quantized models; those are already
        # placed on the target device during loading.
        if not (load_in_4bit or load_in_8bit):
            self.model.to(self.device)
        self.model.eval()
|
    def _process_input(
        self, question: str, image_path: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Build the prompt token IDs and, when an image is provided, the image tensor."""
        # Only prepend image tokens when an image is actually supplied; otherwise
        # the prompt would reference an image that is never passed to the model.
        if image_path:
            if self.model.config.mm_use_im_start_end:
                question = (
                    DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_TOKEN
                    + DEFAULT_IM_END_TOKEN
                    + "\n"
                    + question
                )
            else:
                question = DEFAULT_IMAGE_TOKEN + "\n" + question

        conv = conv_templates["vicuna_v1"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)
        )

        image_tensor = None
        if image_path:
            # Convert to RGB so PNGs with an alpha channel do not break preprocessing.
            image = Image.open(image_path).convert("RGB")
            image_tensor = process_images([image], self.image_processor, self.model.config)[0]
            image_tensor = image_tensor.unsqueeze(0).to(self.device, dtype=self.model.dtype)

        return input_ids, image_tensor
|
    def _run(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Answer a medical question, optionally grounded in an input image.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer and metadata about
                the run. Errors are caught and reported in the return value rather
                than raised.
        """
        try:
            # _process_input already moves both tensors to self.device; image_tensor
            # is None for text-only questions.
            input_ids, image_tensor = self._process_input(question, image_path)

            with torch.inference_mode():
                # Greedy decoding: with do_sample=False a temperature setting would
                # be ignored, so none is passed.
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=False,
                    max_new_tokens=500,
                    use_cache=True,
                )

            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            metadata = {
                "question": question,
                "image_path": image_path,
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            return f"Error generating answer: {str(e)}", {
                "question": question,
                "image_path": image_path,
                "analysis_status": "failed",
            }
|
    async def _arun(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously answer a medical question, optionally grounded in an input image.

        This method currently delegates to the synchronous version, as the model
        inference is not inherently asynchronous. For true asynchronous behavior,
        consider offloading the call to a separate thread or process.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer and metadata about
                the run. Errors are caught and reported in the return value rather
                than raised.
        """
        return self._run(question, image_path)
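
# A minimal usage sketch (assumptions: the LLaVA-Med weights are available in the
# cache directory, a CUDA device is present, and "chest_xray.png" is a hypothetical
# example path):
#
#     tool = LlavaMedTool(cache_dir="/model-weights", device="cuda")
#     answer, metadata = tool._run(
#         question="Is there evidence of pleural effusion?",
#         image_path="chest_xray.png",
#     )
#     print(answer, metadata["analysis_status"])
#
# Inside a LangChain agent the tool is normally invoked through the BaseTool
# interface instead, e.g. tool.invoke({"question": "...", "image_path": "..."}).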
|
|