|
import logging |
|
|
|
from .image_analyzer import XRayImageAnalyzer |
|
from .text_analyzer import MedicalReportAnalyzer |
|
|
|
|
|
class MultimodalFusion:
    """
    A class for fusing insights from image analysis and text analysis of medical data.

    This fusion approach combines the strengths of both modalities:
    - Images provide visual evidence of abnormalities
    - Text reports provide context, history and radiologist interpretations

    The combined analysis provides a more comprehensive understanding than either
    modality alone.

    Per-modality results are plain dicts. A result that is not a dict, or that
    carries an ``"error"`` key (the shape returned by :meth:`analyze_image` and
    :meth:`analyze_text` on failure), is treated as a *missing* modality: it
    contributes no evidence to agreement, findings or severity, instead of being
    read as a normal study.
    """

    DEFAULT_IMAGE_MODEL = "codewithdark/vit-chest-xray"
    DEFAULT_TEXT_MODEL = "medicalai/ClinicalBERT"

    # Condition keywords looked for in both the image label and the report's
    # problem entities when scoring cross-modal agreement.
    _COMMON_CONDITIONS = (
        "pneumonia",
        "effusion",
        "nodule",
        "mass",
        "cardiomegaly",
        "opacity",
        "fracture",
        "tumor",
        "edema",
    )

    # Keyword -> severity score for the image's primary finding. The first
    # matching keyword wins, so insertion order matters. Each score sits inside
    # the band of the label it stands for (nodule 1.5 -> Mild, pneumonia 2.5 ->
    # Moderate, tumor 3.0 -> Severe), which is why levels are derived by flooring.
    _FINDING_SEVERITY_SCORES = {
        "pneumonia": 2.5,
        "pneumothorax": 3.0,
        "effusion": 2.0,
        "pulmonary edema": 2.5,
        "nodule": 1.5,
        "mass": 2.5,
        "tumor": 3.0,
        "cardiomegaly": 1.5,
        "fracture": 2.0,
        "consolidation": 2.0,
    }

    # Severity labels indexed by floor(score), with the score clamped to [0, 4].
    # NOTE(review): assumes the text analyzer's severity "score" uses this same
    # 0-4 scale — confirm against MedicalReportAnalyzer.
    _SEVERITY_LEVELS = ("Normal", "Mild", "Moderate", "Severe", "Critical")

    def __init__(self, image_model=None, text_model=None, device=None):
        """
        Initialize the multimodal fusion module with image and text analyzers.

        Each analyzer is constructed independently. If one fails to load, the
        error is logged and the corresponding attribute is set to ``None`` so
        the other modality can still be used.

        Args:
            image_model (str, optional): Model to use for image analysis
                (defaults to ``DEFAULT_IMAGE_MODEL``).
            text_model (str, optional): Model to use for text analysis
                (defaults to ``DEFAULT_TEXT_MODEL``).
            device (str, optional): Device to run models on ('cuda' or 'cpu').
                Auto-detected with torch when omitted.
        """
        self.logger = logging.getLogger(__name__)

        if device is None:
            import torch  # deferred: only needed for auto-detection

            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.logger.info("Using device: %s", self.device)

        try:
            self.image_analyzer = XRayImageAnalyzer(
                model_name=image_model or self.DEFAULT_IMAGE_MODEL,
                device=self.device,
            )
            self.logger.info("Successfully initialized image analyzer")
        except Exception as e:  # best-effort: keep the text modality usable
            self.logger.error("Failed to initialize image analyzer: %s", e)
            self.image_analyzer = None

        try:
            self.text_analyzer = MedicalReportAnalyzer(
                classifier_model=text_model or self.DEFAULT_TEXT_MODEL,
                device=self.device,
            )
            self.logger.info("Successfully initialized text analyzer")
        except Exception as e:  # best-effort: keep the image modality usable
            self.logger.error("Failed to initialize text analyzer: %s", e)
            self.text_analyzer = None

    def analyze_image(self, image_path):
        """
        Analyze a medical image.

        Args:
            image_path (str): Path to the medical image

        Returns:
            dict: Image analysis results, or ``{"error": <message>}`` if the
            analyzer is unavailable or raised.
        """
        if self.image_analyzer is None:
            self.logger.warning("Image analyzer not available")
            return {"error": "Image analyzer not available"}

        try:
            return self.image_analyzer.analyze(image_path)
        except Exception as e:
            self.logger.error("Error analyzing image: %s", e)
            return {"error": str(e)}

    def analyze_text(self, text):
        """
        Analyze medical report text.

        Args:
            text (str): Medical report text

        Returns:
            dict: Text analysis results, or ``{"error": <message>}`` if the
            analyzer is unavailable or raised.
        """
        if self.text_analyzer is None:
            self.logger.warning("Text analyzer not available")
            return {"error": "Text analyzer not available"}

        try:
            return self.text_analyzer.analyze(text)
        except Exception as e:
            self.logger.error("Error analyzing text: %s", e)
            return {"error": str(e)}

    @staticmethod
    def _has_error(results):
        """Return True if *results* is not a usable analysis result.

        Non-dicts (including ``None``) and dicts carrying an ``"error"`` key,
        which is the failure shape returned by analyze_image / analyze_text,
        both count as failed.
        """
        return not isinstance(results, dict) or "error" in results

    @staticmethod
    def _extract_problems(text_results):
        """Return the report's ``entities["problem"]`` as a list of non-empty strings."""
        problems = (text_results.get("entities") or {}).get("problem") or []
        if isinstance(problems, str):
            # A bare string would otherwise be iterated character by character.
            problems = [problems]
        return [str(problem) for problem in problems if problem]

    def _calculate_agreement_score(self, image_results, text_results):
        """
        Calculate agreement score between image and text analyses.

        Scoring starts from a neutral 0.5, then:
        - +0.25 if both modalities agree on "abnormal vs. not", else -0.25
          (text counts as abnormal unless its severity level is "Normal"/"Unknown");
        - +0.05 for every keyword in ``_COMMON_CONDITIONS`` found in both the
          image's primary finding and the report's problem entities;
        - +0.2 * (shared keywords / keywords mentioned by either side).
        The result is clamped to [0, 1].

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis

        Returns:
            float: Agreement score (0-1, where 1 is perfect agreement). The
            neutral 0.5 is returned when either modality failed (there is
            nothing to compare) or on an unexpected error.
        """
        if self._has_error(image_results) or self._has_error(text_results):
            # Two failed analyses must not look like they "agree".
            return 0.5

        try:
            agreement = 0.5

            image_abnormal = bool(image_results.get("has_abnormality", False))
            text_level = (text_results.get("severity") or {}).get("level", "Unknown")
            text_abnormal = text_level not in ("Normal", "Unknown")
            agreement += 0.25 if image_abnormal == text_abnormal else -0.25

            image_finding = str(image_results.get("primary_finding") or "").lower()
            problem_text = " ".join(self._extract_problems(text_results)).lower()

            matching_conditions = 0
            total_mentioned = 0
            for condition in self._COMMON_CONDITIONS:
                in_image = condition in image_finding
                in_text = condition in problem_text
                if in_image or in_text:
                    total_mentioned += 1
                if in_image and in_text:
                    matching_conditions += 1
                    agreement += 0.05

            if total_mentioned:
                agreement += (matching_conditions / total_mentioned) * 0.2

            return max(0, min(1, agreement))
        except Exception as e:
            self.logger.error("Error calculating agreement score: %s", e)
            return 0.5

    def _get_confidence_weighted_finding(self, image_results, text_results, agreement):
        """
        Get the most confident finding weighted by modality confidence.

        A failed modality contributes no finding and zero confidence, so the
        other modality decides. The image's primary finding is only returned
        when it is non-empty.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis
            agreement (float): Agreement score between modalities

        Returns:
            str: Most confident finding. Returns "No significant findings" when
            the report lists no problems and the image is not confident (> 0.7).
            Returns "Unable to determine primary finding" when both modalities
            failed or on error.
        """
        image_ok = not self._has_error(image_results)
        text_ok = not self._has_error(text_results)
        if not (image_ok or text_ok):
            return "Unable to determine primary finding"

        try:
            if image_ok:
                image_finding = str(image_results.get("primary_finding") or "")
                image_confidence = image_results.get("confidence", 0.5)
            else:
                image_finding, image_confidence = "", 0.0

            if text_ok:
                problems = self._extract_problems(text_results)
                text_confidence = (text_results.get("severity") or {}).get(
                    "confidence", 0.5
                )
            else:
                problems, text_confidence = [], 0.0

            if not problems:
                # Only the image can name a finding; trust it only when confident.
                if image_finding and image_confidence > 0.7:
                    return image_finding
                return "No significant findings"

            # A clearly more confident modality wins outright.
            if image_finding and image_confidence > text_confidence + 0.2:
                return image_finding
            if text_confidence > image_confidence + 0.2:
                return problems[0]

            if agreement > 0.7:
                # Prefer a report problem that is named inside the image label.
                image_lower = image_finding.lower()
                for problem in problems:
                    if problem.lower() in image_lower:
                        return problem
                if image_finding and image_confidence > 0.6:
                    return image_finding
                return problems[0]

            # Low agreement: surface both opinions rather than silently pick one.
            if image_finding:
                return f"{image_finding} (image) / {problems[0]} (report)"
            return problems[0]
        except Exception as e:
            self.logger.error("Error getting weighted finding: %s", e)
            return "Unable to determine primary finding"

    def _image_followup_recommendations(self, image_results):
        """Derive follow-up suggestions from an abnormal image's primary finding.

        Returns an empty list when the image analysis failed or reported no
        abnormality.
        """
        if self._has_error(image_results) or not image_results.get(
            "has_abnormality", False
        ):
            return []

        primary = str(image_results.get("primary_finding") or "")
        primary_lower = primary.lower()
        confidence = image_results.get("confidence", 0)

        if any(term in primary_lower for term in ("nodule", "mass", "tumor")):
            return [f"Follow-up imaging recommended to further evaluate {primary}."]
        if "pneumonia" in primary_lower:
            return ["Clinical correlation and follow-up imaging recommended."]
        if confidence > 0.8:
            # Avoid rendering "monitor ." when the label is missing.
            subject = primary or "the abnormality"
            return [f"Consider follow-up imaging to monitor {subject}."]
        if confidence > 0.5:
            return ["Consider clinical correlation and potential follow-up."]
        return []

    def _merge_followup_recommendations(self, image_results, text_results):
        """
        Merge follow-up recommendations from both modalities.

        Report recommendations come first, then image-derived ones. Any
        recommendation similar to one already kept (see
        :meth:`_is_similar_recommendation`) is dropped.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis

        Returns:
            list: Combined follow-up recommendations, or a single generic
            recommendation if merging fails.
        """
        try:
            text_recommendations = []
            if not self._has_error(text_results):
                text_recommendations = text_results.get("followup_recommendations") or []
                if isinstance(text_recommendations, str):
                    text_recommendations = [text_recommendations]

            candidates = [
                *text_recommendations,
                *self._image_followup_recommendations(image_results),
            ]

            unique_recommendations = []
            for rec in candidates:
                if not any(
                    self._is_similar_recommendation(rec, kept)
                    for kept in unique_recommendations
                ):
                    unique_recommendations.append(rec)
            return unique_recommendations
        except Exception as e:
            self.logger.error("Error merging follow-up recommendations: %s", e)
            return ["Follow-up recommended based on findings."]

    @staticmethod
    def _recommendation_words(text):
        """Lower-cased words of *text* with leading/trailing punctuation stripped."""
        words = (word.strip(".,;:!?()") for word in str(text).lower().split())
        return {word for word in words if word}

    def _is_similar_recommendation(self, rec1, rec2):
        """
        Check whether two recommendations are near-duplicates.

        This is a lexical check, not a semantic one. It computes the Jaccard
        similarity of the two lower-cased word sets, with edge punctuation
        stripped so "recommended." matches "recommended", and reports a match
        when the similarity exceeds 0.6.
        """
        words1 = self._recommendation_words(rec1)
        words2 = self._recommendation_words(rec2)
        union = words1 | words2
        if not union:
            return False
        return len(words1 & words2) / len(union) > 0.6

    def _image_severity_score(self, image_results):
        """Map an image result to a severity score.

        The base score is 2.0 when ``has_abnormality`` is truthy and 0
        otherwise. It is overridden by the first ``_FINDING_SEVERITY_SCORES``
        keyword found in the lower-cased ``primary_finding``.
        """
        base_score = 2.0 if image_results.get("has_abnormality", False) else 0
        finding = str(image_results.get("primary_finding") or "").lower()
        for keyword, keyword_score in self._FINDING_SEVERITY_SCORES.items():
            if keyword in finding:
                return keyword_score
        return base_score

    def _get_final_severity(self, image_results, text_results, agreement):
        """
        Determine final severity based on both modalities.

        The image score comes from :meth:`_image_severity_score` and the text
        score from ``text_results["severity"]["score"]``. A modality that
        failed is left out instead of counting as a score of 0 ("Normal"); for
        text, so is a result with no ``"severity"`` dict. With high agreement
        (> 0.7) the available scores are averaged; otherwise they are weighted
        by each modality's confidence.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis
            agreement (float): Agreement score between modalities

        Returns:
            dict: ``{"level", "score", "confidence"}``.
            - ``score`` is rounded to one decimal.
            - ``level`` is ``_SEVERITY_LEVELS[floor(clamp(score, 0, 4))]``,
              computed from that rounded score so the two always agree.
            - ``confidence`` is the mean confidence of the modalities used.
            ``{"level": "Unknown", "score": 0, "confidence": 0}`` is returned
            when no modality is usable or on error.
        """
        unknown = {"level": "Unknown", "score": 0, "confidence": 0}
        try:
            parts = []  # (score, confidence) for each usable modality
            if not self._has_error(image_results):
                parts.append(
                    (
                        self._image_severity_score(image_results),
                        image_results.get("confidence", 0.5),
                    )
                )
            if not self._has_error(text_results):
                text_severity = text_results.get("severity")
                if isinstance(text_severity, dict) and text_severity:
                    parts.append(
                        (
                            text_severity.get("score", 0),
                            text_severity.get("confidence", 0.5),
                        )
                    )
            if not parts:
                return unknown

            total_confidence = sum(conf for _, conf in parts)
            if len(parts) == 1:
                final_score = parts[0][0]
            elif agreement > 0.7 or total_confidence <= 0:
                # Modalities agree (or carry no confidence): plain average.
                final_score = sum(score for score, _ in parts) / len(parts)
            else:
                final_score = (
                    sum(score * conf for score, conf in parts) / total_confidence
                )

            score = round(final_score, 1)
            # Floor (not round()): Python's round() is half-to-even, which
            # sent 1.5 up to "Moderate" but 2.5 down to "Moderate", contradicting
            # the score bands in _FINDING_SEVERITY_SCORES.
            level_index = int(min(len(self._SEVERITY_LEVELS) - 1, max(0, score)))

            return {
                "level": self._SEVERITY_LEVELS[level_index],
                "score": score,
                "confidence": round(total_confidence / len(parts), 2),
            }
        except Exception as e:
            self.logger.error("Error determining final severity: %s", e)
            return unknown

    def _collect_findings(self, image_results, text_results):
        """Combine the report's findings with the image's primary finding.

        The image finding is appended only if no report finding already
        mentions it (case-insensitive substring match).
        """
        findings = []
        if not self._has_error(text_results):
            text_findings = text_results.get("findings") or []
            if isinstance(text_findings, str):
                text_findings = [text_findings]
            findings.extend(text_findings)

        if not self._has_error(image_results):
            image_finding = str(image_results.get("primary_finding") or "")
            if image_finding and not any(
                image_finding.lower() in str(f).lower() for f in findings
            ):
                findings.append(f"Image finding: {image_finding}")
        return findings

    def fuse_analyses(self, image_results, text_results):
        """
        Fuse the results from image and text analyses.

        If only one modality succeeded, the fused result is built from that one
        alone. If both failed, an error result is returned instead of a
        misleading "Normal" study.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis

        Returns:
            dict: Fused analysis results with keys ``agreement_score``,
            ``primary_finding``, ``severity``, ``findings``,
            ``followup_recommendations`` and ``modality_results``. On failure
            the dict instead holds ``error`` and ``modality_results``.
        """
        modality_results = {"image": image_results, "text": text_results}
        image_failed = self._has_error(image_results)
        text_failed = self._has_error(text_results)

        if image_failed and text_failed:
            self.logger.error("Both image and text analyses failed; nothing to fuse")
            return {
                "error": "Both image and text analyses failed",
                "modality_results": modality_results,
            }
        if image_failed or text_failed:
            self.logger.warning(
                "%s analysis unavailable; fusing a single modality",
                "Image" if image_failed else "Text",
            )

        try:
            agreement = self._calculate_agreement_score(image_results, text_results)
            self.logger.info("Agreement score between modalities: %.2f", agreement)

            primary_finding = self._get_confidence_weighted_finding(
                image_results, text_results, agreement
            )
            followup = self._merge_followup_recommendations(image_results, text_results)
            severity = self._get_final_severity(image_results, text_results, agreement)
            findings = self._collect_findings(image_results, text_results)

            return {
                "agreement_score": round(agreement, 2),
                "primary_finding": primary_finding,
                "severity": severity,
                "findings": findings,
                "followup_recommendations": followup,
                "modality_results": modality_results,
            }
        except Exception as e:
            self.logger.error("Error fusing analyses: %s", e)
            return {"error": str(e), "modality_results": modality_results}

    def analyze(self, image_path, report_text):
        """
        Perform multimodal analysis of medical image and report.

        Args:
            image_path (str): Path to the medical image
            report_text (str): Medical report text

        Returns:
            dict: Fused analysis results (see :meth:`fuse_analyses`), or
            ``{"error": <message>}`` on an unexpected failure.
        """
        try:
            image_results = self.analyze_image(image_path)
            text_results = self.analyze_text(report_text)
            return self.fuse_analyses(image_results, text_results)
        except Exception as e:
            self.logger.error("Error in multimodal analysis: %s", e)
            return {"error": str(e)}

    def get_explanation(self, fused_results):
        """
        Generate a human-readable (Markdown) explanation of the fused analysis.

        Args:
            fused_results (dict): Results from :meth:`fuse_analyses` or
                :meth:`analyze`. Error results (with an ``"error"`` key) are
                explicitly reported as incomplete.

        Returns:
            str: A Markdown summary, or a fixed error message if formatting fails.
        """
        try:
            severity_info = fused_results.get("severity") or {}
            primary_finding = fused_results.get("primary_finding", "Unknown")
            severity = severity_info.get("level", "Unknown")

            explanation = ["# Medical Analysis Summary\n", "## Overview\n"]

            error = fused_results.get("error")
            if error:
                # Make failures explicit instead of presenting "Unknown" as a result.
                explanation.append(f"**Analysis incomplete:** {error}\n")

            explanation.append(f"Primary finding: **{primary_finding}**\n")
            explanation.append(f"Severity level: **{severity}**\n")

            if "agreement_score" in fused_results:
                agreement = fused_results["agreement_score"]
                if agreement > 0.7:
                    agreement_text = "High"
                elif agreement > 0.4:
                    agreement_text = "Moderate"
                else:
                    agreement_text = "Low"
                explanation.append(
                    f"Image and text analysis agreement: **{agreement_text}** ({agreement:.0%})\n"
                )

            explanation.append("\n## Detailed Findings\n")
            findings = fused_results.get("findings") or []
            if findings:
                explanation.extend(f"- {finding}\n" for finding in findings)
            else:
                explanation.append("No specific findings detailed.\n")

            explanation.append("\n## Recommended Follow-up\n")
            followups = fused_results.get("followup_recommendations") or []
            if followups:
                explanation.extend(f"- {followup}\n" for followup in followups)
            else:
                explanation.append("No specific follow-up recommendations provided.\n")

            confidence = severity_info.get("confidence", 0)
            explanation.append(
                f"\n*Note: This analysis has a confidence level of {confidence:.0%}. "
                f"Please consult with healthcare professionals for official diagnosis.*"
            )

            return "\n".join(explanation)
        except Exception as e:
            self.logger.error("Error generating explanation: %s", e)
            return "Error generating analysis explanation."
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point. The relative imports at the top of the file mean this
    # must be run as a module (``python -m <package>.<module>``).
    import os

    logging.basicConfig(level=logging.INFO)

    fusion = MultimodalFusion()

    sample_report = """
    CHEST X-RAY EXAMINATION

    CLINICAL HISTORY: 55-year-old male with cough and fever.

    FINDINGS: The heart size is at the upper limits of normal. The lungs are clear without focal consolidation,
    effusion, or pneumothorax. There is mild prominence of the pulmonary vasculature. No pleural effusion is seen.
    There is a small nodular opacity noted in the right lower lobe measuring approximately 8mm, which is suspicious
    and warrants further investigation. The mediastinum is unremarkable. The visualized bony structures show no acute abnormalities.

    IMPRESSION:
    1. Mild cardiomegaly.
    2. 8mm nodular opacity in the right lower lobe, recommend follow-up CT for further evaluation.
    3. No acute pulmonary parenchymal abnormality.

    RECOMMENDATIONS: Follow-up chest CT to further characterize the nodular opacity in the right lower lobe.
    """

    # NOTE(review): resolved relative to the current working directory.
    sample_dir = "../data/sample"
    sample_files = []
    if os.path.isdir(sample_dir):
        # Sorted for a deterministic pick. Hidden files (e.g. .gitkeep) and
        # subdirectories are skipped so we never hand a non-image to the model.
        sample_files = sorted(
            name
            for name in os.listdir(sample_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(sample_dir, name))
        )

    if sample_files:
        sample_image = os.path.join(sample_dir, sample_files[0])
        print(f"Analyzing sample image: {sample_image}")

        fused_results = fusion.analyze(sample_image, sample_report)
        explanation = fusion.get_explanation(fused_results)

        print("\nFused Analysis Results:")
        print(explanation)
    else:
        print("No sample images found. Only analyzing text report.")

        text_results = fusion.analyze_text(sample_report)

        if "error" in text_results:
            # analyze_text returns {"error": ...} instead of raising.
            print(f"\nText analysis failed: {text_results['error']}")
        else:
            severity = text_results.get("severity") or {}
            print("\nText Analysis Results:")
            print(
                f"Severity: {severity.get('level', 'Unknown')} (Score: {severity.get('score', 'N/A')})"
            )

            print("\nKey Findings:")
            for finding in text_results.get("findings") or []:
                print(f"- {finding}")

            print("\nEntities:")
            for category, items in (text_results.get("entities") or {}).items():
                if items:
                    print(f"- {category.capitalize()}: {', '.join(str(i) for i in items)}")

            print("\nFollow-up Recommendations:")
            for rec in text_results.get("followup_recommendations") or []:
                print(f"- {rec}")
|
|