|
import os |
|
import warnings |
|
from typing import * |
|
from dotenv import load_dotenv |
|
from transformers import logging |
|
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langchain_openai import ChatOpenAI |
|
|
|
from interface import create_demo |
|
from medrax.agent import * |
|
from medrax.tools import * |
|
from medrax.utils import * |
|
|
|
# Quiet third-party noise: Python warnings globally, plus transformers' own logger.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()

# Load variables from a local .env file (if one exists) into the process environment.
load_dotenv()

# Both the API key and the endpoint are mandatory. Validate them here, at import
# time, so a misconfigured environment fails immediately with a clear message.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY not found in environment variables")

base_url = os.getenv("OPENAI_BASE_URL")
if not base_url:
    raise ValueError("OPENAI_BASE_URL not found in environment variables")

# Write the validated values back so any library that reads the environment
# directly sees exactly what was checked above.
os.environ.update(OPENAI_API_KEY=api_key, OPENAI_BASE_URL=base_url)
|
|
|
def initialize_agent(
    prompt_file: str,
    tools_to_use: Optional[Iterable[str]] = None,
    model_dir: str = "./model-weights",
    temp_dir: str = "temp",
    device: str = "cuda",
    model: str = "qwen/qwen2.5-vl-3b-instruct:free",
    temperature: float = 0.7,
    top_p: float = 0.95,
):
    """Initialize the MedRAX agent with specified tools and configuration.

    Args:
        prompt_file: Path to the system-prompt file. It is parsed with
            ``load_prompts_from_file`` and its ``"MEDICAL_ASSISTANT"`` entry
            becomes the agent's system prompt.
        tools_to_use: Tool names to instantiate. ``None`` or an empty
            list/tuple loads every tool in the registry below. Unknown names
            are skipped with a printed warning. Duplicate names are
            instantiated only once.
        model_dir: Weights/cache directory handed to the tools as
            ``cache_dir``. The X-ray generator reads ``<model_dir>/roentgen``.
        temp_dir: Scratch directory for tools that take a ``temp_dir``.
        device: Device string passed to every tool that accepts one.
        model: Model name requested from the OpenAI-compatible endpoint
            configured by the module-level ``api_key``/``base_url``.
        temperature: Sampling temperature for the chat model.
        top_p: Nucleus-sampling parameter for the chat model.

    Returns:
        A ``(agent, tools_dict)`` tuple, where ``tools_dict`` maps each
        instantiated tool's name to its instance.

    Raises:
        KeyError: If the loaded prompts have no ``"MEDICAL_ASSISTANT"`` entry
            (assuming ``load_prompts_from_file`` returns a plain dict).
    """
    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]

    # Factories rather than instances: only the tools actually requested get
    # constructed, so unused model weights are never loaded.
    all_tools = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(
            cache_dir=model_dir, device=device, load_in_8bit=True
        ),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }

    tools_dict = {}
    # A falsy selection (None, [], ()) means "load everything".
    for tool_name in tools_to_use or all_tools:
        if tool_name in tools_dict:
            # Already built; re-instantiating would just reload the same model.
            continue
        factory = all_tools.get(tool_name)
        if factory is None:
            # Reported with print because the module silences `warnings`.
            print(
                f"Warning: unknown tool '{tool_name}' skipped "
                f"(available: {', '.join(all_tools)})"
            )
            continue
        tools_dict[tool_name] = factory()

    # In-memory checkpointer for the agent's conversation state.
    checkpointer = MemorySaver()

    # Named `llm` so it does not shadow the `model` name-string parameter.
    llm = ChatOpenAI(
        model_name=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        top_p=top_p,
    )

    agent = Agent(
        llm,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir="logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )

    print("Agent initialized")
    return agent, tools_dict
|
|
|
|
|
if __name__ == "__main__":
    print("Starting server...")

    # The subset of initialize_agent's tool registry that this launcher enables.
    selected_tools = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
    ]

    # Runtime configuration, forwarded unchanged as keyword arguments.
    agent_settings = {
        "tools_to_use": selected_tools,
        "model_dir": "./model-weights",
        "temp_dir": "temp",
        "device": "cuda",
        "model": "qwen/qwen2.5-vl-3b-instruct:free",
        "temperature": 0.7,
        "top_p": 0.95,
    }
    agent, tools_dict = initialize_agent(
        "medrax/docs/system_prompts.txt", **agent_settings
    )

    # Wrap the agent and its tools in the UI from `interface` and start serving.
    # NOTE(review): presumably a Gradio app, given the `ssr_mode` launch option.
    demo = create_demo(agent, tools_dict)
    demo.launch(debug=True, ssr_mode=False)