File size: 3,870 Bytes
cb3a670 44d5f7f cb3a670 0e203e7 d65a2ca cb3a670 eb57a64 cb3a670 6fd4eef cb3a670 3f298d8 cb3a670 3f298d8 cb3a670 d65a2ca 44d5f7f d65a2ca 3f298d8 c044359 cb3a670 c044359 67192a5 cb3a670 3f298d8 cb3a670 c044359 fd7b76c 7a17e63 c29759d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import warnings
from typing import *
from dotenv import load_dotenv
from transformers import logging
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *
# Keep the console quiet: drop Python warnings and lower the transformers
# logger (imported above as `logging`) to error level.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()

# Load variables from a local .env file (python-dotenv), then resolve the
# credentials for the OpenAI-compatible endpoint used by initialize_agent.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL")

# Both settings are mandatory; fail at import time, reporting the first
# missing one (API key is checked before base URL).
_openai_settings = {"OPENAI_API_KEY": api_key, "OPENAI_BASE_URL": base_url}
_missing_setting = next(
    (name for name, value in _openai_settings.items() if not value), None
)
if _missing_setting is not None:
    raise ValueError(f"{_missing_setting} not found in environment variables")

# Mirror the resolved values into os.environ for libraries that read the
# environment directly instead of receiving explicit arguments.
os.environ.update(_openai_settings)
def initialize_agent(
    prompt_file,
    tools_to_use=None,
    model_dir="./model-weights",
    temp_dir="temp",
    device="cuda",
    model="qwen/qwen2.5-vl-3b-instruct:free",
    temperature=0.7,
    top_p=0.95,
    log_dir="logs",
):
    """Initialize the MedRAX agent with specified tools and configuration.

    Args:
        prompt_file: Path of the prompt file read by ``load_prompts_from_file``;
            it must provide a ``"MEDICAL_ASSISTANT"`` entry (KeyError otherwise).
        tools_to_use: Iterable of tool names to instantiate. ``None`` (the
            default) instantiates every known tool; an empty iterable
            instantiates none. Unknown names are reported and skipped.
        model_dir: Directory passed to the tools that take ``cache_dir`` /
            ``model_path`` (the generator uses ``<model_dir>/roentgen``).
        temp_dir: Directory passed to the tools that take ``temp_dir``.
        device: Device string forwarded to the tools that accept ``device``.
        model: Model name sent to ``ChatOpenAI`` (with the module-level
            ``base_url``).
        temperature: Sampling temperature for the chat model.
        top_p: ``top_p`` value for the chat model.
        log_dir: Directory passed to ``Agent`` as ``log_dir`` (tool logging is
            enabled via ``log_tools=True``).

    Returns:
        tuple: ``(agent, tools_dict)`` where ``tools_dict`` maps each
        instantiated tool name to its instance.
    """
    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]

    # Factories rather than instances, so only the requested tools are
    # actually constructed.
    all_tools = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(
            cache_dir=model_dir, device=device, load_in_8bit=True
        ),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }

    # `is None` instead of `or`: an explicit empty list means "no tools",
    # not "every tool".
    if tools_to_use is None:
        tools_to_use = all_tools.keys()

    tools_dict = {}
    unknown_tools = []
    for tool_name in tools_to_use:
        factory = all_tools.get(tool_name)
        if factory is None:
            unknown_tools.append(tool_name)
            continue
        tools_dict[tool_name] = factory()
    if unknown_tools:
        # Surface typos instead of dropping them silently. print() is used
        # because this module filters out all warnings.
        print(f"Skipping unknown tools: {', '.join(map(str, unknown_tools))}")

    checkpointer = MemorySaver()

    # Credentials are the module-level values validated at import time.
    # Named `llm` so the `model` name parameter is not shadowed.
    llm = ChatOpenAI(
        model_name=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        top_p=top_p,
    )

    agent = Agent(
        llm,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir=log_dir,
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    print("Agent initialized")
    return agent, tools_dict
if __name__ == "__main__":
    print("Starting server...")

    # Tools enabled for this demo. The commented-out entries are also known
    # to initialize_agent but are left disabled here.
    enabled_tools = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
        # "LlavaMedTool",
        # "XRayPhraseGroundingTool",
        # "ChestXRayGeneratorTool",
    ]

    # Agent configuration, spelled out explicitly even where it matches the
    # initialize_agent defaults.
    agent_config = {
        "tools_to_use": enabled_tools,
        "model_dir": "./model-weights",
        "temp_dir": "temp",
        "device": "cuda",
        "model": "qwen/qwen2.5-vl-3b-instruct:free",
        "temperature": 0.7,
        "top_p": 0.95,
    }
    agent, tools_dict = initialize_agent(
        "medrax/docs/system_prompts.txt", **agent_config
    )

    demo = create_demo(agent, tools_dict)
    # Other launch configurations that have been used:
    #   demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
    #   demo.launch(debug=True, queue=True, ssr_mode=False)
    demo.launch(debug=True, ssr_mode=False)