DawnC committed
Commit e4aac34 · verified · 1 parent: 4a77aaa

Upload 6 files

Files changed (6)
  1. app.py +555 -0
  2. color_mapper.py +270 -0
  3. detection_model.py +164 -0
  4. evaluation_metrics.py +323 -0
  5. requirements.txt +8 -0
  6. visualization_helper.py +147 -0
app.py ADDED
@@ -0,0 +1,555 @@
1
+ import os
2
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ from PIL import Image
+ import spaces
12
+
13
+ from detection_model import DetectionModel
14
+ from color_mapper import ColorMapper
15
+ from visualization_helper import VisualizationHelper
16
+ from evaluation_metrics import EvaluationMetrics
17
+
18
+
19
+ color_mapper = ColorMapper()
20
+ model_instances = {}
21
+
22
+ @spaces.GPU
23
+ def process_image(image, model_instance, confidence_threshold, filter_classes=None):
24
+ """
25
+ Process an image for object detection
26
+
27
+ Args:
28
+ image: Input image (numpy array or PIL Image)
29
+ model_instance: DetectionModel instance to use
30
+ confidence_threshold: Confidence threshold for detection
31
+ filter_classes: Optional list of classes to filter results
32
+
33
+ Returns:
34
+ Tuple of (result_image, result_text, stats_data)
35
+ """
36
+ # initialize key variables
37
+ result = None
38
+ stats = {}
39
+ temp_path = None
40
+
41
+ try:
42
+ # update confidence threshold
43
+ model_instance.confidence = confidence_threshold
44
+
45
+ # Process the input image
46
+ if isinstance(image, np.ndarray):
47
+ # Assume an OpenCV-style BGR array and convert to RGB
48
+ if image.ndim == 3 and image.shape[2] == 3:
49
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50
+ else:
51
+ image_rgb = image
52
+ pil_image = Image.fromarray(image_rgb)
53
+ elif image is None:
54
+ return None, "No image provided. Please upload an image.", {}
55
+ else:
56
+ pil_image = image
57
+
58
+ # Save to a temporary file for detection
59
+ import uuid
60
+ import tempfile
61
+
62
+ temp_dir = tempfile.gettempdir() # use system temp directory
63
+ temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
64
+ temp_path = os.path.join(temp_dir, temp_filename)
65
+ pil_image.save(temp_path)
66
+
67
+ # object detection
68
+ result = model_instance.detect(temp_path)
69
+
70
+ if result is None:
71
+ return None, "Detection failed. Please try again with a different image.", {}
72
+
73
+ # calculate stats
74
+ stats = EvaluationMetrics.calculate_basic_stats(result)
75
+
76
+ # Add spatial distribution metrics
77
+ spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
78
+ stats["spatial_metrics"] = spatial_metrics
79
+
80
+ if filter_classes and len(filter_classes) > 0:
81
+ # get classes, boxes, confidence
82
+ classes = result.boxes.cls.cpu().numpy().astype(int)
83
+ confs = result.boxes.conf.cpu().numpy()
84
+ boxes = result.boxes.xyxy.cpu().numpy()
85
+
86
+ mask = np.zeros_like(classes, dtype=bool)
87
+ for cls_id in filter_classes:
88
+ mask = np.logical_or(mask, classes == cls_id)
89
+
90
+ filtered_stats = {
91
+ "total_objects": int(np.sum(mask)),
92
+ "class_statistics": {},
93
+ "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
94
+ "spatial_metrics": stats["spatial_metrics"]
95
+ }
96
+
97
+ # update stats
98
+ names = result.names
99
+ for cls, conf in zip(classes[mask], confs[mask]):
100
+ cls_name = names[int(cls)]
101
+ if cls_name not in filtered_stats["class_statistics"]:
102
+ filtered_stats["class_statistics"][cls_name] = {
103
+ "count": 0,
104
+ "average_confidence": 0
105
+ }
106
+
107
+ filtered_stats["class_statistics"][cls_name]["count"] += 1
108
+ # Keep a running average rather than overwriting with the last confidence
+ entry = filtered_stats["class_statistics"][cls_name]
+ entry["average_confidence"] += (float(conf) - entry["average_confidence"]) / entry["count"]
109
+
110
+ stats = filtered_stats
111
+
112
+ viz_data = EvaluationMetrics.generate_visualization_data(
113
+ result,
114
+ color_mapper.get_all_colors()
115
+ )
116
+
117
+ result_image = VisualizationHelper.visualize_detection(
118
+ temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
119
+ )
120
+
121
+ result_text = EvaluationMetrics.format_detection_summary(viz_data)
122
+
123
+ return result_image, result_text, stats
124
+
125
+ except Exception as e:
126
+ error_message = f"Error occurred: {e}"
127
+ import traceback
128
+ traceback.print_exc()
129
+ print(error_message)
130
+ return None, error_message, {}
131
+
132
+ finally:
133
+ if temp_path and os.path.exists(temp_path):
134
+ try:
135
+ os.remove(temp_path)
136
+ except Exception as e:
137
+ print(f"Cannot delete temp files {temp_path}: {str(e)}")
138
+
139
+ def format_result_text(stats):
140
+ """Format detection statistics into readable text"""
141
+ if not stats or "total_objects" not in stats:
142
+ return "No objects detected."
143
+
144
+ lines = [
145
+ f"Detected {stats['total_objects']} objects.",
146
+ f"Average confidence: {stats.get('average_confidence', 0):.2f}",
147
+ "",
148
+ "Objects by class:",
149
+ ]
150
+
151
+ if "class_statistics" in stats and stats["class_statistics"]:
152
+ # Sort classes by count
153
+ sorted_classes = sorted(
154
+ stats["class_statistics"].items(),
155
+ key=lambda x: x[1]["count"],
156
+ reverse=True
157
+ )
158
+
159
+ for cls_name, cls_stats in sorted_classes:
160
+ lines.append(f"• {cls_name}: {cls_stats['count']} (avg conf: {cls_stats.get('average_confidence', 0):.2f})")
161
+ else:
162
+ lines.append("No class information available.")
163
+
164
+ return "\n".join(lines)
165
+
166
+ def get_all_classes():
167
+ """Get all available COCO classes"""
168
+ try:
+ # Borrow class names from any already-loaded model instance
+ class_names = next(iter(model_instances.values())).class_names
+ return [(idx, name) for idx, name in class_names.items()]
+ except Exception:
172
+ # Fallback to standard COCO classes
173
+ return [
174
+ (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
175
+ (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
176
+ (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
177
+ (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
178
+ (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
179
+ (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
180
+ (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
181
+ (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
182
+ (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
183
+ (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
184
+ (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
185
+ (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
186
+ (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
187
+ (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
188
+ (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
189
+ (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
190
+ (79, 'toothbrush')
191
+ ]
192
+
193
+ def create_interface():
194
+ """Create the Gradio interface"""
195
+ # Get CSS styles
196
+ css = """
197
+ body {
198
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
199
+ background: linear-gradient(120deg, #e0f7fa, #b2ebf2);
200
+ margin: 0;
201
+ padding: 0;
202
+ }
203
+
204
+ .gradio-container {
205
+ max-width: 1200px !important;
206
+ }
207
+
208
+ .app-header {
209
+ text-align: center;
210
+ margin-bottom: 2rem;
211
+ background: rgba(255, 255, 255, 0.8);
212
+ padding: 1.5rem;
213
+ border-radius: 10px;
214
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
215
+ }
216
+
217
+ .app-title {
218
+ color: #2D3748;
219
+ font-size: 2.5rem;
220
+ margin-bottom: 0.5rem;
221
+ background: linear-gradient(90deg, #4299e1, #48bb78);
222
+ -webkit-background-clip: text;
223
+ -webkit-text-fill-color: transparent;
224
+ }
225
+
226
+ .app-subtitle {
227
+ color: #4A5568;
228
+ font-size: 1.2rem;
229
+ font-weight: normal;
230
+ margin-top: 0.25rem;
231
+ }
232
+
233
+ .app-divider {
234
+ width: 50px;
235
+ height: 3px;
236
+ background: linear-gradient(90deg, #4299e1, #48bb78);
237
+ margin: 1rem auto;
238
+ }
239
+
240
+ .input-panel, .output-panel {
241
+ background: white;
242
+ border-radius: 10px;
243
+ padding: 1rem;
244
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
245
+ }
246
+
247
+ .detect-btn {
248
+ background: linear-gradient(90deg, #4299e1, #48bb78) !important;
249
+ color: white !important;
250
+ border: none !important;
251
+ transition: transform 0.3s, box-shadow 0.3s !important;
252
+ }
253
+
254
+ .detect-btn:hover {
255
+ transform: translateY(-2px) !important;
256
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
257
+ }
258
+
259
+ .detect-btn:active {
260
+ transform: translateY(1px) !important;
261
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2) !important;
262
+ }
263
+
264
+ .footer {
265
+ text-align: center;
266
+ margin-top: 2rem;
267
+ font-size: 0.9rem;
268
+ color: #4A5568;
269
+ }
270
+
271
+ /* Responsive adjustments */
272
+ @media (max-width: 768px) {
273
+ .app-title {
274
+ font-size: 2rem;
275
+ }
276
+
277
+ .app-subtitle {
278
+ font-size: 1rem;
279
+ }
280
+ }
281
+ """
282
+
283
+ # Get available model information
284
+ available_models = DetectionModel.get_available_models()
285
+ model_choices = [model["model_file"] for model in available_models]
286
+ model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]
287
+
288
+ # Available classes for filtering
289
+ available_classes = get_all_classes()
290
+ class_choices = [f"{id}: {name}" for id, name in available_classes]
291
+
292
+ # Create Gradio Blocks interface
293
+ with gr.Blocks(css=css) as demo:
294
+ # Header
295
+ with gr.Group(elem_classes="app-header"):
296
+ gr.HTML("""
297
+ <h1 class="app-title">VisionScout</h1>
298
+ <h2 class="app-subtitle">Detect and identify objects in your images</h2>
299
+ <div class="app-divider"></div>
300
+ """)
301
+
302
+ current_model = gr.State("yolov8m.pt") # use medium size as default
303
+
304
+ # Input and Output panels
305
+ with gr.Row():
306
+ # Left column - Input controls
307
+ with gr.Column(scale=4, elem_classes="input-panel"):
308
+ with gr.Group():
309
+ gr.Markdown("### Upload Image")
310
+ image_input = gr.Image(type="pil", label="Upload an image")
311
+
312
+ with gr.Accordion("Advanced Settings", open=False):
313
+ with gr.Row():
314
+ model_dropdown = gr.Dropdown(
315
+ choices=model_choices,
316
+ value="yolov8m.pt",
317
+ label="Select Model",
318
+ info="Choose different models based on your needs for speed vs. accuracy"
319
+ )
320
+
321
+ # display model info
322
+ model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))
323
+
324
+ confidence = gr.Slider(
325
+ minimum=0.1,
326
+ maximum=0.9,
327
+ value=0.25,
328
+ step=0.05,
329
+ label="Confidence Threshold",
330
+ info="Higher values show fewer but more confident detections"
331
+ )
332
+
333
+ with gr.Accordion("Filter Classes", open=False):
334
+ # Common object categories
335
+ with gr.Row():
336
+ people_btn = gr.Button("People")
337
+ vehicles_btn = gr.Button("Vehicles")
338
+ animals_btn = gr.Button("Animals")
339
+ objects_btn = gr.Button("Common Objects")
340
+
341
+ # Class selection
342
+ class_filter = gr.Dropdown(
343
+ choices=class_choices,
344
+ multiselect=True,
345
+ label="Select Classes to Display",
346
+ info="Leave empty to show all detected objects"
347
+ )
348
+
349
+ detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")
350
+
351
+ with gr.Group():
352
+ gr.Markdown("### How to Use")
353
+ gr.Markdown("""
354
+ 1. Upload an image or use the camera
355
+ 2. Adjust confidence threshold if needed
356
+ 3. Optionally filter to specific object classes
357
+ 4. Click "Detect Objects" button
358
+
359
+ The model will identify objects in your image and display them with bounding boxes.
360
+
361
+ **Note:** Detection quality depends on image clarity and object visibility. The model can detect up to 80 different types of common objects.
362
+ """)
363
+
364
+ # Right column - Results display
365
+ with gr.Column(scale=6, elem_classes="output-panel"):
366
+ with gr.Tab("Detection Result"):
367
+ result_image = gr.Image(type="pil", label="Detection Result")
368
+ result_text = gr.Textbox(label="Detection Details", lines=10)
369
+
370
+ with gr.Tab("Statistics"):
371
+ with gr.Row():
372
+ with gr.Column(scale=1):
373
+ stats_json = gr.Json(label="Full Statistics")
374
+
375
+ with gr.Column(scale=1):
376
+ gr.Markdown("### Object Distribution")
377
+ plot_output = gr.Plot(label="Object Distribution")
378
+
379
+ # model option
380
+ model_dropdown.change(
381
+ fn=lambda model: (model, DetectionModel.get_model_description(model)),
382
+ inputs=[model_dropdown],
383
+ outputs=[current_model, model_info]
384
+ )
385
+
386
+ # Run detection and plotting when the detect button is clicked
387
+ detect_btn.click(
388
+ fn=process_and_plot,
389
+ inputs=[image_input, current_model, confidence, class_filter],
390
+ outputs=[result_image, result_text, stats_json, plot_output]
391
+ )
392
+
393
+ # Quick filter buttons
394
+ people_classes = [0] # Person
395
+ vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8] # Various vehicles
396
+ animals_classes = list(range(14, 24)) # Animals in COCO
397
+ common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76] # Common household items
398
+
399
+ people_btn.click(
400
+ lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
401
+ outputs=class_filter
402
+ )
403
+
404
+ vehicles_btn.click(
405
+ lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
406
+ outputs=class_filter
407
+ )
408
+
409
+ animals_btn.click(
410
+ lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
411
+ outputs=class_filter
412
+ )
413
+
414
+ objects_btn.click(
415
+ lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
416
+ outputs=class_filter
417
+ )
418
+
419
+ # Set up example images
420
+ example_images = [
421
+ "room_01.jpg",
422
+ "street_01.jpg",
423
+ "street_02.jpg",
424
+ "street_03.jpg"
425
+ ]
426
+
427
+
428
+ gr.Examples(
429
+ examples=example_images,
430
+ inputs=image_input,
431
+ outputs=None,
432
+ fn=None,
433
+ cache_examples=False,
434
+ )
435
+
436
+ # Footer
437
+ gr.HTML("""
438
+ <div class="footer">
439
+ <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
440
+ <p>Model can detect 80 different classes of objects</p>
441
+ </div>
442
+ """)
443
+
444
+ return demo
445
+
446
+ @spaces.GPU
447
+ def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
448
+ """
449
+ Process image and create plots for statistics
450
+
451
+ Args:
452
+ image: Input image
453
+ model_name: Name of the model to use
454
+ confidence_threshold: Confidence threshold for detection
455
+ filter_classes: Optional list of classes to filter results
456
+
457
+ Returns:
458
+ Tuple of (result_image, result_text, stats_json, plot_figure)
459
+ """
460
+ global model_instances
461
+
462
+ if model_name not in model_instances:
463
+ print(f"Creating new model instance for {model_name}")
464
+ model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
465
+ else:
466
+ print(f"Using existing model instance for {model_name}")
467
+ model_instances[model_name].confidence = confidence_threshold
468
+
469
+ class_ids = None
470
+ if filter_classes:
471
+ class_ids = []
472
+ for class_str in filter_classes:
473
+ try:
474
+ # Extract ID from format "id: name"
475
+ class_id = int(class_str.split(":")[0].strip())
476
+ class_ids.append(class_id)
477
+ except ValueError:
478
+ continue
479
+
480
+ # execute detection
481
+ result_image, result_text, stats = process_image(
482
+ image,
483
+ model_instances[model_name],
484
+ confidence_threshold,
485
+ class_ids
486
+ )
487
+
488
+ # Create the statistics plot
489
+ plot_figure = create_stats_plot(stats)
490
+
491
+ return result_image, result_text, stats, plot_figure
492
+
493
+ def create_stats_plot(stats):
494
+ """
495
+ Create a visualization of statistics data
496
+
497
+ Args:
498
+ stats: Dictionary containing detection statistics
499
+
500
+ Returns:
501
+ Matplotlib figure with visualization
502
+ """
503
+ if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
504
+ # Create empty plot if no data
505
+ fig, ax = plt.subplots(figsize=(8, 6))
506
+ ax.text(0.5, 0.5, "No detection data available",
507
+ ha='center', va='center', fontsize=12)
508
+ ax.set_xlim(0, 1)
509
+ ax.set_ylim(0, 1)
510
+ ax.axis('off')
511
+ return fig
512
+
513
+ # Prepare visualization data
514
+ viz_data = {
515
+ "total_objects": stats.get("total_objects", 0),
516
+ "average_confidence": stats.get("average_confidence", 0),
517
+ "class_data": []
518
+ }
519
+
520
+ # get current model classes
521
+ # This uses the get_all_classes function which should retrieve from the current model
522
+ available_classes = dict(get_all_classes())
523
+
524
+ # process class data
525
+ for cls_name, cls_stats in stats.get("class_statistics", {}).items():
526
+ # search for class ID
527
+ class_id = -1
528
+
529
+ # Try to find the class ID from class names
530
+ for id, name in available_classes.items():
531
+ if name == cls_name:
532
+ class_id = id
533
+ break
534
+
535
+ cls_data = {
536
+ "name": cls_name,
537
+ "class_id": class_id,
538
+ "count": cls_stats.get("count", 0),
539
+ "average_confidence": cls_stats.get("average_confidence", 0),
540
+ "color": color_mapper.get_color(class_id if class_id >= 0 else cls_name)
541
+ }
542
+
543
+ viz_data["class_data"].append(cls_data)
544
+
545
+ # Sort by count in descending order
546
+ viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)
547
+
548
+ return EvaluationMetrics.create_stats_plot(viz_data)
549
+
550
+
551
+ if __name__ == "__main__":
552
+
554
+ demo = create_interface()
555
+ demo.launch()
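For a quick smoke test outside the Gradio UI, a driver along these lines should exercise the same pipeline. This is a sketch, not part of the commit: it assumes the `spaces` package is importable (it is preinstalled on Hugging Face Spaces, and its `@spaces.GPU` decorator degrades gracefully elsewhere) and that an example image such as `street_01.jpg` exists locally.

```python
# Hypothetical headless driver for the pipeline defined in app.py.
from PIL import Image
from app import process_and_plot  # importing app also constructs the shared ColorMapper

img = Image.open("street_01.jpg")  # assumed local copy of one of the example images
result_image, result_text, stats, fig = process_and_plot(
    image=img,
    model_name="yolov8n.pt",                 # the nano model keeps the smoke test fast
    confidence_threshold=0.25,
    filter_classes=["0: person", "2: car"],  # same "id: name" strings the UI passes
)
print(result_text)
fig.savefig("object_distribution.png")       # the plot shown in the Statistics tab
if result_image is not None:
    result_image.save("detection_result.png")
```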
color_mapper.py ADDED
@@ -0,0 +1,270 @@
1
+ import numpy as np
2
+ from typing import Dict, List, Tuple, Union, Any
3
+
4
+ class ColorMapper:
5
+ """
6
+ A class for consistent color mapping of object detection classes
7
+ Provides color schemes for visualization in both RGB and hex formats
8
+ """
9
+
10
+ # Class categories for better organization
11
+ CATEGORIES = {
12
+ "person": [0],
13
+ "vehicles": [1, 2, 3, 4, 5, 6, 7, 8],
14
+ "traffic": [9, 10, 11, 12],
15
+ "animals": [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
16
+ "outdoor": [13, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33],
17
+ "sports": [34, 35, 36, 37, 38],
18
+ "kitchen": [39, 40, 41, 42, 43, 44, 45],
19
+ "food": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55],
20
+ "furniture": [56, 57, 58, 59, 60, 61],
21
+ "electronics": [62, 63, 64, 65, 66, 67, 68, 69, 70],
22
+ "household": [71, 72, 73, 74, 75, 76, 77, 78, 79]
23
+ }
24
+
25
+ # Base colors for each category (in HSV for easier variation)
26
+ CATEGORY_COLORS = {
27
+ "person": (0, 0.8, 0.9), # Red
28
+ "vehicles": (210, 0.8, 0.9), # Blue
29
+ "traffic": (45, 0.8, 0.9), # Orange
30
+ "animals": (120, 0.7, 0.8), # Green
31
+ "outdoor": (180, 0.7, 0.9), # Cyan
32
+ "sports": (270, 0.7, 0.8), # Purple
33
+ "kitchen": (30, 0.7, 0.9), # Light Orange
34
+ "food": (330, 0.7, 0.85), # Pink
35
+ "furniture": (150, 0.5, 0.85), # Light Green
36
+ "electronics": (240, 0.6, 0.9), # Light Blue
37
+ "household": (60, 0.6, 0.9) # Yellow
38
+ }
39
+
40
+ def __init__(self):
41
+ """Initialize the ColorMapper with COCO class mappings"""
42
+ self.class_names = self._get_coco_classes()
43
+ self.color_map = self._generate_color_map()
44
+
45
+ def _get_coco_classes(self) -> Dict[int, str]:
46
+ """Get the standard COCO class names with their IDs"""
47
+ return {
48
+ 0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
49
+ 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
50
+ 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
51
+ 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
52
+ 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
53
+ 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
54
+ 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
55
+ 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
56
+ 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife',
57
+ 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich',
58
+ 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza',
59
+ 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant',
60
+ 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop',
61
+ 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',
62
+ 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book',
63
+ 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier',
64
+ 79: 'toothbrush'
65
+ }
66
+
67
+ def _hsv_to_rgb(self, h: float, s: float, v: float) -> Tuple[int, int, int]:
68
+ """
69
+ Convert HSV color to RGB
70
+
71
+ Args:
72
+ h: Hue (0-360)
73
+ s: Saturation (0-1)
74
+ v: Value (0-1)
75
+
76
+ Returns:
77
+ Tuple of (R, G, B) values (0-255)
78
+ """
79
+ h = h / 60
80
+ i = int(h)
81
+ f = h - i
82
+ p = v * (1 - s)
83
+ q = v * (1 - s * f)
84
+ t = v * (1 - s * (1 - f))
85
+
86
+ if i == 0:
87
+ r, g, b = v, t, p
88
+ elif i == 1:
89
+ r, g, b = q, v, p
90
+ elif i == 2:
91
+ r, g, b = p, v, t
92
+ elif i == 3:
93
+ r, g, b = p, q, v
94
+ elif i == 4:
95
+ r, g, b = t, p, v
96
+ else:
97
+ r, g, b = v, p, q
98
+
99
+ return (int(r * 255), int(g * 255), int(b * 255))
100
+
101
+ def _rgb_to_hex(self, rgb: Tuple[int, int, int]) -> str:
102
+ """
103
+ Convert RGB color to hex color code
104
+
105
+ Args:
106
+ rgb: Tuple of (R, G, B) values (0-255)
107
+
108
+ Returns:
109
+ Hex color code (e.g. '#FF0000')
110
+ """
111
+ return f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}'
112
+
113
+ def _find_category(self, class_id: int) -> str:
114
+ """
115
+ Find the category for a given class ID
116
+
117
+ Args:
118
+ class_id: Class ID (0-79)
119
+
120
+ Returns:
121
+ Category name
122
+ """
123
+ for category, ids in self.CATEGORIES.items():
124
+ if class_id in ids:
125
+ return category
126
+ return "other" # Fallback
127
+
128
+ def _generate_color_map(self) -> Dict:
129
+ """
130
+ Generate a color map for all 80 COCO classes
131
+
132
+ Returns:
133
+ Dictionary mapping class IDs and names to color values
134
+ """
135
+ color_map = {
136
+ 'by_id': {}, # Map class ID to RGB and hex
137
+ 'by_name': {}, # Map class name to RGB and hex
138
+ 'categories': {} # Map category to base color
139
+ }
140
+
141
+ # Generate colors for categories
142
+ for category, hsv in self.CATEGORY_COLORS.items():
143
+ rgb = self._hsv_to_rgb(hsv[0], hsv[1], hsv[2])
144
+ hex_color = self._rgb_to_hex(rgb)
145
+ color_map['categories'][category] = {
146
+ 'rgb': rgb,
147
+ 'hex': hex_color
148
+ }
149
+
150
+ # Generate variations for each class within a category
151
+ for class_id, class_name in self.class_names.items():
152
+ category = self._find_category(class_id)
153
+ base_hsv = self.CATEGORY_COLORS.get(category, (0, 0, 0.8)) # Default gray
154
+
155
+ # Slightly vary the hue and saturation within the category
156
+ ids_in_category = self.CATEGORIES.get(category, [])
157
+ if ids_in_category:
158
+ position = ids_in_category.index(class_id) if class_id in ids_in_category else 0
159
+ variation = position / max(1, len(ids_in_category) - 1) # 0 to 1
160
+
161
+ # Vary hue slightly (±15°) and saturation
162
+ h_offset = 30 * variation - 15 # -15 to +15
163
+ s_offset = 0.2 * variation # 0 to 0.2
164
+
165
+ h = (base_hsv[0] + h_offset) % 360
166
+ s = min(1.0, base_hsv[1] + s_offset)
167
+ v = base_hsv[2]
168
+ else:
169
+ h, s, v = base_hsv
170
+
171
+ rgb = self._hsv_to_rgb(h, s, v)
172
+ hex_color = self._rgb_to_hex(rgb)
173
+
174
+ # Store in both mappings
175
+ color_map['by_id'][class_id] = {
176
+ 'rgb': rgb,
177
+ 'hex': hex_color,
178
+ 'category': category
179
+ }
180
+
181
+ color_map['by_name'][class_name] = {
182
+ 'rgb': rgb,
183
+ 'hex': hex_color,
184
+ 'category': category
185
+ }
186
+
187
+ return color_map
188
+
189
+ def get_color(self, class_identifier: Union[int, str], format: str = 'hex') -> Any:
190
+ """
191
+ Get color for a specific class
192
+
193
+ Args:
194
+ class_identifier: Class ID (int) or name (str)
195
+ format: Color format ('hex', 'rgb', or 'bgr')
196
+
197
+ Returns:
198
+ Color in requested format
199
+ """
200
+ # Determine if identifier is an ID or name
201
+ if isinstance(class_identifier, int):
202
+ color_info = self.color_map['by_id'].get(class_identifier)
203
+ else:
204
+ color_info = self.color_map['by_name'].get(class_identifier)
205
+
206
+ if not color_info:
207
+ # Fallback color if not found
208
+ return '#CCCCCC' if format == 'hex' else (204, 204, 204)
209
+
210
+ if format == 'hex':
211
+ return color_info['hex']
212
+ elif format == 'rgb':
213
+ return color_info['rgb']
214
+ elif format == 'bgr':
215
+ # Convert RGB to BGR for OpenCV
216
+ r, g, b = color_info['rgb']
217
+ return (b, g, r)
218
+ else:
219
+ return color_info['rgb']
220
+
221
+ def get_all_colors(self, format: str = 'hex') -> Dict:
222
+ """
223
+ Get all colors in the specified format
224
+
225
+ Args:
226
+ format: Color format ('hex', 'rgb', or 'bgr')
227
+
228
+ Returns:
229
+ Dictionary mapping class names to colors
230
+ """
231
+ result = {}
232
+ for class_id, class_name in self.class_names.items():
233
+ result[class_name] = self.get_color(class_id, format)
234
+ return result
235
+
236
+ def get_category_colors(self, format: str = 'hex') -> Dict:
237
+ """
238
+ Get base colors for each category
239
+
240
+ Args:
241
+ format: Color format ('hex', 'rgb', or 'bgr')
242
+
243
+ Returns:
244
+ Dictionary mapping categories to colors
245
+ """
246
+ result = {}
247
+ for category, color_info in self.color_map['categories'].items():
248
+ if format == 'hex':
249
+ result[category] = color_info['hex']
250
+ elif format == 'bgr':
251
+ r, g, b = color_info['rgb']
252
+ result[category] = (b, g, r)
253
+ else:
254
+ result[category] = color_info['rgb']
255
+ return result
256
+
257
+ def get_category_for_class(self, class_identifier: Union[int, str]) -> str:
258
+ """
259
+ Get the category for a specific class
260
+
261
+ Args:
262
+ class_identifier: Class ID (int) or name (str)
263
+
264
+ Returns:
265
+ Category name
266
+ """
267
+ if isinstance(class_identifier, int):
268
+ return self.color_map['by_id'].get(class_identifier, {}).get('category', 'other')
269
+ else:
270
+ return self.color_map['by_name'].get(class_identifier, {}).get('category', 'other')
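As a usage note, the mapper can be queried by class ID or class name, and in any of the three supported formats; a minimal sketch (not part of the commit):

```python
# Sanity checks for ColorMapper lookups.
from color_mapper import ColorMapper

mapper = ColorMapper()
print(mapper.get_color(0))                    # 'person' -> a red-family hex string
print(mapper.get_color("car", format="rgb"))  # lookup by name, returns an (R, G, B) tuple
print(mapper.get_color(2, format="bgr"))      # (B, G, R) tuple, ready for OpenCV drawing
print(mapper.get_category_for_class(15))      # 'animals' (cat sits in IDs 14-23)
print(len(mapper.get_all_colors()))           # 80 class-name -> hex entries
```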
detection_model.py ADDED
@@ -0,0 +1,164 @@
1
+ from ultralytics import YOLO
2
+ from typing import Any, List, Dict, Optional
3
+ import torch
6
+
7
+ class DetectionModel:
8
+ """Core detection model class for object detection using YOLOv8"""
9
+
10
+ # Model information dictionary
11
+ MODEL_INFO = {
12
+ "yolov8n.pt": {
13
+ "name": "YOLOv8n (Nano)",
14
+ "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
15
+ "size_mb": 6,
16
+ "inference_speed": "Very Fast"
17
+ },
18
+ "yolov8m.pt": {
19
+ "name": "YOLOv8m (Medium)",
20
+ "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
21
+ "size_mb": 25,
22
+ "inference_speed": "Medium"
23
+ },
24
+ "yolov8x.pt": {
25
+ "name": "YOLOv8x (XLarge)",
26
+ "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
27
+ "size_mb": 68,
28
+ "inference_speed": "Slower"
29
+ }
30
+ }
31
+
32
+ def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
33
+ """
34
+ Initialize the detection model
35
+
36
+ Args:
37
+ model_name: Model name or path, default is yolov8m.pt
38
+ confidence: Confidence threshold, default is 0.25
39
+ iou: IoU threshold for non-maximum suppression, default is 0.45
40
+ """
41
+ self.model_name = model_name
42
+ self.confidence = confidence
43
+ self.iou = iou
44
+ self.model = None
45
+ self.class_names = {}
46
+ self.is_model_loaded = False
47
+
48
+ # Load model on initialization
49
+ self._load_model()
50
+
51
+ def _load_model(self):
52
+ """Load the YOLO model"""
53
+ try:
54
+ print(f"Loading model: {self.model_name}")
55
+ self.model = YOLO(self.model_name)
56
+ self.class_names = self.model.names
57
+ self.is_model_loaded = True
58
+ print(f"Successfully loaded model: {self.model_name}")
59
+ print(f"Number of classes the model can recognize: {len(self.class_names)}")
60
+ except Exception as e:
61
+ print(f"Error occurred when loading the model: {e}")
62
+ self.is_model_loaded = False
63
+
64
+ def change_model(self, new_model_name: str) -> bool:
65
+ """
66
+ Change the currently loaded model
67
+
68
+ Args:
69
+ new_model_name: Name of the new model to load
70
+
71
+ Returns:
72
+ bool: True if model changed successfully, False otherwise
73
+ """
74
+ if self.model_name == new_model_name and self.is_model_loaded:
75
+ print(f"Model {new_model_name} is already loaded")
76
+ return True
77
+
78
+ print(f"Changing model from {self.model_name} to {new_model_name}")
79
+
80
+ # Unload current model to free memory
81
+ if self.model is not None:
82
+ del self.model
83
+ self.model = None
84
+
85
+ # Clean GPU memory if available
86
+ if torch.cuda.is_available():
87
+ torch.cuda.empty_cache()
88
+
89
+ # Update model name and load new model
90
+ self.model_name = new_model_name
91
+ self._load_model()
92
+
93
+ return self.is_model_loaded
94
+
95
+ def reload_model(self):
96
+ """Reload the model (useful for changing model or after error)"""
97
+ if self.model is not None:
98
+ del self.model
99
+ self.model = None
100
+
101
+ # Clean GPU memory if available
102
+ if torch.cuda.is_available():
103
+ torch.cuda.empty_cache()
104
+
105
+ self._load_model()
106
+
107
+ def detect(self, image_input: Any) -> Optional[Any]:
108
+ """
109
+ Perform object detection on a single image
110
+
111
+ Args:
112
+ image_input: Image path (str), PIL Image, or numpy array
113
+
114
+ Returns:
115
+ Detection result object or None if error occurred
116
+ """
117
+ if self.model is None or not self.is_model_loaded:
118
+ print("Model not found or not loaded. Attempting to reload...")
119
+ self._load_model()
120
+ if self.model is None or not self.is_model_loaded:
121
+ print("Failed to load model. Cannot perform detection.")
122
+ return None
123
+
124
+ try:
125
+ results = self.model(image_input, conf=self.confidence, iou=self.iou)
126
+ return results[0]
127
+ except Exception as e:
128
+ print(f"Error occurred during detection: {e}")
129
+ return None
130
+
131
+ def get_class_name(self, class_id: int) -> str:
132
+ """Get class name for a given class ID"""
133
+ return self.class_names.get(class_id, "Unknown Class")
134
+
135
+ def get_supported_classes(self) -> Dict[int, str]:
136
+ """Get all supported classes as a dictionary of {id: class_name}"""
137
+ return self.class_names
138
+
139
+ @classmethod
140
+ def get_available_models(cls) -> List[Dict]:
141
+ """
142
+ Get list of available models with their information
143
+
144
+ Returns:
145
+ List of dictionaries containing model information
146
+ """
147
+ models = []
148
+ for model_file, info in cls.MODEL_INFO.items():
149
+ models.append({
150
+ "model_file": model_file,
151
+ "name": info["name"],
152
+ "description": info["description"],
153
+ "size_mb": info["size_mb"],
154
+ "inference_speed": info["inference_speed"]
155
+ })
156
+ return models
157
+
158
+ @classmethod
159
+ def get_model_description(cls, model_name: str) -> str:
160
+ """Get description for a specific model"""
161
+ if model_name in cls.MODEL_INFO:
162
+ info = cls.MODEL_INFO[model_name]
163
+ return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
164
+ return "Model information not available"
evaluation_metrics.py ADDED
@@ -0,0 +1,323 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from typing import Dict, List, Any, Optional, Tuple
4
+
5
+ class EvaluationMetrics:
6
+ """Class for computing detection metrics, generating statistics and visualization data"""
7
+
8
+ @staticmethod
9
+ def calculate_basic_stats(result: Any) -> Dict:
10
+ """
11
+ Calculate basic statistics for a single detection result
12
+
13
+ Args:
14
+ result: Detection result object
15
+
16
+ Returns:
17
+ Dictionary with basic statistics
18
+ """
19
+ if result is None:
20
+ return {"error": "No detection result provided"}
21
+
22
+ # Get classes and confidences
23
+ classes = result.boxes.cls.cpu().numpy().astype(int)
24
+ confidences = result.boxes.conf.cpu().numpy()
25
+ names = result.names
26
+
27
+ # Count by class
28
+ class_counts = {}
29
+ for cls, conf in zip(classes, confidences):
30
+ cls_name = names[int(cls)]
31
+ if cls_name not in class_counts:
32
+ class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
33
+
34
+ class_counts[cls_name]["count"] += 1
35
+ class_counts[cls_name]["total_confidence"] += float(conf)
36
+ class_counts[cls_name]["confidences"].append(float(conf))
37
+
38
+ # Calculate average confidence
39
+ for cls_name, cls_stats in class_counts.items():
40
+ if cls_stats["count"] > 0:
41
+ cls_stats["average_confidence"] = cls_stats["total_confidence"] / cls_stats["count"]
42
+ cls_stats["confidence_std"] = float(np.std(cls_stats["confidences"])) if len(cls_stats["confidences"]) > 1 else 0
43
+ cls_stats.pop("total_confidence") # Remove intermediate calculation
44
+
45
+ # Prepare summary
46
+ stats = {
47
+ "total_objects": len(classes),
48
+ "class_statistics": class_counts,
49
+ "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
50
+ }
51
+
52
+ return stats
53
+
54
+ @staticmethod
55
+ def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
56
+ """
57
+ Generate structured data suitable for visualization
58
+
59
+ Args:
60
+ result: Detection result object
61
+ class_colors: Dictionary mapping class names to color codes (optional)
62
+
63
+ Returns:
64
+ Dictionary with visualization-ready data
65
+ """
66
+ if result is None:
67
+ return {"error": "No detection result provided"}
68
+
69
+ # Get basic stats first
70
+ stats = EvaluationMetrics.calculate_basic_stats(result)
71
+
72
+ # Create visualization-specific data structure
73
+ viz_data = {
74
+ "total_objects": stats["total_objects"],
75
+ "average_confidence": stats["average_confidence"],
76
+ "class_data": []
77
+ }
78
+
79
+ # Sort classes by count (descending)
80
+ sorted_classes = sorted(
81
+ stats["class_statistics"].items(),
82
+ key=lambda x: x[1]["count"],
83
+ reverse=True
84
+ )
85
+
86
+ # Create class-specific visualization data
87
+ for cls_name, cls_stats in sorted_classes:
88
+ class_id = -1
89
+ # Find the class ID based on the name
90
+ for idx, name in result.names.items():
91
+ if name == cls_name:
92
+ class_id = idx
93
+ break
94
+
95
+ cls_data = {
96
+ "name": cls_name,
97
+ "class_id": class_id,
98
+ "count": cls_stats["count"],
99
+ "average_confidence": cls_stats.get("average_confidence", 0),
100
+ "confidence_std": cls_stats.get("confidence_std", 0),
101
+ "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
102
+ }
103
+
104
+ viz_data["class_data"].append(cls_data)
105
+
106
+ return viz_data
107
+
108
+ @staticmethod
109
+ def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7),
110
+ max_classes: int = 30) -> plt.Figure:
111
+ """
112
+ Create a horizontal bar chart showing detection statistics
113
+
114
+ Args:
115
+ viz_data: Visualization data generated by generate_visualization_data
116
+ figsize: Figure size (width, height) in inches
117
+ max_classes: Maximum number of classes to display
118
+
119
+ Returns:
120
+ Matplotlib figure object
121
+ """
122
+ if "error" in viz_data:
123
+ # Create empty plot if error
124
+ fig, ax = plt.subplots(figsize=figsize)
125
+ ax.text(0.5, 0.5, viz_data["error"],
126
+ ha='center', va='center', fontsize=12)
127
+ ax.set_xlim(0, 1)
128
+ ax.set_ylim(0, 1)
129
+ ax.axis('off')
130
+ return fig
131
+
132
+ if "class_data" not in viz_data or not viz_data["class_data"]:
133
+ # Create empty plot if no data
134
+ fig, ax = plt.subplots(figsize=figsize)
135
+ ax.text(0.5, 0.5, "No detection data available",
136
+ ha='center', va='center', fontsize=12)
137
+ ax.set_xlim(0, 1)
138
+ ax.set_ylim(0, 1)
139
+ ax.axis('off')
140
+ return fig
141
+
142
+ # Limit to max_classes
143
+ class_data = viz_data["class_data"][:max_classes]
144
+
145
+ # Extract data for plotting
146
+ class_names = [item["name"] for item in class_data]
147
+ counts = [item["count"] for item in class_data]
148
+ colors = [item["color"] for item in class_data]
149
+
150
+ # Create figure and horizontal bar chart
151
+ fig, ax = plt.subplots(figsize=figsize)
152
+ y_pos = np.arange(len(class_names))
153
+
154
+ # Create horizontal bars with class-specific colors
155
+ bars = ax.barh(y_pos, counts, color=colors, alpha=0.8)
156
+
157
+ # Add count values at end of each bar
158
+ for i, bar in enumerate(bars):
159
+ width = bar.get_width()
160
+ conf = class_data[i]["average_confidence"]
161
+ ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
162
+ f"{width:.0f} (conf: {conf:.2f})",
163
+ va='center', fontsize=9)
164
+
165
+ # Customize axis and labels
166
+ ax.set_yticks(y_pos)
167
+ ax.set_yticklabels(class_names)
168
+ ax.invert_yaxis() # Labels read top-to-bottom
169
+ ax.set_xlabel('Count')
170
+ ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total')
171
+
172
+ # Add grid for better readability
173
+ ax.set_axisbelow(True)
174
+ ax.grid(axis='x', linestyle='--', alpha=0.7)
175
+
176
+ # Add detection summary as a text box
177
+ summary_text = (
178
+ f"Total Objects: {viz_data['total_objects']}\n"
179
+ f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
180
+ f"Unique Classes: {len(viz_data['class_data'])}"
181
+ )
182
+ plt.figtext(0.02, 0.02, summary_text, fontsize=9,
183
+ bbox=dict(facecolor='white', alpha=0.8, boxstyle='round'))
184
+
185
+ plt.tight_layout()
186
+ return fig
187
+
188
+ @staticmethod
189
+ def format_detection_summary(viz_data: Dict) -> str:
190
+ """
191
+ Format detection results as a readable text summary
192
+ """
193
+ if "error" in viz_data:
194
+ return viz_data["error"]
195
+
196
+ if "total_objects" not in viz_data:
197
+ return "No detection data available."
198
+
199
+ # Timing display removed
200
+ total_objects = viz_data["total_objects"]
201
+ avg_confidence = viz_data["average_confidence"]
202
+
203
+ # Build the header lines
204
+ lines = [
205
+ f"Detected {total_objects} objects.",
206
+ f"Average confidence: {avg_confidence:.2f}",
207
+ "",
208
+ "Objects by class:",
209
+ ]
210
+
211
+ # Append per-class details
212
+ if "class_data" in viz_data and viz_data["class_data"]:
213
+ for item in viz_data["class_data"]:
214
+ lines.append(
215
+ f"• {item['name']}: {item['count']} (avg conf: {item['average_confidence']:.2f})"
216
+ )
217
+ else:
218
+ lines.append("No class information available.")
219
+
220
+ return "\n".join(lines)
221
+
222
+ @staticmethod
223
+ def calculate_distance_metrics(result: Any) -> Dict:
224
+ """
225
+ Calculate distance-related metrics for detected objects
226
+
227
+ Args:
228
+ result: Detection result object
229
+
230
+ Returns:
231
+ Dictionary with distance metrics
232
+ """
233
+ if result is None:
234
+ return {"error": "No detection result provided"}
235
+
236
+ boxes = result.boxes.xyxy.cpu().numpy()
237
+ classes = result.boxes.cls.cpu().numpy().astype(int)
238
+ names = result.names
239
+
240
+ # Initialize metrics
241
+ metrics = {
242
+ "proximity": {}, # Classes that appear close to each other
243
+ "spatial_distribution": {}, # Distribution across the image
244
+ "size_distribution": {} # Size distribution of objects
245
+ }
246
+
247
+ # Calculate image dimensions (assuming normalized coordinates or extract from result)
248
+ img_width, img_height = 1, 1
249
+ if hasattr(result, "orig_shape"):
250
+ img_height, img_width = result.orig_shape[:2]
251
+
252
+ # Calculate bounding box areas and centers
253
+ areas = []
254
+ centers = []
255
+ class_names = []
256
+
257
+ for box, cls in zip(boxes, classes):
258
+ x1, y1, x2, y2 = box
259
+ width, height = x2 - x1, y2 - y1
260
+ area = width * height
261
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
262
+
263
+ areas.append(area)
264
+ centers.append((center_x, center_y))
265
+ class_names.append(names[int(cls)])
266
+
267
+ # Calculate spatial distribution
268
+ if centers:
269
+ x_coords = [c[0] for c in centers]
270
+ y_coords = [c[1] for c in centers]
271
+
272
+ metrics["spatial_distribution"] = {
273
+ "x_mean": float(np.mean(x_coords)) / img_width,
274
+ "y_mean": float(np.mean(y_coords)) / img_height,
275
+ "x_std": float(np.std(x_coords)) / img_width,
276
+ "y_std": float(np.std(y_coords)) / img_height
277
+ }
278
+
279
+ # Calculate size distribution
280
+ if areas:
281
+ metrics["size_distribution"] = {
282
+ "mean_area": float(np.mean(areas)) / (img_width * img_height),
283
+ "std_area": float(np.std(areas)) / (img_width * img_height),
284
+ "min_area": float(np.min(areas)) / (img_width * img_height),
285
+ "max_area": float(np.max(areas)) / (img_width * img_height)
286
+ }
287
+
288
+ # Calculate proximity between different classes
289
+ class_centers = {}
290
+ for cls_name, center in zip(class_names, centers):
291
+ if cls_name not in class_centers:
292
+ class_centers[cls_name] = []
293
+ class_centers[cls_name].append(center)
294
+
295
+ # Find classes that appear close to each other
296
+ proximity_pairs = []
297
+ for i, cls1 in enumerate(class_centers.keys()):
298
+ for j, cls2 in enumerate(class_centers.keys()):
299
+ if i >= j: # Avoid duplicate pairs and self-comparison
300
+ continue
301
+
302
+ # Calculate minimum distance between any two objects of these classes
303
+ min_distance = float('inf')
304
+ for center1 in class_centers[cls1]:
305
+ for center2 in class_centers[cls2]:
306
+ dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
307
+ min_distance = min(min_distance, dist)
308
+
309
+ # Normalize by image diagonal
310
+ img_diagonal = np.sqrt(img_width**2 + img_height**2)
311
+ norm_distance = min_distance / img_diagonal
312
+
313
+ proximity_pairs.append({
314
+ "class1": cls1,
315
+ "class2": cls2,
316
+ "distance": float(norm_distance)
317
+ })
318
+
319
+ # Sort by distance and keep the closest pairs
320
+ proximity_pairs.sort(key=lambda x: x["distance"])
321
+ metrics["proximity"] = proximity_pairs[:5] # Keep top 5 closest pairs
322
+
323
+ return metrics
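Wired together with the other modules, the metrics flow is roughly the following (a sketch under the same assumptions as above; no error handling — see app.py for the guarded version):

```python
# Hypothetical end-to-end statistics pass.
from detection_model import DetectionModel
from color_mapper import ColorMapper
from evaluation_metrics import EvaluationMetrics

result = DetectionModel("yolov8n.pt").detect("street_01.jpg")
assert result is not None, "detection failed"

stats = EvaluationMetrics.calculate_basic_stats(result)
print("total objects:", stats["total_objects"])
print("spatial:", EvaluationMetrics.calculate_distance_metrics(result)["spatial_distribution"])

viz = EvaluationMetrics.generate_visualization_data(result, ColorMapper().get_all_colors())
print(EvaluationMetrics.format_detection_summary(viz))
EvaluationMetrics.create_stats_plot(viz).savefig("stats_plot.png")
```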
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ ultralytics>=8.0.0
4
+ opencv-python>=4.7.0
5
+ pillow>=9.4.0
6
+ numpy>=1.23.5
7
+ matplotlib>=3.7.0
8
+ gradio>=3.32.0
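Note that app.py also imports `spaces`, which is preinstalled on Hugging Face Spaces but absent from this list; for a purely local run you would likely need `pip install spaces` (or to drop the `@spaces.GPU` decorators) on top of `pip install -r requirements.txt`.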
visualization_helper.py ADDED
@@ -0,0 +1,147 @@
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from typing import Any, List, Dict, Tuple, Optional
5
+ import io
6
+ from PIL import Image
7
+
8
+ class VisualizationHelper:
9
+ """Helper class for visualizing detection results"""
10
+
11
+ @staticmethod
12
+ def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
13
+ figsize: Tuple[int, int] = (12, 12),
14
+ return_pil: bool = False) -> Optional[Image.Image]:
15
+ """
16
+ Visualize detection results on a single image
17
+
18
+ Args:
19
+ image: Image path or numpy array
20
+ result: Detection result object
21
+ color_mapper: ColorMapper instance for consistent colors
22
+ figsize: Figure size
23
+ return_pil: If True, returns a PIL Image object
24
+
25
+ Returns:
26
+ PIL Image if return_pil is True, otherwise displays the plot
27
+ """
28
+ if result is None:
29
+ print('No data for visualization')
30
+ return None
31
+
32
+ # Read image if path is provided
33
+ if isinstance(image, str):
34
+ img = cv2.imread(image)
+ # Guard against unreadable paths
+ if img is None:
+ print(f'Could not read image: {image}')
+ return None
35
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
36
+ else:
37
+ img = image
38
+ if len(img.shape) == 3 and img.shape[2] == 3:
39
+ # Check if BGR format (OpenCV) and convert to RGB if needed
40
+ if isinstance(img, np.ndarray):
41
+ # Assuming BGR format from OpenCV
42
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
43
+
44
+ # Create figure
45
+ fig, ax = plt.subplots(figsize=figsize)
46
+ ax.imshow(img)
47
+
48
+ # Get bounding boxes, classes and confidences
49
+ boxes = result.boxes.xyxy.cpu().numpy()
50
+ classes = result.boxes.cls.cpu().numpy()
51
+ confs = result.boxes.conf.cpu().numpy()
52
+
53
+ # Get class names
54
+ names = result.names
55
+
56
+ # Create a default color mapper if none is provided
57
+ if color_mapper is None:
58
+ # For backward compatibility, fallback to a simple color function
59
+ from matplotlib import colormaps
60
+ cmap = colormaps['tab10']
61
+ def get_color(class_id):
62
+ return cmap(class_id % 10)
63
+ else:
64
+ # Use the provided color mapper
65
+ def get_color(class_id):
66
+ hex_color = color_mapper.get_color(class_id)
67
+ # Convert hex to RGB float values for matplotlib
68
+ hex_color = hex_color.lstrip('#')
69
+ return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)
70
+
71
+ # Draw detection results
72
+ for box, cls, conf in zip(boxes, classes, confs):
73
+ x1, y1, x2, y2 = box
74
+ cls_id = int(cls)
75
+ cls_name = names[cls_id]
76
+
77
+ # Get color for this class
78
+ box_color = get_color(cls_id)
79
+
80
+ # Add text label with colored background
81
+ ax.text(x1, y1 - 5, f'{cls_name}: {conf:.2f}',
82
+ color='white', fontsize=10,
83
+ bbox=dict(facecolor=box_color[:3], alpha=0.7))
84
+
85
+ # Add bounding box
86
+ ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
87
+ fill=False, edgecolor=box_color[:3], linewidth=2))
88
+
89
+ ax.axis('off')
90
+ # ax.set_title('Detection Result')
91
+ plt.tight_layout()
92
+
93
+ if return_pil:
94
+ # Convert plot to PIL Image
95
+ buf = io.BytesIO()
96
+ fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
97
+ buf.seek(0)
98
+ pil_img = Image.open(buf)
99
+ plt.close(fig)
100
+ return pil_img
101
+ else:
102
+ plt.show()
103
+ return None
104
+
105
+ @staticmethod
106
+ def create_summary(result: Any) -> Dict:
107
+ """
108
+ Create a summary of detection results
109
+
110
+ Args:
111
+ result: Detection result object
112
+
113
+ Returns:
114
+ Dictionary with detection summary statistics
115
+ """
116
+ if result is None:
117
+ return {"error": "No detection result provided"}
118
+
119
+ # Get classes and confidences
120
+ classes = result.boxes.cls.cpu().numpy().astype(int)
121
+ confidences = result.boxes.conf.cpu().numpy()
122
+ names = result.names
123
+
124
+ # Count detections by class
125
+ class_counts = {}
126
+ for cls, conf in zip(classes, confidences):
127
+ cls_name = names[int(cls)]
128
+ if cls_name not in class_counts:
129
+ class_counts[cls_name] = {"count": 0, "confidences": []}
130
+
131
+ class_counts[cls_name]["count"] += 1
132
+ class_counts[cls_name]["confidences"].append(float(conf))
133
+
134
+ # Calculate average confidence for each class
135
+ for cls_name, stats in class_counts.items():
136
+ if stats["confidences"]:
137
+ stats["average_confidence"] = float(np.mean(stats["confidences"]))
138
+ stats.pop("confidences") # Remove detailed confidences list to keep summary concise
139
+
140
+ # Prepare summary
141
+ summary = {
142
+ "total_objects": len(classes),
143
+ "class_counts": class_counts,
144
+ "unique_classes": len(class_counts)
145
+ }
146
+
147
+ return summary
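Tying detection and rendering together, a final sketch (same assumed example image; not part of the commit):

```python
# Hypothetical rendering pass: detect, draw, and summarize.
from detection_model import DetectionModel
from color_mapper import ColorMapper
from visualization_helper import VisualizationHelper

result = DetectionModel("yolov8n.pt").detect("street_01.jpg")
annotated = VisualizationHelper.visualize_detection(
    "street_01.jpg", result, color_mapper=ColorMapper(), return_pil=True
)
if annotated is not None:
    annotated.save("annotated.png")  # PNG preserves any alpha from the figure buffer
print(VisualizationHelper.create_summary(result))
```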