DawnC committed on
Commit 611206a · verified · 1 Parent(s): c697df2

Upload 6 files

Files changed (6)
  1. app.py +598 -0
  2. color_mapper.py +270 -0
  3. detection_model.py +164 -0
  4. evaluation_metrics.py +360 -0
  5. style.py +282 -0
  6. visualization_helper.py +147 -0
app.py ADDED
@@ -0,0 +1,598 @@
+ import os
+ import numpy as np
+ import torch
+ import cv2
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ import io
+ from PIL import Image, ImageDraw, ImageFont
+ import spaces
+ from typing import Dict, List, Any, Optional, Tuple
+ from ultralytics import YOLO
+
+ from detection_model import DetectionModel
+ from color_mapper import ColorMapper
+ from visualization_helper import VisualizationHelper
+ from evaluation_metrics import EvaluationMetrics
+ from style import Style
+
+
+ color_mapper = ColorMapper()
+ model_instances = {}
+
+ @spaces.GPU
+ def process_image(image, model_instance, confidence_threshold, filter_classes=None):
+     """
+     Process an image for object detection
+
+     Args:
+         image: Input image (numpy array or PIL Image)
+         model_instance: DetectionModel instance to use
+         confidence_threshold: Confidence threshold for detection
+         filter_classes: Optional list of classes to filter results
+
+     Returns:
+         Tuple of (result_image, result_text, stats_data)
+     """
+     # Initialize key variables
+     result = None
+     stats = {}
+     temp_path = None
+
+     try:
+         # Update confidence threshold
+         model_instance.confidence = confidence_threshold
+
+         # Process the input image
+         if isinstance(image, np.ndarray):
+             # Convert BGR to RGB if needed
+             if image.shape[2] == 3:
+                 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             else:
+                 image_rgb = image
+             pil_image = Image.fromarray(image_rgb)
+         elif image is None:
+             return None, "No image provided. Please upload an image.", {}
+         else:
+             pil_image = image
+
+         # Store temp files
+         import uuid
+         import tempfile
+
+         temp_dir = tempfile.gettempdir()  # use system temp directory
+         temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
+         temp_path = os.path.join(temp_dir, temp_filename)
+         pil_image.save(temp_path)
+
+         # Object detection
+         result = model_instance.detect(temp_path)
+
+         if result is None:
+             return None, "Detection failed. Please try again with a different image.", {}
+
+         # Calculate stats
+         stats = EvaluationMetrics.calculate_basic_stats(result)
+
+         # Add spatial metrics
+         spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
+         stats["spatial_metrics"] = spatial_metrics
+
+         if filter_classes and len(filter_classes) > 0:
+             # Get classes, boxes, confidences
+             classes = result.boxes.cls.cpu().numpy().astype(int)
+             confs = result.boxes.conf.cpu().numpy()
+             boxes = result.boxes.xyxy.cpu().numpy()
+
+             mask = np.zeros_like(classes, dtype=bool)
+             for cls_id in filter_classes:
+                 mask = np.logical_or(mask, classes == cls_id)
+
+             filtered_stats = {
+                 "total_objects": int(np.sum(mask)),
+                 "class_statistics": {},
+                 "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
+                 "spatial_metrics": stats["spatial_metrics"]
+             }
+
+             # Update stats
+             names = result.names
+             for cls, conf in zip(classes[mask], confs[mask]):
+                 cls_name = names[int(cls)]
+                 if cls_name not in filtered_stats["class_statistics"]:
+                     filtered_stats["class_statistics"][cls_name] = {
+                         "count": 0,
+                         "average_confidence": 0
+                     }
+
+                 cls_entry = filtered_stats["class_statistics"][cls_name]
+                 cls_entry["count"] += 1
+                 # Running mean keeps the average correct when a class appears multiple times
+                 cls_entry["average_confidence"] += (float(conf) - cls_entry["average_confidence"]) / cls_entry["count"]
+
+             stats = filtered_stats
+
+         viz_data = EvaluationMetrics.generate_visualization_data(
+             result,
+             color_mapper.get_all_colors()
+         )
+
+         result_image = VisualizationHelper.visualize_detection(
+             temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
+         )
+
+         result_text = EvaluationMetrics.format_detection_summary(viz_data)
+
+         return result_image, result_text, stats
+
+     except Exception as e:
+         error_message = f"Error occurred: {str(e)}"
+         import traceback
+         traceback.print_exc()
+         print(error_message)
+         return None, error_message, {}
+
+     finally:
+         if temp_path and os.path.exists(temp_path):
+             try:
+                 os.remove(temp_path)
+             except Exception as e:
+                 print(f"Cannot delete temp file {temp_path}: {str(e)}")
+
+ def format_result_text(stats):
+     """
+     Format detection statistics into readable text with improved spacing
+
+     Args:
+         stats: Dictionary containing detection statistics
+
+     Returns:
+         Formatted text summary
+     """
+     if not stats or "total_objects" not in stats:
+         return "No objects detected."
+
+     # Keep unnecessary blank lines to a minimum
+     lines = [
+         f"Detected {stats['total_objects']} objects.",
+         f"Average confidence: {stats.get('average_confidence', 0):.2f}",
+         "Objects by class:"
+     ]
+
+     if "class_statistics" in stats and stats["class_statistics"]:
+         # Sort classes by count
+         sorted_classes = sorted(
+             stats["class_statistics"].items(),
+             key=lambda x: x[1]["count"],
+             reverse=True
+         )
+
+         for cls_name, cls_stats in sorted_classes:
+             count = cls_stats["count"]
+             conf = cls_stats.get("average_confidence", 0)
+
+             item_text = "item" if count == 1 else "items"
+             lines.append(f"• {cls_name}: {count} {item_text} (avg conf: {conf:.2f})")
+     else:
+         lines.append("No class information available.")
+
+     # Add spatial information
+     if "spatial_metrics" in stats and "spatial_distribution" in stats["spatial_metrics"]:
+         lines.append("Object Distribution:")
+
+         dist = stats["spatial_metrics"]["spatial_distribution"]
+         x_mean = dist.get("x_mean", 0)
+         y_mean = dist.get("y_mean", 0)
+
+         # Describe the approximate position of the objects
+         if x_mean < 0.33:
+             h_pos = "on the left side"
+         elif x_mean < 0.67:
+             h_pos = "in the center"
+         else:
+             h_pos = "on the right side"
+
+         if y_mean < 0.33:
+             v_pos = "in the upper part"
+         elif y_mean < 0.67:
+             v_pos = "in the middle"
+         else:
+             v_pos = "in the lower part"
+
+         lines.append(f"• Most objects appear {h_pos} {v_pos} of the image")
+
+     return "\n".join(lines)
+
+ def format_json_for_display(stats):
+     """
+     Format statistics JSON for better display
+
+     Args:
+         stats: Raw statistics dictionary
+
+     Returns:
+         Formatted statistics structure for display
+     """
+     # Create a cleaner copy of the stats for display
+     display_stats = {}
+
+     # Add summary section
+     display_stats["summary"] = {
+         "total_objects": stats.get("total_objects", 0),
+         "average_confidence": round(stats.get("average_confidence", 0), 3)
+     }
+
+     # Add class statistics in a more organized way
+     if "class_statistics" in stats and stats["class_statistics"]:
+         # Sort classes by count (descending)
+         sorted_classes = sorted(
+             stats["class_statistics"].items(),
+             key=lambda x: x[1].get("count", 0),
+             reverse=True
+         )
+
+         class_stats = {}
+         for cls_name, cls_data in sorted_classes:
+             class_stats[cls_name] = {
+                 "count": cls_data.get("count", 0),
+                 "average_confidence": round(cls_data.get("average_confidence", 0), 3)
+             }
+
+         display_stats["detected_objects"] = class_stats
+
+     # Simplify spatial metrics
+     if "spatial_metrics" in stats:
+         spatial = stats["spatial_metrics"]
+
+         # Simplify spatial distribution
+         if "spatial_distribution" in spatial:
+             dist = spatial["spatial_distribution"]
+             display_stats["spatial"] = {
+                 "distribution": {
+                     "x_mean": round(dist.get("x_mean", 0), 3),
+                     "y_mean": round(dist.get("y_mean", 0), 3),
+                     "x_std": round(dist.get("x_std", 0), 3),
+                     "y_std": round(dist.get("y_std", 0), 3)
+                 }
+             }
+
+         # Add simplified size information
+         if "size_distribution" in spatial:
+             size = spatial["size_distribution"]
+             display_stats["spatial"]["size"] = {
+                 "mean_area": round(size.get("mean_area", 0), 3),
+                 "min_area": round(size.get("min_area", 0), 3),
+                 "max_area": round(size.get("max_area", 0), 3)
+             }
+
+     return display_stats
+
+ def get_all_classes():
+     """
+     Get all available COCO classes from the currently active model, falling back to the standard COCO classes
+
+     Returns:
+         List of tuples (class_id, class_name)
+     """
+     global model_instances
+
+     # Try to get class names from any loaded model
+     for model_name, model_instance in model_instances.items():
+         if model_instance and model_instance.is_model_loaded:
+             try:
+                 class_names = model_instance.class_names
+                 return [(idx, name) for idx, name in class_names.items()]
+             except Exception:
+                 pass
+
+     # Fallback to standard COCO classes
+     return [
+         (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
+         (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
+         (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
+         (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
+         (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
+         (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
+         (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
+         (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
+         (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
+         (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
+         (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
+         (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
+         (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
+         (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
+         (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
+         (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
+         (79, 'toothbrush')
+     ]
+
+ def create_interface():
+     """Create the Gradio interface with polished visual styling"""
+     css = Style.get_css()
+
+     # Get information on the available models
+     available_models = DetectionModel.get_available_models()
+     model_choices = [model["model_file"] for model in available_models]
+     model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]
+
+     # Available class filter options
+     available_classes = get_all_classes()
+     class_choices = [f"{id}: {name}" for id, name in available_classes]
+
+     # Create the Gradio Blocks interface
+     with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="teal", secondary_hue="blue")) as demo:
+         # Page header
+         with gr.Group(elem_classes="app-header"):
+             gr.HTML("""
+                 <div style="text-align: center; width: 100%;">
+                     <h1 class="app-title">VisionScout</h1>
+                     <h2 class="app-subtitle">Detect and identify objects in your images</h2>
+                     <div class="app-divider"></div>
+                 </div>
+             """)
+
+         current_model = gr.State("yolov8m.pt")  # use the medium-size model as default
+
+         # Main content area - input and output panels
+         with gr.Row(equal_height=True):
+             # Left side - input controls (image upload)
+             with gr.Column(scale=4, elem_classes="input-panel"):
+                 with gr.Group():
+                     gr.HTML('<div class="section-heading">Upload Image</div>')
+                     image_input = gr.Image(type="pil", label="Upload an image", elem_classes="upload-box")
+
+                     with gr.Accordion("Advanced Settings", open=False):
+                         with gr.Row():
+                             model_dropdown = gr.Dropdown(
+                                 choices=model_choices,
+                                 value="yolov8m.pt",
+                                 label="Select Model",
+                                 info="Choose different models based on your needs for speed vs. accuracy"
+                             )
+
+                         # Display model info
+                         model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))
+
+                         confidence = gr.Slider(
+                             minimum=0.1,
+                             maximum=0.9,
+                             value=0.25,
+                             step=0.05,
+                             label="Confidence Threshold",
+                             info="Higher values show fewer but more confident detections"
+                         )
+
+                         with gr.Accordion("Filter Classes", open=False):
+                             # Quick-select buttons for common object categories
+                             gr.HTML('<div class="section-heading" style="font-size: 1rem;">Common Categories</div>')
+                             with gr.Row():
+                                 people_btn = gr.Button("People", size="sm")
+                                 vehicles_btn = gr.Button("Vehicles", size="sm")
+                                 animals_btn = gr.Button("Animals", size="sm")
+                                 objects_btn = gr.Button("Common Objects", size="sm")
+
+                             # Class selection dropdown
+                             class_filter = gr.Dropdown(
+                                 choices=class_choices,
+                                 multiselect=True,
+                                 label="Select Classes to Display",
+                                 info="Leave empty to show all detected objects"
+                             )
+
+                 # Detect button
+                 detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")
+
+                 # How-to-use section
+                 with gr.Group(elem_classes="how-to-use"):
+                     gr.HTML('<div class="section-heading">How to Use</div>')
+                     gr.Markdown("""
+                     1. Upload an image or use the camera
+                     2. Adjust confidence threshold if needed
+                     3. Optionally filter to specific object classes
+                     4. Click "Detect Objects" button
+
+                     The model will identify objects in your image and display them with bounding boxes.
+
+                     **Note:** Detection quality depends on image clarity and object visibility. The model can detect up to 80 different types of common objects.
+                     """)
+
+             # Right side - results display
+             with gr.Column(scale=6, elem_classes="output-panel"):
+                 with gr.Tabs(elem_classes="tabs"):
+                     with gr.Tab("Detection Result"):
+                         result_image = gr.Image(type="pil", label="Detection Result")
+
+                         # Formatting for the result text box
+                         with gr.Group(elem_classes="result-details-box"):
+                             gr.HTML('<div class="section-heading">Detection Details</div>')
+                             # Text box settings for a wider display
+                             result_text = gr.Textbox(
+                                 label=None,
+                                 lines=12,
+                                 max_lines=15,
+                                 elem_classes="wide-result-text",
+                                 elem_id="detection-details",
+                                 container=False,
+                                 scale=2,
+                                 min_width=600
+                             )
+
+                     with gr.Tab("Statistics"):
+                         with gr.Row():
+                             with gr.Column(scale=3, elem_classes="plot-column"):
+                                 gr.HTML('<div class="section-heading">Object Distribution</div>')
+                                 plot_output = gr.Plot(
+                                     label=None,
+                                     elem_classes="large-plot-container"
+                                 )
+
+                             # JSON data goes on the right for clarity
+                             with gr.Column(scale=2, elem_classes="stats-column"):
+                                 gr.HTML('<div class="section-heading">Detection Statistics</div>')
+                                 stats_json = gr.JSON(
+                                     label=None,  # remove label
+                                     elem_classes="enhanced-json-display"
+                                 )
+
+         detect_btn.click(
+             fn=lambda img, model, conf, classes: process_and_plot(img, model, conf, classes),
+             inputs=[image_input, current_model, confidence, class_filter],
+             outputs=[result_image, result_text, stats_json, plot_output]
+         )
+
+         # Model option
+         model_dropdown.change(
+             fn=lambda model: (model, DetectionModel.get_model_description(model)),
+             inputs=[model_dropdown],
+             outputs=[current_model, model_info]
+         )
+
+         # Class ID lists for each quick-select category
+         people_classes = [0]  # person
+         vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]  # all vehicle types
+         animals_classes = list(range(14, 24))  # animals in COCO
+         common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # common household items
+
+         # Link the quick-select buttons
+         people_btn.click(
+             lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
+             outputs=class_filter
+         )
+
+         vehicles_btn.click(
+             lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
+             outputs=class_filter
+         )
+
+         animals_btn.click(
+             lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
+             outputs=class_filter
+         )
+
+         objects_btn.click(
+             lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
+             outputs=class_filter
+         )
+
+         example_images = [
+             "room_01.jpg",
+             "street_01.jpg",
+             "street_02.jpg",
+             "street_03.jpg"
+         ]
+
+         # Add example images
+         gr.Examples(
+             examples=example_images,
+             inputs=image_input,
+             outputs=None,
+             fn=None,
+             cache_examples=False,
+         )
+
+         # Footer
+         gr.HTML("""
+             <div class="footer">
+                 <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
+                 <p>Model can detect 80 different classes of objects</p>
+             </div>
+         """)
+
+     return demo
+
+ @spaces.GPU
+ def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
+     """
+     Process image and create plots for statistics with enhanced visualization
+
+     Args:
+         image: Input image
+         model_name: Name of the model to use
+         confidence_threshold: Confidence threshold for detection
+         filter_classes: Optional list of classes to filter results
+
+     Returns:
+         Tuple of (result_image, result_text, formatted_stats, plot_figure)
+     """
+     global model_instances
+
+     if model_name not in model_instances:
+         print(f"Creating new model instance for {model_name}")
+         model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
+     else:
+         print(f"Using existing model instance for {model_name}")
+         model_instances[model_name].confidence = confidence_threshold
+
+     class_ids = None
+     if filter_classes:
+         class_ids = []
+         for class_str in filter_classes:
+             try:
+                 # Extract ID from format "id: name"
+                 class_id = int(class_str.split(":")[0].strip())
+                 class_ids.append(class_id)
+             except ValueError:
+                 continue
+
+     # Execute detection
+     result_image, result_text, stats = process_image(
+         image,
+         model_instances[model_name],
+         confidence_threshold,
+         class_ids
+     )
+
+     # Format the statistics for better display
+     formatted_stats = format_json_for_display(stats)
+
+     if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
+         # Create an empty placeholder plot
+         fig, ax = plt.subplots(figsize=(8, 6))
+         ax.text(0.5, 0.5, "No detection data available",
+                 ha='center', va='center', fontsize=14, fontfamily='Arial')
+         ax.set_xlim(0, 1)
+         ax.set_ylim(0, 1)
+         ax.axis('off')
+         plot_figure = fig
+     else:
+         # Prepare visualization data
+         viz_data = {
+             "total_objects": stats.get("total_objects", 0),
+             "average_confidence": stats.get("average_confidence", 0),
+             "class_data": []
+         }
+
+         # Get the color map
+         color_mapper_instance = ColorMapper()
+
+         # Class data
+         available_classes = dict(get_all_classes())
+         for cls_name, cls_stats in stats.get("class_statistics", {}).items():
+             # Look up the class ID by name
+             class_id = -1
+             for id, name in available_classes.items():
+                 if name == cls_name:
+                     class_id = id
+                     break
+
+             cls_data = {
+                 "name": cls_name,
+                 "class_id": class_id,
+                 "count": cls_stats.get("count", 0),
+                 "average_confidence": cls_stats.get("average_confidence", 0),
+                 "color": color_mapper_instance.get_color(class_id if class_id >= 0 else cls_name)
+             }
+
+             viz_data["class_data"].append(cls_data)
+
+         # Sort by count in descending order
+         viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)
+
+         plot_figure = EvaluationMetrics.create_enhanced_stats_plot(viz_data)
+
+     return result_image, result_text, formatted_stats, plot_figure
+
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()
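
For reference, the pipeline above can also be exercised headlessly, without the Gradio UI. A minimal sketch under assumptions not in the commit (a local `street_01.jpg`, network access for the YOLOv8 weights, and the Space's dependencies including `spaces` installed); the driver script itself is hypothetical:

    # Hypothetical smoke test for app.py; not part of this commit.
    from PIL import Image

    from app import process_and_plot

    image = Image.open("street_01.jpg")  # one of the example images listed above
    result_image, result_text, formatted_stats, plot_figure = process_and_plot(
        image,
        model_name="yolov8n.pt",                 # nano model keeps the test fast
        confidence_threshold=0.25,
        filter_classes=["0: person", "2: car"],  # same "id: name" strings the UI passes
    )
    print(result_text)
    if result_image is not None:
        result_image.save("street_01_annotated.png")
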
color_mapper.py ADDED
@@ -0,0 +1,270 @@
+ import numpy as np
+ from typing import Dict, List, Tuple, Union, Any
+
+ class ColorMapper:
+     """
+     A class for consistent color mapping of object detection classes
+     Provides color schemes for visualization in both RGB and hex formats
+     """
+
+     # Class categories for better organization
+     CATEGORIES = {
+         "person": [0],
+         "vehicles": [1, 2, 3, 4, 5, 6, 7, 8],
+         "traffic": [9, 10, 11, 12],
+         "animals": [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
+         "outdoor": [13, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33],
+         "sports": [34, 35, 36, 37, 38],
+         "kitchen": [39, 40, 41, 42, 43, 44, 45],
+         "food": [46, 47, 48, 49, 50, 51, 52, 53, 54, 55],
+         "furniture": [56, 57, 58, 59, 60, 61],
+         "electronics": [62, 63, 64, 65, 66, 67, 68, 69, 70],
+         "household": [71, 72, 73, 74, 75, 76, 77, 78, 79]
+     }
+
+     # Base colors for each category (in HSV for easier variation)
+     CATEGORY_COLORS = {
+         "person": (0, 0.8, 0.9),        # Red
+         "vehicles": (210, 0.8, 0.9),    # Blue
+         "traffic": (45, 0.8, 0.9),      # Orange
+         "animals": (120, 0.7, 0.8),     # Green
+         "outdoor": (180, 0.7, 0.9),     # Cyan
+         "sports": (270, 0.7, 0.8),      # Purple
+         "kitchen": (30, 0.7, 0.9),      # Light Orange
+         "food": (330, 0.7, 0.85),       # Pink
+         "furniture": (150, 0.5, 0.85),  # Light Green
+         "electronics": (240, 0.6, 0.9), # Light Blue
+         "household": (60, 0.6, 0.9)     # Yellow
+     }
+
+     def __init__(self):
+         """Initialize the ColorMapper with COCO class mappings"""
+         self.class_names = self._get_coco_classes()
+         self.color_map = self._generate_color_map()
+
+     def _get_coco_classes(self) -> Dict[int, str]:
+         """Get the standard COCO class names with their IDs"""
+         return {
+             0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
+             5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
+             10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
+             14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
+             20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
+             25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
+             30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
+             35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
+             39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife',
+             44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich',
+             49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza',
+             54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant',
+             59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop',
+             64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',
+             69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book',
+             74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier',
+             79: 'toothbrush'
+         }
+
+     def _hsv_to_rgb(self, h: float, s: float, v: float) -> Tuple[int, int, int]:
+         """
+         Convert HSV color to RGB
+
+         Args:
+             h: Hue (0-360)
+             s: Saturation (0-1)
+             v: Value (0-1)
+
+         Returns:
+             Tuple of (R, G, B) values (0-255)
+         """
+         h = h / 60
+         i = int(h)
+         f = h - i
+         p = v * (1 - s)
+         q = v * (1 - s * f)
+         t = v * (1 - s * (1 - f))
+
+         if i == 0:
+             r, g, b = v, t, p
+         elif i == 1:
+             r, g, b = q, v, p
+         elif i == 2:
+             r, g, b = p, v, t
+         elif i == 3:
+             r, g, b = p, q, v
+         elif i == 4:
+             r, g, b = t, p, v
+         else:
+             r, g, b = v, p, q
+
+         return (int(r * 255), int(g * 255), int(b * 255))
+
+     def _rgb_to_hex(self, rgb: Tuple[int, int, int]) -> str:
+         """
+         Convert RGB color to hex color code
+
+         Args:
+             rgb: Tuple of (R, G, B) values (0-255)
+
+         Returns:
+             Hex color code (e.g. '#FF0000')
+         """
+         return f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}'
+
+     def _find_category(self, class_id: int) -> str:
+         """
+         Find the category for a given class ID
+
+         Args:
+             class_id: Class ID (0-79)
+
+         Returns:
+             Category name
+         """
+         for category, ids in self.CATEGORIES.items():
+             if class_id in ids:
+                 return category
+         return "other"  # Fallback
+
+     def _generate_color_map(self) -> Dict:
+         """
+         Generate a color map for all 80 COCO classes
+
+         Returns:
+             Dictionary mapping class IDs and names to color values
+         """
+         color_map = {
+             'by_id': {},      # Map class ID to RGB and hex
+             'by_name': {},    # Map class name to RGB and hex
+             'categories': {}  # Map category to base color
+         }
+
+         # Generate colors for categories
+         for category, hsv in self.CATEGORY_COLORS.items():
+             rgb = self._hsv_to_rgb(hsv[0], hsv[1], hsv[2])
+             hex_color = self._rgb_to_hex(rgb)
+             color_map['categories'][category] = {
+                 'rgb': rgb,
+                 'hex': hex_color
+             }
+
+         # Generate variations for each class within a category
+         for class_id, class_name in self.class_names.items():
+             category = self._find_category(class_id)
+             base_hsv = self.CATEGORY_COLORS.get(category, (0, 0, 0.8))  # Default gray
+
+             # Slightly vary the hue and saturation within the category
+             ids_in_category = self.CATEGORIES.get(category, [])
+             if ids_in_category:
+                 position = ids_in_category.index(class_id) if class_id in ids_in_category else 0
+                 variation = position / max(1, len(ids_in_category) - 1)  # 0 to 1
+
+                 # Vary hue slightly (±15°) and saturation
+                 h_offset = 30 * variation - 15  # -15 to +15
+                 s_offset = 0.2 * variation  # 0 to 0.2
+
+                 h = (base_hsv[0] + h_offset) % 360
+                 s = min(1.0, base_hsv[1] + s_offset)
+                 v = base_hsv[2]
+             else:
+                 h, s, v = base_hsv
+
+             rgb = self._hsv_to_rgb(h, s, v)
+             hex_color = self._rgb_to_hex(rgb)
+
+             # Store in both mappings
+             color_map['by_id'][class_id] = {
+                 'rgb': rgb,
+                 'hex': hex_color,
+                 'category': category
+             }
+
+             color_map['by_name'][class_name] = {
+                 'rgb': rgb,
+                 'hex': hex_color,
+                 'category': category
+             }
+
+         return color_map
+
+     def get_color(self, class_identifier: Union[int, str], format: str = 'hex') -> Any:
+         """
+         Get color for a specific class
+
+         Args:
+             class_identifier: Class ID (int) or name (str)
+             format: Color format ('hex', 'rgb', or 'bgr')
+
+         Returns:
+             Color in requested format
+         """
+         # Determine if identifier is an ID or name
+         if isinstance(class_identifier, int):
+             color_info = self.color_map['by_id'].get(class_identifier)
+         else:
+             color_info = self.color_map['by_name'].get(class_identifier)
+
+         if not color_info:
+             # Fallback color if not found
+             return '#CCCCCC' if format == 'hex' else (204, 204, 204)
+
+         if format == 'hex':
+             return color_info['hex']
+         elif format == 'rgb':
+             return color_info['rgb']
+         elif format == 'bgr':
+             # Convert RGB to BGR for OpenCV
+             r, g, b = color_info['rgb']
+             return (b, g, r)
+         else:
+             return color_info['rgb']
+
+     def get_all_colors(self, format: str = 'hex') -> Dict:
+         """
+         Get all colors in the specified format
+
+         Args:
+             format: Color format ('hex', 'rgb', or 'bgr')
+
+         Returns:
+             Dictionary mapping class names to colors
+         """
+         result = {}
+         for class_id, class_name in self.class_names.items():
+             result[class_name] = self.get_color(class_id, format)
+         return result
+
+     def get_category_colors(self, format: str = 'hex') -> Dict:
+         """
+         Get base colors for each category
+
+         Args:
+             format: Color format ('hex', 'rgb', or 'bgr')
+
+         Returns:
+             Dictionary mapping categories to colors
+         """
+         result = {}
+         for category, color_info in self.color_map['categories'].items():
+             if format == 'hex':
+                 result[category] = color_info['hex']
+             elif format == 'bgr':
+                 r, g, b = color_info['rgb']
+                 result[category] = (b, g, r)
+             else:
+                 result[category] = color_info['rgb']
+         return result
+
+     def get_category_for_class(self, class_identifier: Union[int, str]) -> str:
+         """
+         Get the category for a specific class
+
+         Args:
+             class_identifier: Class ID (int) or name (str)
+
+         Returns:
+             Category name
+         """
+         if isinstance(class_identifier, int):
+             return self.color_map['by_id'].get(class_identifier, {}).get('category', 'other')
+         else:
+             return self.color_map['by_name'].get(class_identifier, {}).get('category', 'other')
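
As a quick illustration of the mapping behavior (a hypothetical snippet, not part of the commit): class IDs and class names resolve to the same color, and the `'bgr'` format flips the channel order for OpenCV:

    # Hypothetical ColorMapper usage; exact color values depend on the HSV variation above.
    from color_mapper import ColorMapper

    mapper = ColorMapper()
    assert mapper.get_color(16) == mapper.get_color("dog")   # ID and name agree
    print(mapper.get_color("dog", format="rgb"))             # (R, G, B) tuple in 0-255
    print(mapper.get_color("dog", format="bgr"))             # same tuple reversed for OpenCV
    print(mapper.get_category_for_class("dog"))              # -> "animals"
    print(mapper.get_color(999))                             # unknown ID falls back to '#CCCCCC'
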
detection_model.py ADDED
@@ -0,0 +1,164 @@
+ from ultralytics import YOLO
+ from typing import Any, List, Dict, Optional
+ import torch
+ import numpy as np
+ import os
+
+ class DetectionModel:
+     """Core detection model class for object detection using YOLOv8"""
+
+     # Model information dictionary
+     MODEL_INFO = {
+         "yolov8n.pt": {
+             "name": "YOLOv8n (Nano)",
+             "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.",
+             "size_mb": 6,
+             "inference_speed": "Very Fast"
+         },
+         "yolov8m.pt": {
+             "name": "YOLOv8m (Medium)",
+             "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.",
+             "size_mb": 25,
+             "inference_speed": "Medium"
+         },
+         "yolov8x.pt": {
+             "name": "YOLOv8x (XLarge)",
+             "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.",
+             "size_mb": 68,
+             "inference_speed": "Slower"
+         }
+     }
+
+     def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.45):
+         """
+         Initialize the detection model
+
+         Args:
+             model_name: Model name or path, default is yolov8m.pt
+             confidence: Confidence threshold, default is 0.25
+             iou: IoU threshold for non-maximum suppression, default is 0.45
+         """
+         self.model_name = model_name
+         self.confidence = confidence
+         self.iou = iou
+         self.model = None
+         self.class_names = {}
+         self.is_model_loaded = False
+
+         # Load model on initialization
+         self._load_model()
+
+     def _load_model(self):
+         """Load the YOLO model"""
+         try:
+             print(f"Loading model: {self.model_name}")
+             self.model = YOLO(self.model_name)
+             self.class_names = self.model.names
+             self.is_model_loaded = True
+             print(f"Successfully loaded model: {self.model_name}")
+             print(f"Number of classes the model can recognize: {len(self.class_names)}")
+         except Exception as e:
+             print(f"Error occurred when loading the model: {e}")
+             self.is_model_loaded = False
+
+     def change_model(self, new_model_name: str) -> bool:
+         """
+         Change the currently loaded model
+
+         Args:
+             new_model_name: Name of the new model to load
+
+         Returns:
+             bool: True if model changed successfully, False otherwise
+         """
+         if self.model_name == new_model_name and self.is_model_loaded:
+             print(f"Model {new_model_name} is already loaded")
+             return True
+
+         print(f"Changing model from {self.model_name} to {new_model_name}")
+
+         # Unload current model to free memory
+         if self.model is not None:
+             del self.model
+             self.model = None
+
+         # Clean GPU memory if available
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         # Update model name and load new model
+         self.model_name = new_model_name
+         self._load_model()
+
+         return self.is_model_loaded
+
+     def reload_model(self):
+         """Reload the model (useful for changing model or after error)"""
+         if self.model is not None:
+             del self.model
+             self.model = None
+
+         # Clean GPU memory if available
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         self._load_model()
+
+     def detect(self, image_input: Any) -> Optional[Any]:
+         """
+         Perform object detection on a single image
+
+         Args:
+             image_input: Image path (str), PIL Image, or numpy array
+
+         Returns:
+             Detection result object or None if error occurred
+         """
+         if self.model is None or not self.is_model_loaded:
+             print("Model not found or not loaded. Attempting to reload...")
+             self._load_model()
+             if self.model is None or not self.is_model_loaded:
+                 print("Failed to load model. Cannot perform detection.")
+                 return None
+
+         try:
+             results = self.model(image_input, conf=self.confidence, iou=self.iou)
+             return results[0]
+         except Exception as e:
+             print(f"Error occurred during detection: {e}")
+             return None
+
+     def get_class_names(self, class_id: int) -> str:
+         """Get class name for a given class ID"""
+         return self.class_names.get(class_id, "Unknown Class")
+
+     def get_supported_classes(self) -> Dict[int, str]:
+         """Get all supported classes as a dictionary of {id: class_name}"""
+         return self.class_names
+
+     @classmethod
+     def get_available_models(cls) -> List[Dict]:
+         """
+         Get list of available models with their information
+
+         Returns:
+             List of dictionaries containing model information
+         """
+         models = []
+         for model_file, info in cls.MODEL_INFO.items():
+             models.append({
+                 "model_file": model_file,
+                 "name": info["name"],
+                 "description": info["description"],
+                 "size_mb": info["size_mb"],
+                 "inference_speed": info["inference_speed"]
+             })
+         return models
+
+     @classmethod
+     def get_model_description(cls, model_name: str) -> str:
+         """Get description for a specific model"""
+         if model_name in cls.MODEL_INFO:
+             info = cls.MODEL_INFO[model_name]
+             return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})"
+         return "Model information not available"
evaluation_metrics.py ADDED
@@ -0,0 +1,360 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from typing import Dict, List, Any, Optional, Tuple
+
+ class EvaluationMetrics:
+     """Class for computing detection metrics, generating statistics and visualization data"""
+
+     @staticmethod
+     def calculate_basic_stats(result: Any) -> Dict:
+         """
+         Calculate basic statistics for a single detection result
+
+         Args:
+             result: Detection result object
+
+         Returns:
+             Dictionary with basic statistics
+         """
+         if result is None:
+             return {"error": "No detection result provided"}
+
+         # Get classes and confidences
+         classes = result.boxes.cls.cpu().numpy().astype(int)
+         confidences = result.boxes.conf.cpu().numpy()
+         names = result.names
+
+         # Count by class
+         class_counts = {}
+         for cls, conf in zip(classes, confidences):
+             cls_name = names[int(cls)]
+             if cls_name not in class_counts:
+                 class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
+
+             class_counts[cls_name]["count"] += 1
+             class_counts[cls_name]["total_confidence"] += float(conf)
+             class_counts[cls_name]["confidences"].append(float(conf))
+
+         # Calculate average confidence
+         for cls_name, stats in class_counts.items():
+             if stats["count"] > 0:
+                 stats["average_confidence"] = stats["total_confidence"] / stats["count"]
+                 stats["confidence_std"] = float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
+                 stats.pop("total_confidence")  # Remove intermediate calculation
+
+         # Prepare summary
+         stats = {
+             "total_objects": len(classes),
+             "class_statistics": class_counts,
+             "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
+         }
+
+         return stats
+
+     @staticmethod
+     def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
+         """
+         Generate structured data suitable for visualization
+
+         Args:
+             result: Detection result object
+             class_colors: Dictionary mapping class names to color codes (optional)
+
+         Returns:
+             Dictionary with visualization-ready data
+         """
+         if result is None:
+             return {"error": "No detection result provided"}
+
+         # Get basic stats first
+         stats = EvaluationMetrics.calculate_basic_stats(result)
+
+         # Create visualization-specific data structure
+         viz_data = {
+             "total_objects": stats["total_objects"],
+             "average_confidence": stats["average_confidence"],
+             "class_data": []
+         }
+
+         # Sort classes by count (descending)
+         sorted_classes = sorted(
+             stats["class_statistics"].items(),
+             key=lambda x: x[1]["count"],
+             reverse=True
+         )
+
+         # Create class-specific visualization data
+         for cls_name, cls_stats in sorted_classes:
+             class_id = -1
+             # Find the class ID based on the name
+             for idx, name in result.names.items():
+                 if name == cls_name:
+                     class_id = idx
+                     break
+
+             cls_data = {
+                 "name": cls_name,
+                 "class_id": class_id,
+                 "count": cls_stats["count"],
+                 "average_confidence": cls_stats.get("average_confidence", 0),
+                 "confidence_std": cls_stats.get("confidence_std", 0),
+                 "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
+             }
+
+             viz_data["class_data"].append(cls_data)
+
+         return viz_data
+
+     @staticmethod
+     def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
+         """
+         Create a horizontal bar chart showing detection statistics
+
+         Args:
+             viz_data: Visualization data generated by generate_visualization_data
+             figsize: Figure size (width, height) in inches
+             max_classes: Maximum number of classes to display
+
+         Returns:
+             Matplotlib figure object
+         """
+         # Use the enhanced version
+         return EvaluationMetrics.create_enhanced_stats_plot(viz_data, figsize, max_classes)
+
+     @staticmethod
+     def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
+         """
+         Create an enhanced horizontal bar chart with larger fonts and better styling
+
+         Args:
+             viz_data: Visualization data dictionary
+             figsize: Figure size (width, height) in inches
+             max_classes: Maximum number of classes to display
+
+         Returns:
+             Matplotlib figure with enhanced styling
+         """
+         if "error" in viz_data:
+             # Create empty plot if error
+             fig, ax = plt.subplots(figsize=figsize)
+             ax.text(0.5, 0.5, viz_data["error"],
+                     ha='center', va='center', fontsize=14, fontfamily='Arial')
+             ax.set_xlim(0, 1)
+             ax.set_ylim(0, 1)
+             ax.axis('off')
+             return fig
+
+         if "class_data" not in viz_data or not viz_data["class_data"]:
+             # Create empty plot if no data
+             fig, ax = plt.subplots(figsize=figsize)
+             ax.text(0.5, 0.5, "No detection data available",
+                     ha='center', va='center', fontsize=14, fontfamily='Arial')
+             ax.set_xlim(0, 1)
+             ax.set_ylim(0, 1)
+             ax.axis('off')
+             return fig
+
+         # Limit to max_classes
+         class_data = viz_data["class_data"][:max_classes]
+
+         # Extract data for plotting
+         class_names = [item["name"] for item in class_data]
+         counts = [item["count"] for item in class_data]
+         colors = [item["color"] for item in class_data]
+
+         # Create figure and horizontal bar chart with improved styling
+         plt.rcParams['font.family'] = 'Arial'
+         fig, ax = plt.subplots(figsize=figsize)
+
+         # Set background color to white
+         fig.patch.set_facecolor('white')
+         ax.set_facecolor('white')
+
+         y_pos = np.arange(len(class_names))
+
+         # Create horizontal bars with class-specific colors
+         bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)
+
+         # Add count values at end of each bar with larger font
+         for i, bar in enumerate(bars):
+             width = bar.get_width()
+             conf = class_data[i]["average_confidence"]
+             ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
+                     f"{width:.0f} (conf: {conf:.2f})",
+                     va='center', fontsize=12, fontfamily='Arial')
+
+         # Customize axis and labels with larger fonts
+         ax.set_yticks(y_pos)
+         ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
+         ax.invert_yaxis()  # Labels read top-to-bottom
+         ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
+         ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
+                      fontsize=16, fontfamily='Arial', fontweight='bold')
+
+         # Add grid for better readability
+         ax.set_axisbelow(True)
+         ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')
+
+         # Increase tick label font size
+         ax.tick_params(axis='both', which='major', labelsize=12)
+
+         # Add detection summary as a text box with improved styling
+         summary_text = (
+             f"Total Objects: {viz_data['total_objects']}\n"
+             f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
+             f"Unique Classes: {len(viz_data['class_data'])}"
+         )
+         plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
+                     bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
+                               edgecolor='#E5E7EB'))
+
+         plt.tight_layout()
+         return fig
+
+     @staticmethod
+     def format_detection_summary(viz_data: Dict) -> str:
+         """
+         Format detection results as a readable text summary with improved spacing
+
+         Args:
+             viz_data: Visualization data generated by generate_visualization_data
+
+         Returns:
+             Formatted text with proper spacing
+         """
+         if "error" in viz_data:
+             return viz_data["error"]
+
+         if "total_objects" not in viz_data:
+             return "No detection data available."
+
+         # Get basic statistics
+         total_objects = viz_data["total_objects"]
+         avg_confidence = viz_data["average_confidence"]
+
+         # Create the header; extra blank lines improve readability
+         lines = [
+             f"Detected {total_objects} objects.",
+             f"Average confidence: {avg_confidence:.2f}",
+             "\n",  # extra blank line
+             "Objects by class:",
+         ]
+
+         # Add class details, giving each class more room
+         if "class_data" in viz_data and viz_data["class_data"]:
+             for item in viz_data["class_data"]:
+                 count = item['count']
+                 # Use the correct singular/plural form
+                 item_text = "item" if count == 1 else "items"
+
+                 # Add a blank line before each item and indent the details
+                 lines.append("\n")
+                 lines.append(f"• {item['name']}: {count} {item_text}")
+                 lines.append(f"  Confidence: {item['average_confidence']:.2f}")
+         else:
+             lines.append("\nNo class information available.")
+
+         return "\n".join(lines)
+
+     @staticmethod
+     def calculate_distance_metrics(result: Any) -> Dict:
+         """
+         Calculate distance-related metrics for detected objects
+
+         Args:
+             result: Detection result object
+
+         Returns:
+             Dictionary with distance metrics
+         """
+         if result is None:
+             return {"error": "No detection result provided"}
+
+         boxes = result.boxes.xyxy.cpu().numpy()
+         classes = result.boxes.cls.cpu().numpy().astype(int)
+         names = result.names
+
+         # Initialize metrics
+         metrics = {
+             "proximity": {},             # Classes that appear close to each other
+             "spatial_distribution": {},  # Distribution across the image
+             "size_distribution": {}      # Size distribution of objects
+         }
+
+         # Calculate image dimensions (assuming normalized coordinates or extract from result)
+         img_width, img_height = 1, 1
+         if hasattr(result, "orig_shape"):
+             img_height, img_width = result.orig_shape[:2]
+
+         # Calculate bounding box areas and centers
+         areas = []
+         centers = []
+         class_names = []
+
+         for box, cls in zip(boxes, classes):
+             x1, y1, x2, y2 = box
+             width, height = x2 - x1, y2 - y1
+             area = width * height
+             center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
+
+             areas.append(area)
+             centers.append((center_x, center_y))
+             class_names.append(names[int(cls)])
+
+         # Calculate spatial distribution
+         if centers:
+             x_coords = [c[0] for c in centers]
+             y_coords = [c[1] for c in centers]
+
+             metrics["spatial_distribution"] = {
+                 "x_mean": float(np.mean(x_coords)) / img_width,
+                 "y_mean": float(np.mean(y_coords)) / img_height,
+                 "x_std": float(np.std(x_coords)) / img_width,
+                 "y_std": float(np.std(y_coords)) / img_height
+             }
+
+         # Calculate size distribution
+         if areas:
+             metrics["size_distribution"] = {
+                 "mean_area": float(np.mean(areas)) / (img_width * img_height),
+                 "std_area": float(np.std(areas)) / (img_width * img_height),
+                 "min_area": float(np.min(areas)) / (img_width * img_height),
+                 "max_area": float(np.max(areas)) / (img_width * img_height)
+             }
+
+         # Calculate proximity between different classes
+         class_centers = {}
+         for cls_name, center in zip(class_names, centers):
+             if cls_name not in class_centers:
+                 class_centers[cls_name] = []
+             class_centers[cls_name].append(center)
+
+         # Find classes that appear close to each other
+         proximity_pairs = []
+         for i, cls1 in enumerate(class_centers.keys()):
+             for j, cls2 in enumerate(class_centers.keys()):
+                 if i >= j:  # Avoid duplicate pairs and self-comparison
+                     continue
+
+                 # Calculate minimum distance between any two objects of these classes
+                 min_distance = float('inf')
+                 for center1 in class_centers[cls1]:
+                     for center2 in class_centers[cls2]:
+                         dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
+                         min_distance = min(min_distance, dist)
+
+                 # Normalize by image diagonal
+                 img_diagonal = np.sqrt(img_width**2 + img_height**2)
+                 norm_distance = min_distance / img_diagonal
+
+                 proximity_pairs.append({
+                     "class1": cls1,
+                     "class2": cls2,
+                     "distance": float(norm_distance)
+                 })
+
+         # Sort by distance and keep the closest pairs
+         proximity_pairs.sort(key=lambda x: x["distance"])
+         metrics["proximity"] = proximity_pairs[:5]  # Keep top 5 closest pairs
+
+         return metrics
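
The proximity score in `calculate_distance_metrics` is the minimum pairwise center distance normalized by the image diagonal, so 0 means coincident centers and 1 means opposite corners. A worked check of that arithmetic with made-up numbers (illustrative only, not from the code above):

    # Worked check of the proximity normalization (hypothetical values).
    import numpy as np

    img_width, img_height = 640, 480
    img_diagonal = np.sqrt(img_width**2 + img_height**2)   # = 800.0

    center1 = (100.0, 120.0)   # e.g. a person's box center
    center2 = (420.0, 360.0)   # e.g. a dog's box center
    dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)  # = 400.0
    print(dist / img_diagonal)  # 0.5: the pair is half an image diagonal apart
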
style.py ADDED
@@ -0,0 +1,282 @@
+ class Style:
+     @staticmethod
+     def get_css():
+         """Return the application's CSS styles with improved aesthetics"""
+         css = """
+         /* Base styles and typography */
+         body {
+             font-family: Arial, sans-serif;
+             background: linear-gradient(135deg, #f0f9ff, #e1f5fe);
+             margin: 0;
+             padding: 0;
+             display: flex;
+             justify-content: center;
+             min-height: 100vh;
+         }
+
+         /* Typography improvements */
+         h1, h2, h3, h4, h5, h6, p, span, div, label, button {
+             font-family: Arial, sans-serif;
+         }
+
+         /* Container styling */
+         .gradio-container {
+             max-width: 1200px !important;
+             margin: 0 auto;
+             padding: 1rem;
+             width: 100%;
+         }
+
+         /* Header area styling with gradient background */
+         .app-header {
+             text-align: center;
+             margin-bottom: 2rem;
+             background: linear-gradient(135deg, #f8f9fa, #e9ecef);
+             padding: 1.5rem;
+             border-radius: 10px;
+             box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
+             width: 100%;
+         }
+
+         .app-title {
+             color: #2D3748;
+             font-size: 2.5rem;
+             margin-bottom: 0.5rem;
+             background: linear-gradient(90deg, #38b2ac, #4299e1);
+             -webkit-background-clip: text;
+             -webkit-text-fill-color: transparent;
+             font-weight: bold;
+         }
+
+         .app-subtitle {
+             color: #4A5568;
+             font-size: 1.2rem;
+             font-weight: normal;
+             margin-top: 0.25rem;
+         }
+
+         .app-divider {
+             width: 80px;
+             height: 3px;
+             background: linear-gradient(90deg, #38b2ac, #4299e1);
+             margin: 1rem auto;
+         }
+
+         /* Panel styling - gradient background */
+         .input-panel, .output-panel {
+             background: white;
+             border-radius: 10px;
+             padding: 1.5rem;
+             box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
+             margin: 0 auto 1rem auto;
+         }
+
+         /* Section heading styling with gradient background */
+         .section-heading {
+             font-size: 1.25rem;
+             font-weight: 600;
+             color: #2D3748;
+             margin-bottom: 1rem;
+             margin-top: 0.5rem;
+             text-align: center;
+             padding: 0.8rem;
+             background: linear-gradient(to right, #e6f3fc, #f0f9ff);
+             border-radius: 8px;
+         }
+
+         /* How-to-use section with gradient background */
+         .how-to-use {
+             background: linear-gradient(135deg, #f8fafc, #e8f4fd);
+             border-radius: 10px;
+             padding: 1.5rem;
+             margin-top: 1rem;
+             box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
+             color: #2d3748;
+         }
+
+         /* Detection button styling */
+         .detect-btn {
+             background: linear-gradient(90deg, #38b2ac, #4299e1) !important;
+             color: white !important;
+             border: none !important;
+             border-radius: 8px !important;
+             transition: transform 0.3s, box-shadow 0.3s !important;
+             font-weight: bold !important;
+             letter-spacing: 0.5px !important;
+             padding: 0.75rem 1.5rem !important;
+             width: 100%;
+             margin: 1rem auto !important;
+             font-family: Arial, sans-serif !important;
+         }
+
+         .detect-btn:hover {
+             transform: translateY(-2px) !important;
+             box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
+         }
+
+         .detect-btn:active {
+             transform: translateY(1px) !important;
+             box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2) !important;
+         }
+
+         /* JSON display improvements */
+         .json-display pre {
+             background: #f8fafc;
+             border-radius: 6px;
+             padding: 1rem;
+             font-family: 'Consolas', 'Monaco', monospace;
+             white-space: pre-wrap;
+             max-height: 500px;
+             overflow-y: auto;
+             box-shadow: inset 0 0 4px rgba(0, 0, 0, 0.1);
+         }
+
+         .json-key {
+             color: #e53e3e;
+         }
+
+         .json-value {
+             color: #2b6cb0;
+         }
+
+         .json-string {
+             color: #38a169;
+         }
+
+         /* Chart/plot styling improvements */
+         .plot-container {
+             background: white;
+             border-radius: 8px;
+             padding: 0.5rem;
+             box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
+         }
+
+         /* Larger font for plots */
+         .plot-container text {
+             font-family: Arial, sans-serif !important;
+             font-size: 14px !important;
+         }
+
+         /* Title styling for charts */
+         .plot-title {
+             font-family: Arial, sans-serif !important;
+             font-size: 16px !important;
+             font-weight: bold !important;
+         }
+
+         /* Tab styling with subtle gradient */
+         .tabs {
+             width: 100%;
+             display: flex;
+             justify-content: center;
+         }
+
+         .tabs > div:first-child {
+             background: linear-gradient(to right, #f8fafc, #e8f4fd) !important;
+             border-radius: 8px 8px 0 0;
+         }
+
+         /* Footer styling with gradient background */
+         .footer {
+             text-align: center;
+             margin-top: 2rem;
+             font-size: 0.9rem;
+             color: #4A5568;
+             padding: 1rem;
+             background: linear-gradient(135deg, #f8f9fa, #e1effe);
+             border-radius: 10px;
+             box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
+             width: 100%;
+         }
+
+         /* Ensure centering works for all elements */
+         .container, .gr-container, .gr-row, .gr-col {
+             display: flex;
+             flex-direction: column;
+             align-items: center;
+             justify-content: center;
+             width: 100%;
+         }
+
+         /* Improved styling for the result text box */
+         #detection-details, .wide-result-text {
+             width: 100% !important;
+             max-width: 100% !important;
+             box-sizing: border-box !important;
+         }
+
+         .wide-result-text textarea {
+             width: 100% !important;
+             min-width: 600px !important;
+             font-family: 'Arial', sans-serif !important;
+             font-size: 14px !important;
+             line-height: 1.5 !important;  /* reduce line spacing */
+             padding: 16px !important;
+             white-space: pre-wrap !important;
+             background-color: #f8f9fa !important;
+             border-radius: 8px !important;
+             min-height: 300px !important;
+             resize: none !important;
+             overflow-y: auto !important;
+             border: 1px solid #e2e8f0 !important;
+             display: block !important;
+         }
+
+         /* Result details panel style - with gradient background */
+         .result-details-box {
+             width: 100% !important;
+             margin-top: 1.5rem;
+             background: linear-gradient(135deg, #f8fafc, #e8f4fd);
+             border-radius: 10px;
+             padding: 1rem;
+             box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
+         }
+
+         /* Ensure elements inside the result details panel adapt to its width */
+         .result-details-box > * {
+             width: 100% !important;
+             max-width: 100% !important;
+         }
+
+         /* Ensure the text area width is not constrained */
+         .result-details-box .gr-text-input {
+             width: 100% !important;
+             max-width: none !important;
+         }
+
+         /* Layout adjustments for the output panel content */
+         .output-panel {
+             display: flex;
+             flex-direction: column;
+             width: 100%;
+             padding: 0 !important;
+         }
+
+         /* Ensure elements inside the output panel adapt to its width */
+         .output-panel > * {
+             width: 100%;
+         }
+
+         /* Improve the statistics panel column layout */
+         .plot-column, .stats-column {
+             display: flex;
+             flex-direction: column;
+             padding: 1rem;
+         }
+
+         /* Responsive adjustments */
+         @media (max-width: 768px) {
+             .app-title {
+                 font-size: 2rem;
+             }
+
+             .app-subtitle {
+                 font-size: 1rem;
+             }
+
+             .gradio-container {
+                 padding: 0.5rem;
+             }
+         }
+         """
+         return css
visualization_helper.py ADDED
@@ -0,0 +1,147 @@
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from typing import Any, List, Dict, Tuple, Optional
+ import io
+ from PIL import Image
+
+ class VisualizationHelper:
+     """Helper class for visualizing detection results"""
+
+     @staticmethod
+     def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
+                             figsize: Tuple[int, int] = (12, 12),
+                             return_pil: bool = False) -> Optional[Image.Image]:
+         """
+         Visualize detection results on a single image
+
+         Args:
+             image: Image path or numpy array
+             result: Detection result object
+             color_mapper: ColorMapper instance for consistent colors
+             figsize: Figure size
+             return_pil: If True, returns a PIL Image object
+
+         Returns:
+             PIL Image if return_pil is True, otherwise displays the plot
+         """
+         if result is None:
+             print('No data for visualization')
+             return None
+
+         # Read image if path is provided
+         if isinstance(image, str):
+             img = cv2.imread(image)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         else:
+             img = image
+             if len(img.shape) == 3 and img.shape[2] == 3:
+                 # Check if BGR format (OpenCV) and convert to RGB if needed
+                 if isinstance(img, np.ndarray):
+                     # Assuming BGR format from OpenCV
+                     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         # Create figure
+         fig, ax = plt.subplots(figsize=figsize)
+         ax.imshow(img)
+
+         # Get bounding boxes, classes and confidences
+         boxes = result.boxes.xyxy.cpu().numpy()
+         classes = result.boxes.cls.cpu().numpy()
+         confs = result.boxes.conf.cpu().numpy()
+
+         # Get class names
+         names = result.names
+
+         # Create a default color mapper if none is provided
+         if color_mapper is None:
+             # For backward compatibility, fall back to a simple color function
+             from matplotlib import colormaps
+             cmap = colormaps['tab10']
+             def get_color(class_id):
+                 return cmap(class_id % 10)
+         else:
+             # Use the provided color mapper
+             def get_color(class_id):
+                 hex_color = color_mapper.get_color(class_id)
+                 # Convert hex to RGB float values for matplotlib
+                 hex_color = hex_color.lstrip('#')
+                 return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)
+
+         # Draw detection results
+         for box, cls, conf in zip(boxes, classes, confs):
+             x1, y1, x2, y2 = box
+             cls_id = int(cls)
+             cls_name = names[cls_id]
+
+             # Get color for this class
+             box_color = get_color(cls_id)
+
+             # Add text label with colored background
+             ax.text(x1, y1 - 5, f'{cls_name}: {conf:.2f}',
+                     color='white', fontsize=10,
+                     bbox=dict(facecolor=box_color[:3], alpha=0.7))
+
+             # Add bounding box
+             ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
+                                        fill=False, edgecolor=box_color[:3], linewidth=2))
+
+         ax.axis('off')
+         # ax.set_title('Detection Result')
+         plt.tight_layout()
+
+         if return_pil:
+             # Convert plot to PIL Image
+             buf = io.BytesIO()
+             fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+             buf.seek(0)
+             pil_img = Image.open(buf)
+             plt.close(fig)
+             return pil_img
+         else:
+             plt.show()
+             return None
+
+     @staticmethod
+     def create_summary(result: Any) -> Dict:
+         """
+         Create a summary of detection results
+
+         Args:
+             result: Detection result object
+
+         Returns:
+             Dictionary with detection summary statistics
+         """
+         if result is None:
+             return {"error": "No detection result provided"}
+
+         # Get classes and confidences
+         classes = result.boxes.cls.cpu().numpy().astype(int)
+         confidences = result.boxes.conf.cpu().numpy()
+         names = result.names
+
+         # Count detections by class
+         class_counts = {}
+         for cls, conf in zip(classes, confidences):
+             cls_name = names[int(cls)]
+             if cls_name not in class_counts:
+                 class_counts[cls_name] = {"count": 0, "confidences": []}
+
+             class_counts[cls_name]["count"] += 1
+             class_counts[cls_name]["confidences"].append(float(conf))
+
+         # Calculate average confidence for each class
+         for cls_name, stats in class_counts.items():
+             if stats["confidences"]:
+                 stats["average_confidence"] = float(np.mean(stats["confidences"]))
+             stats.pop("confidences")  # Remove detailed confidences list to keep summary concise
+
+         # Prepare summary
+         summary = {
+             "total_objects": len(classes),
+             "class_counts": class_counts,
+             "unique_classes": len(class_counts)
+         }
+
+         return summary
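
Putting the modules together outside the app, a minimal end-to-end sketch (hypothetical; same assumptions about downloadable weights and a local example image):

    # Hypothetical end-to-end sketch combining the modules in this commit.
    from color_mapper import ColorMapper
    from detection_model import DetectionModel
    from visualization_helper import VisualizationHelper

    model = DetectionModel("yolov8n.pt")
    result = model.detect("street_01.jpg")

    annotated = VisualizationHelper.visualize_detection(
        "street_01.jpg", result, color_mapper=ColorMapper(), return_pil=True
    )
    if annotated is not None:
        annotated.save("street_01_boxes.png")

    print(VisualizationHelper.create_summary(result))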