DawnC committed on
Commit
5888da9
·
verified ·
1 Parent(s): 0137ce3

Upload 2 files

Browse files
Files changed (2) hide show
  1. evaluation_metrics.py +72 -90
  2. visualization_helper.py +37 -28
evaluation_metrics.py CHANGED
@@ -4,85 +4,85 @@ from typing import Dict, List, Any, Optional, Tuple
4
 
5
  class EvaluationMetrics:
6
  """Class for computing detection metrics, generating statistics and visualization data"""
7
-
8
  @staticmethod
9
  def calculate_basic_stats(result: Any) -> Dict:
10
  """
11
  Calculate basic statistics for a single detection result
12
-
13
  Args:
14
  result: Detection result object
15
-
16
  Returns:
17
  Dictionary with basic statistics
18
  """
19
  if result is None:
20
  return {"error": "No detection result provided"}
21
-
22
  # Get classes and confidences
23
  classes = result.boxes.cls.cpu().numpy().astype(int)
24
  confidences = result.boxes.conf.cpu().numpy()
25
  names = result.names
26
-
27
  # Count by class
28
  class_counts = {}
29
  for cls, conf in zip(classes, confidences):
30
  cls_name = names[int(cls)]
31
  if cls_name not in class_counts:
32
  class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
33
-
34
  class_counts[cls_name]["count"] += 1
35
  class_counts[cls_name]["total_confidence"] += float(conf)
36
  class_counts[cls_name]["confidences"].append(float(conf))
37
-
38
  # Calculate average confidence
39
  for cls_name, stats in class_counts.items():
40
  if stats["count"] > 0:
41
  stats["average_confidence"] = stats["total_confidence"] / stats["count"]
42
  stats["confidence_std"] = float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
43
  stats.pop("total_confidence") # Remove intermediate calculation
44
-
45
  # Prepare summary
46
  stats = {
47
  "total_objects": len(classes),
48
  "class_statistics": class_counts,
49
  "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
50
  }
51
-
52
  return stats
53
-
54
  @staticmethod
55
  def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
56
  """
57
  Generate structured data suitable for visualization
58
-
59
  Args:
60
  result: Detection result object
61
  class_colors: Dictionary mapping class names to color codes (optional)
62
-
63
  Returns:
64
  Dictionary with visualization-ready data
65
  """
66
  if result is None:
67
  return {"error": "No detection result provided"}
68
-
69
  # Get basic stats first
70
  stats = EvaluationMetrics.calculate_basic_stats(result)
71
-
72
  # Create visualization-specific data structure
73
  viz_data = {
74
  "total_objects": stats["total_objects"],
75
  "average_confidence": stats["average_confidence"],
76
  "class_data": []
77
  }
78
-
79
  # Sort classes by count (descending)
80
  sorted_classes = sorted(
81
  stats["class_statistics"].items(),
82
  key=lambda x: x[1]["count"],
83
  reverse=True
84
  )
85
-
86
  # Create class-specific visualization data
87
  for cls_name, cls_stats in sorted_classes:
88
  class_id = -1
@@ -91,7 +91,7 @@ class EvaluationMetrics:
91
  if name == cls_name:
92
  class_id = idx
93
  break
94
-
95
  cls_data = {
96
  "name": cls_name,
97
  "class_id": class_id,
@@ -100,21 +100,21 @@ class EvaluationMetrics:
100
  "confidence_std": cls_stats.get("confidence_std", 0),
101
  "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC"
102
  }
103
-
104
  viz_data["class_data"].append(cls_data)
105
-
106
  return viz_data
107
-
108
  @staticmethod
109
  def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
110
  """
111
  Create a horizontal bar chart showing detection statistics
112
-
113
  Args:
114
  viz_data: Visualization data generated by generate_visualization_data
115
  figsize: Figure size (width, height) in inches
116
  max_classes: Maximum number of classes to display
117
-
118
  Returns:
119
  Matplotlib figure object
120
  """
@@ -125,79 +125,79 @@ class EvaluationMetrics:
125
  def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
126
  """
127
  Create an enhanced horizontal bar chart with larger fonts and better styling
128
-
129
  Args:
130
  viz_data: Visualization data dictionary
131
  figsize: Figure size (width, height) in inches
132
  max_classes: Maximum number of classes to display
133
-
134
  Returns:
135
  Matplotlib figure with enhanced styling
136
  """
