File size: 6,854 Bytes
05e3595 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import logging
import os
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
class XRayImageAnalyzer:
    """
    Analyze chest X-ray images with a Hugging Face image-classification model.

    Any checkpoint loadable by ``AutoModelForImageClassification`` can be used.
    The default is ``codewithdark/vit-chest-xray`` (a ViT-based checkpoint judging
    by its name — confirm on the model card). An image is reported as normal only
    when the model has a "normal"/"no finding" class label and that class's
    probability exceeds the caller-supplied threshold.
    """

    # Lower-cased class labels treated as "no abnormality present".
    _NORMAL_LABELS = frozenset({"normal", "no finding"})

    def __init__(self, model_name="codewithdark/vit-chest-xray", device=None):
        """
        Load the feature extractor and classification model.

        Args:
            model_name (str): The Hugging Face model name (or local path) to load.
            device (str, optional): Device to run the model on ('cuda' or 'cpu').
                Defaults to 'cuda' when available, otherwise 'cpu'.

        Raises:
            Exception: Any error raised while loading the model is logged and
                re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)

        # Prefer the GPU when the caller did not choose a device explicitly.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.logger.info("Using device: %s", self.device)

        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()  # evaluation mode for inference (e.g. dropout off)
            self.logger.info("Successfully loaded model: %s", model_name)
            # Class-index -> label-name mapping taken from the model config.
            self.labels = self.model.config.id2label
        except Exception as e:
            self.logger.error("Failed to load model: %s", e)
            raise

    @staticmethod
    def _is_normal_label(label):
        """Return True if *label* denotes the 'no abnormality' class."""
        return str(label).lower() in XRayImageAnalyzer._NORMAL_LABELS

    def preprocess_image(self, image_path):
        """
        Load (if needed) and preprocess an X-ray image for model input.

        Args:
            image_path (str, os.PathLike or PIL.Image): Path to an image file, or
                an already-loaded PIL Image object.

        Returns:
            tuple: ``(inputs, image)``, where ``inputs`` is the dict of tensors
            produced by the feature extractor (moved to ``self.device``) and
            ``image`` is the RGB PIL image that was processed.

        Raises:
            FileNotFoundError: If a path is given and the file does not exist.
        """
        try:
            if isinstance(image_path, (str, os.PathLike)):
                if not os.path.exists(image_path):
                    raise FileNotFoundError(f"Image file not found: {image_path}")
                # convert() returns a new, fully loaded image, so the source file
                # handle can be closed as soon as the conversion is done.
                with Image.open(image_path) as source:
                    image = source.convert("RGB")
            else:
                # Assume it's already a PIL Image.
                image = image_path.convert("RGB")

            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            return inputs, image
        except Exception as e:
            self.logger.error("Error in preprocessing image: %s", e)
            raise

    def analyze(self, image_path, threshold=0.5):
        """
        Classify an X-ray image and decide whether it shows an abnormality.

        Args:
            image_path (str, os.PathLike or PIL.Image): Path to the X-ray image
                or a PIL Image object.
            threshold (float): Probability the "normal"/"no finding" class must
                exceed for the image to be reported as normal.

        Returns:
            dict: Analysis results including:
                - predictions: List of (label, probability) tuples, sorted by
                  probability in descending order
                - primary_finding: The most likely abnormal label, or
                  "No abnormalities detected"
                - has_abnormality: Boolean indicating if abnormalities were detected
                - confidence: Probability of the primary finding (or of the
                  normal class when no abnormality is reported)
        """
        try:
            inputs, _ = self.preprocess_image(image_path)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # [0] selects the single image in the batch.
            # NOTE(review): softmax assumes mutually exclusive classes; a
            # multi-label checkpoint would need a sigmoid instead — confirm
            # against the model card.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            predictions = sorted(
                (
                    (self.labels[i], float(p))
                    for i, p in enumerate(probabilities.cpu().tolist())
                ),
                key=lambda item: item[1],
                reverse=True,
            )

            normal = next(
                (pred for pred in predictions if self._is_normal_label(pred[0])),
                None,
            )
            if normal is not None and normal[1] > threshold:
                has_abnormality = False
                primary_finding = "No abnormalities detected"
                confidence = normal[1]
            else:
                has_abnormality = True
                # The normal class can rank first without exceeding the
                # threshold; report the most likely *abnormal* label so the
                # result is not self-contradictory. Fall back to the top label
                # only if every class is a "normal" label.
                abnormal = [
                    pred for pred in predictions if not self._is_normal_label(pred[0])
                ]
                primary_finding, confidence = (abnormal or predictions)[0]

            return {
                "predictions": predictions,
                "primary_finding": primary_finding,
                "has_abnormality": has_abnormality,
                "confidence": confidence,
            }
        except Exception as e:
            self.logger.error("Error analyzing image: %s", e)
            raise

    def get_explanation(self, results):
        """
        Generate a human-readable explanation of the analysis results.

        Args:
            results (dict): The results returned by the analyze method.

        Returns:
            str: A text explanation of the findings. For an abnormal result, up
            to three other labels with probability above 5% are listed; the
            "Other potential findings" section is omitted when there are none.
        """
        if not results["has_abnormality"]:
            return (
                f"The X-ray appears normal with {results['confidence']:.1%} confidence."
            )

        primary = results["primary_finding"]
        explanation = (
            f"The primary finding is {primary} "
            f"with {results['confidence']:.1%} confidence."
        )
        # Exclude the primary label by name (it is not necessarily ranked first).
        others = [
            (label, prob)
            for label, prob in results["predictions"]
            if label != primary and prob > 0.05  # only include if > 5%
        ][:3]
        if others:
            explanation += "\n\nOther potential findings include:\n"
            explanation += "".join(f"- {label}: {prob:.1%}\n" for label, prob in others)
        return explanation
# Example usage
if __name__ == "__main__":
# Set up logging
logging.basicConfig(level=logging.INFO)
# Test on a sample image if available
analyzer = XRayImageAnalyzer()
# Check if sample data directory exists
sample_dir = "../data/sample"
if os.path.exists(sample_dir) and os.listdir(sample_dir):
sample_image = os.path.join(sample_dir, os.listdir(sample_dir)[0])
print(f"Analyzing sample image: {sample_image}")
results = analyzer.analyze(sample_image)
explanation = analyzer.get_explanation(results)
print("\nAnalysis Results:")
print(explanation)
else:
print("No sample images found in ../data/sample directory")
|