Spaces: Running on Zero

Upload 6 files
- app.py +598 -0
- color_mapper.py +270 -0
- detection_model.py +164 -0
- evaluation_metrics.py +360 -0
- style.py +282 -0
- visualization_helper.py +147 -0
app.py
ADDED
@@ -0,0 +1,598 @@
import os
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import gradio as gr
import io
from PIL import Image, ImageDraw, ImageFont
import spaces
from typing import Dict, List, Any, Optional, Tuple
from ultralytics import YOLO

from detection_model import DetectionModel
from color_mapper import ColorMapper
from visualization_helper import VisualizationHelper
from evaluation_metrics import EvaluationMetrics
from style import Style


color_mapper = ColorMapper()
model_instances = {}

@spaces.GPU
def process_image(image, model_instance, confidence_threshold, filter_classes=None):
    """
    Process an image for object detection

    Args:
        image: Input image (numpy array or PIL Image)
        model_instance: DetectionModel instance to use
        confidence_threshold: Confidence threshold for detection
        filter_classes: Optional list of classes to filter results

    Returns:
        Tuple of (result_image, result_text, stats_data)
    """
    # initialize key variables
    result = None
    stats = {}
    temp_path = None

    try:
        # update confidence threshold
        model_instance.confidence = confidence_threshold

        # process input image
        if isinstance(image, np.ndarray):
            # Convert BGR to RGB if needed
            if image.shape[2] == 3:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image_rgb = image
            pil_image = Image.fromarray(image_rgb)
        elif image is None:
            return None, "No image provided. Please upload an image.", {}
        else:
            pil_image = image

        # store a temp file
        import uuid
        import tempfile

        temp_dir = tempfile.gettempdir()  # use system temp directory
        temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
        temp_path = os.path.join(temp_dir, temp_filename)
        pil_image.save(temp_path)

        # object detection
        result = model_instance.detect(temp_path)

        if result is None:
            return None, "Detection failed. Please try again with a different image.", {}

        # calculate stats
        stats = EvaluationMetrics.calculate_basic_stats(result)

        # add spatial metrics
        spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
        stats["spatial_metrics"] = spatial_metrics

        if filter_classes and len(filter_classes) > 0:
            # get classes, boxes, confidences
            classes = result.boxes.cls.cpu().numpy().astype(int)
            confs = result.boxes.conf.cpu().numpy()
            boxes = result.boxes.xyxy.cpu().numpy()

            mask = np.zeros_like(classes, dtype=bool)
            for cls_id in filter_classes:
                mask = np.logical_or(mask, classes == cls_id)

            filtered_stats = {
                "total_objects": int(np.sum(mask)),
                "class_statistics": {},
                "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
                "spatial_metrics": stats["spatial_metrics"]
            }

            # update stats
            names = result.names
            for cls, conf in zip(classes[mask], confs[mask]):
                cls_name = names[int(cls)]
                if cls_name not in filtered_stats["class_statistics"]:
                    filtered_stats["class_statistics"][cls_name] = {
                        "count": 0,
                        "average_confidence": 0
                    }

                cls_entry = filtered_stats["class_statistics"][cls_name]
                cls_entry["count"] += 1
                # maintain a running average of confidence for this class
                cls_entry["average_confidence"] += (float(conf) - cls_entry["average_confidence"]) / cls_entry["count"]

            stats = filtered_stats

        viz_data = EvaluationMetrics.generate_visualization_data(
            result,
            color_mapper.get_all_colors()
        )

        result_image = VisualizationHelper.visualize_detection(
            temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
        )

        result_text = EvaluationMetrics.format_detection_summary(viz_data)

        return result_image, result_text, stats

    except Exception as e:
        error_message = f"Error occurred: {str(e)}"
        import traceback
        traceback.print_exc()
        print(error_message)
        return None, error_message, {}

    finally:
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except Exception as e:
                print(f"Cannot delete temp file {temp_path}: {str(e)}")

def format_result_text(stats):
    """
    Format detection statistics into readable text with improved spacing

    Args:
        stats: Dictionary containing detection statistics

    Returns:
        Formatted text summary
    """
    if not stats or "total_objects" not in stats:
        return "No objects detected."

    # keep unnecessary blank lines to a minimum
    lines = [
        f"Detected {stats['total_objects']} objects.",
        f"Average confidence: {stats.get('average_confidence', 0):.2f}",
        "Objects by class:"
    ]

    if "class_statistics" in stats and stats["class_statistics"]:
        # sort classes by count
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda x: x[1]["count"],
            reverse=True
        )

        for cls_name, cls_stats in sorted_classes:
            count = cls_stats["count"]
            conf = cls_stats.get("average_confidence", 0)

            item_text = "item" if count == 1 else "items"
            lines.append(f"• {cls_name}: {count} {item_text} (avg conf: {conf:.2f})")
    else:
        lines.append("No class information available.")

    # add spatial information
    if "spatial_metrics" in stats and "spatial_distribution" in stats["spatial_metrics"]:
        lines.append("Object Distribution:")

        dist = stats["spatial_metrics"]["spatial_distribution"]
        x_mean = dist.get("x_mean", 0)
        y_mean = dist.get("y_mean", 0)

        # describe the approximate location of the objects
        if x_mean < 0.33:
            h_pos = "on the left side"
        elif x_mean < 0.67:
            h_pos = "in the center"
        else:
            h_pos = "on the right side"

        if y_mean < 0.33:
            v_pos = "in the upper part"
        elif y_mean < 0.67:
            v_pos = "in the middle"
        else:
            v_pos = "in the lower part"

        lines.append(f"• Most objects appear {h_pos} {v_pos} of the image")

    return "\n".join(lines)

def format_json_for_display(stats):
    """
    Format statistics JSON for better display

    Args:
        stats: Raw statistics dictionary

    Returns:
        Formatted statistics structure for display
    """
    # Create a cleaner copy of the stats for display
    display_stats = {}

    # Add summary section
    display_stats["summary"] = {
        "total_objects": stats.get("total_objects", 0),
        "average_confidence": round(stats.get("average_confidence", 0), 3)
    }

    # Add class statistics in a more organized way
    if "class_statistics" in stats and stats["class_statistics"]:
        # Sort classes by count (descending)
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda x: x[1].get("count", 0),
            reverse=True
        )

        class_stats = {}
        for cls_name, cls_data in sorted_classes:
            class_stats[cls_name] = {
                "count": cls_data.get("count", 0),
                "average_confidence": round(cls_data.get("average_confidence", 0), 3)
            }

        display_stats["detected_objects"] = class_stats

    # Simplify spatial metrics
    if "spatial_metrics" in stats:
        spatial = stats["spatial_metrics"]

        # Simplify spatial distribution
        if "spatial_distribution" in spatial:
            dist = spatial["spatial_distribution"]
            display_stats["spatial"] = {
                "distribution": {
                    "x_mean": round(dist.get("x_mean", 0), 3),
                    "y_mean": round(dist.get("y_mean", 0), 3),
                    "x_std": round(dist.get("x_std", 0), 3),
                    "y_std": round(dist.get("y_std", 0), 3)
                }
            }

        # Add simplified size information
        if "size_distribution" in spatial:
            size = spatial["size_distribution"]
            display_stats["spatial"]["size"] = {
                "mean_area": round(size.get("mean_area", 0), 3),
                "min_area": round(size.get("min_area", 0), 3),
                "max_area": round(size.get("max_area", 0), 3)
            }

    return display_stats

def get_all_classes():
    """
    Get all available COCO classes from the currently active model, or fall back to standard COCO classes

    Returns:
        List of tuples (class_id, class_name)
    """
    global model_instances

    # Try to get class names from any loaded model
    for model_name, model_instance in model_instances.items():
        if model_instance and model_instance.is_model_loaded:
            try:
                class_names = model_instance.class_names
                return [(idx, name) for idx, name in class_names.items()]
            except Exception:
                pass

    # Fall back to standard COCO classes
    return [
        (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
        (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
        (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
        (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
        (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
        (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
        (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
        (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
        (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
        (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
        (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
        (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
        (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
        (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
        (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
        (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
        (79, 'toothbrush')
    ]

def create_interface():
    """Create the Gradio interface with polished visual styling"""
    css = Style.get_css()

    # Get available model information
    available_models = DetectionModel.get_available_models()
    model_choices = [model["model_file"] for model in available_models]
    model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]

    # Available class filter options
    available_classes = get_all_classes()
    class_choices = [f"{id}: {name}" for id, name in available_classes]

    # Create the Gradio Blocks interface
    with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="blue")) as demo:
        # Page header
        with gr.Group(elem_classes="app-header"):
            gr.HTML("""
                <div style="text-align: center; width: 100%;">
                    <h1 class="app-title">VisionScout</h1>
                    <h2 class="app-subtitle">Detect and identify objects in your images</h2>
                    <div class="app-divider"></div>
                </div>
            """)

        current_model = gr.State("yolov8m.pt")  # use the medium-size model as default

        # Main content area - input and output panels
        with gr.Row(equal_height=True):
            # Left side - input controls (image upload)
            with gr.Column(scale=4, elem_classes="input-panel"):
                with gr.Group():
                    gr.HTML('<div class="section-heading">Upload Image</div>')
                    image_input = gr.Image(type="pil", label="Upload an image", elem_classes="upload-box")

                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            model_dropdown = gr.Dropdown(
                                choices=model_choices,
                                value="yolov8m.pt",
                                label="Select Model",
                                info="Choose different models based on your needs for speed vs. accuracy"
                            )

                        # display model info
                        model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))

                        confidence = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            value=0.25,
                            step=0.05,
                            label="Confidence Threshold",
                            info="Higher values show fewer but more confident detections"
                        )

                    with gr.Accordion("Filter Classes", open=False):
                        # Quick-select buttons for common object categories
                        gr.HTML('<div class="section-heading" style="font-size: 1rem;">Common Categories</div>')
                        with gr.Row():
                            people_btn = gr.Button("People", size="sm")
                            vehicles_btn = gr.Button("Vehicles", size="sm")
                            animals_btn = gr.Button("Animals", size="sm")
                            objects_btn = gr.Button("Common Objects", size="sm")

                        # Class selection dropdown
                        class_filter = gr.Dropdown(
                            choices=class_choices,
                            multiselect=True,
                            label="Select Classes to Display",
                            info="Leave empty to show all detected objects"
                        )

                # detect button
                detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")

                # Usage instructions
                with gr.Group(elem_classes="how-to-use"):
                    gr.HTML('<div class="section-heading">How to Use</div>')
                    gr.Markdown("""
                    1. Upload an image or use the camera
                    2. Adjust confidence threshold if needed
                    3. Optionally filter to specific object classes
                    4. Click "Detect Objects" button

                    The model will identify objects in your image and display them with bounding boxes.

                    **Note:** Detection quality depends on image clarity and object visibility. The model can detect up to 80 different types of common objects.
                    """)

            # Right side - results display
            with gr.Column(scale=6, elem_classes="output-panel"):
                with gr.Tabs(elem_classes="tabs"):
                    with gr.Tab("Detection Result"):
                        result_image = gr.Image(type="pil", label="Detection Result")

                        # formatting for the details text box
                        with gr.Group(elem_classes="result-details-box"):
                            gr.HTML('<div class="section-heading">Detection Details</div>')
                            # text box settings for a wider display
                            result_text = gr.Textbox(
                                label=None,
                                lines=12,
                                max_lines=15,
                                elem_classes="wide-result-text",
                                elem_id="detection-details",
                                container=False,
                                scale=2,
                                min_width=600
                            )

                    with gr.Tab("Statistics"):
                        with gr.Row():
                            with gr.Column(scale=3, elem_classes="plot-column"):
                                gr.HTML('<div class="section-heading">Object Distribution</div>')
                                plot_output = gr.Plot(
                                    label=None,
                                    elem_classes="large-plot-container"
                                )

                            # JSON data reads more clearly on the right
                            with gr.Column(scale=2, elem_classes="stats-column"):
                                gr.HTML('<div class="section-heading">Detection Statistics</div>')
                                stats_json = gr.JSON(
                                    label=None,  # remove label
                                    elem_classes="enhanced-json-display"
                                )

        detect_btn.click(
            fn=process_and_plot,
            inputs=[image_input, current_model, confidence, class_filter],
            outputs=[result_image, result_text, stats_json, plot_output]
        )

        # model option
        model_dropdown.change(
            fn=lambda model: (model, DetectionModel.get_model_description(model)),
            inputs=[model_dropdown],
            outputs=[current_model, model_info]
        )

        # class ID groups for the quick-select buttons
        people_classes = [0]                                   # person
        vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]            # all kinds of vehicles
        animals_classes = list(range(14, 24))                  # animals in COCO
        common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # common household items

        # link the quick-select buttons
        people_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
            outputs=class_filter
        )

        vehicles_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
            outputs=class_filter
        )

        animals_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
            outputs=class_filter
        )

        objects_btn.click(
            lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
            outputs=class_filter
        )

        example_images = [
            "room_01.jpg",
            "street_01.jpg",
            "street_02.jpg",
            "street_03.jpg"
        ]

        # add example images
        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=None,
            fn=None,
            cache_examples=False,
        )

        # Footer
        gr.HTML("""
            <div class="footer">
                <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
                <p>Model can detect 80 different classes of objects</p>
            </div>
        """)

    return demo

@spaces.GPU
def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
    """
    Process image and create plots for statistics with enhanced visualization

    Args:
        image: Input image
        model_name: Name of the model to use
        confidence_threshold: Confidence threshold for detection
        filter_classes: Optional list of classes to filter results

    Returns:
        Tuple of (result_image, result_text, formatted_stats, plot_figure)
    """
    global model_instances

    if model_name not in model_instances:
        print(f"Creating new model instance for {model_name}")
        model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
    else:
        print(f"Using existing model instance for {model_name}")
        model_instances[model_name].confidence = confidence_threshold

    class_ids = None
    if filter_classes:
        class_ids = []
        for class_str in filter_classes:
            try:
                # Extract ID from format "id: name"
                class_id = int(class_str.split(":")[0].strip())
                class_ids.append(class_id)
            except ValueError:
                continue

    # Execute detection
    result_image, result_text, stats = process_image(
        image,
        model_instances[model_name],
        confidence_threshold,
        class_ids
    )

    # Format the statistics for better display
    formatted_stats = format_json_for_display(stats)

    if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
        # Create a placeholder figure when there is no data to plot
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, "No detection data available",
                ha='center', va='center', fontsize=14, fontfamily='Arial')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        plot_figure = fig
    else:
        # prepare visualization data
        viz_data = {
            "total_objects": stats.get("total_objects", 0),
            "average_confidence": stats.get("average_confidence", 0),
            "class_data": []
        }

        # get the color map
        color_mapper_instance = ColorMapper()

        # class data
        available_classes = dict(get_all_classes())
        for cls_name, cls_stats in stats.get("class_statistics", {}).items():
            # look up the class ID by name
            class_id = -1
            for id, name in available_classes.items():
                if name == cls_name:
                    class_id = id
                    break

            cls_data = {
                "name": cls_name,
                "class_id": class_id,
                "count": cls_stats.get("count", 0),
                "average_confidence": cls_stats.get("average_confidence", 0),
                "color": color_mapper_instance.get_color(class_id if class_id >= 0 else cls_name)
            }

            viz_data["class_data"].append(cls_data)

        # sort in descending order of count
        viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)

        plot_figure = EvaluationMetrics.create_enhanced_stats_plot(viz_data)

    return result_image, result_text, formatted_stats, plot_figure


if __name__ == "__main__":
    import time

    demo = create_interface()
    demo.launch()
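app.py imports gradio, spaces (Hugging Face ZeroGPU), ultralytics, torch, OpenCV, matplotlib, numpy, and Pillow, so running the Space would require roughly the following dependency list. This is a sketch inferred from the imports above; the actual requirements file is not part of this commit and any pins are assumptions:

# requirements.txt (assumed: inferred from app.py's imports, not included in this commit)
gradio
spaces
ultralytics
torch
opencv-python-headless
matplotlib
numpy
pillow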
color_mapper.py
ADDED
@@ -0,0 +1,270 @@
import numpy as np
from typing import Dict, List, Tuple, Union, Any

class ColorMapper:
    """
    A class for consistent color mapping of object detection classes
    Provides color schemes for visualization in both RGB and hex formats
    """

    # Class categories for better organization
    CATEGORIES = {
        "person": [0],
        "vehicles": [1, 2, 3, 4, 5, 6, 7, 8],
        "traffic": [9, 10, 11, 12],
        "animals": [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
        "outdoor": [13, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33],
        "sports": [34, 35, 36, 37, 38],
        "kitchen": [39, 40, 41, 42, 43, 44, 45],
        "food": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55],
        "furniture": [56, 57, 58, 59, 60, 61],
        "electronics": [62, 63, 64, 65, 66, 67, 68, 69, 70],
        "household": [71, 72, 73, 74, 75, 76, 77, 78, 79]
    }

    # Base colors for each category (in HSV for easier variation)
    CATEGORY_COLORS = {
        "person": (0, 0.8, 0.9),        # Red
        "vehicles": (210, 0.8, 0.9),    # Blue
        "traffic": (45, 0.8, 0.9),      # Orange
        "animals": (120, 0.7, 0.8),     # Green
        "outdoor": (180, 0.7, 0.9),     # Cyan
        "sports": (270, 0.7, 0.8),      # Purple
        "kitchen": (30, 0.7, 0.9),      # Light Orange
        "food": (330, 0.7, 0.85),       # Pink
        "furniture": (150, 0.5, 0.85),  # Light Green
        "electronics": (240, 0.6, 0.9), # Light Blue
        "household": (60, 0.6, 0.9)     # Yellow
    }

    def __init__(self):
        """Initialize the ColorMapper with COCO class mappings"""
        self.class_names = self._get_coco_classes()
        self.color_map = self._generate_color_map()

    def _get_coco_classes(self) -> Dict[int, str]:
        """Get the standard COCO class names with their IDs"""
        return {
            0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
            5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
            10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
            14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
            20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
            25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
            30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
            35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
            39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife',
            44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich',
            49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza',
            54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant',
            59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop',
            64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',
            69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book',
            74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier',
            79: 'toothbrush'
        }

    def _hsv_to_rgb(self, h: float, s: float, v: float) -> Tuple[int, int, int]:
        """
        Convert HSV color to RGB

        Args:
            h: Hue (0-360)
            s: Saturation (0-1)
            v: Value (0-1)

        Returns:
            Tuple of (R, G, B) values (0-255)
        """
        h = h / 60
        i = int(h)
        f = h - i
        p = v * (1 - s)
        q = v * (1 - s * f)
        t = v * (1 - s * (1 - f))

        if i == 0:
            r, g, b = v, t, p
        elif i == 1:
            r, g, b = q, v, p
        elif i == 2:
            r, g, b = p, v, t
        elif i == 3:
            r, g, b = p, q, v
        elif i == 4:
            r, g, b = t, p, v
        else:
            r, g, b = v, p, q

        return (int(r * 255), int(g * 255), int(b * 255))

    def _rgb_to_hex(self, rgb: Tuple[int, int, int]) -> str:
        """
        Convert RGB color to hex color code

        Args:
            rgb: Tuple of (R, G, B) values (0-255)

        Returns:
            Hex color code (e.g. '#FF0000')
        """
        return f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}'

    def _find_category(self, class_id: int) -> str:
        """
        Find the category for a given class ID

        Args:
            class_id: Class ID (0-79)

        Returns:
            Category name
        """
        for category, ids in self.CATEGORIES.items():
            if class_id in ids:
                return category
        return "other"  # Fallback

    def _generate_color_map(self) -> Dict:
        """
        Generate a color map for all 80 COCO classes

        Returns:
            Dictionary mapping class IDs and names to color values
        """
        color_map = {
            'by_id': {},      # Map class ID to RGB and hex
            'by_name': {},    # Map class name to RGB and hex
            'categories': {}  # Map category to base color
        }

        # Generate colors for categories
        for category, hsv in self.CATEGORY_COLORS.items():
            rgb = self._hsv_to_rgb(hsv[0], hsv[1], hsv[2])
            hex_color = self._rgb_to_hex(rgb)
            color_map['categories'][category] = {
                'rgb': rgb,
                'hex': hex_color
            }

        # Generate variations for each class within a category
        for class_id, class_name in self.class_names.items():
            category = self._find_category(class_id)
            base_hsv = self.CATEGORY_COLORS.get(category, (0, 0, 0.8))  # Default gray

            # Slightly vary the hue and saturation within the category
            ids_in_category = self.CATEGORIES.get(category, [])
            if ids_in_category:
                position = ids_in_category.index(class_id) if class_id in ids_in_category else 0
                variation = position / max(1, len(ids_in_category) - 1)  # 0 to 1

                # Vary hue slightly (±15°) and saturation
                h_offset = 30 * variation - 15  # -15 to +15
                s_offset = 0.2 * variation      # 0 to 0.2

                h = (base_hsv[0] + h_offset) % 360
                s = min(1.0, base_hsv[1] + s_offset)
                v = base_hsv[2]
            else:
                h, s, v = base_hsv

            rgb = self._hsv_to_rgb(h, s, v)
            hex_color = self._rgb_to_hex(rgb)

            # Store in both mappings
            color_map['by_id'][class_id] = {
                'rgb': rgb,
                'hex': hex_color,
                'category': category
            }

            color_map['by_name'][class_name] = {
                'rgb': rgb,
                'hex': hex_color,
                'category': category
            }

        return color_map

    def get_color(self, class_identifier: Union[int, str], format: str = 'hex') -> Any:
        """
        Get color for a specific class

        Args:
            class_identifier: Class ID (int) or name (str)
            format: Color format ('hex', 'rgb', or 'bgr')

        Returns:
            Color in requested format
        """
        # Determine if identifier is an ID or name
        if isinstance(class_identifier, int):
            color_info = self.color_map['by_id'].get(class_identifier)
        else:
            color_info = self.color_map['by_name'].get(class_identifier)

        if not color_info:
            # Fallback color if not found
            return '#CCCCCC' if format == 'hex' else (204, 204, 204)

        if format == 'hex':
            return color_info['hex']
        elif format == 'rgb':
            return color_info['rgb']
        elif format == 'bgr':
            # Convert RGB to BGR for OpenCV
            r, g, b = color_info['rgb']
            return (b, g, r)
        else:
            return color_info['rgb']

    def get_all_colors(self, format: str = 'hex') -> Dict:
        """
        Get all colors in the specified format

        Args:
            format: Color format ('hex', 'rgb', or 'bgr')

        Returns:
            Dictionary mapping class names to colors
        """
        result = {}
        for class_id, class_name in self.class_names.items():
            result[class_name] = self.get_color(class_id, format)
        return result

    def get_category_colors(self, format: str = 'hex') -> Dict:
        """
        Get base colors for each category

        Args:
            format: Color format ('hex', 'rgb', or 'bgr')

        Returns:
            Dictionary mapping categories to colors
        """
        result = {}
        for category, color_info in self.color_map['categories'].items():
            if format == 'hex':
                result[category] = color_info['hex']
            elif format == 'bgr':
                r, g, b = color_info['rgb']
                result[category] = (b, g, r)
            else:
                result[category] = color_info['rgb']
        return result

    def get_category_for_class(self, class_identifier: Union[int, str]) -> str:
        """
        Get the category for a specific class

        Args:
            class_identifier: Class ID (int) or name (str)

        Returns:
            Category name
        """
        if isinstance(class_identifier, int):
            return self.color_map['by_id'].get(class_identifier, {}).get('category', 'other')
        else:
            return self.color_map['by_name'].get(class_identifier, {}).get('category', 'other')
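For reference, a minimal usage sketch of ColorMapper on its own (assuming color_mapper.py is on the import path; class IDs follow the COCO mapping above, e.g. 2 = 'car', 15 = 'cat'):

from color_mapper import ColorMapper

mapper = ColorMapper()

# The same class in three formats: hex for plots, RGB for PIL, BGR for OpenCV drawing
print(mapper.get_color(2, format='hex'))
print(mapper.get_color('car', format='rgb'))
print(mapper.get_color('car', format='bgr'))

# Category lookup and the per-category base palette
print(mapper.get_category_for_class(15))         # 'animals'
print(mapper.get_category_colors(format='hex'))  # {'person': '#...', 'vehicles': '#...', ...}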
detection_model.py
ADDED
@@ -0,0 +1,164 @@
from ultralytics import YOLO
from typing import Any, List, Dict, Optional
import torch
import numpy as np
import os

class DetectionModel:
    """Core detection model class for object detection using YOLOv8"""

    # Model information dictionary
    MODEL_INFO = {
        "yolov8n.pt": {
            "name": "YOLOv8n (Nano)",
            "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
            "size_mb": 6,
            "inference_speed": "Very Fast"
        },
        "yolov8m.pt": {
            "name": "YOLOv8m (Medium)",
            "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
            "size_mb": 25,
            "inference_speed": "Medium"
        },
        "yolov8x.pt": {
            "name": "YOLOv8x (XLarge)",
            "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
            "size_mb": 68,
            "inference_speed": "Slower"
        }
    }

    def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
        """
        Initialize the detection model

        Args:
            model_name: Model name or path, default is yolov8m.pt
            confidence: Confidence threshold, default is 0.25
            iou: IoU threshold for non-maximum suppression, default is 0.45
        """
        self.model_name = model_name
        self.confidence = confidence
        self.iou = iou
        self.model = None
        self.class_names = {}
        self.is_model_loaded = False

        # Load model on initialization
        self._load_model()

    def _load_model(self):
        """Load the YOLO model"""
        try:
            print(f"Loading model: {self.model_name}")
            self.model = YOLO(self.model_name)
            self.class_names = self.model.names
            self.is_model_loaded = True
            print(f"Successfully loaded model: {self.model_name}")
            print(f"Number of classes the model can recognize: {len(self.class_names)}")
        except Exception as e:
            print(f"Error occurred when loading the model: {e}")
            self.is_model_loaded = False

    def change_model(self, new_model_name: str) -> bool:
        """
        Change the currently loaded model

        Args:
            new_model_name: Name of the new model to load

        Returns:
            bool: True if model changed successfully, False otherwise
        """
        if self.model_name == new_model_name and self.is_model_loaded:
            print(f"Model {new_model_name} is already loaded")
            return True

        print(f"Changing model from {self.model_name} to {new_model_name}")

        # Unload current model to free memory
        if self.model is not None:
            del self.model
            self.model = None

        # Clean GPU memory if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Update model name and load new model
        self.model_name = new_model_name
        self._load_model()

        return self.is_model_loaded

    def reload_model(self):
        """Reload the model (useful for changing model or after error)"""
        if self.model is not None:
            del self.model
            self.model = None

        # Clean GPU memory if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._load_model()

    def detect(self, image_input: Any) -> Optional[Any]:
        """
        Perform object detection on a single image

        Args:
            image_input: Image path (str), PIL Image, or numpy array

        Returns:
            Detection result object or None if an error occurred
        """
        if self.model is None or not self.is_model_loaded:
            print("Model not found or not loaded. Attempting to reload...")
            self._load_model()
            if self.model is None or not self.is_model_loaded:
                print("Failed to load model. Cannot perform detection.")
                return None

        try:
            results = self.model(image_input, conf=self.confidence, iou=self.iou)
            return results[0]
        except Exception as e:
            print(f"Error occurred during detection: {e}")
            return None

    def get_class_names(self, class_id: int) -> str:
        """Get class name for a given class ID"""
        return self.class_names.get(class_id, "Unknown Class")

    def get_supported_classes(self) -> Dict[int, str]:
        """Get all supported classes as a dictionary of {id: class_name}"""
        return self.class_names

    @classmethod
    def get_available_models(cls) -> List[Dict]:
        """
        Get list of available models with their information

        Returns:
            List of dictionaries containing model information
        """
        models = []
        for model_file, info in cls.MODEL_INFO.items():
            models.append({
                "model_file": model_file,
                "name": info["name"],
                "description": info["description"],
                "size_mb": info["size_mb"],
                "inference_speed": info["inference_speed"]
            })
        return models

    @classmethod
    def get_model_description(cls, model_name: str) -> str:
        """Get description for a specific model"""
        if model_name in cls.MODEL_INFO:
            info = cls.MODEL_INFO[model_name]
            return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
        return "Model information not available"
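A minimal usage sketch for DetectionModel (assuming ultralytics can download the yolov8n.pt weights on first use; 'street_01.jpg' is one of the example images app.py references):

from detection_model import DetectionModel

model = DetectionModel(model_name='yolov8n.pt', confidence=0.3, iou=0.45)

# detect() accepts a path, PIL Image, or numpy array and returns results[0], or None on error
result = model.detect('street_01.jpg')
if result is not None:
    class_ids = result.boxes.cls.cpu().numpy().astype(int)
    for cls_id in class_ids:
        print(model.get_class_names(int(cls_id)))

# change_model() frees the old weights and clears the GPU cache before loading the new ones
model.change_model('yolov8m.pt')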
evaluation_metrics.py
ADDED
@@ -0,0 +1,360 @@
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from typing import Dict, List, Any, Optional, Tuple
|
4 |
+
|
5 |
+
class EvaluationMetrics:
|
6 |
+
"""Class for computing detection metrics, generating statistics and visualization data"""
|
7 |
+
|
8 |
+
@staticmethod
|
9 |
+
def calculate_basic_stats(result: Any) -> Dict:
|
10 |
+
"""
|
11 |
+
Calculate basic statistics for a single detection result
|
12 |
+
|
13 |
+
Args:
|
14 |
+
result: Detection result object
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
Dictionary with basic statistics
|
18 |
+
"""
|
19 |
+
if result is None:
|
20 |
+
return {"error": "No detection result provided"}
|
21 |
+
|
22 |
+
# Get classes and confidences
|
23 |
+
classes = result.boxes.cls.cpu().numpy().astype(int)
|
24 |
+
confidences = result.boxes.conf.cpu().numpy()
|
25 |
+
names = result.names
|
26 |
+
|
27 |
+
# Count by class
|
28 |
+
class_counts = {}
|
29 |
+
for cls, conf in zip(classes, confidences):
|
30 |
+
cls_name = names[int(cls)]
|
31 |
+
if cls_name not in class_counts:
|
32 |
+
class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
|
33 |
+
|
34 |
+
class_counts[cls_name]["count"] += 1
|
35 |
+
class_counts[cls_name]["total_confidence"] += float(conf)
|
36 |
+
class_counts[cls_name]["confidences"].append(float(conf))
|
37 |
+
|
38 |
+
# Calculate average confidence
|
39 |
+
for cls_name, stats in class_counts.items():
|
40 |
+
if stats["count"] > 0:
|
41 |
+
stats["average_confidence"] = stats["total_confidence"] / stats["count"]
|
42 |
+
stats["confidence_std"] = float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
|
43 |
+
stats.pop("total_confidence") # Remove intermediate calculation
|
44 |
+
|
45 |
+
# Prepare summary
|
46 |
+
stats = {
|
47 |
+
"total_objects": len(classes),
|
48 |
+
"class_statistics": class_counts,
|
49 |
+
"average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
|
50 |
+
}
|
51 |
+
|
52 |
+
return stats
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
|
56 |
+
"""
|
57 |
+
Generate structured data suitable for visualization
|
58 |
+
|
59 |
+
Args:
|
60 |
+
result: Detection result object
|
61 |
+
class_colors: Dictionary mapping class names to color codes (optional)
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Dictionary with visualization-ready data
|
65 |
+
"""
|
66 |
+
if result is None:
|
67 |
+
return {"error": "No detection result provided"}
|
68 |
+
|
69 |
+
# Get basic stats first
|
70 |
+
stats = EvaluationMetrics.calculate_basic_stats(result)
|
71 |
+
|
72 |
+
# Create visualization-specific data structure
|
73 |
+
viz_data = {
|
74 |
+
"total_objects": stats["total_objects"],
|
75 |
+
"average_confidence": stats["average_confidence"],
|
76 |
+
"class_data": []
|
77 |
+
}
|
78 |
+
|
79 |
+
# Sort classes by count (descending)
|
80 |
+
sorted_classes = sorted(
|
81 |
+
stats["class_statistics"].items(),
|
82 |
+
key=lambda x: x[1]["count"],
|
83 |
+
reverse=True
|
84 |
+
)
|
85 |
+
|
86 |
+
# Create class-specific visualization data
|
87 |
+
for cls_name, cls_stats in sorted_classes:
|
88 |
+
class_id = -1
|
89 |
+
# Find the class ID based on the name
|
90 |
+
for idx, name in result.names.items():
|
91 |
+
if name == cls_name:
|
92 |
+
class_id = idx
|
93 |
+
break
|
94 |
+
|
95 |
+
cls_data = {
|
96 |
+
"name": cls_name,
|
97 |
+
"class_id": class_id,
|
98 |
+
"count": cls_stats["count"],
|
99 |
+
"average_confidence": cls_stats.get("average_confidence", 0),
|
100 |
+
"confidence_std": cls_stats.get("confidence_std", 0),
|
101 |
+
"color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
|
102 |
+
}
|
103 |
+
|
104 |
+
viz_data["class_data"].append(cls_data)
|
105 |
+
|
106 |
+
return viz_data
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
|
110 |
+
"""
|
111 |
+
Create a horizontal bar chart showing detection statistics
|
112 |
+
|
113 |
+
Args:
|
114 |
+
viz_data: Visualization data generated by generate_visualization_data
|
115 |
+
figsize: Figure size (width, height) in inches
|
116 |
+
max_classes: Maximum number of classes to display
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
Matplotlib figure object
|
120 |
+
"""
|
121 |
+
# Use the enhanced version
|
122 |
+
return EvaluationMetrics.create_enhanced_stats_plot(viz_data, figsize, max_classes)
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
|
126 |
+
"""
|
127 |
+
Create an enhanced horizontal bar chart with larger fonts and better styling
|
128 |
+
|
129 |
+
Args:
|
130 |
+
viz_data: Visualization data dictionary
|
131 |
+
figsize: Figure size (width, height) in inches
|
132 |
+
max_classes: Maximum number of classes to display
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
Matplotlib figure with enhanced styling
|
136 |
+
"""
|
137 |
+
if "error" in viz_data:
|
138 |
+
# Create empty plot if error
|
139 |
+
fig, ax = plt.subplots(figsize=figsize)
|
140 |
+
ax.text(0.5, 0.5, viz_data["error"],
|
141 |
+
ha='center', va='center', fontsize=14, fontfamily='Arial')
|
142 |
+
ax.set_xlim(0, 1)
|
143 |
+
ax.set_ylim(0, 1)
|
144 |
+
ax.axis('off')
|
145 |
+
return fig
|
146 |
+
|
147 |
+
if "class_data" not in viz_data or not viz_data["class_data"]:
|
148 |
+
# Create empty plot if no data
|
149 |
+
fig, ax = plt.subplots(figsize=figsize)
|
150 |
+
ax.text(0.5, 0.5, "No detection data available",
|
151 |
+
ha='center', va='center', fontsize=14, fontfamily='Arial')
|
152 |
+
ax.set_xlim(0, 1)
|
153 |
+
ax.set_ylim(0, 1)
|
154 |
+
ax.axis('off')
|
155 |
+
return fig
|
156 |
+
|
157 |
+
# Limit to max_classes
|
158 |
+
class_data = viz_data["class_data"][:max_classes]
|
159 |
+
|
160 |
+
# Extract data for plotting
|
161 |
+
class_names = [item["name"] for item in class_data]
|
162 |
+
counts = [item["count"] for item in class_data]
|
163 |
+
colors = [item["color"] for item in class_data]
|
164 |
+
|
165 |
+
# Create figure and horizontal bar chart with improved styling
|
166 |
+
plt.rcParams['font.family'] = 'Arial'
|
167 |
+
fig, ax = plt.subplots(figsize=figsize)
|
168 |
+
|
169 |
+
# Set background color to white
|
170 |
+
fig.patch.set_facecolor('white')
|
171 |
+
ax.set_facecolor('white')
|
172 |
+
|
173 |
+
y_pos = np.arange(len(class_names))
|
174 |
+
|
175 |
+
# Create horizontal bars with class-specific colors
|
176 |
+
bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)
|
177 |
+
|
178 |
+
# Add count values at end of each bar with larger font
|
179 |
+
for i, bar in enumerate(bars):
|
180 |
+
width = bar.get_width()
|
181 |
+
conf = class_data[i]["average_confidence"]
|
182 |
+
ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
|
183 |
+
f"{width:.0f} (conf: {conf:.2f})",
|
184 |
+
va='center', fontsize=12, fontfamily='Arial')
|
185 |
+
|
186 |
+
# Customize axis and labels with larger fonts
|
187 |
+
ax.set_yticks(y_pos)
|
188 |
+
ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
|
189 |
+
ax.invert_yaxis() # Labels read top-to-bottom
|
190 |
+
ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
|
191 |
+
ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
|
192 |
+
fontsize=16, fontfamily='Arial', fontweight='bold')
|
193 |
+
|
194 |
+
# Add grid for better readability
|
195 |
+
ax.set_axisbelow(True)
|
196 |
+
ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')
|
197 |
+
|
198 |
+
# Increase tick label font size
|
199 |
+
ax.tick_params(axis='both', which='major', labelsize=12)
|
200 |
+
|
201 |
+
# Add detection summary as a text box with improved styling
|
202 |
+
summary_text = (
|
203 |
+
f"Total Objects: {viz_data['total_objects']}\n"
|
204 |
+
f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
|
205 |
+
f"Unique Classes: {len(viz_data['class_data'])}"
|
206 |
+
)
|
207 |
+
plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
|
208 |
+
bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
|
209 |
+
edgecolor='#E5E7EB'))
|
210 |
+
|
211 |
+
plt.tight_layout()
|
212 |
+
return fig
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def format_detection_summary(viz_data: Dict) -> str:
|
216 |
+
"""
|
217 |
+
Format detection results as a readable text summary with improved spacing
|
218 |
+
|
219 |
+
Args:
|
220 |
+
viz_data: Visualization data generated by generate_visualization_data
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Formatted text with proper spacing
|
224 |
+
"""
|
225 |
+
if "error" in viz_data:
|
226 |
+
return viz_data["error"]
|
227 |
+
|
228 |
+
if "total_objects" not in viz_data:
|
229 |
+
return "No detection data available."
|
230 |
+
|
231 |
+
# 獲取基本統計信息
|
232 |
+
total_objects = viz_data["total_objects"]
|
233 |
+
avg_confidence = viz_data["average_confidence"]
|
234 |
+
|
235 |
+
# 創建標題,使用更多空白行增加可讀性
|
236 |
+
lines = [
|
237 |
+
f"Detected {total_objects} objects.",
|
238 |
+
f"Average confidence: {avg_confidence:.2f}",
|
239 |
+
"\n", # 添加額外的空行
|
240 |
+
"Objects by class:",
|
241 |
+
]
|
242 |
+
|
243 |
+
# 添加類別詳情,每個類別使用更多空間
|
244 |
+
if "class_data" in viz_data and viz_data["class_data"]:
|
245 |
+
for item in viz_data["class_data"]:
|
246 |
+
count = item['count']
|
247 |
+
# 使用正確的單複數形式
|
248 |
+
item_text = "item" if count == 1 else "items"
|
249 |
+
|
250 |
+
# 每個項目前添加空行,並使用縮進格式化
|
251 |
+
lines.append("\n") # 每個項目前添加空白行
|
252 |
+
lines.append(f"• {item['name']}: {count} {item_text}")
|
253 |
+
lines.append(f" Confidence: {item['average_confidence']:.2f}")
|
254 |
+
else:
|
255 |
+
lines.append("\nNo class information available.")
|
256 |
+
|
257 |
+
return "\n".join(lines)
|
258 |
+
|
259 |
+
@staticmethod
|
260 |
+
def calculate_distance_metrics(result: Any) -> Dict:
|
261 |
+
"""
|
262 |
+
Calculate distance-related metrics for detected objects
|
263 |
+
|
264 |
+
Args:
|
265 |
+
result: Detection result object
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
Dictionary with distance metrics
|
269 |
+
"""
|
270 |
+
if result is None:
|
271 |
+
return {"error": "No detection result provided"}
|
272 |
+
|
273 |
+
boxes = result.boxes.xyxy.cpu().numpy()
|
274 |
+
classes = result.boxes.cls.cpu().numpy().astype(int)
|
275 |
+
names = result.names
|
276 |
+
|
277 |
+
# Initialize metrics
|
278 |
+
metrics = {
|
279 |
+
"proximity": {}, # Classes that appear close to each other
|
280 |
+
"spatial_distribution": {}, # Distribution across the image
|
281 |
+
"size_distribution": {} # Size distribution of objects
|
282 |
+
}
|
283 |
+
|
284 |
+
# Calculate image dimensions (assuming normalized coordinates or extract from result)
|
285 |
+
img_width, img_height = 1, 1
|
286 |
+
if hasattr(result, "orig_shape"):
|
287 |
+
img_height, img_width = result.orig_shape[:2]
|
288 |
+
|
289 |
+
# Calculate bounding box areas and centers
|
290 |
+
areas = []
|
291 |
+
centers = []
|
292 |
+
class_names = []
|
293 |
+
|
294 |
+
for box, cls in zip(boxes, classes):
|
295 |
+
x1, y1, x2, y2 = box
|
296 |
+
width, height = x2 - x1, y2 - y1
|
297 |
+
area = width * height
|
298 |
+
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
|
299 |
+
|
300 |
+
areas.append(area)
|
301 |
+
centers.append((center_x, center_y))
|
302 |
+
class_names.append(names[int(cls)])
|
303 |
+
|
304 |
+
# Calculate spatial distribution
|
305 |
+
if centers:
|
306 |
+
x_coords = [c[0] for c in centers]
|
307 |
+
y_coords = [c[1] for c in centers]
|
308 |
+
|
309 |
+
metrics["spatial_distribution"] = {
|
310 |
+
"x_mean": float(np.mean(x_coords)) / img_width,
|
311 |
+
"y_mean": float(np.mean(y_coords)) / img_height,
|
312 |
+
"x_std": float(np.std(x_coords)) / img_width,
|
313 |
+
"y_std": float(np.std(y_coords)) / img_height
|
314 |
+
}
|
315 |
+
|
316 |
+
# Calculate size distribution
|
317 |
+
if areas:
|
318 |
+
metrics["size_distribution"] = {
|
319 |
+
"mean_area": float(np.mean(areas)) / (img_width * img_height),
|
320 |
+
"std_area": float(np.std(areas)) / (img_width * img_height),
|
321 |
+
"min_area": float(np.min(areas)) / (img_width * img_height),
|
322 |
+
"max_area": float(np.max(areas)) / (img_width * img_height)
|
323 |
+
}
|
324 |
+
|
325 |
+
# Calculate proximity between different classes
|
326 |
+
class_centers = {}
|
327 |
+
for cls_name, center in zip(class_names, centers):
|
328 |
+
if cls_name not in class_centers:
|
329 |
+
class_centers[cls_name] = []
|
330 |
+
class_centers[cls_name].append(center)
|
331 |
+
|
332 |
+
# Find classes that appear close to each other
|
333 |
+
proximity_pairs = []
|
334 |
+
for i, cls1 in enumerate(class_centers.keys()):
|
335 |
+
for j, cls2 in enumerate(class_centers.keys()):
|
336 |
+
if i >= j: # Avoid duplicate pairs and self-comparison
|
337 |
+
continue
|
338 |
+
|
339 |
+
# Calculate minimum distance between any two objects of these classes
|
340 |
+
min_distance = float('inf')
|
341 |
+
for center1 in class_centers[cls1]:
|
342 |
+
for center2 in class_centers[cls2]:
|
343 |
+
dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
|
344 |
+
min_distance = min(min_distance, dist)
|
345 |
+
|
346 |
+
# Normalize by image diagonal
|
347 |
+
img_diagonal = np.sqrt(img_width**2 + img_height**2)
|
348 |
+
norm_distance = min_distance / img_diagonal
|
349 |
+
|
350 |
+
proximity_pairs.append({
|
351 |
+
"class1": cls1,
|
352 |
+
"class2": cls2,
|
353 |
+
"distance": float(norm_distance)
|
354 |
+
})
|
355 |
+
|
356 |
+
# Sort by distance and keep the closest pairs
|
357 |
+
proximity_pairs.sort(key=lambda x: x["distance"])
|
358 |
+
metrics["proximity"] = proximity_pairs[:5] # Keep top 5 closest pairs
|
359 |
+
|
360 |
+
return metrics
|
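A minimal usage sketch for calculate_distance_metrics, not part of the uploaded files; the weights file "yolov8n.pt" and the test image "street.jpg" are placeholder assumptions:

# Hypothetical usage example (assumes ultralytics is installed and a local test image exists)
from ultralytics import YOLO
from evaluation_metrics import EvaluationMetrics

model = YOLO("yolov8n.pt")              # placeholder weights
result = model("street.jpg")[0]         # single Ultralytics result object
metrics = EvaluationMetrics.calculate_distance_metrics(result)
print(metrics["spatial_distribution"])  # normalized center statistics
print(metrics["proximity"])             # up to five closest class pairs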
style.py
ADDED
@@ -0,0 +1,282 @@
class Style:
    @staticmethod
    def get_css():
        """Return the application's CSS styles with improved aesthetics"""
        css = """
        /* Base styles and typography */
        body {
            font-family: Arial, sans-serif;
            background: linear-gradient(135deg, #f0f9ff, #e1f5fe);
            margin: 0;
            padding: 0;
            display: flex;
            justify-content: center;
            min-height: 100vh;
        }

        /* Typography improvements */
        h1, h2, h3, h4, h5, h6, p, span, div, label, button {
            font-family: Arial, sans-serif;
        }

        /* Container styling */
        .gradio-container {
            max-width: 1200px !important;
            margin: 0 auto;
            padding: 1rem;
            width: 100%;
        }

        /* Header area styling with gradient background */
        .app-header {
            text-align: center;
            margin-bottom: 2rem;
            background: linear-gradient(135deg, #f8f9fa, #e9ecef);
            padding: 1.5rem;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
            width: 100%;
        }

        .app-title {
            color: #2D3748;
            font-size: 2.5rem;
            margin-bottom: 0.5rem;
            background: linear-gradient(90deg, #38b2ac, #4299e1);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            font-weight: bold;
        }

        .app-subtitle {
            color: #4A5568;
            font-size: 1.2rem;
            font-weight: normal;
            margin-top: 0.25rem;
        }

        .app-divider {
            width: 80px;
            height: 3px;
            background: linear-gradient(90deg, #38b2ac, #4299e1);
            margin: 1rem auto;
        }

        /* Panel styling - gradient background */
        .input-panel, .output-panel {
            background: white;
            border-radius: 10px;
            padding: 1.5rem;
            box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
            margin: 0 auto 1rem auto;
        }

        /* Section heading styling with gradient background */
        .section-heading {
            font-size: 1.25rem;
            font-weight: 600;
            color: #2D3748;
            margin-bottom: 1rem;
            margin-top: 0.5rem;
            text-align: center;
            padding: 0.8rem;
            background: linear-gradient(to right, #e6f3fc, #f0f9ff);
            border-radius: 8px;
        }

        /* How-to-use section with gradient background */
        .how-to-use {
            background: linear-gradient(135deg, #f8fafc, #e8f4fd);
            border-radius: 10px;
            padding: 1.5rem;
            margin-top: 1rem;
            box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
            color: #2d3748;
        }

        /* Detection button styling */
        .detect-btn {
            background: linear-gradient(90deg, #38b2ac, #4299e1) !important;
            color: white !important;
            border: none !important;
            border-radius: 8px !important;
            transition: transform 0.3s, box-shadow 0.3s !important;
            font-weight: bold !important;
            letter-spacing: 0.5px !important;
            padding: 0.75rem 1.5rem !important;
            width: 100%;
            margin: 1rem auto !important;
            font-family: Arial, sans-serif !important;
        }

        .detect-btn:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
        }

        .detect-btn:active {
            transform: translateY(1px) !important;
            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2) !important;
        }

        /* JSON display improvements */
        .json-display pre {
            background: #f8fafc;
            border-radius: 6px;
            padding: 1rem;
            font-family: 'Consolas', 'Monaco', monospace;
            white-space: pre-wrap;
            max-height: 500px;
            overflow-y: auto;
            box-shadow: inset 0 0 4px rgba(0, 0, 0, 0.1);
        }

        .json-key {
            color: #e53e3e;
        }

        .json-value {
            color: #2b6cb0;
        }

        .json-string {
            color: #38a169;
        }

        /* Chart/plot styling improvements */
        .plot-container {
            background: white;
            border-radius: 8px;
            padding: 0.5rem;
            box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
        }

        /* Larger font for plots */
        .plot-container text {
            font-family: Arial, sans-serif !important;
            font-size: 14px !important;
        }

        /* Title styling for charts */
        .plot-title {
            font-family: Arial, sans-serif !important;
            font-size: 16px !important;
            font-weight: bold !important;
        }

        /* Tab styling with subtle gradient */
        .tabs {
            width: 100%;
            display: flex;
            justify-content: center;
        }

        .tabs > div:first-child {
            background: linear-gradient(to right, #f8fafc, #e8f4fd) !important;
            border-radius: 8px 8px 0 0;
        }

        /* Footer styling with gradient background */
        .footer {
            text-align: center;
            margin-top: 2rem;
            font-size: 0.9rem;
            color: #4A5568;
            padding: 1rem;
            background: linear-gradient(135deg, #f8f9fa, #e1effe);
            border-radius: 10px;
            box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
            width: 100%;
        }

        /* Ensure centering works for all elements */
        .container, .gr-container, .gr-row, .gr-col {
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            width: 100%;
        }

        /* Improved styling for the result text box */
        #detection-details, .wide-result-text {
            width: 100% !important;
            max-width: 100% !important;
            box-sizing: border-box !important;
        }

        .wide-result-text textarea {
            width: 100% !important;
            min-width: 600px !important;
            font-family: 'Arial', sans-serif !important;
            font-size: 14px !important;
            line-height: 1.5 !important;  /* reduce line spacing */
            padding: 16px !important;
            white-space: pre-wrap !important;
            background-color: #f8f9fa !important;
            border-radius: 8px !important;
            min-height: 300px !important;
            resize: none !important;
            overflow-y: auto !important;
            border: 1px solid #e2e8f0 !important;
            display: block !important;
        }

        /* Result details panel styling - gradient background */
        .result-details-box {
            width: 100% !important;
            margin-top: 1.5rem;
            background: linear-gradient(135deg, #f8fafc, #e8f4fd);
            border-radius: 10px;
            padding: 1rem;
            box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
        }

        /* Ensure elements inside the result details panel fit its width */
        .result-details-box > * {
            width: 100% !important;
            max-width: 100% !important;
        }

        /* Ensure the text area width is not constrained */
        .result-details-box .gr-text-input {
            width: 100% !important;
            max-width: none !important;
        }

        /* Layout adjustments for the output panel content */
        .output-panel {
            display: flex;
            flex-direction: column;
            width: 100%;
            padding: 0 !important;
        }

        /* Ensure elements inside the output panel fit its width */
        .output-panel > * {
            width: 100%;
        }

        /* Improve the statistics panel column layout */
        .plot-column, .stats-column {
            display: flex;
            flex-direction: column;
            padding: 1rem;
        }

        /* Responsive adjustments */
        @media (max-width: 768px) {
            .app-title {
                font-size: 2rem;
            }

            .app-subtitle {
                font-size: 1rem;
            }

            .gradio-container {
                padding: 0.5rem;
            }
        }
        """
        return css
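A minimal sketch of wiring the stylesheet into a Gradio app, not part of the uploaded files; the header markup is a placeholder, and the class names must match the elem_classes used when building the UI in app.py:

# Hypothetical usage example
import gradio as gr
from style import Style

with gr.Blocks(css=Style.get_css()) as demo:
    # Any element using the "app-header"/"app-title" classes picks up the styles above
    gr.HTML('<div class="app-header"><h1 class="app-title">Demo</h1></div>')
demo.launch()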
visualization_helper.py
ADDED
@@ -0,0 +1,147 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, List, Dict, Tuple, Optional
import io
from PIL import Image

class VisualizationHelper:
    """Helper class for visualizing detection results"""

    @staticmethod
    def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
                            figsize: Tuple[int, int] = (12, 12),
                            return_pil: bool = False) -> Optional[Image.Image]:
        """
        Visualize detection results on a single image

        Args:
            image: Image path or numpy array
            result: Detection result object
            color_mapper: ColorMapper instance for consistent colors
            figsize: Figure size
            return_pil: If True, returns a PIL Image object

        Returns:
            PIL Image if return_pil is True, otherwise displays the plot
        """
        if result is None:
            print('No data for visualization')
            return None

        # Read image if a path is provided
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:  # guard against an unreadable path
                print(f'Could not read image: {image}')
                return None
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img = image
            if len(img.shape) == 3 and img.shape[2] == 3:
                # Check if BGR format (OpenCV) and convert to RGB if needed
                if isinstance(img, np.ndarray):
                    # Assuming BGR format from OpenCV
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img)

        # Get bounding boxes, classes and confidences
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()

        # Get class names
        names = result.names

        # Create a default color mapper if none is provided
        if color_mapper is None:
            # For backward compatibility, fall back to a simple colormap function
            from matplotlib import colormaps
            cmap = colormaps['tab10']
            def get_color(class_id):
                return cmap(class_id % 10)
        else:
            # Use the provided color mapper
            def get_color(class_id):
                hex_color = color_mapper.get_color(class_id)
                # Convert hex to RGB float values for matplotlib
                hex_color = hex_color.lstrip('#')
                return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)

        # Draw detection results
        for box, cls, conf in zip(boxes, classes, confs):
            x1, y1, x2, y2 = box
            cls_id = int(cls)
            cls_name = names[cls_id]

            # Get color for this class
            box_color = get_color(cls_id)

            # Add text label with colored background
            ax.text(x1, y1 - 5, f'{cls_name}: {conf:.2f}',
                    color='white', fontsize=10,
                    bbox=dict(facecolor=box_color[:3], alpha=0.7))

            # Add bounding box
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, edgecolor=box_color[:3], linewidth=2))

        ax.axis('off')
        # ax.set_title('Detection Result')
        plt.tight_layout()

        if return_pil:
            # Convert plot to PIL Image
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            pil_img = Image.open(buf)
            plt.close(fig)
            return pil_img
        else:
            plt.show()
            return None

    @staticmethod
    def create_summary(result: Any) -> Dict:
        """
        Create a summary of detection results

        Args:
            result: Detection result object

        Returns:
            Dictionary with detection summary statistics
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Count detections by class
        class_counts = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            if cls_name not in class_counts:
                class_counts[cls_name] = {"count": 0, "confidences": []}

            class_counts[cls_name]["count"] += 1
            class_counts[cls_name]["confidences"].append(float(conf))

        # Calculate average confidence for each class
        for cls_name, stats in class_counts.items():
            if stats["confidences"]:
                stats["average_confidence"] = float(np.mean(stats["confidences"]))
                stats.pop("confidences")  # Remove detailed confidences list to keep the summary concise

        # Prepare summary
        summary = {
            "total_objects": len(classes),
            "class_counts": class_counts,
            "unique_classes": len(class_counts)
        }

        return summary
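A minimal usage sketch for VisualizationHelper, not part of the uploaded files; the weights file and image path are placeholders, and ColorMapper comes from the color_mapper.py module listed above:

# Hypothetical usage example
from ultralytics import YOLO
from color_mapper import ColorMapper
from visualization_helper import VisualizationHelper

model = YOLO("yolov8n.pt")       # placeholder weights
result = model("street.jpg")[0]  # placeholder test image

annotated = VisualizationHelper.visualize_detection(
    "street.jpg", result, color_mapper=ColorMapper(), return_pil=True)
if annotated is not None:
    annotated.save("annotated.png")

print(VisualizationHelper.create_summary(result))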