137
  if "error" in viz_data:
138
  # Create empty plot if error
139
  fig, ax = plt.subplots(figsize=figsize)
140
- ax.text(0.5, 0.5, viz_data["error"],
141
  ha='center', va='center', fontsize=14, fontfamily='Arial')
142
  ax.set_xlim(0, 1)
143
  ax.set_ylim(0, 1)
144
  ax.axis('off')
145
  return fig
146
-
147
  if "class_data" not in viz_data or not viz_data["class_data"]:
148
  # Create empty plot if no data
149
  fig, ax = plt.subplots(figsize=figsize)
150
- ax.text(0.5, 0.5, "No detection data available",
151
  ha='center', va='center', fontsize=14, fontfamily='Arial')
152
  ax.set_xlim(0, 1)
153
  ax.set_ylim(0, 1)
154
  ax.axis('off')
155
  return fig
156
-
157
  # Limit to max_classes
158
  class_data = viz_data["class_data"][:max_classes]
159
-
160
  # Extract data for plotting
161
  class_names = [item["name"] for item in class_data]
162
  counts = [item["count"] for item in class_data]
163
  colors = [item["color"] for item in class_data]
164
-
165
  # Create figure and horizontal bar chart with improved styling
166
  plt.rcParams['font.family'] = 'Arial'
167
  fig, ax = plt.subplots(figsize=figsize)
168
-
169
  # Set background color to white
170
  fig.patch.set_facecolor('white')
171
  ax.set_facecolor('white')
172
-
173
  y_pos = np.arange(len(class_names))
174
-
175
  # Create horizontal bars with class-specific colors
176
  bars = ax.barh(y_pos, counts, color=colors, alpha=0.8, height=0.6)
177
-
178
  # Add count values at end of each bar with larger font
179
  for i, bar in enumerate(bars):
180
  width = bar.get_width()
181
  conf = class_data[i]["average_confidence"]
182
- ax.text(width + 0.3, bar.get_y() + bar.get_height()/2,
183
- f"{width:.0f} (conf: {conf:.2f})",
184
  va='center', fontsize=12, fontfamily='Arial')
185
-
186
  # Customize axis and labels with larger fonts
187
  ax.set_yticks(y_pos)
188
  ax.set_yticklabels(class_names, fontsize=14, fontfamily='Arial')
189
  ax.invert_yaxis() # Labels read top-to-bottom
190
  ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
191
- ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
192
  fontsize=16, fontfamily='Arial', fontweight='bold')
193
-
194
  # Add grid for better readability
195
  ax.set_axisbelow(True)
196
  ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')
197
-
198
  # Increase tick label font size
199
  ax.tick_params(axis='both', which='major', labelsize=12)
200
-
201
  # Add detection summary as a text box with improved styling
202
  summary_text = (
203
  f"Total Objects: {viz_data['total_objects']}\n"
@@ -205,114 +205,96 @@ class EvaluationMetrics:
205
  f"Unique Classes: {len(viz_data['class_data'])}"
206
  )
207
  plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
208
- bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
209
  edgecolor='#E5E7EB'))
210
-
211
  plt.tight_layout()
212
  return fig
213
-
214
  @staticmethod
215
  def format_detection_summary(viz_data: Dict) -> str:
216
- """
217
- Format detection results as a readable text summary with improved spacing
218
-
219
- Args:
220
- viz_data: Visualization data generated by generate_visualization_data
221
-
222
- Returns:
223
- Formatted text with proper spacing
224
- """
225
  if "error" in viz_data:
226
  return viz_data["error"]
227
-
228
  if "total_objects" not in viz_data:
229
  return "No detection data available."
230
-
231
- # 獲取基本統計信息
232
  total_objects = viz_data["total_objects"]
233
  avg_confidence = viz_data["average_confidence"]
234
-
235
- # 創建標題,使用更多空白行增加可讀性
236
  lines = [
237
- f"Detected {total_objects} objects.",
238
  f"Average confidence: {avg_confidence:.2f}",
239
- "\n", # 添加額外的空行
240
- "Objects by class:",
241
  ]
242
-
243
- # 添加類別詳情,每個類別使用更多空間
244
  if "class_data" in viz_data and viz_data["class_data"]:
