|
import logging |
|
import os |
|
|
|
import torch |
|
from PIL import Image |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
|
|
|
class XRayImageAnalyzer:
    """
    Analyze medical X-ray images with a pre-trained Hugging Face image classifier.

    The checkpoint named by ``model_name`` is loaded once at construction time
    together with its feature extractor. ``analyze`` classifies a single image
    and ``get_explanation`` renders the result as human-readable text.

    NOTE(review): the default checkpoint name ("codewithdark/vit-chest-xray")
    suggests a ViT model fine-tuned on chest X-rays -- confirm against the
    model card before documenting a specific architecture here.
    """

    # Lower-cased label names that mean "nothing abnormal was found".
    _NORMAL_LABELS = frozenset({"normal", "no finding"})

    def __init__(self, model_name="codewithdark/vit-chest-xray", device=None):
        """
        Initialize the X-ray image analyzer with a specific pre-trained model.

        Args:
            model_name (str): The Hugging Face model name to use
            device (str, optional): Device to run the model on ('cuda' or 'cpu').
                When omitted, 'cuda' is used if available, otherwise 'cpu'.

        Raises:
            Exception: Whatever the model/feature-extractor loading raises; the
                error is logged and re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.logger.info("Using device: %s", self.device)

        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            self.model.to(self.device)
            # Inference only: switch off dropout / batch-norm updates.
            self.model.eval()
            self.logger.info("Successfully loaded model: %s", model_name)

            # Mapping from class index to label name, taken from the model config.
            self.labels = self.model.config.id2label
        except Exception as e:
            self.logger.error("Failed to load model: %s", e)
            raise

    def preprocess_image(self, image_path):
        """
        Preprocess an X-ray image for model input.

        Args:
            image_path (str, os.PathLike or PIL.Image): Path to an image file,
                or an already opened PIL Image object

        Returns:
            tuple: ``(inputs, image)`` where ``inputs`` is a dict of tensors
            moved to ``self.device`` and ``image`` is the RGB PIL image that
            was fed to the feature extractor

        Raises:
            FileNotFoundError: If a path is given and it does not exist
        """
        try:
            if isinstance(image_path, (str, os.PathLike)):
                if not os.path.exists(image_path):
                    raise FileNotFoundError(f"Image file not found: {image_path}")
                # Image.open is lazy; convert() loads the pixels and returns a
                # new image, so the file handle can be released right away.
                with Image.open(image_path) as opened:
                    image = opened.convert("RGB")
            else:
                # Anything that is not a path is treated as a PIL image.
                image = image_path.convert("RGB")

            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
            return inputs, image

        except Exception as e:
            self.logger.error("Error in preprocessing image: %s", e)
            raise

    def analyze(self, image_path, threshold=0.5):
        """
        Analyze an X-ray image and detect abnormalities.

        The image is reported as normal only when a "normal" / "no finding"
        label scores strictly above ``threshold``. Otherwise the most probable
        label that is *not* a normal label becomes the primary finding, so an
        abnormal result never names "normal" as its finding.

        Args:
            image_path (str, os.PathLike or PIL.Image): Path to the X-ray image
                or PIL Image object
            threshold (float): Classification threshold for positive findings

        Returns:
            dict: Analysis results including:
                - predictions: List of (label, probability) tuples, sorted by
                  descending probability
                - primary_finding: The most likely abnormality
                - has_abnormality: Boolean indicating if abnormalities were detected
                - confidence: Confidence score for the primary finding
        """
        try:
            inputs, _ = self.preprocess_image(image_path)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Single image in the batch -> take row 0 of the softmax output.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            probabilities = probabilities.cpu().tolist()

            predictions = [
                (self.labels[index], float(prob))
                for index, prob in enumerate(probabilities)
            ]
            predictions.sort(key=lambda item: item[1], reverse=True)

            # Highest-scoring "normal"-type label, if the model has one.
            normal_prob = next(
                (
                    prob
                    for label, prob in predictions
                    if label.lower() in self._NORMAL_LABELS
                ),
                None,
            )

            if normal_prob is not None and normal_prob > threshold:
                has_abnormality = False
                primary_finding = "No abnormalities detected"
                confidence = normal_prob
            else:
                has_abnormality = True
                # Skip normal labels; fall back to the top prediction only if
                # the model has no other kind of label.
                primary_finding, confidence = next(
                    (
                        (label, prob)
                        for label, prob in predictions
                        if label.lower() not in self._NORMAL_LABELS
                    ),
                    predictions[0],
                )

            return {
                "predictions": predictions,
                "primary_finding": primary_finding,
                "has_abnormality": has_abnormality,
                "confidence": confidence,
            }

        except Exception as e:
            self.logger.error("Error analyzing image: %s", e)
            raise

    def get_explanation(self, results):
        """
        Generate a human-readable explanation of the analysis results.

        Args:
            results (dict): The results returned by the analyze method

        Returns:
            str: A text explanation of the findings. For abnormal results, up
            to three further predictions above 5% are listed; the "other
            findings" section is omitted entirely when none qualify.
        """
        confidence = results["confidence"]
        if not results["has_abnormality"]:
            return f"The X-ray appears normal with {confidence:.1%} confidence."

        primary = results["primary_finding"]
        explanation = (
            f"The primary finding is {primary} with {confidence:.1%} confidence."
        )

        # Next three predictions after the primary one, keeping only those
        # with a non-negligible probability.
        runners_up = [
            (label, prob)
            for label, prob in results["predictions"]
            if label != primary
        ][:3]
        bullets = "".join(
            f"- {label}: {prob:.1%}\n" for label, prob in runners_up if prob > 0.05
        )
        if bullets:
            explanation += "\n\nOther potential findings include:\n" + bullets

        return explanation
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: analyze the first sample image, if one is available.
    logging.basicConfig(level=logging.INFO)

    sample_dir = "../data/sample"

    # os.listdir returns entries in arbitrary order and may include
    # sub-directories, so keep regular files only and sort them to make the
    # chosen sample deterministic.
    sample_files = []
    if os.path.isdir(sample_dir):
        sample_files = sorted(
            name
            for name in os.listdir(sample_dir)
            if os.path.isfile(os.path.join(sample_dir, name))
        )

    if sample_files:
        # Load the model only once we know there is something to analyze.
        analyzer = XRayImageAnalyzer()

        sample_image = os.path.join(sample_dir, sample_files[0])
        print(f"Analyzing sample image: {sample_image}")

        results = analyzer.analyze(sample_image)
        explanation = analyzer.get_explanation(results)

        print("\nAnalysis Results:")
        print(explanation)
    else:
        print("No sample images found in ../data/sample directory")
|
|