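"""LLaVA-Med visual question answering tool.

Wraps the LLaVA-Med model as a LangChain tool that answers questions about
biomedical images (JPG or PNG) as well as general medical questions without an image.
"""
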
from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field
import torch
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from medrax.llava.conversation import conv_templates
from medrax.llava.model.builder import load_pretrained_model
from medrax.llava.mm_utils import tokenizer_image_token, process_images
from medrax.llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)


class LlavaMedInput(BaseModel):
"""Input for the LLaVA-Med Visual QA tool. Only supports JPG or PNG images."""
question: str = Field(..., description="The question to ask about the medical image")
image_path: Optional[str] = Field(
None,
description="Path to the medical image file (optional), only supports JPG or PNG images",
)


class LlavaMedTool(BaseTool):
"""Tool that performs medical visual question answering using LLaVA-Med.
This tool uses a large language model fine-tuned on medical images to answer
questions about medical images. It can handle both image-based questions and
general medical questions without images.
"""
name: str = "llava_med_qa"
description: str = (
"A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
"While it can process chest X-rays, it may not be as reliable for detailed chest X-ray analysis. "
"Input should be a question and optionally a path to a medical image file."
)
args_schema: Type[BaseModel] = LlavaMedInput
    tokenizer: Any = None
    model: Any = None
    image_processor: Any = None
    context_len: int = 200000
    device: Any = None  # declared so the pydantic-based BaseTool allows assignment in __init__

def __init__(
self,
model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
cache_dir: str = "/model-weights",
low_cpu_mem_usage: bool = True,
torch_dtype: torch.dtype = torch.bfloat16,
device: str = "cuda",
load_in_4bit: bool = False,
load_in_8bit: bool = False,
**kwargs,
):
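        """Load the LLaVA-Med weights and prepare the model for inference.

        Args:
            model_path (str): Hugging Face model ID or local path of the LLaVA-Med checkpoint.
            cache_dir (str): Directory where downloaded weights are cached.
            low_cpu_mem_usage (bool): Whether to reduce CPU memory usage while loading.
            torch_dtype (torch.dtype): Data type to load the model weights in.
            device (str): Device to run inference on ("cuda" or "cpu").
            load_in_4bit (bool): Whether to load the model with 4-bit quantization.
            load_in_8bit (bool): Whether to load the model with 8-bit quantization.
            **kwargs: Additional arguments forwarded to load_pretrained_model.
        """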
super().__init__()
# Set the device (cuda or cpu)
self.device = torch.device(device) if device else torch.device("cuda")
# Load the model and tokenizer
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=model_path,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
cache_dir=cache_dir,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=torch_dtype,
device=device,
**kwargs,
)
        # Quantized (4-bit/8-bit) checkpoints are already placed by the loader and
        # cannot be moved with .to(); only move full-precision models to the device.
        if not (load_in_4bit or load_in_8bit):
            self.model.to(self.device)
        self.model.eval()

def _process_input(
self, question: str, image_path: Optional[str] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
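        """Build the prompt and (optional) image tensor for the model.

        Args:
            question (str): The question to ask about the image or topic.
            image_path (Optional[str]): Path to a JPG or PNG image, if any.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: Tokenized prompt ids and the
            preprocessed image tensor (None when no image path is given).
        """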
        # Prepend the image token(s) only when an image is actually provided;
        # text-only questions are passed through unchanged.
        if image_path:
            if self.model.config.mm_use_im_start_end:
                question = (
                    DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_TOKEN
                    + DEFAULT_IM_END_TOKEN
                    + "\n"
                    + question
                )
            else:
                question = DEFAULT_IMAGE_TOKEN + "\n" + question
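        # Wrap the question in the conversation template (user turn plus an empty
        # assistant turn) and render it into a single prompt string.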
conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = (
tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.to(self.device) # Move to the correct device
)
image_tensor = None
if image_path:
            # Convert to RGB so grayscale and RGBA PNG inputs are handled consistently
            image = Image.open(image_path).convert("RGB")
image_tensor = process_images([image], self.image_processor, self.model.config)[0]
image_tensor = image_tensor.unsqueeze(0).to(self.device, dtype=self.model.dtype) # Move to device
        return input_ids, image_tensor

def _run(
self,
question: str,
image_path: Optional[str] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Tuple[str, Dict]:
"""Answer a medical question, optionally based on an input image.
Args:
question (str): The medical question to answer.
image_path (Optional[str]): The path to the medical image file (if applicable).
run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
Returns:
Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
Raises:
Exception: If there's an error processing the input or generating the answer.
"""
try:
input_ids, image_tensor = self._process_input(question, image_path)
            # Ensure that inputs are on the same device (and dtype) as the model
            input_ids = input_ids.to(self.device)
            if image_tensor is not None:
                image_tensor = image_tensor.to(self.device, dtype=self.model.dtype)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=image_tensor,
                    do_sample=False,  # greedy decoding; no sampling temperature needed
max_new_tokens=500,
use_cache=True,
)
output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
metadata = {
"question": question,
"image_path": image_path,
"analysis_status": "completed",
}
return output, metadata
except Exception as e:
return f"Error generating answer: {str(e)}", {
"question": question,
"image_path": image_path,
"analysis_status": "failed",
            }

async def _arun(
self,
question: str,
image_path: Optional[str] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Tuple[str, Dict]:
"""Asynchronously answer a medical question, optionally based on an input image.
This method currently calls the synchronous version, as the model inference
is not inherently asynchronous. For true asynchronous behavior, consider
using a separate thread or process.
Args:
question (str): The medical question to answer.
image_path (Optional[str]): The path to the medical image file (if applicable).
run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
Returns:
Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
Raises:
Exception: If there's an error processing the input or generating the answer.
"""
return self._run(question, image_path)
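

# Minimal usage sketch (not part of the tool itself). It uses the default model ID
# defined above; the question and image path below are placeholders, and running it
# requires the LLaVA-Med weights to be downloadable or already cached locally.
if __name__ == "__main__":
    tool = LlavaMedTool(device="cuda" if torch.cuda.is_available() else "cpu")
    answer, meta = tool._run(
        question="What abnormality, if any, is visible in this chest X-ray?",
        image_path="path/to/chest_xray.png",  # placeholder path
    )
    print(answer)
    print(meta)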