245
  for item in viz_data["class_data"]:
246
  count = item['count']
247
- # 使用正確的單複數形式
248
  item_text = "item" if count == 1 else "items"
249
-
250
- # 每個項目前添加空行,並使用縮進格式化
251
- lines.append("\n") # 每個項目前添加空白行
252
- lines.append(f"• {item['name']}: {count} {item_text}")
253
- lines.append(f" Confidence: {item['average_confidence']:.2f}")
254
  else:
255
- lines.append("\nNo class information available.")
256
-
257
  return "\n".join(lines)
258
-
259
  @staticmethod
260
  def calculate_distance_metrics(result: Any) -> Dict:
261
  """
262
  Calculate distance-related metrics for detected objects
263
-
264
  Args:
265
  result: Detection result object
266
-
267
  Returns:
268
  Dictionary with distance metrics
269
  """
270
  if result is None:
271
  return {"error": "No detection result provided"}
272
-
273
  boxes = result.boxes.xyxy.cpu().numpy()
274
  classes = result.boxes.cls.cpu().numpy().astype(int)
275
  names = result.names
276
-
277
  # Initialize metrics
278
  metrics = {
279
  "proximity": {}, # Classes that appear close to each other
280
  "spatial_distribution": {}, # Distribution across the image
281
  "size_distribution": {} # Size distribution of objects
282
  }
283
-
284
  # Calculate image dimensions (assuming normalized coordinates or extract from result)
285
  img_width, img_height = 1, 1
286
  if hasattr(result, "orig_shape"):
287
  img_height, img_width = result.orig_shape[:2]
288
-
289
  # Calculate bounding box areas and centers
290
  areas = []
291
  centers = []
292
  class_names = []
293
-
294
  for box, cls in zip(boxes, classes):
295
  x1, y1, x2, y2 = box
296
  width, height = x2 - x1, y2 - y1
297
  area = width * height
298
  center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
299
-
300
  areas.append(area)
301
  centers.append((center_x, center_y))
302
  class_names.append(names[int(cls)])
303
-
304
  # Calculate spatial distribution
305
  if centers:
306
  x_coords = [c[0] for c in centers]
307
  y_coords = [c[1] for c in centers]
308
-
309
  metrics["spatial_distribution"] = {
310
  "x_mean": float(np.mean(x_coords)) / img_width,
311
  "y_mean": float(np.mean(y_coords)) / img_height,
312
  "x_std": float(np.std(x_coords)) / img_width,
313
  "y_std": float(np.std(y_coords)) / img_height
314
  }
315
-
316
  # Calculate size distribution
317
  if areas:
318
  metrics["size_distribution"] = {
@@ -321,40 +303,40 @@ class EvaluationMetrics:
321
  "min_area": float(np.min(areas)) / (img_width * img_height),
322
  "max_area": float(np.max(areas)) / (img_width * img_height)
323
  }
324
-
325
  # Calculate proximity between different classes
326
  class_centers = {}
327
  for cls_name, center in zip(class_names, centers):
328
  if cls_name not in class_centers:
329
  class_centers[cls_name] = []
330
  class_centers[cls_name].append(center)
331
-
332
  # Find classes that appear close to each other
333
  proximity_pairs = []
334
  for i, cls1 in enumerate(class_centers.keys()):
335
  for j, cls2 in enumerate(class_centers.keys()):
336
  if i >= j: # Avoid duplicate pairs and self-comparison
337
  continue
338
-
339
  # Calculate minimum distance between any two objects of these classes
340
  min_distance = float('inf')
341
  for center1 in class_centers[cls1]:
342
  for center2 in class_centers[cls2]:
343
  dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
344
  min_distance = min(min_distance, dist)
345
-
346
  # Normalize by image diagonal
347
  img_diagonal = np.sqrt(img_width**2 + img_height**2)
348
  norm_distance = min_distance / img_diagonal
349
-
350
  proximity_pairs.append({
351
  "class1": cls1,
352
  "class2": cls2,
353
  "distance": float(norm_distance)
354
  })
355
-
356
  # Sort by distance and keep the closest pairs
357
  proximity_pairs.sort(key=lambda x: x["distance"])
358
  metrics["proximity"] = proximity_pairs[:5] # Keep top 5 closest pairs
359
-
360
  return metrics
 
