"""Plotting helpers for X-ray predictions, report entities and fused results."""

import base64
import io
import logging
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Set up logging
logger = logging.getLogger(__name__)

# Text color used for each entity category in the report visualization
_CATEGORY_COLORS = {
    "problem": "#e74c3c",  # Red
    "test": "#3498db",  # Blue
    "treatment": "#2ecc71",  # Green
    "anatomy": "#9b59b6",  # Purple
}

# Text color used for each severity level in the overview panel
_SEVERITY_COLORS = {
    "Normal": "#2ecc71",  # Green
    "Mild": "#3498db",  # Blue
    "Moderate": "#f39c12",  # Orange
    "Severe": "#e74c3c",  # Red
    "Critical": "#c0392b",  # Dark Red
}


def _load_pil_image(image):
    """Return ``image`` unchanged, or open it with PIL when it is a path string."""
    if isinstance(image, str):
        return Image.open(image)
    return image


def _error_figure(error, partial_fig=None):
    """Close ``partial_fig`` (if any) and return a figure that displays ``error``."""
    if partial_fig is not None:
        # pyplot keeps every figure it created alive until it is closed, so a
        # half-built figure would otherwise leak on each failed call.
        plt.close(partial_fig)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.text(0.5, 0.5, f"Error: {str(error)}", ha="center", va="center")
    return fig


def _highlight_entities(text, entities, colors):
    """Wrap whole-word entity occurrences in ``text`` with textcolor markup.

    All entities are matched in a single pass, longest alternative first, so a
    short entity can never re-match inside markup that was already inserted.
    When the same entity is listed under several categories, the first
    category encountered decides its color.
    """
    color_by_entity = {}
    for category, items in entities.items():
        for item in items or ():
            if item:  # an empty string would match at every position
                color_by_entity.setdefault(item, colors.get(category, "black"))

    if not color_by_entity:
        return text

    alternatives = sorted(color_by_entity, key=len, reverse=True)
    # Lookarounds instead of \b so that entities starting or ending with a
    # non-word character (e.g. a parenthesis) can still match.
    pattern = re.compile(
        r"(?<!\w)(?:" + "|".join(re.escape(alt) for alt in alternatives) + r")(?!\w)"
    )

    def _wrap(match):
        entity = match.group(0)
        return f"\\textcolor{{{color_by_entity[entity]}}}{{{entity}}}"

    # A callable replacement keeps the backslash in the markup literal.
    return pattern.sub(_wrap, text)


def _draw_bullet_panel(ax, heading, items, empty_message):
    """Draw a centered heading and a bulleted list (or a placeholder) on ``ax``."""
    ax.axis("off")
    y_pos = 0.9
    ax.text(
        0.5,
        y_pos,
        heading,
        fontsize=14,
        fontweight="bold",
        ha="center",
        va="center",
    )
    y_pos -= 0.1

    if not items:
        ax.text(0.1, y_pos, empty_message, fontsize=11, ha="left", va="center")
        return

    for item in items:
        ax.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
        ax.text(0.1, y_pos, item, fontsize=11, ha="left", va="center", wrap=True)
        y_pos -= 0.15


def plot_image_prediction(image, predictions, title=None, figsize=(10, 8)):
    """
    Plot an image next to a bar chart of its top predictions.

    Args:
        image (PIL.Image or str): Image or path to image
        predictions (list): List of (label, probability) tuples; only the five
            most probable entries are drawn
        title (str, optional): Plot title
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a figure showing the error message is returned.
    """
    fig = None
    try:
        img = _load_pil_image(image)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # Left panel: the image itself
        ax1.imshow(img)
        ax1.set_title("X-ray Image")
        ax1.axis("off")

        # Right panel: horizontal bar chart of the most probable labels
        if predictions:
            top = sorted(predictions, key=lambda x: x[1], reverse=True)[:5]
            labels = [pred[0] for pred in top]
            probs = [pred[1] for pred in top]

            y_pos = np.arange(len(top))
            ax2.barh(y_pos, probs, align="center")
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels(labels)
            ax2.set_xlabel("Probability")
            ax2.set_title("Top Predictions")
            ax2.set_xlim(0, 1)

            # Annotate each bar with its probability as a percentage
            for i, prob in enumerate(probs):
                ax2.text(prob + 0.02, i, f"{prob:.1%}", va="center")

        if title:
            fig.suptitle(title, fontsize=16)

        fig.tight_layout()
        return fig

    except Exception as e:
        logger.exception("Error plotting image prediction: %s", e)
        return _error_figure(e, fig)


