navjotk commited on
Commit
b63bd55
·
verified ·
1 Parent(s): 23e2414

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,23 +1,21 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw, ImageFont
 
4
 
5
  # Load YOLOS object detection model
6
  detector = pipeline("object-detection", model="hustvl/yolos-small")
7
 
8
- # Confidence threshold
9
- CONFIDENCE_THRESHOLD = 0.5
10
-
11
- # Color palette
12
  COLORS = ["red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"]
13
 
 
14
  def get_color_for_label(label):
15
  return COLORS[hash(label) % len(COLORS)]
16
 
17
- def detect_and_draw(image):
 
18
  results = detector(image)
19
-
20
- # Convert to RGB for drawing
21
  image = image.convert("RGB")
22
  draw = ImageDraw.Draw(image)
23
 
@@ -30,21 +28,19 @@ def detect_and_draw(image):
30
 
31
  for obj in results:
32
  score = obj["score"]
33
- if score < CONFIDENCE_THRESHOLD:
34
  continue
35
 
36
  label = f"{obj['label']} ({score:.2f})"
37
  box = obj["box"]
38
  color = get_color_for_label(obj["label"])
39
 
40
- # Draw box
41
  draw.rectangle(
42
  [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
43
  outline=color,
44
  width=3,
45
  )
46
 
47
- # Draw label
48
  draw.text(
49
  (box["xmin"] + 5, box["ymin"] + 5),
50
  label,
@@ -52,19 +48,28 @@ def detect_and_draw(image):
52
  font=font
53
  )
54
 
55
- # AnnotatedImage expects (box_tuple, label)
56
  box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
57
  annotations.append((box_coords, label))
58
 
59
- return image, annotations
 
 
 
 
60
 
61
- # Gradio interface
62
  demo = gr.Interface(
63
  fn=detect_and_draw,
64
- inputs=gr.Image(type="pil"),
65
- outputs=gr.AnnotatedImage(),
66
- title="YOLOS Object Detection",
67
- description=f"Upload an image to detect objects using the YOLOS model. Only objects with confidence > {CONFIDENCE_THRESHOLD} are shown.",
 
 
 
 
 
 
68
  )
69
 
70
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw, ImageFont
4
+ import tempfile
5
 
6
  # Load YOLOS object detection model
7
  detector = pipeline("object-detection", model="hustvl/yolos-small")
8
 
9
+ # Colors for labels
 
 
 
10
  COLORS = ["red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"]
11
 
12
+ # Pick color based on label
13
  def get_color_for_label(label):
14
  return COLORS[hash(label) % len(COLORS)]
15
 
16
+ # Main detection + drawing function
17
+ def detect_and_draw(image, threshold):
18
  results = detector(image)
 
 
19
  image = image.convert("RGB")
20
  draw = ImageDraw.Draw(image)
21
 
 
28
 
29
  for obj in results:
30
  score = obj["score"]
31
+ if score < threshold:
32
  continue
33
 
34
  label = f"{obj['label']} ({score:.2f})"
35
  box = obj["box"]
36
  color = get_color_for_label(obj["label"])
37
 
 
38
  draw.rectangle(
39
  [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
40
  outline=color,
41
  width=3,
42
  )
43
 
 
44
  draw.text(
45
  (box["xmin"] + 5, box["ymin"] + 5),
46
  label,
 
48
  font=font
49
  )
50
 
 
51
  box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
52
  annotations.append((box_coords, label))
53
 
54
+ # Save image for download
55
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
56
+ image.save(temp_file.name)
57
+
58
+ return image, annotations, temp_file.name
59
 
60
+ # Gradio UI
61
  demo = gr.Interface(
62
  fn=detect_and_draw,
63
+ inputs=[
64
+ gr.Image(type="pil", label="Upload Image"),
65
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
66
+ ],
67
+ outputs=[
68
+ gr.AnnotatedImage(label="Detected Image"),
69
+ gr.File(label="Download Processed Image"),
70
+ ],
71
+ title="YOLOS Object Detection (CPU Friendly)",
72
+ description="Upload an image to detect objects using the YOLOS-small model. Adjust the slider to set the confidence threshold.",
73
  )
74
 
75
  demo.launch()