4
 
5
  class EvaluationMetrics:
6
  """Class for computing detection metrics, generating statistics and visualization data"""
7
+
8
  @staticmethod
9
  def calculate_basic_stats(result: Any) -> Dict:
10
  """
11
  Calculate basic statistics for a single detection result
12
+
13
  Args:
14
  result: Detection result object
15
+
16
  Returns:
17
  Dictionary with basic statistics
18
  """
19
  if result is None:
20
  return {"error": "No detection result provided"}
21
+
22
  # Get classes and confidences
23
  classes = result.boxes.cls.cpu().numpy().astype(int)
24
  confidences = result.boxes.conf.cpu().numpy()
25
  names = result.names
26
+
27
  # Count by class
28
  class_counts = {}
29
  for cls, conf in zip(classes, confidences):
30
  cls_name = names[int(cls)]
31
  if cls_name not in class_counts:
32
  class_counts[cls_name] = {"count": 0, "total_confidence": 0, "confidences": []}
33
+
34
  class_counts[cls_name]["count"] += 1
35
  class_counts[cls_name]["total_confidence"] += float(conf)
36
  class_counts[cls_name]["confidences"].append(float(conf))
37
+
38
  # Calculate average confidence
39
  for cls_name, stats in class_counts.items():
40
  if stats["count"] > 0:
41
  stats["average_confidence"] = stats["total_confidence"] / stats["count"]
42
  stats["confidence_std"] = float(np.std(stats["confidences"])) if len(stats["confidences"]) > 1 else 0
43
  stats.pop("total_confidence") # Remove intermediate calculation
44
+
45
  # Prepare summary
46
  stats = {
47
  "total_objects": len(classes),
48
  "class_statistics": class_counts,
49
  "average_confidence": float(np.mean(confidences)) if len(confidences) > 0 else 0
50
  }
51
+
52
  return stats
53
+
54
@staticmethod
def generate_visualization_data(result: Any, class_colors: Dict = None) -> Dict:
    """
    Generate structured data suitable for visualization.

    Args:
        result: Detection result object.
        class_colors: Optional mapping of class name -> hex color; classes
            without an entry fall back to "#CCCCCC".

    Returns:
        Dictionary with total/average statistics and a per-class list
        sorted by detection count (descending), or ``{"error": ...}``
        when ``result`` is None.
    """
    if result is None:
        return {"error": "No detection result provided"}

    # Reuse the basic statistics as the data source.
    stats = EvaluationMetrics.calculate_basic_stats(result)

    viz_data = {
        "total_objects": stats["total_objects"],
        "average_confidence": stats["average_confidence"],
        "class_data": [],
    }

    # Most frequently detected classes first.
    by_count_desc = sorted(
        stats["class_statistics"].items(),
        key=lambda kv: kv[1]["count"],
        reverse=True,
    )

    for cls_name, cls_stats in by_count_desc:
        # Recover the numeric class id by scanning the id -> name mapping.
        # NOTE(review): this lookup header was elided in the diff view and is
        # reconstructed from the surviving loop body — confirm against the
        # full file.
        class_id = -1
        for idx, name in result.names.items():
            if name == cls_name:
                class_id = idx
                break

        # NOTE(review): the "count"/"average_confidence" entries were elided
        # in the diff view; reconstructed from what the plotting/summary
        # consumers read — confirm against the full file.
        viz_data["class_data"].append({
            "name": cls_name,
            "class_id": class_id,
            "count": cls_stats["count"],
            "average_confidence": cls_stats["average_confidence"],
            "confidence_std": cls_stats.get("confidence_std", 0),
            "color": class_colors.get(cls_name, "#CCCCCC") if class_colors else "#CCCCCC",
        })

    return viz_data
107
+
108
  @staticmethod
109
  def create_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
110
  """
111
  Create a horizontal bar chart showing detection statistics
112
+
113
  Args:
114
  viz_data: Visualization data generated by generate_visualization_data
115
  figsize: Figure size (width, height) in inches
116
  max_classes: Maximum number of classes to display
117
+
118
  Returns:
119
  Matplotlib figure object
120
  """
 
