# VisionScout / visualization_helper.py
# (File-viewer page residue, kept as a comment so the module parses:
#  DawnC's picture | Upload 6 files | 611206a verified | raw | history blame | 5.39 kB)
import cv2
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, List, Dict, Tuple, Optional
import io
from PIL import Image
class VisualizationHelper:
    """Helper class for visualizing detection results.

    Both public entry points are static methods that read a detection
    ``result`` object through ``result.boxes`` (``xyxy``, ``cls`` and ``conf``
    tensors exposing ``.cpu().numpy()``) and ``result.names`` (class id to
    class name lookup).
    """

    @staticmethod
    def _to_rgb(image: Any) -> Any:
        """Return ``image`` in a form ``imshow`` can display as RGB.

        A string is treated as a file path and loaded with OpenCV. A
        3-channel numpy array is *assumed* to be in OpenCV's BGR order (no
        detection is possible) and is converted to RGB. Anything else is
        returned unchanged.

        Raises:
            ValueError: If ``image`` is a path that OpenCV cannot read.
        """
        if isinstance(image, str):
            loaded = cv2.imread(image)
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with a readable message rather than inside cvtColor.
            if loaded is None:
                raise ValueError(f"Could not read image from path: {image!r}")
            return cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)

        # The type check comes first so objects without a ``shape`` attribute
        # are passed through instead of raising AttributeError.
        if isinstance(image, np.ndarray) and image.ndim == 3 and image.shape[2] == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image

    @staticmethod
    def _make_color_getter(color_mapper: Optional[Any]) -> Any:
        """Build a ``class_id -> RGBA tuple`` function for matplotlib."""
        if color_mapper is None:
            # Fallback palette: cycle through the 10 colors of 'tab10'.
            from matplotlib import colormaps
            cmap = colormaps['tab10']
            return lambda class_id: cmap(class_id % 10)

        def from_mapper(class_id):
            # The mapper returns a hex string such as '#RRGGBB'; convert it to
            # float channels in [0, 1] plus an opaque alpha.
            hex_color = color_mapper.get_color(class_id).lstrip('#')
            return tuple(int(hex_color[i:i + 2], 16) / 255 for i in (0, 2, 4)) + (1.0,)

        return from_mapper

    @staticmethod
    def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
                            figsize: Tuple[int, int] = (12, 12),
                            return_pil: bool = False) -> "Optional[Image.Image]":
        """
        Visualize detection results on a single image.

        Args:
            image: Image path, or an image array. 3-channel numpy arrays are
                assumed to be BGR and are converted to RGB before display.
            result: Detection result object; ``None`` prints a message and
                returns ``None``. A result whose ``boxes`` is ``None`` is
                drawn without any boxes.
            color_mapper: Optional object whose ``get_color(class_id)`` returns
                a hex color string; when omitted the 'tab10' colormap is used.
            figsize: Figure size passed to ``plt.subplots``.
            return_pil: If True, render the figure to PNG and return it as a
                PIL Image instead of showing it.

        Returns:
            PIL Image if return_pil is True, otherwise None (the plot is
            displayed with ``plt.show()``).

        Raises:
            ValueError: If ``image`` is a path that cannot be read.
        """
        if result is None:
            print('No data for visualization')
            return None

        img = VisualizationHelper._to_rgb(image)

        # Pull everything out of the result before creating the figure so a
        # malformed result cannot leave an orphaned figure behind.
        detections = result.boxes
        if detections is None:
            # NOTE(review): some result objects may carry no boxes at all
            # (e.g. Ultralytics Results documents boxes as optional); treat
            # that as zero detections.
            boxes, classes, confs = [], [], []
        else:
            boxes = detections.xyxy.cpu().numpy()
            classes = detections.cls.cpu().numpy()
            confs = detections.conf.cpu().numpy()
        names = result.names
        get_color = VisualizationHelper._make_color_getter(color_mapper)

        fig, ax = plt.subplots(figsize=figsize)
        pil_img = None
        try:
            ax.imshow(img)

            for box, cls, conf in zip(boxes, classes, confs):
                x1, y1, x2, y2 = box
                cls_id = int(cls)
                cls_name = names[cls_id]
                # Only the RGB channels are used; transparency is set per artist.
                rgb = get_color(cls_id)[:3]

                # Text label with colored background, just above the box
                ax.text(x1, y1 - 5, f'{cls_name}: {conf:.2f}',
                        color='white', fontsize=10,
                        bbox=dict(facecolor=rgb, alpha=0.7))
                # Bounding box outline
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           fill=False, edgecolor=rgb, linewidth=2))

            ax.axis('off')
            fig.tight_layout()

            if return_pil:
                # Render the figure into an in-memory PNG and reopen it with PIL.
                buf = io.BytesIO()
                fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
                buf.seek(0)
                pil_img = Image.open(buf)
        except Exception:
            # pyplot keeps every open figure alive; release it on failure.
            plt.close(fig)
            raise

        if return_pil:
            plt.close(fig)
            return pil_img

        plt.show()
        return None

    @staticmethod
    def create_summary(result: Any) -> Dict:
        """
        Create a summary of detection results.

        Args:
            result: Detection result object, or None.

        Returns:
            ``{"error": ...}`` when ``result`` is None. Otherwise a dictionary
            with ``total_objects`` (number of detections), ``class_counts``
            (class name -> ``{"count", "average_confidence"}``) and
            ``unique_classes`` (number of distinct class names). A result
            whose ``boxes`` is None yields an all-zero summary.
        """
        if result is None:
            return {"error": "No detection result provided"}

        detections = result.boxes
        if detections is None:
            # No boxes container at all: report zero detections.
            return {"total_objects": 0, "class_counts": {}, "unique_classes": 0}

        classes = detections.cls.cpu().numpy().astype(int)
        confidences = detections.conf.cpu().numpy()
        names = result.names

        # Group confidences by class name, preserving first-seen order.
        scores_by_class: Dict[str, List[float]] = {}
        for cls, conf in zip(classes, confidences):
            scores_by_class.setdefault(names[int(cls)], []).append(float(conf))

        # Keep only the count and the mean so the summary stays concise.
        class_counts = {
            cls_name: {
                "count": len(scores),
                "average_confidence": float(np.mean(scores)),
            }
            for cls_name, scores in scores_by_class.items()
        }

        return {
            "total_objects": len(classes),
            "class_counts": class_counts,
            "unique_classes": len(class_counts),
        }