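"""Utility functions for X-ray image preprocessing and radiology report text
processing: contrast enhancement, report normalization, section extraction,
and measurement extraction."""
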
import logging
import os
import re
import cv2
from PIL import Image
# Set up logging
logger = logging.getLogger(__name__)
def preprocess_image(image_path, target_size=(224, 224)):
"""
Preprocess X-ray image for model input.
Args:
image_path (str): Path to the X-ray image
target_size (tuple): Target size for resizing
Returns:
PIL.Image: Preprocessed image
"""
try:
# Check if file exists
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
# Load image
image = Image.open(image_path)
# Convert grayscale to RGB if needed
if image.mode != "RGB":
image = image.convert("RGB")
# Resize image
image = image.resize(target_size, Image.LANCZOS)
return image
except Exception as e:
logger.error(f"Error preprocessing image: {e}")
raise
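
# A minimal usage sketch (illustrative only; "chest_xray.png" is a hypothetical
# path, not a file shipped with this project):
#
#     image = preprocess_image("chest_xray.png", target_size=(224, 224))
#     image.save("chest_xray_224.png")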
def enhance_xray_image(image_path, output_path=None, clahe_clip=2.0, clahe_grid=(8, 8)):
"""
Enhance X-ray image contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization).
Args:
image_path (str): Path to the X-ray image
output_path (str, optional): Path to save enhanced image
clahe_clip (float): Clip limit for CLAHE
clahe_grid (tuple): Grid size for CLAHE
Returns:
str or np.ndarray: Path to enhanced image or image array
"""
try:
# Read image
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
if img is None:
raise ValueError(f"Failed to read image: {image_path}")
# Create CLAHE object
clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=clahe_grid)
# Apply CLAHE
enhanced = clahe.apply(img)
# Save enhanced image if output path is provided
if output_path:
cv2.imwrite(output_path, enhanced)
return output_path
else:
return enhanced
except Exception as e:
logger.error(f"Error enhancing X-ray image: {e}")
raise
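
# A minimal usage sketch (illustrative only; both paths are hypothetical).
# Contrast enhancement and resizing are independent steps, so the CLAHE output
# can be fed back into preprocess_image before inference:
#
#     enhanced_path = enhance_xray_image("chest_xray.png", "chest_xray_clahe.png")
#     model_input = preprocess_image(enhanced_path)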
def normalize_report_text(text):
"""
Normalize medical report text for consistent processing.
Args:
text (str): Medical report text
Returns:
str: Normalized text
"""
try:
        # Collapse runs of whitespace (including newlines) into single spaces
text = re.sub(r"\s+", " ", text)
# Standardize section headers
section_patterns = {
r"(?i)clinical\s*(?:history|indication)": "CLINICAL HISTORY:",
r"(?i)technique": "TECHNIQUE:",
r"(?i)comparison": "COMPARISON:",
r"(?i)findings": "FINDINGS:",
r"(?i)impression": "IMPRESSION:",
r"(?i)recommendation": "RECOMMENDATION:",
r"(?i)comment": "COMMENT:",
}
for pattern, replacement in section_patterns.items():
text = re.sub(pattern + r"\s*:", replacement, text)
        # Standardize common abbreviations. "w/o" must be expanded before the
        # bare "w/" rule, otherwise "w/o" would be mangled into "witho".
        abbrev_patterns = {
            r"(?i)\bw\/o\b": "without",
            r"(?i)\bw\/\s*": "with ",
            r"(?i)\bs\/p\b": "status post",
            r"(?i)\bc\/w\b": "consistent with",
            r"(?i)\br\/o\b": "rule out",
            r"(?i)\bhx\b": "history",
            r"(?i)\bdx\b": "diagnosis",
            r"(?i)\btx\b": "treatment",
        }
for pattern, replacement in abbrev_patterns.items():
text = re.sub(pattern, replacement, text)
return text.strip()
except Exception as e:
logger.error(f"Error normalizing report text: {e}")
return text # Return original text if normalization fails
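
# Example (illustrative): "Findings: 3mm nodule, r/o infection w/o effusion"
# normalizes to "FINDINGS: 3mm nodule, rule out infection without effusion".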
def extract_sections(text):
"""
Extract sections from a medical report.
Args:
text (str): Medical report text
Returns:
dict: Dictionary of extracted sections
"""
try:
# Normalize text first
normalized_text = normalize_report_text(text)
# Define section patterns
section_headers = [
"CLINICAL HISTORY:",
"TECHNIQUE:",
"COMPARISON:",
"FINDINGS:",
"IMPRESSION:",
"RECOMMENDATION:",
]
        # normalize_report_text collapses all whitespace (including newlines),
        # so the report is a single line at this point; split on the section
        # headers themselves rather than on line breaks.
        header_regex = "(" + "|".join(re.escape(h) for h in section_headers) + ")"
        parts = re.split(header_regex, normalized_text)
        sections = {}
        # Any text before the first section header
        preamble = parts[0].strip()
        if preamble:
            sections["PREAMBLE"] = preamble
        # The remaining parts alternate between a header and its content
        for i in range(1, len(parts) - 1, 2):
            section = parts[i].rstrip(":")
            content = parts[i + 1].strip()
            if content:
                sections[section] = content
        return sections
except Exception as e:
logger.error(f"Error extracting sections: {e}")
return {"FULL_TEXT": text} # Return full text if extraction fails
def extract_measurements(text):
"""
Extract measurements from medical text (sizes, volumes, etc.).
Args:
text (str): Medical text
Returns:
list: List of tuples containing (measurement, value, unit)
"""
try:
# Pattern for measurements like "5mm nodule" or "nodule measuring 5mm"
# or "8x10mm mass" or "mass of size 8x10mm"
size_pattern = r"(\d+(?:\.\d+)?(?:\s*[x×]\s*\d+(?:\.\d+)?)?(?:\s*[x×]\s*\d+(?:\.\d+)?)?)\s*(mm|cm|mm2|cm2|mm3|cm3|ml|cc)"
# Find measurements with context
context_pattern = (
r"([A-Za-z\s]+(?:mass|nodule|effusion|opacity|lesion|tumor|cyst|structure|area|region)[A-Za-z\s]*)"
+ size_pattern
)
context_measurements = []
for match in re.finditer(context_pattern, text, re.IGNORECASE):
context, size, unit = match.groups()
context_measurements.append((context.strip(), size, unit))
        # Fall back to bare measurements (with an empty context string) when no
        # measurement had a recognizable anatomical context, so they are not lost.
        if context_measurements:
            return context_measurements
        all_measurements = []
        for match in re.finditer(size_pattern, text):
            size, unit = match.groups()
            all_measurements.append(("", size, unit))
        return all_measurements
except Exception as e:
logger.error(f"Error extracting measurements: {e}")
return []
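
# Example (illustrative): "a small pleural effusion measuring 2.5 cm" yields
# [("a small pleural effusion measuring", "2.5", "cm")], while a bare "5 mm"
# with no recognized context falls back to [("", "5", "mm")].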
def prepare_sample_batch(image_paths, reports=None, target_size=(224, 224)):
"""
Prepare a batch of samples for model processing.
Args:
image_paths (list): List of paths to images
reports (list, optional): List of corresponding reports
target_size (tuple): Target image size
Returns:
tuple: Batch of preprocessed images and reports
"""
try:
processed_images = []
processed_reports = []
for i, image_path in enumerate(image_paths):
# Process image
image = preprocess_image(image_path, target_size)
processed_images.append(image)
# Process report if available
if reports and i < len(reports):
normalized_report = normalize_report_text(reports[i])
processed_reports.append(normalized_report)
        return processed_images, (processed_reports if reports else None)
except Exception as e:
logger.error(f"Error preparing sample batch: {e}")
raise
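
# A self-contained smoke test for the text utilities. It runs only when the
# module is executed directly; the report below is a fabricated example, not
# real patient data.
if __name__ == "__main__":
    sample_report = (
        "Clinical History: 54-year-old male w/o fever, hx of smoking. "
        "Findings: Right upper lobe nodule measuring 8 x 10 mm. No effusion. "
        "Impression: Findings c/w early infection. R/O malignancy."
    )
    print(normalize_report_text(sample_report))
    print(extract_sections(sample_report))
    print(extract_measurements(normalize_report_text(sample_report)))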