125
def create_enhanced_stats_plot(viz_data: Dict, figsize: Tuple[int, int] = (10, 7), max_classes: int = 30) -> plt.Figure:
    """
    Create an enhanced horizontal bar chart with larger fonts and better styling.

    Args:
        viz_data: Visualization data dictionary.
        figsize: Figure size (width, height) in inches.
        max_classes: Maximum number of classes to display.

    Returns:
        Matplotlib figure with enhanced styling.
    """
    def _placeholder_figure(message):
        # Blank, axis-less figure carrying only a centered message.
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, message,
                ha='center', va='center', fontsize=14, fontfamily='Arial')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    if "error" in viz_data:
        return _placeholder_figure(viz_data["error"])

    if "class_data" not in viz_data or not viz_data["class_data"]:
        return _placeholder_figure("No detection data available")

    # Cap the number of plotted classes.
    class_data = viz_data["class_data"][:max_classes]

    labels = [entry["name"] for entry in class_data]
    counts = [entry["count"] for entry in class_data]
    bar_colors = [entry["color"] for entry in class_data]

    # White-on-white figure with the project's Arial styling.
    plt.rcParams['font.family'] = 'Arial'
    fig, ax = plt.subplots(figsize=figsize)
    fig.patch.set_facecolor('white')
    ax.set_facecolor('white')

    positions = np.arange(len(labels))
    bars = ax.barh(positions, counts, color=bar_colors, alpha=0.8, height=0.6)

    # Annotate each bar with its count and mean confidence.
    for entry, bar in zip(class_data, bars):
        bar_len = bar.get_width()
        ax.text(bar_len + 0.3, bar.get_y() + bar.get_height() / 2,
                f"{bar_len:.0f} (conf: {entry['average_confidence']:.2f})",
                va='center', fontsize=12, fontfamily='Arial')

    ax.set_yticks(positions)
    ax.set_yticklabels(labels, fontsize=14, fontfamily='Arial')
    ax.invert_yaxis()  # Labels read top-to-bottom
    ax.set_xlabel('Count', fontsize=14, fontfamily='Arial')
    ax.set_title(f'Objects Detected: {viz_data["total_objects"]} Total',
                 fontsize=16, fontfamily='Arial', fontweight='bold')

    # Light dashed grid behind the bars for readability.
    ax.set_axisbelow(True)
    ax.grid(axis='x', linestyle='--', alpha=0.7, color='#E5E7EB')
    ax.tick_params(axis='both', which='major', labelsize=12)

    summary_text = (
        f"Total Objects: {viz_data['total_objects']}\n"
        # NOTE(review): the middle summary line was elided in the diff view;
        # reconstructed from the other stats — confirm against the full file.
        f"Average Confidence: {viz_data['average_confidence']:.2f}\n"
        f"Unique Classes: {len(viz_data['class_data'])}"
    )
    plt.figtext(0.02, 0.02, summary_text, fontsize=12, fontfamily='Arial',
                bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.5',
                          edgecolor='#E5E7EB'))

    plt.tight_layout()
    return fig
213
+
214
  @staticmethod
215
  def format_detection_summary(viz_data: Dict) -> str:
 
 
 
 
 
 
 
 
 
216
  if "error" in viz_data:
217
  return viz_data["error"]
218
+
219
  if "total_objects" not in viz_data:
220
  return "No detection data available."
221
+
 
222
  total_objects = viz_data["total_objects"]
223
  avg_confidence = viz_data["average_confidence"]
224
+
 
225
  lines = [
226
+ f"Detected {total_objects} objects.",
227
  f"Average confidence: {avg_confidence:.2f}",
228
+ "Objects by class:"
 
229
  ]
230
+
 
231
  if "class_data" in viz_data and viz_data["class_data"]:
232
  for item in viz_data["class_data"]:
233
  count = item['count']
 
234
  item_text = "item" if count == 1 else "items"
235
+ lines.append(f"• {item['name']}: {count} {item_text} (Confidence: {item['average_confidence']:.2f})")
 
 
 
 
236
  else:
237
+ lines.append("No class information available.")
238
+
239
  return "\n".join(lines)
240
+
241
  @staticmethod
242
  def calculate_distance_metrics(result: Any) -> Dict:
243
  """
244
  Calculate distance-related metrics for detected objects
245
+
246
  Args:
247
  result: Detection result object
248
+
249
  Returns:
250
  Dictionary with distance metrics
251
  """
252
  if result is None:
