Upload 6 files

- app.py +555 -0
- color_mapper.py +270 -0
- detection_model.py +164 -0
- evaluation_metrics.py +323 -0
- requirements.txt +8 -0
- visualization_helper.py +147 -0
app.py
ADDED
@@ -0,0 +1,555 @@
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import gradio as gr
from PIL import Image
import spaces

from detection_model import DetectionModel
from color_mapper import ColorMapper
from visualization_helper import VisualizationHelper
from evaluation_metrics import EvaluationMetrics


color_mapper = ColorMapper()
model_instances = {}

@spaces.GPU
def process_image(image, model_instance, confidence_threshold, filter_classes=None):
    """
    Process an image for object detection

    Args:
        image: Input image (numpy array or PIL Image)
        model_instance: DetectionModel instance to use
        confidence_threshold: Confidence threshold for detection
        filter_classes: Optional list of classes to filter results

    Returns:
        Tuple of (result_image, result_text, stats_data)
    """
    # initialize key variables
    result = None
    stats = {}
    temp_path = None

    try:
        # update confidence threshold
        model_instance.confidence = confidence_threshold

        # process input image
        if isinstance(image, np.ndarray):
            # Convert BGR to RGB if needed
            if image.shape[2] == 3:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image_rgb = image
            pil_image = Image.fromarray(image_rgb)
        elif image is None:
            return None, "No image provided. Please upload an image.", {}
        else:
            pil_image = image

        # store a temp file
        import uuid
        import tempfile

        temp_dir = tempfile.gettempdir()  # use the system temp directory
        temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
        temp_path = os.path.join(temp_dir, temp_filename)
        pil_image.save(temp_path)

        # object detection
        result = model_instance.detect(temp_path)

        if result is None:
            return None, "Detection failed. Please try again with a different image.", {}

        # calculate stats
        stats = EvaluationMetrics.calculate_basic_stats(result)

        # add spatial metrics
        spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
        stats["spatial_metrics"] = spatial_metrics

        if filter_classes and len(filter_classes) > 0:
            # get classes and confidences
            classes = result.boxes.cls.cpu().numpy().astype(int)
            confs = result.boxes.conf.cpu().numpy()

            mask = np.zeros_like(classes, dtype=bool)
            for cls_id in filter_classes:
                mask = np.logical_or(mask, classes == cls_id)

            filtered_stats = {
                "total_objects": int(np.sum(mask)),
                "class_statistics": {},
                "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
                "spatial_metrics": stats["spatial_metrics"]
            }

            # accumulate per-class counts and confidence totals
            names = result.names
            for cls, conf in zip(classes[mask], confs[mask]):
                cls_name = names[int(cls)]
                if cls_name not in filtered_stats["class_statistics"]:
                    filtered_stats["class_statistics"][cls_name] = {
                        "count": 0,
                        "total_confidence": 0.0
                    }

                filtered_stats["class_statistics"][cls_name]["count"] += 1
                filtered_stats["class_statistics"][cls_name]["total_confidence"] += float(conf)

            # convert confidence totals into per-class averages
            for cls_stats in filtered_stats["class_statistics"].values():
                cls_stats["average_confidence"] = cls_stats["total_confidence"] / cls_stats["count"]
                cls_stats.pop("total_confidence")

            stats = filtered_stats

        viz_data = EvaluationMetrics.generate_visualization_data(
            result,
            color_mapper.get_all_colors()
        )

        result_image = VisualizationHelper.visualize_detection(
            temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
        )

        result_text = EvaluationMetrics.format_detection_summary(viz_data)

        return result_image, result_text, stats

    except Exception as e:
        error_message = f"Error occurred: {str(e)}"
        import traceback
        traceback.print_exc()
        print(error_message)
        return None, error_message, {}

    finally:
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except Exception as e:
                print(f"Could not delete temp file {temp_path}: {str(e)}")

def format_result_text(stats):
    """Format detection statistics into readable text"""
    if not stats or "total_objects" not in stats:
        return "No objects detected."

    lines = [
        f"Detected {stats['total_objects']} objects.",
        f"Average confidence: {stats.get('average_confidence', 0):.2f}",
        "",
        "Objects by class:",
    ]

    if "class_statistics" in stats and stats["class_statistics"]:
        # Sort classes by count
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda x: x[1]["count"],
            reverse=True
        )

        for cls_name, cls_stats in sorted_classes:
            lines.append(f"• {cls_name}: {cls_stats['count']} (avg conf: {cls_stats.get('average_confidence', 0):.2f})")
    else:
        lines.append("No class information available.")

    return "\n".join(lines)

def get_all_classes():
    """Get all available COCO classes"""
    try:
        # reuse the class names from any already-loaded model instance
        for model_instance in model_instances.values():
            if model_instance.is_model_loaded:
                return [(idx, name) for idx, name in model_instance.class_names.items()]
    except Exception:
        pass

    # Fallback to the standard COCO classes
    return [
        (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
        (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
        (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
        (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
        (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
        (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
        (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
        (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
        (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
        (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
        (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
        (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
        (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
        (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
        (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
        (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
        (79, 'toothbrush')
    ]

def create_interface():
    """Create the Gradio interface"""
    # Get CSS styles
    css = """
    body {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
        background: linear-gradient(120deg, #e0f7fa, #b2ebf2);
        margin: 0;
        padding: 0;
    }

    .gradio-container {
        max-width: 1200px !important;
    }

    .app-header {
        text-align: center;
        margin-bottom: 2rem;
        background: rgba(255, 255, 255, 0.8);
        padding: 1.5rem;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .app-title {
        color: #2D3748;
        font-size: 2.5rem;
        margin-bottom: 0.5rem;
        background: linear-gradient(90deg, #4299e1, #48bb78);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }

    .app-subtitle {
        color: #4A5568;
        font-size: 1.2rem;
        font-weight: normal;
        margin-top: 0.25rem;
    }

    .app-divider {
        width: 50px;
        height: 3px;
        background: linear-gradient(90deg, #4299e1, #48bb78);
        margin: 1rem auto;
    }

    .input-panel, .output-panel {
        background: white;
        border-radius: 10px;
        padding: 1rem;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    }

    .detect-btn {
        background: linear-gradient(90deg, #4299e1, #48bb78) !important;
        color: white !important;
        border: none !important;
        transition: transform 0.3s, box-shadow 0.3s !important;
    }

    .detect-btn:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
    }

    .detect-btn:active {
        transform: translateY(1px) !important;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2) !important;
    }

    .footer {
        text-align: center;
        margin-top: 2rem;
        font-size: 0.9rem;
        color: #4A5568;
    }

    /* Responsive adjustments */
    @media (max-width: 768px) {
        .app-title {
            font-size: 2rem;
        }

        .app-subtitle {
            font-size: 1rem;
        }
    }
    """

    # get model info
    available_models = DetectionModel.get_available_models()
    model_choices = [model["model_file"] for model in available_models]
    model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]

    # Available classes for filtering
    available_classes = get_all_classes()
    class_choices = [f"{cls_id}: {cls_name}" for cls_id, cls_name in available_classes]

    # Create the Gradio Blocks interface
    with gr.Blocks(css=css) as demo:
        # Header
        with gr.Group(elem_classes="app-header"):
            gr.HTML("""
                <h1 class="app-title">VisionScout</h1>
                <h2 class="app-subtitle">Detect and identify objects in your images</h2>
                <div class="app-divider"></div>
            """)

        current_model = gr.State("yolov8m.pt")  # use the medium model as the default

        # Input and Output panels
        with gr.Row():
            # Left column - Input controls
            with gr.Column(scale=4, elem_classes="input-panel"):
                with gr.Group():
                    gr.Markdown("### Upload Image")
                    image_input = gr.Image(type="pil", label="Upload an image")

                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            model_dropdown = gr.Dropdown(
                                choices=model_choices,
                                value="yolov8m.pt",
                                label="Select Model",
                                info="Choose different models based on your needs for speed vs. accuracy"
                            )

                        # display model info
                        model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))

                        confidence = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            value=0.25,
                            step=0.05,
                            label="Confidence Threshold",
                            info="Higher values show fewer but more confident detections"
                        )

                        with gr.Accordion("Filter Classes", open=False):
                            # Common object categories
                            with gr.Row():
                                people_btn = gr.Button("People")
                                vehicles_btn = gr.Button("Vehicles")
                                animals_btn = gr.Button("Animals")
                                objects_btn = gr.Button("Common Objects")

                            # Class selection
                            class_filter = gr.Dropdown(
                                choices=class_choices,
                                multiselect=True,
                                label="Select Classes to Display",
                                info="Leave empty to show all detected objects"
                            )

                    detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")

                with gr.Group():
                    gr.Markdown("### How to Use")
                    gr.Markdown("""
                    1. Upload an image or use the camera
                    2. Adjust the confidence threshold if needed
                    3. Optionally filter to specific object classes
                    4. Click the "Detect Objects" button

                    The model will identify objects in your image and display them with bounding boxes.

                    **Note:** Detection quality depends on image clarity and object visibility. The model can detect up to 80 different types of common objects.
                    """)

            # Right column - Results display
            with gr.Column(scale=6, elem_classes="output-panel"):
                with gr.Tab("Detection Result"):
                    result_image = gr.Image(type="pil", label="Detection Result")
                    result_text = gr.Textbox(label="Detection Details", lines=10)

                with gr.Tab("Statistics"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            stats_json = gr.Json(label="Full Statistics")

                        with gr.Column(scale=1):
                            gr.Markdown("### Object Distribution")
                            plot_output = gr.Plot(label="Object Distribution")

        # update the state and description when the model selection changes
        model_dropdown.change(
            fn=lambda model: (model, DetectionModel.get_model_description(model)),
            inputs=[model_dropdown],
            outputs=[current_model, model_info]
        )

        # run detection when the button is clicked
        detect_btn.click(
            fn=process_and_plot,
            inputs=[image_input, current_model, confidence, class_filter],
            outputs=[result_image, result_text, stats_json, plot_output]
        )

        # Quick filter buttons
        people_classes = [0]  # Person
        vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]  # Various vehicles
        animals_classes = list(range(14, 24))  # Animals in COCO
        common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # Common household items

        people_btn.click(
            lambda: [f"{cls_id}: {cls_name}" for cls_id, cls_name in available_classes if cls_id in people_classes],
            outputs=class_filter
        )

        vehicles_btn.click(
            lambda: [f"{cls_id}: {cls_name}" for cls_id, cls_name in available_classes if cls_id in vehicles_classes],
            outputs=class_filter
        )

        animals_btn.click(
            lambda: [f"{cls_id}: {cls_name}" for cls_id, cls_name in available_classes if cls_id in animals_classes],
            outputs=class_filter
        )

        objects_btn.click(
            lambda: [f"{cls_id}: {cls_name}" for cls_id, cls_name in available_classes if cls_id in common_objects],
            outputs=class_filter
        )

        # Set up example images
        example_images = [
            "room_01.jpg",
            "street_01.jpg",
            "street_02.jpg",
            "street_03.jpg"
        ]

        gr.Examples(
            examples=example_images,
            inputs=image_input,
            outputs=None,
            fn=None,
            cache_examples=False,
        )

        # Footer
        gr.HTML("""
            <div class="footer">
                <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
                <p>Model can detect 80 different classes of objects</p>
            </div>
        """)

    return demo

@spaces.GPU
def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
    """
    Process an image and create plots for its statistics

    Args:
        image: Input image
        model_name: Name of the model to use
        confidence_threshold: Confidence threshold for detection
        filter_classes: Optional list of classes to filter results

    Returns:
        Tuple of (result_image, result_text, stats_json, plot_figure)
    """
    global model_instances

    if model_name not in model_instances:
        print(f"Creating new model instance for {model_name}")
        model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
    else:
        print(f"Using existing model instance for {model_name}")
        model_instances[model_name].confidence = confidence_threshold

    class_ids = None
    if filter_classes:
        class_ids = []
        for class_str in filter_classes:
            try:
                # Extract the ID from the "id: name" format
                class_id = int(class_str.split(":")[0].strip())
                class_ids.append(class_id)
            except (ValueError, IndexError):
                continue

    # execute detection
    result_image, result_text, stats = process_image(
        image,
        model_instances[model_name],
        confidence_threshold,
        class_ids
    )

    # create the stats plot
    plot_figure = create_stats_plot(stats)

    return result_image, result_text, stats, plot_figure

def create_stats_plot(stats):
    """
    Create a visualization of statistics data

    Args:
        stats: Dictionary containing detection statistics

    Returns:
        Matplotlib figure with the visualization
    """
    if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
        # Create an empty plot if there is no data
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, "No detection data available",
                ha='center', va='center', fontsize=12)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    # prepare visualization data
    viz_data = {
        "total_objects": stats.get("total_objects", 0),
        "average_confidence": stats.get("average_confidence", 0),
        "class_data": []
    }

    # get the classes known to the current model via get_all_classes
    available_classes = dict(get_all_classes())

    # process class data
    for cls_name, cls_stats in stats.get("class_statistics", {}).items():
        # look up the class ID by name
        class_id = -1
        for cls_id, name in available_classes.items():
            if name == cls_name:
                class_id = cls_id
                break

        cls_data = {
            "name": cls_name,
            "class_id": class_id,
            "count": cls_stats.get("count", 0),
            "average_confidence": cls_stats.get("average_confidence", 0),
            "color": color_mapper.get_color(class_id if class_id >= 0 else cls_name)
        }

        viz_data["class_data"].append(cls_data)

    # Sort by count in descending order
    viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)

    return EvaluationMetrics.create_stats_plot(viz_data)


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
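
For quick local testing it can help to drive process_image without launching the UI. A minimal sketch, assuming a local test image at sample.jpg (hypothetical path), that the spaces package is installed (its GPU decorator is a pass-through off-Spaces), and that ultralytics can download the weights on first use:

import PIL.Image

from app import process_image
from detection_model import DetectionModel

# one model instance, as process_and_plot would create
model = DetectionModel(model_name="yolov8n.pt", confidence=0.25, iou=0.45)

# run detection directly; filter to persons only (COCO class ID 0)
img = PIL.Image.open("sample.jpg")  # hypothetical test image
result_image, result_text, stats = process_image(img, model, 0.25, filter_classes=[0])
print(result_text)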
color_mapper.py
ADDED
@@ -0,0 +1,270 @@
import numpy as np
from typing import Dict, List, Tuple, Union, Any

class ColorMapper:
    """
    A class for consistent color mapping of object detection classes
    Provides color schemes for visualization in both RGB and hex formats
    """

    # Class categories for better organization
    CATEGORIES = {
        "person": [0],
        "vehicles": [1, 2, 3, 4, 5, 6, 7, 8],
        "traffic": [9, 10, 11, 12],
        "animals": [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
        "outdoor": [13, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33],
        "sports": [34, 35, 36, 37, 38],
        "kitchen": [39, 40, 41, 42, 43, 44, 45],
        "food": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55],
        "furniture": [56, 57, 58, 59, 60, 61],
        "electronics": [62, 63, 64, 65, 66, 67, 68, 69, 70],
        "household": [71, 72, 73, 74, 75, 76, 77, 78, 79]
    }

    # Base colors for each category (in HSV for easier variation)
    CATEGORY_COLORS = {
        "person": (0, 0.8, 0.9),        # Red
        "vehicles": (210, 0.8, 0.9),    # Blue
        "traffic": (45, 0.8, 0.9),      # Orange
        "animals": (120, 0.7, 0.8),     # Green
        "outdoor": (180, 0.7, 0.9),     # Cyan
        "sports": (270, 0.7, 0.8),      # Purple
        "kitchen": (30, 0.7, 0.9),      # Light Orange
        "food": (330, 0.7, 0.85),       # Pink
        "furniture": (150, 0.5, 0.85),  # Light Green
        "electronics": (240, 0.6, 0.9), # Light Blue
        "household": (60, 0.6, 0.9)     # Yellow
    }

    def __init__(self):
        """Initialize the ColorMapper with COCO class mappings"""
        self.class_names = self._get_coco_classes()
        self.color_map = self._generate_color_map()

    def _get_coco_classes(self) -> Dict[int, str]:
        """Get the standard COCO class names with their IDs"""
        return {
            0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
            5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
            10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
            14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
            20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
            25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
            30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
            35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
            39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife',
            44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich',
            49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza',
            54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant',
            59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop',
            64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',
            69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book',
            74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier',
            79: 'toothbrush'
        }

    def _hsv_to_rgb(self, h: float, s: float, v: float) -> Tuple[int, int, int]:
        """
        Convert HSV color to RGB

        Args:
            h: Hue (0-360)
            s: Saturation (0-1)
            v: Value (0-1)

        Returns:
            Tuple of (R, G, B) values (0-255)
        """
        h = h / 60
        i = int(h)
        f = h - i
        p = v * (1 - s)
        q = v * (1 - s * f)
        t = v * (1 - s * (1 - f))

        if i == 0:
            r, g, b = v, t, p
        elif i == 1:
            r, g, b = q, v, p
        elif i == 2:
            r, g, b = p, v, t
        elif i == 3:
            r, g, b = p, q, v
        elif i == 4:
            r, g, b = t, p, v
        else:
            r, g, b = v, p, q

        return (int(r * 255), int(g * 255), int(b * 255))

    def _rgb_to_hex(self, rgb: Tuple[int, int, int]) -> str:
        """
        Convert RGB color to hex color code

        Args:
            rgb: Tuple of (R, G, B) values (0-255)

        Returns:
            Hex color code (e.g. '#FF0000')
        """
        return f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}'

    def _find_category(self, class_id: int) -> str:
        """
        Find the category for a given class ID

        Args:
            class_id: Class ID (0-79)

        Returns:
            Category name
        """
        for category, ids in self.CATEGORIES.items():
            if class_id in ids:
                return category
        return "other"  # Fallback

    def _generate_color_map(self) -> Dict:
        """
        Generate a color map for all 80 COCO classes

        Returns:
            Dictionary mapping class IDs and names to color values
        """
        color_map = {
            'by_id': {},      # Map class ID to RGB and hex
            'by_name': {},    # Map class name to RGB and hex
            'categories': {}  # Map category to base color
        }

        # Generate colors for categories
        for category, hsv in self.CATEGORY_COLORS.items():
            rgb = self._hsv_to_rgb(hsv[0], hsv[1], hsv[2])
            hex_color = self._rgb_to_hex(rgb)
            color_map['categories'][category] = {
                'rgb': rgb,
                'hex': hex_color
            }

        # Generate variations for each class within a category
        for class_id, class_name in self.class_names.items():
            category = self._find_category(class_id)
            base_hsv = self.CATEGORY_COLORS.get(category, (0, 0, 0.8))  # Default gray

            # Slightly vary the hue and saturation within the category
            ids_in_category = self.CATEGORIES.get(category, [])
            if ids_in_category:
                position = ids_in_category.index(class_id) if class_id in ids_in_category else 0
                variation = position / max(1, len(ids_in_category) - 1)  # 0 to 1

                # Vary hue slightly (±15°) and saturation
                h_offset = 30 * variation - 15  # -15 to +15
                s_offset = 0.2 * variation  # 0 to 0.2

                h = (base_hsv[0] + h_offset) % 360
                s = min(1.0, base_hsv[1] + s_offset)
                v = base_hsv[2]
            else:
                h, s, v = base_hsv

            rgb = self._hsv_to_rgb(h, s, v)
            hex_color = self._rgb_to_hex(rgb)

            # Store in both mappings
            color_map['by_id'][class_id] = {
                'rgb': rgb,
                'hex': hex_color,
                'category': category
            }

            color_map['by_name'][class_name] = {
                'rgb': rgb,
                'hex': hex_color,
                'category': category
            }

        return color_map

    def get_color(self, class_identifier: Union[int, str], format: str = 'hex') -> Any:
        """
        Get the color for a specific class

        Args:
            class_identifier: Class ID (int) or name (str)
            format: Color format ('hex', 'rgb', or 'bgr')

        Returns:
            Color in the requested format
        """
        # Determine if the identifier is an ID or a name
        if isinstance(class_identifier, int):
            color_info = self.color_map['by_id'].get(class_identifier)
        else:
            color_info = self.color_map['by_name'].get(class_identifier)

        if not color_info:
            # Fallback color if not found
            return '#CCCCCC' if format == 'hex' else (204, 204, 204)

        if format == 'hex':
            return color_info['hex']
        elif format == 'rgb':
            return color_info['rgb']
        elif format == 'bgr':
            # Convert RGB to BGR for OpenCV
            r, g, b = color_info['rgb']
            return (b, g, r)
        else:
            return color_info['rgb']

    def get_all_colors(self, format: str = 'hex') -> Dict:
        """
        Get all colors in the specified format

        Args:
            format: Color format ('hex', 'rgb', or 'bgr')

        Returns:
            Dictionary mapping class names to colors
        """
        result = {}
        for class_id, class_name in self.class_names.items():
            result[class_name] = self.get_color(class_id, format)
        return result

    def get_category_colors(self, format: str = 'hex') -> Dict:
        """
        Get the base colors for each category

        Args:
            format: Color format ('hex', 'rgb', or 'bgr')

        Returns:
            Dictionary mapping categories to colors
        """
        result = {}
        for category, color_info in self.color_map['categories'].items():
            if format == 'hex':
                result[category] = color_info['hex']
            elif format == 'bgr':
                r, g, b = color_info['rgb']
                result[category] = (b, g, r)
            else:
                result[category] = color_info['rgb']
        return result

    def get_category_for_class(self, class_identifier: Union[int, str]) -> str:
        """
        Get the category for a specific class

        Args:
            class_identifier: Class ID (int) or name (str)

        Returns:
            Category name
        """
        if isinstance(class_identifier, int):
            return self.color_map['by_id'].get(class_identifier, {}).get('category', 'other')
        else:
            return self.color_map['by_name'].get(class_identifier, {}).get('category', 'other')
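
A short usage sketch of the mapper; the exact hex values depend on the HSV variation above, so the printed colors are illustrative only:

from color_mapper import ColorMapper

mapper = ColorMapper()

print(mapper.get_color(0))                     # hex for 'person' (red family)
print(mapper.get_color("dog", format="rgb"))   # RGB tuple, looked up by class name
print(mapper.get_color(2, format="bgr"))       # BGR tuple, ready for cv2 drawing calls
print(mapper.get_category_for_class("pizza"))  # -> 'food'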
detection_model.py
ADDED
@@ -0,0 +1,164 @@
from ultralytics import YOLO
from typing import Any, List, Dict, Optional
import torch

class DetectionModel:
    """Core detection model class for object detection using YOLOv8"""

    # Model information dictionary
    MODEL_INFO = {
        "yolov8n.pt": {
            "name": "YOLOv8n (Nano)",
            "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
            "size_mb": 6,
            "inference_speed": "Very Fast"
        },
        "yolov8m.pt": {
            "name": "YOLOv8m (Medium)",
            "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
            "size_mb": 25,
            "inference_speed": "Medium"
        },
        "yolov8x.pt": {
            "name": "YOLOv8x (XLarge)",
            "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
            "size_mb": 68,
            "inference_speed": "Slower"
        }
    }

    def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
        """
        Initialize the detection model

        Args:
            model_name: Model name or path, default is yolov8m.pt
            confidence: Confidence threshold, default is 0.25
            iou: IoU threshold for non-maximum suppression, default is 0.45
        """
        self.model_name = model_name
        self.confidence = confidence
        self.iou = iou
        self.model = None
        self.class_names = {}
        self.is_model_loaded = False

        # Load the model on initialization
        self._load_model()

    def _load_model(self):
        """Load the YOLO model"""
        try:
            print(f"Loading model: {self.model_name}")
            self.model = YOLO(self.model_name)
            self.class_names = self.model.names
            self.is_model_loaded = True
            print(f"Successfully loaded model: {self.model_name}")
            print(f"Number of classes the model can recognize: {len(self.class_names)}")
        except Exception as e:
            print(f"Error occurred when loading the model: {e}")
            self.is_model_loaded = False

    def change_model(self, new_model_name: str) -> bool:
        """
        Change the currently loaded model

        Args:
            new_model_name: Name of the new model to load

        Returns:
            bool: True if the model changed successfully, False otherwise
        """
        if self.model_name == new_model_name and self.is_model_loaded:
            print(f"Model {new_model_name} is already loaded")
            return True

        print(f"Changing model from {self.model_name} to {new_model_name}")

        # Unload the current model to free memory
        if self.model is not None:
            del self.model
            self.model = None

        # Clean GPU memory if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Update the model name and load the new model
        self.model_name = new_model_name
        self._load_model()

        return self.is_model_loaded

    def reload_model(self):
        """Reload the model (useful for changing models or after an error)"""
        if self.model is not None:
            del self.model
            self.model = None

        # Clean GPU memory if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._load_model()

    def detect(self, image_input: Any) -> Optional[Any]:
        """
        Perform object detection on a single image

        Args:
            image_input: Image path (str), PIL Image, or numpy array

        Returns:
            Detection result object, or None if an error occurred
        """
        if self.model is None or not self.is_model_loaded:
            print("Model not found or not loaded. Attempting to reload...")
            self._load_model()
            if self.model is None or not self.is_model_loaded:
                print("Failed to load model. Cannot perform detection.")
                return None

        try:
            results = self.model(image_input, conf=self.confidence, iou=self.iou)
            return results[0]
        except Exception as e:
            print(f"Error occurred during detection: {e}")
            return None

    def get_class_names(self, class_id: int) -> str:
        """Get the class name for a given class ID"""
        return self.class_names.get(class_id, "Unknown Class")

    def get_supported_classes(self) -> Dict[int, str]:
        """Get all supported classes as a dictionary of {id: class_name}"""
        return self.class_names

    @classmethod
    def get_available_models(cls) -> List[Dict]:
        """
        Get the list of available models with their information

        Returns:
            List of dictionaries containing model information
        """
        models = []
        for model_file, info in cls.MODEL_INFO.items():
            models.append({
                "model_file": model_file,
                "name": info["name"],
                "description": info["description"],
                "size_mb": info["size_mb"],
                "inference_speed": info["inference_speed"]
            })
        return models

    @classmethod
    def get_model_description(cls, model_name: str) -> str:
        """Get the description for a specific model"""
        if model_name in cls.MODEL_INFO:
            info = cls.MODEL_INFO[model_name]
            return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
        return "Model information not available"
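
A minimal usage sketch, assuming sample.jpg exists locally (hypothetical path) and that ultralytics can fetch the weights on first use:

from detection_model import DetectionModel

model = DetectionModel(model_name="yolov8n.pt", confidence=0.3)
result = model.detect("sample.jpg")  # also accepts a PIL Image or numpy array

if result is not None:
    # iterate detections: box coordinates, class ID, confidence
    for box, cls, conf in zip(result.boxes.xyxy, result.boxes.cls, result.boxes.conf):
        print(model.get_class_names(int(cls)), float(conf), box.tolist())

# swap weights at runtime; GPU memory is freed before the new load
model.change_model("yolov8x.pt")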
evaluation_metrics.py
ADDED
@@ -0,0 +1,323 @@
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Any, Optional, Tuple

class EvaluationMetrics:
    """Class for computing detection metrics, generating statistics and visualization data"""

    @staticmethod
    def calculate_basic_stats(result: Any) -> Dict:
        """
        Calculate basic statistics for a single detection result

        Args:
            result: Detection result object

        Returns:
            Dictionary with basic statistics
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Count by class
        class_counts = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            if cls_name not in class_counts:
                class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}

            class_counts[cls_name]["count"] += 1
            class_counts[cls_name]["total_confidence"] += float(conf)
            class_counts[cls_name]["confidences"].append(float(conf))

        # Calculate average confidence
        for cls_name, stats in class_counts.items():
            if stats["count"] > 0:
                stats["average_confidence"] = stats["total_confidence"] / stats["count"]
                stats["confidence_std"] = float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
                stats.pop("total_confidence")  # Remove the intermediate calculation

        # Prepare the summary
        stats = {
            "total_objects": len(classes),
            "class_statistics": class_counts,
            "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
        }

        return stats

    @staticmethod
    def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
        """
        Generate structured data suitable for visualization

        Args:
            result: Detection result object
            class_colors: Dictionary mapping class names to color codes (optional)

        Returns:
            Dictionary with visualization-ready data
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get basic stats first
        stats = EvaluationMetrics.calculate_basic_stats(result)

        # Create a visualization-specific data structure
        viz_data = {
            "total_objects": stats["total_objects"],
            "average_confidence": stats["average_confidence"],
            "class_data": []
        }

        # Sort classes by count (descending)
        sorted_classes = sorted(
            stats["class_statistics"].items(),
            key=lambda x: x[1]["count"],
            reverse=True
        )

        # Create class-specific visualization data
        for cls_name, cls_stats in sorted_classes:
            class_id = -1
            # Find the class ID based on the name
            for idx, name in result.names.items():
                if name == cls_name:
                    class_id = idx
                    break

            cls_data = {
                "name": cls_name,
                "class_id": class_id,
                "count": cls_stats["count"],
                "average_confidence": cls_stats.get("average_confidence", 0),
                "confidence_std": cls_stats.get("confidence_std", 0),
                "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
            }

            viz_data["class_data"].append(cls_data)

        return viz_data

    @staticmethod
    def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7),
                          max_classes: int = 30) -> plt.Figure:
        """
        Create a horizontal bar chart showing detection statistics

        Args:
            viz_data: Visualization data generated by generate_visualization_data
            figsize: Figure size (width, height) in inches
            max_classes: Maximum number of classes to display

        Returns:
            Matplotlib figure object
        """
        if "error" in viz_data:
            # Create an empty plot on error
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, viz_data["error"],
                    ha='center', va='center', fontsize=12)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            return fig

        if "class_data" not in viz_data or not viz_data["class_data"]:
            # Create an empty plot if there is no data
            fig, ax = plt.subplots(figsize=figsize)
            ax.text(0.5, 0.5, "No detection data available",
                    ha='center', va='center', fontsize=12)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            return fig

        # Limit to max_classes
        class_data = viz_data["class_data"][:max_classes]

        # Extract data for plotting
        class_names = [item["name"] for item in class_data]
        counts = [item["count"] for item in class_data]
        colors = [item["color"] for item in class_data]

        # Create the figure and horizontal bar chart
        fig, ax = plt.subplots(figsize=figsize)
        y_pos = np.arange(len(class_names))

        # Create horizontal bars with class-specific colors
        bars = ax.barh(y_pos, counts, color=colors, alpha=0.8)

        # Add count values at the end of each bar
        for i, bar in enumerate(bars):
            width = bar.get_width()
            conf = class_data[i]["average_confidence"]
            ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
                    f"{width:.0f} (conf: {conf:.2f})",
                    va='center', fontsize=9)

        # Customize axes and labels
        ax.set_yticks(y_pos)
        ax.set_yticklabels(class_names)
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count')
        ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total')

        # Add a grid for better readability
        ax.set_axisbelow(True)
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # Add the detection summary as a text box
        summary_text = (
            f"Total Objects: {viz_data['total_objects']}\n"
            f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
            f"Unique Classes: {len(viz_data['class_data'])}"
        )
        plt.figtext(0.02, 0.02, summary_text, fontsize=9,
                    bbox=dict(facecolor='white', alpha=0.8, boxstyle='round'))

        plt.tight_layout()
        return fig

    @staticmethod
    def format_detection_summary(viz_data: Dict) -> str:
        """
        Format detection results as a readable text summary
        """
        if "error" in viz_data:
            return viz_data["error"]

        if "total_objects" not in viz_data:
            return "No detection data available."

        # timing information is intentionally not displayed
        total_objects = viz_data["total_objects"]
        avg_confidence = viz_data["average_confidence"]

        # build the header lines
        lines = [
            f"Detected {total_objects} objects.",
            f"Average confidence: {avg_confidence:.2f}",
            "",
            "Objects by class:",
        ]

        # add per-class details
        if "class_data" in viz_data and viz_data["class_data"]:
            for item in viz_data["class_data"]:
                lines.append(
                    f"• {item['name']}: {item['count']} (avg conf: {item['average_confidence']:.2f})"
                )
        else:
            lines.append("No class information available.")

        return "\n".join(lines)

    @staticmethod
    def calculate_distance_metrics(result: Any) -> Dict:
        """
        Calculate distance-related metrics for detected objects

        Args:
            result: Detection result object

        Returns:
            Dictionary with distance metrics
        """
        if result is None:
            return {"error": "No detection result provided"}

        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy().astype(int)
        names = result.names

        # Initialize metrics
        metrics = {
            "proximity": {},             # Classes that appear close to each other
            "spatial_distribution": {},  # Distribution across the image
            "size_distribution": {}      # Size distribution of objects
        }

        # Calculate image dimensions (assume normalized coordinates, or extract from the result)
        img_width, img_height = 1, 1
        if hasattr(result, "orig_shape"):
            img_height, img_width = result.orig_shape[:2]

        # Calculate bounding box areas and centers
        areas = []
        centers = []
        class_names = []

        for box, cls in zip(boxes, classes):
            x1, y1, x2, y2 = box
            width, height = x2 - x1, y2 - y1
            area = width * height
            center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2

            areas.append(area)
            centers.append((center_x, center_y))
            class_names.append(names[int(cls)])

        # Calculate the spatial distribution
        if centers:
            x_coords = [c[0] for c in centers]
            y_coords = [c[1] for c in centers]

            metrics["spatial_distribution"] = {
                "x_mean": float(np.mean(x_coords)) / img_width,
                "y_mean": float(np.mean(y_coords)) / img_height,
                "x_std": float(np.std(x_coords)) / img_width,
                "y_std": float(np.std(y_coords)) / img_height
            }

        # Calculate the size distribution
        if areas:
            metrics["size_distribution"] = {
                "mean_area": float(np.mean(areas)) / (img_width * img_height),
                "std_area": float(np.std(areas)) / (img_width * img_height),
                "min_area": float(np.min(areas)) / (img_width * img_height),
                "max_area": float(np.max(areas)) / (img_width * img_height)
            }

        # Calculate proximity between different classes
        class_centers = {}
        for cls_name, center in zip(class_names, centers):
            if cls_name not in class_centers:
                class_centers[cls_name] = []
            class_centers[cls_name].append(center)

        # Find classes that appear close to each other
        proximity_pairs = []
        for i, cls1 in enumerate(class_centers.keys()):
            for j, cls2 in enumerate(class_centers.keys()):
                if i >= j:  # Avoid duplicate pairs and self-comparison
                    continue

                # Calculate the minimum distance between any two objects of these classes
                min_distance = float('inf')
                for center1 in class_centers[cls1]:
                    for center2 in class_centers[cls2]:
                        dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
                        min_distance = min(min_distance, dist)

                # Normalize by the image diagonal
                img_diagonal = np.sqrt(img_width**2 + img_height**2)
                norm_distance = min_distance / img_diagonal

                proximity_pairs.append({
                    "class1": cls1,
                    "class2": cls2,
                    "distance": float(norm_distance)
                })

        # Sort by distance and keep the closest pairs
        proximity_pairs.sort(key=lambda x: x["distance"])
        metrics["proximity"] = proximity_pairs[:5]  # Keep the 5 closest pairs

        return metrics
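
The metrics are designed to chain off a raw detection result. A sketch of the typical flow, under the same sample.jpg assumption as above (the output filename is hypothetical):

from detection_model import DetectionModel
from color_mapper import ColorMapper
from evaluation_metrics import EvaluationMetrics

model = DetectionModel("yolov8n.pt")
result = model.detect("sample.jpg")

if result is not None:
    stats = EvaluationMetrics.calculate_basic_stats(result)
    viz_data = EvaluationMetrics.generate_visualization_data(result, ColorMapper().get_all_colors())

    print(EvaluationMetrics.format_detection_summary(viz_data))

    fig = EvaluationMetrics.create_stats_plot(viz_data)
    fig.savefig("detection_stats.png")  # hypothetical output file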
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=2.0.0
torchvision>=0.15.0
ultralytics>=8.0.0
opencv-python>=4.7.0
pillow>=9.4.0
numpy>=1.23.5
matplotlib>=3.7.0
gradio>=3.32.0
visualization_helper.py
ADDED
@@ -0,0 +1,147 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, List, Dict, Tuple, Optional
import io
from PIL import Image

class VisualizationHelper:
    """Helper class for visualizing detection results"""

    @staticmethod
    def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
                            figsize: Tuple[int, int] = (12, 12),
                            return_pil: bool = False) -> Optional[Image.Image]:
        """
        Visualize detection results on a single image

        Args:
            image: Image path or numpy array
            result: Detection result object
            color_mapper: ColorMapper instance for consistent colors
            figsize: Figure size
            return_pil: If True, returns a PIL Image object

        Returns:
            PIL Image if return_pil is True, otherwise displays the plot
        """
        if result is None:
            print('No data for visualization')
            return None

        # Read the image if a path is provided
        if isinstance(image, str):
            img = cv2.imread(image)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img = image
            if len(img.shape) == 3 and img.shape[2] == 3:
                # Check for BGR format (OpenCV) and convert to RGB if needed
                if isinstance(img, np.ndarray):
                    # Assuming BGR format from OpenCV
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Create the figure
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img)

        # Get bounding boxes, classes and confidences
        boxes = result.boxes.xyxy.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()

        # Get class names
        names = result.names

        # Create a default color mapper if none is provided
        if color_mapper is None:
            # For backward compatibility, fall back to a simple color function
            from matplotlib import colormaps
            cmap = colormaps['tab10']
            def get_color(class_id):
                return cmap(class_id % 10)
        else:
            # Use the provided color mapper
            def get_color(class_id):
                hex_color = color_mapper.get_color(class_id)
                # Convert hex to RGB float values for matplotlib
                hex_color = hex_color.lstrip('#')
                return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)

        # Draw detection results
        for box, cls, conf in zip(boxes, classes, confs):
            x1, y1, x2, y2 = box
            cls_id = int(cls)
            cls_name = names[cls_id]

            # Get the color for this class
            box_color = get_color(cls_id)

            # Add a text label with a colored background
            ax.text(x1, y1 - 5, f'{cls_name}: {conf:.2f}',
                    color='white', fontsize=10,
                    bbox=dict(facecolor=box_color[:3], alpha=0.7))

            # Add the bounding box
            ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
                                       fill=False, edgecolor=box_color[:3], linewidth=2))

        ax.axis('off')
        plt.tight_layout()

        if return_pil:
            # Convert the plot to a PIL Image
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            pil_img = Image.open(buf)
            plt.close(fig)
            return pil_img
        else:
            plt.show()
            return None

    @staticmethod
    def create_summary(result: Any) -> Dict:
        """
        Create a summary of detection results

        Args:
            result: Detection result object

        Returns:
            Dictionary with detection summary statistics
        """
        if result is None:
            return {"error": "No detection result provided"}

        # Get classes and confidences
        classes = result.boxes.cls.cpu().numpy().astype(int)
        confidences = result.boxes.conf.cpu().numpy()
        names = result.names

        # Count detections by class
        class_counts = {}
        for cls, conf in zip(classes, confidences):
            cls_name = names[int(cls)]
            if cls_name not in class_counts:
                class_counts[cls_name] = {"count": 0, "confidences": []}

            class_counts[cls_name]["count"] += 1
            class_counts[cls_name]["confidences"].append(float(conf))

        # Calculate the average confidence for each class
        for cls_name, stats in class_counts.items():
            if stats["confidences"]:
                stats["average_confidence"] = float(np.mean(stats["confidences"]))
            stats.pop("confidences")  # Remove the detailed confidence list to keep the summary concise

        # Prepare the summary
        summary = {
            "total_objects": len(classes),
            "class_counts": class_counts,
            "unique_classes": len(class_counts)
        }

        return summary
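
As a usage note, return_pil=True is what app.py relies on to feed the annotated figure into gr.Image. A standalone sketch under the same sample.jpg assumption as above (the output filename is hypothetical):

from detection_model import DetectionModel
from color_mapper import ColorMapper
from visualization_helper import VisualizationHelper

model = DetectionModel("yolov8n.pt")
result = model.detect("sample.jpg")

# return_pil=True yields a PIL image instead of calling plt.show()
annotated = VisualizationHelper.visualize_detection(
    "sample.jpg", result, color_mapper=ColorMapper(), return_pil=True
)
if annotated is not None:
    annotated.save("annotated.jpg")  # hypothetical output file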