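"""LangChain tool for generating synthetic chest X-ray images from text
descriptions using a fine-tuned Stable Diffusion pipeline."""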
import tempfile
import uuid
from pathlib import Path
from typing import Dict, Optional, Tuple, Type

import torch
from diffusers import StableDiffusionPipeline
from langchain_core.callbacks import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field


class ChestXRayGeneratorInput(BaseModel):
    """Input schema for the Chest X-Ray Generator Tool."""

    prompt: str = Field(
        ...,
        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')",
    )
    height: int = Field(
        512,
        description="Height of the generated image in pixels",
    )
    width: int = Field(
        512,
        description="Width of the generated image in pixels",
    )
    num_inference_steps: int = Field(
        75,
        description="Number of denoising steps (higher = better quality but slower)",
    )
    guidance_scale: float = Field(
        4.0,
        description="How closely to follow the prompt (higher = more faithful but less diverse)",
    )


class ChestXRayGeneratorTool(BaseTool):
    """Tool for generating synthetic chest X-ray images using a fine-tuned Stable Diffusion model."""

    name: str = "chest_xray_generator"
    description: str = (
        "Generates synthetic chest X-ray images from text descriptions of medical conditions. "
        "Input: Text description of the medical finding or condition to generate, "
        "along with optional parameters for image size (height, width), "
        "quality (num_inference_steps), and prompt adherence (guidance_scale). "
        "Output: Path to the generated X-ray image and generation metadata."
    )
    args_schema: Type[BaseModel] = ChestXRayGeneratorInput

    model: Optional[StableDiffusionPipeline] = None
    device: Optional[torch.device] = None
    temp_dir: Optional[Path] = None

    def __init__(
        self,
        model_path: str = "/model-weights/roentgen",
        cache_dir: str = "/model-weights",
        temp_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize the chest X-ray generator tool.

        Args:
            model_path: Path to (or Hugging Face ID of) the fine-tuned pipeline weights
            cache_dir: Directory where diffusers caches downloaded weights
            temp_dir: Directory for generated images; a fresh temporary directory by default
            device: Torch device string ("cuda" or "cpu"); auto-detected when omitted
        """
        super().__init__()

        # Fall back to CPU when no CUDA device is available.
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)

        # Load the pipeline in full precision and move it to the target device.
        self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
        self.model = self.model.to(torch.float32).to(self.device)

        # Generated images are written to a dedicated (temporary) directory.
        self.temp_dir = Path(temp_dir) if temp_dir else Path(tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _run(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Generate a chest X-ray image from a text description.

        Args:
            prompt: Text description of the medical condition to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: How closely to follow the prompt
            height: Height of the generated image in pixels
            width: Width of the generated image in pixels
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict, Dict]: Output dictionary with the image path, and a metadata dictionary
        """
        try:
            generation_output = self.model(
                [prompt],
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
            )

            # Use a random suffix so concurrent or repeated calls do not overwrite each other.
            image_path = self.temp_dir / f"generated_xray_{uuid.uuid4().hex[:8]}.png"
            generation_output.images[0].save(image_path)

            output = {
                "image_path": str(image_path),
            }
            metadata = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "device": str(self.device),
                "image_size": (height, width),
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "prompt": prompt,
                    "analysis_status": "failed",
                    "error_details": str(e),
                },
            )

    async def _arun(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Async version of _run; delegates to the synchronous implementation."""
        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
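

# Example usage: a minimal sketch, assuming the fine-tuned weights exist at the
# default /model-weights/roentgen path and a GPU (or the CPU fallback) is
# available. The prompt below is illustrative only.
if __name__ == "__main__":
    tool = ChestXRayGeneratorTool()
    # Calling _run directly for illustration; in a LangChain agent the tool
    # would normally be invoked through the standard tool interface instead.
    output, metadata = tool._run(prompt="big left-sided pleural effusion")
    print(output)
    print(metadata)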