File size: 6,854 Bytes
05e3595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import logging
import os

import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification


class XRayImageAnalyzer:
    """
    Analyze medical X-ray images with a pre-trained Hugging Face image classifier.

    By default the analyzer loads ``codewithdark/vit-chest-xray`` which, judging by
    its name, is a Vision Transformer fine-tuned on chest X-ray images. Any
    checkpoint loadable with ``AutoModelForImageClassification`` can be supplied
    instead. Class names are taken from the model's ``config.id2label``.
    """

    # Lower-cased class labels that are interpreted as "no abnormality".
    _NORMAL_LABELS = frozenset({"normal", "no finding"})

    def __init__(
        self, model_name="codewithdark/vit-chest-xray", device=None
    ):
        """
        Initialize the analyzer and load the model and its feature extractor.

        Args:
            model_name (str): Hugging Face model name passed to ``from_pretrained``.
            device (str, optional): Device to run the model on ('cuda' or 'cpu').
                Defaults to 'cuda' when ``torch.cuda.is_available()``, else 'cpu'.

        Raises:
            Exception: Any error raised while loading the model or feature
                extractor is logged and re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)

        # Pick the device: explicit argument wins, otherwise prefer the GPU.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.logger.info("Using device: %s", self.device)

        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()  # inference mode (disables dropout etc.)
            self.logger.info("Successfully loaded model: %s", model_name)

            # Class-index -> label-name mapping taken from the model config.
            self.labels = self.model.config.id2label

        except Exception as e:
            self.logger.error("Failed to load model: %s", e)
            raise

    def preprocess_image(self, image_path):
        """
        Load (if needed) and preprocess an X-ray image for model input.

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Path to an image
                file, or an already-loaded PIL image.

        Returns:
            tuple: ``(inputs, image)``. ``inputs`` is the dict of tensors produced
            by the feature extractor, moved to ``self.device``. ``image`` is the
            RGB PIL image that was fed to the extractor.

        Raises:
            FileNotFoundError: If a path is given and the file does not exist.
        """
        try:
            if isinstance(image_path, (str, os.PathLike)):
                if not os.path.exists(image_path):
                    raise FileNotFoundError(f"Image file not found: {image_path}")
                # Close the file handle right away; convert() returns a fully
                # loaded copy, so the result stays valid after the file closes.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert("RGB")
            else:
                # Assume it's already a PIL Image
                image = image_path.convert("RGB")

            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            return inputs, image

        except Exception as e:
            self.logger.error("Error in preprocessing image: %s", e)
            raise

    def analyze(self, image_path, threshold=0.5):
        """
        Classify an X-ray image and summarize whether it looks abnormal.

        Args:
            image_path (str | os.PathLike | PIL.Image.Image): Path to the X-ray
                image, or a PIL Image object.
            threshold (float): Probability that a "normal" / "no finding" class
                must exceed for the image to be reported as normal. If the model
                has no such class, the image is always reported as abnormal.

        Returns:
            dict: Analysis results including:
                - predictions: List of (label, probability) tuples, sorted by
                  descending probability
                - primary_finding: The most likely *abnormal* label, or
                  "No abnormalities detected" when the image is judged normal
                - has_abnormality: Boolean indicating if abnormalities were detected
                - confidence: Probability of the normal class (normal case) or of
                  the primary finding (abnormal case)
        """
        try:
            inputs, _ = self.preprocess_image(image_path)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Softmax over classes; [0] selects the single image in the batch.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

            predictions = [
                (self.labels[i], p)
                for i, p in enumerate(probabilities.cpu().tolist())
            ]
            predictions.sort(key=lambda x: x[1], reverse=True)

            # Split into "normal" and "abnormal" classes, keeping sort order.
            normal = [
                pred for pred in predictions
                if pred[0].lower() in self._NORMAL_LABELS
            ]
            abnormal = [
                pred for pred in predictions
                if pred[0].lower() not in self._NORMAL_LABELS
            ]

            if normal and normal[0][1] > threshold:
                has_abnormality = False
                primary_finding = "No abnormalities detected"
                confidence = normal[0][1]
            else:
                has_abnormality = True
                # Report the most likely abnormal label, even if a normal class
                # ranked higher without clearing the threshold. Fall back to the
                # overall top prediction if every label is a "normal" one.
                primary_finding, confidence = abnormal[0] if abnormal else predictions[0]

            return {
                "predictions": predictions,
                "primary_finding": primary_finding,
                "has_abnormality": has_abnormality,
                "confidence": confidence,
            }

        except Exception as e:
            self.logger.error("Error analyzing image: %s", e)
            raise

    def get_explanation(self, results):
        """
        Generate a human-readable explanation of the analysis results.

        Args:
            results (dict): The results returned by the analyze method.

        Returns:
            str: A text explanation of the findings. In the abnormal case, up to
            three other labels (excluding the primary finding) with probability
            above 5% are listed. The list header is omitted when none qualify.
        """
        if not results["has_abnormality"]:
            return (
                f"The X-ray appears normal with {results['confidence']:.1%} confidence."
            )

        primary = results["primary_finding"]
        explanation = (
            f"The primary finding is {primary} "
            f"with {results['confidence']:.1%} confidence."
        )

        # Top 3 runner-up labels (primary excluded by label, not by position),
        # then drop anything at or below 5% to avoid listing noise.
        runners_up = [
            (label, prob)
            for label, prob in results["predictions"]
            if label != primary
        ][:3]
        others = [(label, prob) for label, prob in runners_up if prob > 0.05]

        if others:
            explanation += "\n\nOther potential findings include:\n"
            explanation += "".join(f"- {label}: {prob:.1%}\n" for label, prob in others)

        return explanation


# Example usage
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Find a sample image *before* loading the model, so a missing sample
    # directory does not trigger a potentially slow model load/download.
    # NOTE: the path is resolved relative to the current working directory.
    sample_dir = "../data/sample"
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff")

    sample_images = []
    if os.path.isdir(sample_dir):
        # Sort for a deterministic pick (os.listdir order is arbitrary) and keep
        # only regular files with an image extension, skipping sub-directories
        # and stray files such as .gitkeep / .DS_Store.
        sample_images = sorted(
            name
            for name in os.listdir(sample_dir)
            if name.lower().endswith(image_exts)
            and os.path.isfile(os.path.join(sample_dir, name))
        )

    if sample_images:
        sample_image = os.path.join(sample_dir, sample_images[0])
        print(f"Analyzing sample image: {sample_image}")

        analyzer = XRayImageAnalyzer()
        results = analyzer.analyze(sample_image)
        explanation = analyzer.get_explanation(results)

        print("\nAnalysis Results:")
        print(explanation)
    else:
        print("No sample images found in ../data/sample directory")