def create_heatmap_overlay(image, heatmap, alpha=0.4):
    """
    Create a heatmap overlay on an X-ray image to highlight areas of interest.

    Args:
        image (PIL.Image or str): Image or path to image. Any other value is
            used as-is and is expected to be an OpenCV-style array.
        heatmap (numpy.ndarray): Heatmap array; negative values are clipped to
            zero and the rest is scaled by the maximum value
        alpha (float): Transparency of the overlay

    Returns:
        PIL.Image: Image with heatmap overlay. If the overlay cannot be built,
        the error is logged and the original image is returned instead.
    """
    try:
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        elif isinstance(image, Image.Image):
            # Convert to RGB first: grayscale or palette images would otherwise
            # become 2-D arrays that the RGB->BGR conversion rejects.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img = image

        # Ensure image is in BGR format for OpenCV
        if len(img.shape) == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # Resize heatmap to match image dimensions (cv2 expects width, height)
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

        # Normalize heatmap (0-1); an all-zero map stays zero instead of
        # turning into NaN through a division by zero.
        heatmap = np.maximum(heatmap, 0)
        peak = float(np.max(heatmap))
        if peak > 0:
            heatmap = np.minimum(heatmap / peak, 1)

        # Apply colormap (jet) to heatmap
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        overlay = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)

        # Convert back to PIL image
        overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        return Image.fromarray(overlay)

    except Exception as e:
        logger.exception("Error creating heatmap overlay: %s", e)
        # Return original image if error occurs
        if isinstance(image, str):
            return Image.open(image)
        if isinstance(image, Image.Image):
            return image
        if getattr(image, "ndim", None) == 2:
            # Single-channel arrays have no BGR order to undo
            return Image.fromarray(image)
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))