253
  return {"error": "No detection result provided"}
254
+
255
  boxes = result.boxes.xyxy.cpu().numpy()
256
  classes = result.boxes.cls.cpu().numpy().astype(int)
257
  names = result.names
258
+
259
  # Initialize metrics
260
  metrics = {
261
  "proximity": {}, # Classes that appear close to each other
262
  "spatial_distribution": {}, # Distribution across the image
263
  "size_distribution": {} # Size distribution of objects
264
  }
265
+
266
  # Calculate image dimensions (assuming normalized coordinates or extract from result)
267
  img_width, img_height = 1, 1
268
  if hasattr(result, "orig_shape"):
269
  img_height, img_width = result.orig_shape[:2]
270
+
271
  # Calculate bounding box areas and centers
272
  areas = []
273
  centers = []
274
  class_names = []
275
+
276
  for box, cls in zip(boxes, classes):
277
  x1, y1, x2, y2 = box
278
  width, height = x2 - x1, y2 - y1
279
  area = width * height
280
  center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
281
+
282
  areas.append(area)
283
  centers.append((center_x, center_y))
284
  class_names.append(names[int(cls)])
285
+
286
  # Calculate spatial distribution
287
  if centers:
288
  x_coords = [c[0] for c in centers]
289
  y_coords = [c[1] for c in centers]
290
+
291
  metrics["spatial_distribution"] = {
292
  "x_mean": float(np.mean(x_coords)) / img_width,
293
  "y_mean": float(np.mean(y_coords)) / img_height,
294
  "x_std": float(np.std(x_coords)) / img_width,
295
  "y_std": float(np.std(y_coords)) / img_height
296
  }
297
+
298
  # Calculate size distribution
299
  if areas:
300
  metrics["size_distribution"] = {
 
303
  "min_area": float(np.min(areas)) / (img_width * img_height),
304
  "max_area": float(np.max(areas)) / (img_width * img_height)
305
  }
306
+
307
  # Calculate proximity between different classes
308
  class_centers = {}
309
  for cls_name, center in zip(class_names, centers):
310
  if cls_name not in class_centers:
311
  class_centers[cls_name] = []
312
  class_centers[cls_name].append(center)
313
+
314
  # Find classes that appear close to each other
315
  proximity_pairs = []
316
  for i, cls1 in enumerate(class_centers.keys()):
317
  for j, cls2 in enumerate(class_centers.keys()):
318
  if i >= j: # Avoid duplicate pairs and self-comparison
319
  continue
320
+
321
  # Calculate minimum distance between any two objects of these classes
322
  min_distance = float('inf')
323
  for center1 in class_centers[cls1]:
324
  for center2 in class_centers[cls2]:
325
  dist = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
326
  min_distance = min(min_distance, dist)
327
+
328
  # Normalize by image diagonal
329
  img_diagonal = np.sqrt(img_width**2 + img_height**2)
330
  norm_distance = min_distance / img_diagonal
331
+
332
  proximity_pairs.append({
333
  "class1": cls1,
334
  "class2": cls2,
335
  "distance": float(norm_distance)
336
  })
337
+
338
  # Sort by distance and keep the closest pairs
339
  proximity_pairs.sort(key=lambda x: x["distance"])
340
  metrics["proximity"] = proximity_pairs[:5] # Keep top 5 closest pairs
341
+
342
  return metrics
visualization_helper.py CHANGED
@@ -1,34 +1,35 @@
1
  import cv2
2
  import numpy as np
3
  import matplotlib.pyplot as plt
 
4
  from typing import Any, List, Dict, Tuple, Optional
5
  import io
6
  from PIL import Image
7
 
8
  class VisualizationHelper:
9
  """Helper class for visualizing detection results"""
10
-
11
  @staticmethod
12
  def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
13
  figsize: Tuple[int, int] = (12, 12),
14
  return_pil: bool = False) -> Optional[Image.Image]:
15
  """
16
  Visualize detection results on a single image
17
-
18
  Args:
19
  image: Image path or numpy array
20
  result: Detection result object
21
  color_mapper: ColorMapper instance for consistent colors
22
  figsize: Figure size
23
  return_pil: If True, returns a PIL Image object
24
-
25
  Returns:
26
  PIL Image if return_pil is True, otherwise displays the plot
27
  """
28
  if result is None:
