File size: 1,746 Bytes
5a39643
 
e0b3a35
 
6ccbd18
e0b3a35
 
5a39643
 
f9ccbb5
5a39643
 
 
f9ccbb5
5a39643
 
 
 
 
82abe24
5a39643
 
e0b3a35
 
5a39643
6ccbd18
7539610
 
 
e0b3a35
6ccbd18
 
 
e0b3a35
7539610
 
 
5a39643
82abe24
5a39643
 
bc9d9eb
5a39643
 
7539610
5a39643
e0b3a35
6ccbd18
 
 
 
e0b3a35
5a39643
 
 
 
7539610
5a39643
7539610
 
5a39643
7539610
 
 
5a39643
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from transformers import pipeline
import torch


# Device index handed to every pipeline: GPU 0 when CUDA is usable,
# otherwise -1 (which transformers pipelines treat as "run on CPU").
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Hub checkpoints to compare. The order matters: results are produced,
# and output panels are created, in exactly this order.
model_names = [
    # Apple
    "apple/mobilevit-small",
    # Facebook
    "facebook/deit-base-patch16-224",
    "facebook/convnext-base-224",
    # Google
    "google/vit-base-patch16-224",
    "google/mobilenet_v2_1.4_224",
    # Microsoft
    "microsoft/resnet-50",
    "microsoft/swin-base-patch4-window7-224",
    "microsoft/beit-base-patch16-224",
    # NVIDIA
    "nvidia/mit-b0",
    # SHI Labs
    "shi-labs/nat-base-in1k-224",
    "shi-labs/dinat-base-in1k-224",
]

# Model name -> loaded pipeline. Starts empty and is filled on demand so
# that each model is loaded at most once per process.
pipelines = dict()


def _get_pipeline(model_name):
    """Return the image-classification pipeline for *model_name*.

    The pipeline is created on first use and stored in the module-level
    ``pipelines`` dict, so later calls reuse the already-loaded model.
    """
    try:
        return pipelines[model_name]
    except KeyError:
        loaded = pipeline(
            "image-classification", model=model_name, device=device
        )
        pipelines[model_name] = loaded
        return loaded


def process(image_file, top_k):
    """Classify one image with every model listed in ``model_names``.

    Args:
        image_file: Image input, forwarded unchanged to each pipeline
            (the UI in this file supplies a file path).
        top_k: Maximum number of predictions to keep per model. Whole
            numbers given as floats (e.g. ``2.0``) are accepted; ``None``
            keeps every prediction the pipeline returns.

    Returns:
        A list with one ``{label: score}`` dict per model, in the same
        order as ``model_names``.
    """
    # A slice bound must be an integer; a float such as 5.0 would raise
    # TypeError, so normalise the value once up front.
    limit = None if top_k is None else int(top_k)
    return [
        {
            entry["label"]: entry["score"]
            for entry in _get_pipeline(name)(image_file)[:limit]
        }
        for name in model_names
    ]


# Input widgets: the uploaded image reaches the callback as a file path,
# and the slider chooses how many classes to show per model (1 to 5).
image = gr.Image(label="Upload an image", type="filepath")
top_k = gr.Slider(label="Top k classes", minimum=1, maximum=5, value=5, step=1)

# Output widgets: one label panel per model, titled with the model name.
labels = [gr.Label(label=model_name) for model_name in model_names]

# Blurb shown in the interface, split one sentence per literal.
description = (
    "This Space compares popular image classifiers available on the Hugging Face hub, including NAT and DINAT models. "
    "All models have been fine-tuned on ImageNet-1k. "
    "The sample images were generated with Stable Diffusion."
)

# Wire the callback to its widgets. Every bundled example asks for the
# top 5 classes.
iface = gr.Interface(
    fn=process,
    inputs=[image, top_k],
    outputs=labels,
    examples=[[sample, 5] for sample in ("bike.jpg", "car.jpg", "food.jpg")],
    description=description,
    theme="huggingface",
    layout="horizontal",
    allow_flagging="never",
)

# Start serving the app.
iface.launch()