def plot_report_entities(text, entities, figsize=(12, 8)):
    """
    Visualize entities extracted from a medical report.

    Args:
        text (str): Report text
        entities (dict): Dictionary mapping a category name to a list of
            entity strings
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a figure showing the error message is returned.
    """
    fig = None
    try:
        entities = entities or {}

        fig, ax = plt.subplots(figsize=figsize)
        ax.axis("off")

        # Set background color
        fig.patch.set_facecolor("#f8f9fa")
        ax.set_facecolor("#f8f9fa")

        # Title
        ax.text(
            0.5,
            0.98,
            "Medical Report Analysis",
            ha="center",
            va="top",
            fontsize=18,
            fontweight="bold",
            color="#2c3e50",
        )

        # Section heading for the per-category entity lists
        y_pos = 0.9
        ax.text(
            0.05,
            y_pos,
            "Extracted Entities:",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        # One heading line plus one comma-separated line per non-empty category
        for category, items in entities.items():
            if items:
                y_pos -= 0.05
                ax.text(
                    0.1,
                    y_pos,
                    f"{category.capitalize()}:",
                    fontsize=12,
                    fontweight="bold",
                )
                y_pos -= 0.05
                ax.text(
                    0.15,
                    y_pos,
                    ", ".join(items),
                    wrap=True,
                    fontsize=11,
                    color=_CATEGORY_COLORS.get(category, "black"),
                )

        # Section heading for the report text
        y_pos -= 0.1
        ax.text(
            0.05,
            y_pos,
            "Report Text (with highlighted entities):",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        # NOTE(review): the textcolor markup is passed to matplotlib as plain
        # text; confirm whether the rendering setup interprets it.
        highlighted_text = _highlight_entities(text or "", entities, _CATEGORY_COLORS)

        ax.text(0.05, y_pos, highlighted_text, va="top", fontsize=10, wrap=True)

        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        return fig

    except Exception as e:
        logger.exception("Error plotting report entities: %s", e)
        return _error_figure(e, fig)


def plot_multimodal_results(
    fused_results, image=None, report_text=None, figsize=(12, 10)
):
    """
    Visualize the results of multimodal analysis on a 2x2 grid.

    The panels are: overview (top left), key findings (top right), image
    (bottom left) and recommendations (bottom right).

    Args:
        fused_results (dict): Results from multimodal fusion. The keys read
            are "severity" (dict with "level" and "score"), "primary_finding",
            "agreement_score", "findings" and "followup_recommendations".
        image (PIL.Image or str, optional): Image or path to image
        report_text (str, optional): Report text (accepted for interface
            compatibility; not drawn by this function)
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a figure showing the error message is returned.
    """
    fig = None
    try:
        # Create figure with a grid layout
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2)

        fig.suptitle(
            "Multimodal Medical Analysis Results",
            fontsize=18,
            fontweight="bold",
            y=0.98,
        )

        # 1. Overview panel (top left)
        ax_overview = fig.add_subplot(gs[0, 0])
        ax_overview.axis("off")

        # A key that is present but None is treated like a missing key
        severity = fused_results.get("severity") or {}
        severity_level = severity.get("level", "Unknown")
        severity_score = severity.get("score", 0)
        primary_finding = fused_results.get("primary_finding", "Unknown")
        agreement = fused_results.get("agreement_score") or 0

        y_pos = 0.9
        ax_overview.text(
            0.5,
            y_pos,
            "ANALYSIS OVERVIEW",
            fontsize=14,
            fontweight="bold",
            ha="center",
            va="center",
        )
        y_pos -= 0.15

        ax_overview.text(
            0.1,
            y_pos,
            f"Primary Finding: {primary_finding}",
            fontsize=12,
            ha="left",
            va="center",
        )
        y_pos -= 0.1

        # Severity with color
        severity_color = _SEVERITY_COLORS.get(severity_level, "black")
        ax_overview.text(
            0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            severity_level,
            fontsize=12,
            color=severity_color,
            fontweight="bold",
            ha="left",
            va="center",
        )
        ax_overview.text(
            0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center"
        )
        y_pos -= 0.1

        # Agreement score with color: green above 0.7, orange above 0.4, else red
        if agreement > 0.7:
            agreement_color = "#2ecc71"
        elif agreement > 0.4:
            agreement_color = "#f39c12"
        else:
            agreement_color = "#e74c3c"
        ax_overview.text(
            0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center"
        )
        ax_overview.text(
            0.4,
            y_pos,
            f"{agreement:.0%}",
            fontsize=12,
            color=agreement_color,
            fontweight="bold",
            ha="left",
            va="center",
        )

        # 2. Findings panel (top right), limited to 5 findings
        findings = fused_results.get("findings", [])
        _draw_bullet_panel(
            fig.add_subplot(gs[0, 1]),
            "KEY FINDINGS",
            findings[:5] if findings else findings,
            "No specific findings detailed.",
        )

        # 3. Image panel (bottom left)
        ax_image = fig.add_subplot(gs[1, 0])
        if image is not None:
            ax_image.imshow(_load_pil_image(image))
            ax_image.set_title("X-ray Image", fontsize=12)
        else:
            ax_image.text(0.5, 0.5, "No image available", ha="center", va="center")
        ax_image.axis("off")

        # 4. Recommendation panel (bottom right)
        _draw_bullet_panel(
            fig.add_subplot(gs[1, 1]),
            "RECOMMENDATIONS",
            fused_results.get("followup_recommendations", []),
            "No specific recommendations provided.",
        )

        # Add disclaimer
        fig.text(
            0.5,
            0.03,
            "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
            fontsize=9,
            style="italic",
            ha="center",
        )

        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig

    except Exception as e:
        logger.exception("Error plotting multimodal results: %s", e)
        return _error_figure(e, fig)


def figure_to_base64(fig):
    """
    Convert matplotlib figure to base64 string.

    Args:
        fig (matplotlib.figure.Figure): Figure object

    Returns:
        str: Base64 encoded PNG data, or an empty string if the figure could
        not be rendered (the error is logged).
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", bbox_inches="tight")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        logger.exception("Error converting figure to base64: %s", e)
        return ""