import tempfile
import uuid
from pathlib import Path
from typing import Dict, Optional, Tuple, Type

import torch
from diffusers import StableDiffusionPipeline
from langchain_core.callbacks import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field


class ChestXRayGeneratorInput(BaseModel):
    """Input schema for the Chest X-Ray Generator Tool."""

    prompt: str = Field(
        ...,
        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')",
    )
    height: int = Field(
        512,
        description="Height of generated image in pixels",
    )
    width: int = Field(
        512,
        description="Width of generated image in pixels",
    )
    num_inference_steps: int = Field(
        75,
        description="Number of denoising steps (higher = better quality but slower)",
    )
    guidance_scale: float = Field(
        4.0,
        description="How closely to follow the prompt (higher = more faithful but less diverse)",
    )
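
# A minimal sketch (not part of the original file) of how this schema validates
# tool input: pydantic fills in the defaults above and rejects calls that omit
# the required `prompt` field. The example values are illustrative.
#
#     args = ChestXRayGeneratorInput(prompt="small right-sided pneumothorax")
#     assert args.height == 512 and args.guidance_scale == 4.0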


class ChestXRayGeneratorTool(BaseTool):
    """Tool for generating synthetic chest X-ray images using a fine-tuned Stable Diffusion model."""

    name: str = "chest_xray_generator"
    description: str = (
        "Generates synthetic chest X-ray images from text descriptions of medical conditions. "
        "Input: Text description of the medical finding or condition to generate, "
        "along with optional parameters for image size (height, width), "
        "quality (num_inference_steps), and prompt adherence (guidance_scale). "
        "Output: Path to the generated X-ray image and generation metadata."
    )
    args_schema: Type[BaseModel] = ChestXRayGeneratorInput
    model: Optional[StableDiffusionPipeline] = None
    device: Optional[torch.device] = None
    temp_dir: Optional[Path] = None

    def __init__(
        self,
        model_path: str = "/model-weights/roentgen",
        cache_dir: str = "/model-weights",
        temp_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize the chest X-ray generator tool and load the diffusion pipeline onto the target device."""
        super().__init__()
        # Automatically detect the device when none is given (cuda if available, else cpu).
        # The argument must default to None (not "cuda") for this fallback to be reachable.
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)
        self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
        self.model = self.model.to(torch.float32).to(self.device)
        # Write generated images to the caller-supplied directory, or a fresh temporary one.
        self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _run(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Generate a chest X-ray image from a text description.

        Args:
            prompt: Text description of the medical condition to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: How closely to follow the prompt
            height: Height of generated image in pixels
            width: Width of generated image in pixels
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict, Dict]: Output dictionary with the image path, and a metadata dictionary
        """
        try:
            # Run the diffusion pipeline on a single-prompt batch.
            generation_output = self.model(
                [prompt],
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
            )
            # Save the generated image under a collision-resistant name.
            image_path = self.temp_dir / f"generated_xray_{uuid.uuid4().hex[:8]}.png"
            generation_output.images[0].save(image_path)
            output = {
                "image_path": str(image_path),
            }
            metadata = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "device": str(self.device),
                "image_size": (height, width),
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            # Report the failure to the caller instead of raising.
            return (
                {"error": str(e)},
                {
                    "prompt": prompt,
                    "analysis_status": "failed",
                    "error_details": str(e),
                },
            )

    async def _arun(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Async version of _run (delegates to the synchronous implementation, so it blocks during generation)."""
        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
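

# A minimal usage sketch (not part of the original module). It assumes the
# fine-tuned weights actually exist under /model-weights/roentgen; the prompt
# below is illustrative. In an agent, the tool would normally be invoked
# through the LangChain executor rather than by calling _run directly.
if __name__ == "__main__":
    tool = ChestXRayGeneratorTool()
    output, metadata = tool._run(prompt="moderate right-sided pleural effusion")
    if "error" in output:
        print(f"Generation failed: {output['error']}")
    else:
        print(f"Image written to {output['image_path']}")
        print(f"Generation metadata: {metadata}")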