File size: 4,596 Bytes
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import Optional, Type, Dict, Tuple
from pydantic import BaseModel, Field
import matplotlib.pyplot as plt
import skimage.io
from pathlib import Path

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool


class ImageVisualizerInput(BaseModel):
    """Input schema for the Image Visualizer Tool. Only supports JPG or PNG images."""

    image_path: str = Field(..., description="Path to the image file to display, only supports JPG or PNG images")
    title: Optional[str] = Field(None, description="Optional title to display above the image")
    description: Optional[str] = Field(
        None, description="Optional description to display below the image"
    )
    figsize: Optional[tuple] = Field(
        (10, 10), description="Optional figure size as (width, height) in inches"
    )
    cmap: Optional[str] = Field(
        "rgb", description="Optional colormap to use for displaying the image"
    )


class ImageVisualizerTool(BaseTool):
    """Tool for displaying medical images to users with annotations."""

    name: str = "image_visualizer"
    description: str = (
        "Displays images to users with optional titles and descriptions. "
        "Input: Path to image file and optional display parameters. "
        "Output: Dict with image path and metadata."
    )
    args_schema: Type[BaseModel] = ImageVisualizerInput

    def _display_image(
        self,
        image_path: str,
        title: Optional[str] = None,
        description: Optional[str] = None,
        figsize: tuple = (10, 10),
        cmap: str = "rgb",
    ) -> None:
        """Display an image with optional annotations."""
        plt.figure(figsize=figsize)

        img = skimage.io.imread(image_path)
        if len(img.shape) > 2 and cmap != "rgb":
            img = img[..., 0]

        plt.imshow(img, cmap=None if cmap == "rgb" else cmap)
        plt.axis("off")

        if title:
            plt.title(title, pad=15, fontsize=12)

        # Add description if provided
        if description:
            plt.figtext(
                0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10
            )

        # Adjust margins to minimize whitespace while preventing overlap
        plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
        plt.show()

    def _run(
        self,
        image_path: str,
        title: Optional[str] = None,
        description: Optional[str] = None,
        figsize: tuple = (10, 10),
        cmap: str = "rgb",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> dict:
        """
        Display an image to the user with optional annotations.

        Args:
            image_path: Path to the image file
            title: Optional title to display above image
            description: Optional description to display below image
            figsize: Optional figure size as (width, height)
            cmap: Optional colormap to use for displaying the image
            run_manager: Optional callback manager

        Returns:
            Dict containing display status and metadata
        """
        try:
            # Verify image path
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image file not found: {image_path}")

            # Display image
            # self._display_image(image_path, title, description, figsize, cmap)

            output = {"image_path": image_path}
            metadata = {
                "image_path": image_path,
                "title": bool(title),
                "description": bool(description),
                "figsize": figsize,
                "cmap": cmap,
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "visualization_status": "failed",
                    "note": "An error occurred during image visualization",
                },
            )

    async def _arun(
        self,
        image_path: str,
        title: Optional[str] = None,
        description: Optional[str] = None,
        figsize: tuple = (10, 10),
        cmap: str = "rgb",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, any], Dict]:
        """Async version of _run."""
        return self._run(image_path, title, description, figsize, cmap)