29
  print('No data for visualization')
30
  return None
31
-
32
  # Read image if path is provided
33
  if isinstance(image, str):
34
  img = cv2.imread(image)
@@ -40,19 +41,19 @@ class VisualizationHelper:
40
  if isinstance(img, np.ndarray):
41
  # Assuming BGR format from OpenCV
42
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
43
-
44
  # Create figure
45
  fig, ax = plt.subplots(figsize=figsize)
46
  ax.imshow(img)
47
-
48
  # Get bounding boxes, classes and confidences
49
  boxes = result.boxes.xyxy.cpu().numpy()
50
  classes = result.boxes.cls.cpu().numpy()
51
  confs = result.boxes.conf.cpu().numpy()
52
-
53
  # Get class names
54
  names = result.names
55
-
56
  # Create a default color mapper if none is provided
57
  if color_mapper is None:
58
  # For backward compatibility, fallback to a simple color function
@@ -67,29 +68,37 @@ class VisualizationHelper:
67
  # Convert hex to RGB float values for matplotlib
68
  hex_color = hex_color.lstrip('#')
69
  return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)
70
-
71
  # Draw detection results
72
  for box, cls, conf in zip(boxes, classes, confs):
73
  x1, y1, x2, y2 = box
74
  cls_id = int(cls)
75
  cls_name = names[cls_id]
76
-
77
  # Get color for this class
78
  box_color = get_color(cls_id)
79
-
80
- # Add text label with colored background
81
- ax.text(x1, y1 - 5, f'{cls_name}: {conf:.2f}',
82
- color='white', fontsize=10,
83
- bbox=dict(facecolor=box_color[:3], alpha=0.7))
84
-
 
 
 
 
 
 
 
 
85
  # Add bounding box
86
- ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
87
  fill=False, edgecolor=box_color[:3], linewidth=2))
88
-
89
  ax.axis('off')
90
  # ax.set_title('Detection Result')
91
  plt.tight_layout()
92
-
93
  if return_pil:
94
  # Convert plot to PIL Image
95
  buf = io.BytesIO()
@@ -101,47 +110,47 @@ class VisualizationHelper:
101
  else:
102
  plt.show()
103
  return None
104
-
105
  @staticmethod
106
  def create_summary(result: Any) -> Dict:
107
  """
108
  Create a summary of detection results
109
-
110
  Args:
111
  result: Detection result object
112
-
113
  Returns:
114
  Dictionary with detection summary statistics
115
  """
116
  if result is None:
117
  return {"error": "No detection result provided"}
118
-
119
  # Get classes and confidences
120
  classes = result.boxes.cls.cpu().numpy().astype(int)
121
  confidences = result.boxes.conf.cpu().numpy()
122
  names = result.names
123
-
124
  # Count detections by class
125
  class_counts = {}
126
  for cls, conf in zip(classes, confidences):
127
  cls_name = names[int(cls)]
128
  if cls_name not in class_counts:
129
  class_counts[cls_name] = {"count": 0, "confidences": []}
130
-
131
  class_counts[cls_name]["count"] += 1
132
  class_counts[cls_name]["confidences"].append(float(conf))
133
-
134
  # Calculate average confidence for each class
135
  for cls_name, stats in class_counts.items():
136
  if stats["confidences"]:
137
  stats["average_confidence"] = float(np.mean(stats["confidences"]))
138
  stats.pop("confidences") # Remove detailed confidences list to keep summary concise
139
-
140
  # Prepare summary
141
  summary = {
142
  "total_objects": len(classes),
143
  "class_counts": class_counts,
144
  "unique_classes": len(class_counts)
145
  }
146
-
147
  return summary
 
1
  import cv2
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
+ import matplotlib.patheffects as path_effects
5
  from typing import Any, List, Dict, Tuple, Optional
6
  import io
7
  from PIL import Image
8
 
9
  class VisualizationHelper:
10
  """Helper class for visualizing detection results"""
11
+
12
  @staticmethod
13
  def visualize_detection(image: Any, result: Any, color_mapper: Optional[Any] = None,
14
  figsize: Tuple[int, int] = (12, 12),
15
  return_pil: bool = False) -> Optional[Image.Image]:
