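"""LangChain tool for medical visual question answering with LLaVA-Med.

This module wraps the LLaVA-Med v1.5 (Mistral-7B) model behind a LangChain
``BaseTool`` so that an agent can ask free-text questions about biomedical
images (JPG or PNG) or general medical questions without an image.
"""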
from typing import Any, Dict, Optional, Tuple, Type

import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field

from medrax.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from medrax.llava.conversation import conv_templates
from medrax.llava.model.builder import load_pretrained_model
from medrax.llava.mm_utils import tokenizer_image_token, process_images


class LlavaMedInput(BaseModel):
    """Input for the LLaVA-Med visual QA tool. Only supports JPG or PNG images."""

    question: str = Field(..., description="The question to ask about the medical image")
    image_path: Optional[str] = Field(
        None,
        description="Path to the medical image file (optional); only JPG or PNG images are supported",
    )


class LlavaMedTool(BaseTool):
    """Tool that performs medical visual question answering using LLaVA-Med.

    This tool uses a large language model fine-tuned on medical images to answer
    questions about medical images. It can handle both image-based questions and
    general medical questions without images.
    """

    name: str = "llava_med_qa"
    description: str = (
        "A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
        "While it can process chest X-rays, it may not be as reliable for detailed chest X-ray analysis. "
        "Input should be a question and optionally a path to a medical image file."
    )
    args_schema: Type[BaseModel] = LlavaMedInput
    tokenizer: Any = None
    model: Any = None
    image_processor: Any = None
    context_len: int = 200000
    device: Any = None  # declared as a field so pydantic allows the assignment in __init__

    def __init__(
        self,
        model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
        cache_dir: str = "/model-weights",
        low_cpu_mem_usage: bool = True,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        super().__init__()
        # Resolve the target device (e.g. "cuda" or "cpu")
        self.device = torch.device(device) if device else torch.device("cuda")
        # Load the model, tokenizer, and image processor
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=model_path,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            cache_dir=cache_dir,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=torch_dtype,
            device=device,
            **kwargs,
        )
        # Quantized (4/8-bit) models are placed on the GPU by bitsandbytes at load
        # time and cannot be moved afterwards, so only relocate full-precision models
        if not (load_in_4bit or load_in_8bit):
            self.model.to(self.device)
        # Switch to inference mode
        self.model.eval()

    def _process_input(
        self, question: str, image_path: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Build the prompt token IDs and, if an image is given, its pixel tensor."""
        if image_path:
            # Prepend the image placeholder token(s) only when an image is provided,
            # so text-only questions do not reference a missing image
            if self.model.config.mm_use_im_start_end:
                question = (
                    DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_TOKEN
                    + DEFAULT_IM_END_TOKEN
                    + "\n"
                    + question
                )
            else:
                question = DEFAULT_IMAGE_TOKEN + "\n" + question

        conv = conv_templates["vicuna_v1"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)  # Move to the correct device
        )

        image_tensor = None
        if image_path:
            # Convert to RGB so grayscale or RGBA files are accepted by the processor
            image = Image.open(image_path).convert("RGB")
            image_tensor = process_images([image], self.image_processor, self.model.config)[0]
            image_tensor = image_tensor.unsqueeze(0).to(self.device, dtype=self.model.dtype)
        return input_ids, image_tensor

    def _run(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Answer a medical question, optionally based on an input image.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer (or an error
                message) and metadata describing the run. Errors are caught and
                reported via the metadata's "analysis_status" field rather than raised.
        """
        try:
            input_ids, image_tensor = self._process_input(question, image_path)
            # Ensure the inputs live on the same device as the model; the image
            # tensor is None for text-only questions and must not be moved
            input_ids = input_ids.to(self.device)
            if image_tensor is not None:
                image_tensor = image_tensor.to(self.device, dtype=self.model.dtype)

            with torch.inference_mode():
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
                    temperature=0.2,
                    max_new_tokens=500,
                    use_cache=True,
                )
            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            metadata = {
                "question": question,
                "image_path": image_path,
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            return f"Error generating answer: {str(e)}", {
                "question": question,
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously answer a medical question, optionally based on an input image.

        This method currently calls the synchronous version, as the model inference
        is not inherently asynchronous. For true asynchronous behavior, consider
        running inference in a separate thread or process.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer (or an error
                message) and metadata describing the run.
        """
        return self._run(question, image_path)
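

# Example usage: a minimal sketch, not part of the tool itself. It assumes the
# LLaVA-Med weights can be fetched from the Hugging Face Hub (or found in
# cache_dir), that a CUDA device is available, and that "chest_xray.jpg" is a
# placeholder path to be replaced with a real image. In an agent, the tool
# would normally be invoked through LangChain rather than via _run directly.
if __name__ == "__main__":
    tool = LlavaMedTool(device="cuda")

    # Image-grounded question
    answer, metadata = tool._run(
        question="Is there evidence of pleural effusion in this image?",
        image_path="chest_xray.jpg",  # hypothetical example path
    )
    print(answer, metadata)

    # Text-only medical question (no image attached)
    answer, metadata = tool._run(question="What are common causes of cardiomegaly?")
    print(answer, metadata)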