DawnC committed
Commit 7542b9c · verified · 1 parent: e4aac34

Delete app.py

Files changed (1):
  1. app.py +0 -555
app.py DELETED
@@ -1,555 +0,0 @@
- import os
- import numpy as np
- import torch
- import cv2
- import matplotlib.pyplot as plt
- import gradio as gr
- import io
- from PIL import Image, ImageDraw, ImageFont
- import spaces
- from typing import Dict, List, Any, Optional, Tuple
- from ultralytics import YOLO
-
- from detection_model import DetectionModel
- from color_mapper import ColorMapper
- from visualization_helper import VisualizationHelper
- from evaluation_metrics import EvaluationMetrics
-
-
- color_mapper = ColorMapper()
- model_instances = {}
-
- @spaces.GPU
- def process_image(image, model_instance, confidence_threshold, filter_classes=None):
-     """
-     Process an image for object detection.
-
-     Args:
-         image: Input image (numpy array or PIL Image)
-         model_instance: DetectionModel instance to use
-         confidence_threshold: Confidence threshold for detection
-         filter_classes: Optional list of class IDs to filter results
-
-     Returns:
-         Tuple of (result_image, result_text, stats_data)
-     """
-     # Initialize key variables
-     result = None
-     stats = {}
-     temp_path = None
-
-     try:
-         # Update the confidence threshold
-         model_instance.confidence = confidence_threshold
-
-         # Process the input image
-         if image is None:
-             return None, "No image provided. Please upload an image.", {}
-         elif isinstance(image, np.ndarray):
-             # Convert BGR to RGB if needed (OpenCV-style arrays arrive as BGR)
-             if image.ndim == 3 and image.shape[2] == 3:
-                 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             else:
-                 image_rgb = image
-             pil_image = Image.fromarray(image_rgb)
-         else:
-             pil_image = image
-
-         # Save to a temp file so the detector can read the image from disk
-         import uuid
-         import tempfile
-
-         temp_dir = tempfile.gettempdir()  # use the system temp directory
-         temp_filename = f"temp_{uuid.uuid4().hex}.jpg"
-         temp_path = os.path.join(temp_dir, temp_filename)
-         pil_image.save(temp_path)
-
-         # Object detection
-         result = model_instance.detect(temp_path)
-
-         if result is None:
-             return None, "Detection failed. Please try again with a different image.", {}
-
-         # Calculate basic statistics
-         stats = EvaluationMetrics.calculate_basic_stats(result)
-
-         # Add spatial metrics
-         spatial_metrics = EvaluationMetrics.calculate_distance_metrics(result)
-         stats["spatial_metrics"] = spatial_metrics
-
-         if filter_classes:
-             # Get classes, confidences, and boxes
-             classes = result.boxes.cls.cpu().numpy().astype(int)
-             confs = result.boxes.conf.cpu().numpy()
-             boxes = result.boxes.xyxy.cpu().numpy()
-
-             # Keep only detections whose class ID is in the filter list
-             mask = np.isin(classes, filter_classes)
-
-             filtered_stats = {
-                 "total_objects": int(np.sum(mask)),
-                 "class_statistics": {},
-                 "average_confidence": float(np.mean(confs[mask])) if np.any(mask) else 0,
-                 "spatial_metrics": stats["spatial_metrics"]
-             }
-
-             # Accumulate per-class counts and a running average confidence
-             names = result.names
-             for cls, conf in zip(classes[mask], confs[mask]):
-                 cls_name = names[int(cls)]
-                 if cls_name not in filtered_stats["class_statistics"]:
-                     filtered_stats["class_statistics"][cls_name] = {
-                         "count": 0,
-                         "average_confidence": 0
-                     }
-
-                 entry = filtered_stats["class_statistics"][cls_name]
-                 entry["count"] += 1
-                 # Incremental mean instead of overwriting with the last confidence
-                 entry["average_confidence"] += (float(conf) - entry["average_confidence"]) / entry["count"]
-
-             stats = filtered_stats
-
-         viz_data = EvaluationMetrics.generate_visualization_data(
-             result,
-             color_mapper.get_all_colors()
-         )
-
-         result_image = VisualizationHelper.visualize_detection(
-             temp_path, result, color_mapper=color_mapper, figsize=(12, 12), return_pil=True
-         )
-
-         result_text = EvaluationMetrics.format_detection_summary(viz_data)
-
-         return result_image, result_text, stats
-
-     except Exception as e:
-         error_message = f"Error occurred: {str(e)}"
-         import traceback
-         traceback.print_exc()
-         print(error_message)
-         return None, error_message, {}
-
-     finally:
-         if temp_path and os.path.exists(temp_path):
-             try:
-                 os.remove(temp_path)
-             except Exception as e:
-                 print(f"Could not delete temp file {temp_path}: {str(e)}")
-
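For reference, a minimal sketch of how process_image could be driven directly (outside Gradio), assuming the Space's DetectionModel module is importable; the sample path and class IDs are illustrative:

    from PIL import Image
    from detection_model import DetectionModel

    model = DetectionModel(model_name="yolov8m.pt", confidence=0.25, iou=0.45)
    img = Image.open("street_01.jpg")  # illustrative sample path
    # filter_classes=[0, 2] keeps only people and cars; pass None to keep everything
    annotated, summary, stats = process_image(img, model, confidence_threshold=0.25, filter_classes=[0, 2])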
- def format_result_text(stats):
-     """Format detection statistics into readable text."""
-     if not stats or "total_objects" not in stats:
-         return "No objects detected."
-
-     lines = [
-         f"Detected {stats['total_objects']} objects.",
-         f"Average confidence: {stats.get('average_confidence', 0):.2f}",
-         "",
-         "Objects by class:",
-     ]
-
-     if "class_statistics" in stats and stats["class_statistics"]:
-         # Sort classes by count, most frequent first
-         sorted_classes = sorted(
-             stats["class_statistics"].items(),
-             key=lambda x: x[1]["count"],
-             reverse=True
-         )
-
-         for cls_name, cls_stats in sorted_classes:
-             lines.append(f"• {cls_name}: {cls_stats['count']} (avg conf: {cls_stats.get('average_confidence', 0):.2f})")
-     else:
-         lines.append("No class information available.")
-
-     return "\n".join(lines)
-
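As a sketch of the expected input, format_result_text consumes the same stats dictionary shape that process_image builds above (the values here are illustrative):

    stats = {
        "total_objects": 3,
        "average_confidence": 0.82,
        "class_statistics": {
            "person": {"count": 2, "average_confidence": 0.85},
            "dog": {"count": 1, "average_confidence": 0.76},
        },
    }
    print(format_result_text(stats))
    # Detected 3 objects.
    # Average confidence: 0.82
    #
    # Objects by class:
    # • person: 2 (avg conf: 0.85)
    # • dog: 1 (avg conf: 0.76)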
- def get_all_classes():
-     """Get all available COCO classes."""
-     # Prefer class names from an already-loaded model instance, if any
-     try:
-         if model_instances:
-             first_instance = next(iter(model_instances.values()))
-             return [(idx, name) for idx, name in first_instance.class_names.items()]
-     except Exception:
-         pass
-
-     # Fallback to the standard COCO classes
-     return [
-         (0, 'person'), (1, 'bicycle'), (2, 'car'), (3, 'motorcycle'), (4, 'airplane'),
-         (5, 'bus'), (6, 'train'), (7, 'truck'), (8, 'boat'), (9, 'traffic light'),
-         (10, 'fire hydrant'), (11, 'stop sign'), (12, 'parking meter'), (13, 'bench'),
-         (14, 'bird'), (15, 'cat'), (16, 'dog'), (17, 'horse'), (18, 'sheep'), (19, 'cow'),
-         (20, 'elephant'), (21, 'bear'), (22, 'zebra'), (23, 'giraffe'), (24, 'backpack'),
-         (25, 'umbrella'), (26, 'handbag'), (27, 'tie'), (28, 'suitcase'), (29, 'frisbee'),
-         (30, 'skis'), (31, 'snowboard'), (32, 'sports ball'), (33, 'kite'), (34, 'baseball bat'),
-         (35, 'baseball glove'), (36, 'skateboard'), (37, 'surfboard'), (38, 'tennis racket'),
-         (39, 'bottle'), (40, 'wine glass'), (41, 'cup'), (42, 'fork'), (43, 'knife'),
-         (44, 'spoon'), (45, 'bowl'), (46, 'banana'), (47, 'apple'), (48, 'sandwich'),
-         (49, 'orange'), (50, 'broccoli'), (51, 'carrot'), (52, 'hot dog'), (53, 'pizza'),
-         (54, 'donut'), (55, 'cake'), (56, 'chair'), (57, 'couch'), (58, 'potted plant'),
-         (59, 'bed'), (60, 'dining table'), (61, 'toilet'), (62, 'tv'), (63, 'laptop'),
-         (64, 'mouse'), (65, 'remote'), (66, 'keyboard'), (67, 'cell phone'), (68, 'microwave'),
-         (69, 'oven'), (70, 'toaster'), (71, 'sink'), (72, 'refrigerator'), (73, 'book'),
-         (74, 'clock'), (75, 'vase'), (76, 'scissors'), (77, 'teddy bear'), (78, 'hair drier'),
-         (79, 'toothbrush')
-     ]
-
- def create_interface():
-     """Create the Gradio interface."""
-     # CSS styles for the interface
-     css = """
-     body {
-         font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
-         background: linear-gradient(120deg, #e0f7fa, #b2ebf2);
-         margin: 0;
-         padding: 0;
-     }
-
-     .gradio-container {
-         max-width: 1200px !important;
-     }
-
-     .app-header {
-         text-align: center;
-         margin-bottom: 2rem;
-         background: rgba(255, 255, 255, 0.8);
-         padding: 1.5rem;
-         border-radius: 10px;
-         box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
-     }
-
-     .app-title {
-         color: #2D3748;
-         font-size: 2.5rem;
-         margin-bottom: 0.5rem;
-         background: linear-gradient(90deg, #4299e1, #48bb78);
-         -webkit-background-clip: text;
-         -webkit-text-fill-color: transparent;
-     }
-
-     .app-subtitle {
-         color: #4A5568;
-         font-size: 1.2rem;
-         font-weight: normal;
-         margin-top: 0.25rem;
-     }
-
-     .app-divider {
-         width: 50px;
-         height: 3px;
-         background: linear-gradient(90deg, #4299e1, #48bb78);
-         margin: 1rem auto;
-     }
-
-     .input-panel, .output-panel {
-         background: white;
-         border-radius: 10px;
-         padding: 1rem;
-         box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
-     }
-
-     .detect-btn {
-         background: linear-gradient(90deg, #4299e1, #48bb78) !important;
-         color: white !important;
-         border: none !important;
-         transition: transform 0.3s, box-shadow 0.3s !important;
-     }
-
-     .detect-btn:hover {
-         transform: translateY(-2px) !important;
-         box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
-     }
-
-     .detect-btn:active {
-         transform: translateY(1px) !important;
-         box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2) !important;
-     }
-
-     .footer {
-         text-align: center;
-         margin-top: 2rem;
-         font-size: 0.9rem;
-         color: #4A5568;
-     }
-
-     /* Responsive adjustments */
-     @media (max-width: 768px) {
-         .app-title {
-             font-size: 2rem;
-         }
-
-         .app-subtitle {
-             font-size: 1rem;
-         }
-     }
-     """
-
-     # Get the available models info
-     available_models = DetectionModel.get_available_models()
-     model_choices = [model["model_file"] for model in available_models]
-     model_labels = [f"{model['name']} - {model['inference_speed']}" for model in available_models]
-
-     # Available classes for filtering
-     available_classes = get_all_classes()
-     class_choices = [f"{id}: {name}" for id, name in available_classes]
-
-     # Create the Gradio Blocks interface
-     with gr.Blocks(css=css) as demo:
-         # Header
-         with gr.Group(elem_classes="app-header"):
-             gr.HTML("""
-                 <h1 class="app-title">VisionScout</h1>
-                 <h2 class="app-subtitle">Detect and identify objects in your images</h2>
-                 <div class="app-divider"></div>
-             """)
-
-         current_model = gr.State("yolov8m.pt")  # use the medium-size model as default
-
-         # Input and output panels
-         with gr.Row():
-             # Left column - input controls
-             with gr.Column(scale=4, elem_classes="input-panel"):
-                 with gr.Group():
-                     gr.Markdown("### Upload Image")
-                     image_input = gr.Image(type="pil", label="Upload an image")
-
-                 with gr.Accordion("Advanced Settings", open=False):
-                     with gr.Row():
-                         model_dropdown = gr.Dropdown(
-                             choices=model_choices,
-                             value="yolov8m.pt",
-                             label="Select Model",
-                             info="Choose different models based on your needs for speed vs. accuracy"
-                         )
-
-                     # Display model info
-                     model_info = gr.Markdown(DetectionModel.get_model_description("yolov8m.pt"))
-
-                     confidence = gr.Slider(
-                         minimum=0.1,
-                         maximum=0.9,
-                         value=0.25,
-                         step=0.05,
-                         label="Confidence Threshold",
-                         info="Higher values show fewer but more confident detections"
-                     )
-
-                     with gr.Accordion("Filter Classes", open=False):
-                         # Common object categories
-                         with gr.Row():
-                             people_btn = gr.Button("People")
-                             vehicles_btn = gr.Button("Vehicles")
-                             animals_btn = gr.Button("Animals")
-                             objects_btn = gr.Button("Common Objects")
-
-                         # Class selection
-                         class_filter = gr.Dropdown(
-                             choices=class_choices,
-                             multiselect=True,
-                             label="Select Classes to Display",
-                             info="Leave empty to show all detected objects"
-                         )
-
-                 detect_btn = gr.Button("Detect Objects", variant="primary", elem_classes="detect-btn")
-
-                 with gr.Group():
-                     gr.Markdown("### How to Use")
-                     gr.Markdown("""
-                     1. Upload an image or use the camera
-                     2. Adjust the confidence threshold if needed
-                     3. Optionally filter to specific object classes
-                     4. Click the "Detect Objects" button
-
-                     The model will identify objects in your image and display them with bounding boxes.
-
-                     **Note:** Detection quality depends on image clarity and object visibility. The model can detect up to 80 different types of common objects.
-                     """)
-
-             # Right column - results display
-             with gr.Column(scale=6, elem_classes="output-panel"):
-                 with gr.Tab("Detection Result"):
-                     result_image = gr.Image(type="pil", label="Detection Result")
-                     result_text = gr.Textbox(label="Detection Details", lines=10)
-
-                 with gr.Tab("Statistics"):
-                     with gr.Row():
-                         with gr.Column(scale=1):
-                             stats_json = gr.Json(label="Full Statistics")
-
-                         with gr.Column(scale=1):
-                             gr.Markdown("### Object Distribution")
-                             plot_output = gr.Plot(label="Object Distribution")
-
-         # Update the stored model and its description when the dropdown changes
-         model_dropdown.change(
-             fn=lambda model: (model, DetectionModel.get_model_description(model)),
-             inputs=[model_dropdown],
-             outputs=[current_model, model_info]
-         )
-
-         # Run detection and plotting when the detect button is clicked
-         detect_btn.click(
-             fn=process_and_plot,
-             inputs=[image_input, current_model, confidence, class_filter],
-             outputs=[result_image, result_text, stats_json, plot_output]
-         )
-
-         # Quick filter buttons
-         people_classes = [0]  # person
-         vehicles_classes = [1, 2, 3, 4, 5, 6, 7, 8]  # various vehicles
-         animals_classes = list(range(14, 24))  # animals in COCO
-         common_objects = [41, 42, 43, 44, 45, 67, 73, 74, 76]  # common household items
-
-         people_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in people_classes],
-             outputs=class_filter
-         )
-
-         vehicles_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in vehicles_classes],
-             outputs=class_filter
-         )
-
-         animals_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in animals_classes],
-             outputs=class_filter
-         )
-
-         objects_btn.click(
-             lambda: [f"{id}: {name}" for id, name in available_classes if id in common_objects],
-             outputs=class_filter
-         )
-
-         # Set up example images
-         example_images = [
-             "room_01.jpg",
-             "street_01.jpg",
-             "street_02.jpg",
-             "street_03.jpg"
-         ]
-
-         gr.Examples(
-             examples=example_images,
-             inputs=image_input,
-             outputs=None,
-             fn=None,
-             cache_examples=False,
-         )
-
-         # Footer
-         gr.HTML("""
-             <div class="footer">
-                 <p>Powered by YOLOv8 and Ultralytics • Created with Gradio</p>
-                 <p>Model can detect 80 different classes of objects</p>
-             </div>
-         """)
-
-     return demo
-
- @spaces.GPU
- def process_and_plot(image, model_name, confidence_threshold, filter_classes=None):
-     """
-     Process an image and create plots for statistics.
-
-     Args:
-         image: Input image
-         model_name: Name of the model to use
-         confidence_threshold: Confidence threshold for detection
-         filter_classes: Optional list of "id: name" strings to filter results
-
-     Returns:
-         Tuple of (result_image, result_text, stats_json, plot_figure)
-     """
-     global model_instances
-
-     # Reuse a cached model instance when possible
-     if model_name not in model_instances:
-         print(f"Creating new model instance for {model_name}")
-         model_instances[model_name] = DetectionModel(model_name=model_name, confidence=confidence_threshold, iou=0.45)
-     else:
-         print(f"Using existing model instance for {model_name}")
-         model_instances[model_name].confidence = confidence_threshold
-
-     class_ids = None
-     if filter_classes:
-         class_ids = []
-         for class_str in filter_classes:
-             try:
-                 # Extract the ID from the "id: name" format
-                 class_id = int(class_str.split(":")[0].strip())
-                 class_ids.append(class_id)
-             except ValueError:
-                 continue
-
-     # Run detection
-     result_image, result_text, stats = process_image(
-         image,
-         model_instances[model_name],
-         confidence_threshold,
-         class_ids
-     )
-
-     # Create the statistics plot
-     plot_figure = create_stats_plot(stats)
-
-     return result_image, result_text, stats, plot_figure
-
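A small illustration of the "id: name" parsing used above, with dropdown entries in the format produced by create_interface (the values are illustrative):

    labels = ["0: person", "2: car", "not a class"]
    class_ids = []
    for class_str in labels:
        try:
            class_ids.append(int(class_str.split(":")[0].strip()))
        except ValueError:
            continue
    # class_ids == [0, 2]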
- def create_stats_plot(stats):
-     """
-     Create a visualization of statistics data.
-
-     Args:
-         stats: Dictionary containing detection statistics
-
-     Returns:
-         Matplotlib figure with the visualization
-     """
-     if not stats or "class_statistics" not in stats or not stats["class_statistics"]:
-         # Create an empty plot if there is no data
-         fig, ax = plt.subplots(figsize=(8, 6))
-         ax.text(0.5, 0.5, "No detection data available",
-                 ha='center', va='center', fontsize=12)
-         ax.set_xlim(0, 1)
-         ax.set_ylim(0, 1)
-         ax.axis('off')
-         return fig
-
-     # Prepare visualization data
-     viz_data = {
-         "total_objects": stats.get("total_objects", 0),
-         "average_confidence": stats.get("average_confidence", 0),
-         "class_data": []
-     }
-
-     # Map class names back to IDs via get_all_classes()
-     available_classes = dict(get_all_classes())
-
-     # Process per-class data
-     for cls_name, cls_stats in stats.get("class_statistics", {}).items():
-         # Look up the class ID for this class name
-         class_id = -1
-         for cls_id, name in available_classes.items():
-             if name == cls_name:
-                 class_id = cls_id
-                 break
-
-         cls_data = {
-             "name": cls_name,
-             "class_id": class_id,
-             "count": cls_stats.get("count", 0),
-             "average_confidence": cls_stats.get("average_confidence", 0),
-             "color": color_mapper.get_color(class_id if class_id >= 0 else cls_name)
-         }
-
-         viz_data["class_data"].append(cls_data)
-
-     # Sort by count in descending order
-     viz_data["class_data"].sort(key=lambda x: x["count"], reverse=True)
-
-     return EvaluationMetrics.create_stats_plot(viz_data)
-
-
- if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch()