File size: 2,191 Bytes
90a129b
 
2054fb5
b63bd55
90a129b
b0c291d
90a129b
 
b0c291d
23e2414
2054fb5
b0c291d
2054fb5
 
 
b0c291d
b63bd55
2054fb5
 
 
23e2414
2054fb5
 
 
 
 
23e2414
 
90a129b
2054fb5
b63bd55
2054fb5
 
 
 
 
 
 
 
 
 
 
23e2414
b0c291d
2054fb5
23e2414
 
2054fb5
b63bd55
 
 
 
b0c291d
 
90a129b
b0c291d
90a129b
2054fb5
b63bd55
 
 
 
 
 
 
 
4e8c594
b0c291d
90a129b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import tempfile
import zlib

import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline

# Build the object-detection pipeline once at module load, backed by the
# YOLOS-small checkpoint.
detector = pipeline(
    task="object-detection",
    model="hustvl/yolos-small",
)

# Palette cycled through to differentiate classes
COLORS = ["red", "blue", "green", "orange", "purple", "yellow", "cyan", "magenta"]


# Helper function to assign color per label
def get_color_for_label(label):
    """Return the palette color assigned to *label*.

    The color is chosen with a CRC32 checksum of the label text rather than
    the built-in ``hash()``: string hashes are salted per interpreter
    process, so ``hash()`` would give the same class a different color after
    every restart. CRC32 keeps the mapping stable across runs.

    Args:
        label: Class name of a detected object (converted with ``str()``).

    Returns:
        One of the entries of ``COLORS``.
    """
    checksum = zlib.crc32(str(label).encode("utf-8"))
    return COLORS[checksum % len(COLORS)]

# Main function: detect, draw, and return outputs
def detect_and_draw(image, threshold):
    """Detect objects in *image* and draw those scoring at least *threshold*.

    Args:
        image: PIL image from the Gradio input, or ``None`` when no image
            was provided.
        threshold: Minimum confidence score for a detection to be drawn.

    Returns:
        ``((annotated_image, annotations), png_path)``: the RGB copy with
        boxes and labels drawn on it, a list of
        ``((xmin, ymin, xmax, ymax), label_text)`` pairs, and the path of a
        PNG copy saved for download. ``(None, None)`` when *image* is
        ``None``.
    """
    if image is None:
        # Nothing was uploaded: return empty outputs instead of raising.
        return None, None

    # Pass the threshold through to the pipeline; otherwise it applies its
    # own default cut-off and detections below that value never reach the
    # filter in the loop, no matter how low the slider is set.
    results = detector(image, threshold=threshold)
    image = image.convert("RGB")
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # Font file not found or FreeType unavailable: use Pillow's default.
        font = ImageFont.load_default()

    annotations = []

    for obj in results:
        score = obj["score"]
        if score < threshold:
            continue

        label = f"{obj['label']} ({score:.2f})"
        box = obj["box"]
        color = get_color_for_label(obj["label"])

        draw.rectangle(
            [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
            outline=color,
            width=3,
        )

        draw.text((box["xmin"] + 5, box["ymin"] + 5), label, fill=color, font=font)

        box_coords = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
        annotations.append((box_coords, label))

    # Save image for download. delete=False keeps the file on disk after the
    # handle is closed; the `with` block closes the handle so it is not leaked.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image.save(temp_file, format="PNG")

    # Return the (image, annotations) tuple and the path to the saved image
    return (image, annotations), temp_file.name

# Gradio UI setup: build each component first, then wire them into the
# interface around detect_and_draw.
image_input = gr.Image(type="pil", label="Upload Image")
threshold_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.5,
    step=0.05,
    label="Confidence Threshold",
)
annotated_output = gr.AnnotatedImage(label="Detected Image")
download_output = gr.File(label="Download Processed Image")

demo = gr.Interface(
    fn=detect_and_draw,
    inputs=[image_input, threshold_slider],
    outputs=[annotated_output, download_output],
    title="YOLOS Object Detection",
    description="Upload an image to detect objects using the YOLOS-small model. Adjust the confidence threshold using the slider.",
)

# Start the web app.
demo.launch()