16
  """
17
  Visualize detection results on a single image
18
+
19
  Args:
20
  image: Image path or numpy array
21
  result: Detection result object
22
  color_mapper: ColorMapper instance for consistent colors
23
  figsize: Figure size
24
  return_pil: If True, returns a PIL Image object
25
+
26
  Returns:
27
  PIL Image if return_pil is True, otherwise displays the plot
28
  """
29
  if result is None:
30
  print('No data for visualization')
31
  return None
32
+
33
  # Read image if path is provided
34
  if isinstance(image, str):
35
  img = cv2.imread(image)
 
41
  if isinstance(img, np.ndarray):
42
  # Assuming BGR format from OpenCV
43
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
44
+
45
  # Create figure
46
  fig, ax = plt.subplots(figsize=figsize)
47
  ax.imshow(img)
48
+
49
  # Get bounding boxes, classes and confidences
50
  boxes = result.boxes.xyxy.cpu().numpy()
51
  classes = result.boxes.cls.cpu().numpy()
52
  confs = result.boxes.conf.cpu().numpy()
53
+
54
  # Get class names
55
  names = result.names
56
+
57
  # Create a default color mapper if none is provided
58
  if color_mapper is None:
59
  # For backward compatibility, fallback to a simple color function
 
68
  # Convert hex to RGB float values for matplotlib
69
  hex_color = hex_color.lstrip('#')
70
  return tuple(int(hex_color[i:i+2], 16) / 255 for i in (0, 2, 4)) + (1.0,)
71
+
72
  # Draw detection results
73
  for box, cls, conf in zip(boxes, classes, confs):
74
  x1, y1, x2, y2 = box
75
  cls_id = int(cls)
76
  cls_name = names[cls_id]
77
+
78
  # Get color for this class
79
  box_color = get_color(cls_id)
80
+
81
+ box_width = x2 - x1
82
+ box_height = y2 - y1
83
+ box_area = box_width * box_height
84
+
85
+ # 根據框大小調整字體大小,但有限制
86
+ adaptive_fontsize = max(10, min(14, int(10 + box_area / 10000)))
87
+
88
+
89
+ ax.text(x1, y1 - 8, f'{cls_name}: {conf:.2f}',
90
+ color='white', fontsize=adaptive_fontsize, fontweight="bold",
91
+ bbox=dict(facecolor=box_color[:3], alpha=0.85, pad=3, boxstyle="round,pad=0.3"),
92
+ path_effects=[path_effects.withStroke(linewidth=1.5, foreground="black")])
93
+
94
  # Add bounding box
95
+ ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1,
96
  fill=False, edgecolor=box_color[:3], linewidth=2))
97
+
98
  ax.axis('off')
99
  # ax.set_title('Detection Result')
100
  plt.tight_layout()
101
+
102
  if return_pil:
103
  # Convert plot to PIL Image
104
  buf = io.BytesIO()
 
110
  else:
111
  plt.show()
112
  return None
113
+
114
  @staticmethod
115
  def create_summary(result: Any) -> Dict:
116
  """
117
  Create a summary of detection results
118
+
119
  Args:
120
  result: Detection result object
121
+
122
  Returns:
123
  Dictionary with detection summary statistics
124
  """
125
  if result is None:
126
  return {"error": "No detection result provided"}
127
+
128
  # Get classes and confidences
129
  classes = result.boxes.cls.cpu().numpy().astype(int)
130
  confidences = result.boxes.conf.cpu().numpy()
131
  names = result.names
132
+
133
  # Count detections by class
134
  class_counts = {}
135
  for cls, conf in zip(classes, confidences):
136
  cls_name = names[int(cls)]
137
  if cls_name not in class_counts:
138
  class_counts[cls_name] = {"count": 0, "confidences": []}
139
+
140
  class_counts[cls_name]["count"] += 1
141
  class_counts[cls_name]["confidences"].append(float(conf))
142
+
143
  # Calculate average confidence for each class
144
  for cls_name, stats in class_counts.items():
145
  if stats["confidences"]:
146
  stats["average_confidence"] = float(np.mean(stats["confidences"]))
147
  stats.pop("confidences") # Remove detailed confidences list to keep summary concise
148
+
149
  # Prepare summary
150
  summary = {
151
  "total_objects": len(classes),
152
  "class_counts": class_counts,
153
  "unique_classes": len(class_counts)
154
  }
155
+
156
  return summary