File size: 3,870 Bytes
cb3a670
 
 
 
 
 
 
44d5f7f
cb3a670
 
 
 
 
 
 
 
0e203e7
d65a2ca
 
 
 
 
 
 
 
 
 
 
 
 
cb3a670
 
 
 
eb57a64
cb3a670
 
6fd4eef
cb3a670
3f298d8
cb3a670
3f298d8
 
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d65a2ca
 
44d5f7f
d65a2ca
 
 
3f298d8
 
 
c044359
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c044359
 
 
67192a5
cb3a670
3f298d8
cb3a670
 
c044359
fd7b76c
7a17e63
c29759d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import warnings
from typing import *
from dotenv import load_dotenv
from transformers import logging

from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI

from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *

# Quiet down library output (all Python warnings, transformers logging below
# error level) and pull settings from a .env file into the environment.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
load_dotenv()

# The chat model endpoint is configured entirely through these two variables;
# fail fast at import time when either is missing or empty.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY not found in environment variables")

base_url = os.getenv("OPENAI_BASE_URL")
if not base_url:
    raise ValueError("OPENAI_BASE_URL not found in environment variables")

# Write the validated values back so libraries that read os.environ directly
# see exactly the values this module checked.
os.environ.update(OPENAI_API_KEY=api_key, OPENAI_BASE_URL=base_url)

def initialize_agent(
    prompt_file,
    tools_to_use=None,
    model_dir="./model-weights",
    temp_dir="temp",
    device="cuda",
    model="qwen/qwen2.5-vl-3b-instruct:free",
    temperature=0.7,
    top_p=0.95
):
    """Initialize the MedRAX agent with specified tools and configuration.

    Args:
        prompt_file: Path of the prompts file read by ``load_prompts_from_file``;
            it must provide a ``"MEDICAL_ASSISTANT"`` prompt, used as the
            agent's system prompt.
        tools_to_use: Tool names to instantiate, either an iterable of names or
            a single name string. ``None`` or an empty value selects every
            known tool. Unknown names are reported and skipped. Duplicate
            names are constructed only once.
        model_dir: Weights/cache directory passed to the tools that take
            ``cache_dir``. ``ChestXRayGeneratorTool`` uses its ``roentgen``
            subdirectory.
        temp_dir: Scratch directory passed to the tools that take ``temp_dir``.
        device: Device string (e.g. ``"cuda"``) passed to the tools that
            accept one.
        model: Model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature for the chat model.
        top_p: Nucleus-sampling ``top_p`` for the chat model.

    Returns:
        A ``(agent, tools_dict)`` tuple. ``tools_dict`` maps each instantiated
        tool name to its tool instance, in request order.

    Note:
        Uses the module-level ``api_key`` / ``base_url`` validated at import.
    """

    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]

    # Factories rather than instances: only the tools selected below are
    # constructed. Several take cache_dir/device and presumably load model
    # weights, so eager construction would be costly.
    all_tools = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }

    # A bare string would otherwise be iterated character by character and
    # silently select nothing.
    if isinstance(tools_to_use, str):
        tools_to_use = [tools_to_use]
    # None / empty falls back to every known tool (the original `or` semantics).
    requested = list(tools_to_use or all_tools)

    unknown = [name for name in requested if name not in all_tools]
    if unknown:
        # Surface typos instead of silently running with fewer tools.
        print(f"Ignoring unknown tool names: {unknown}")

    tools_dict = {}
    for tool_name in requested:
        # Skip unknown names, and never construct the same tool twice.
        if tool_name in all_tools and tool_name not in tools_dict:
            tools_dict[tool_name] = all_tools[tool_name]()

    checkpointer = MemorySaver()

    # Explicitly pass the API key and base URL. Bound to `llm` so the `model`
    # name argument is not shadowed.
    llm = ChatOpenAI(
        model_name=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        top_p=top_p,
    )

    agent = Agent(
        llm,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir="logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )

    print("Agent initialized")
    return agent, tools_dict


if __name__ == "__main__":
    print("Starting server...")

    # Tools instantiated at startup. LlavaMedTool, XRayPhraseGroundingTool and
    # ChestXRayGeneratorTool are currently disabled; add their names here to
    # enable them.
    selected_tools = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
    ]

    # Runtime configuration for the agent and its tools.
    agent_config = {
        "tools_to_use": selected_tools,
        "model_dir": "./model-weights",
        "temp_dir": "temp",
        "device": "cuda",
        "model": "qwen/qwen2.5-vl-3b-instruct:free",
        "temperature": 0.7,
        "top_p": 0.95,
    }
    agent, tools_dict = initialize_agent("medrax/docs/system_prompts.txt", **agent_config)

    # Launch the UI locally. Alternatives previously used here: pass
    # server_name="0.0.0.0", server_port=8585, share=True to serve it on the
    # network / via a share link, or add queue=True.
    demo = create_demo(agent, tools_dict)
    demo.launch(debug=True, ssr_mode=False)