# Hugging Face Spaces status: Build error
import gradio as gr
import torch
from transformers import pipeline
# Device index for transformers pipelines: 0 selects the first CUDA GPU,
# -1 runs on the CPU when no GPU is available.
device = 0 if torch.cuda.is_available() else -1
# Hugging Face Hub checkpoints compared by this app. `process` runs them in
# this order, and the UI builds one output Label per entry in the same order.
model_names = [
    "apple/mobilevit-small",
    "facebook/deit-base-patch16-224",
    "facebook/convnext-base-224",
    "google/vit-base-patch16-224",
    "google/mobilenet_v2_1.4_224",
    "microsoft/resnet-50",
    "microsoft/swin-base-patch4-window7-224",
    "microsoft/beit-base-patch16-224",
    "nvidia/mit-b0",
    "shi-labs/nat-base-in1k-224",
    "shi-labs/dinat-base-in1k-224",
]
# Cache of loaded pipelines keyed by model id, filled lazily by `process`,
# so each checkpoint is loaded at most once per process instead of per request.
pipelines = {}
def process(image_file, top_k):
    """Classify one image with every model in ``model_names``.

    Pipelines are created on first use and stored in the module-level
    ``pipelines`` cache, so later calls reuse the already-loaded models.

    Args:
        image_file: Path to the image to classify (the input component is
            configured with ``type="filepath"``).
        top_k: Number of highest-scoring classes to return per model.

    Returns:
        A list with one ``{label: score}`` dict per model, in the same order
        as ``model_names`` (and therefore as the output ``gr.Label``s).
    """
    # The slider value may arrive as a float (e.g. 3.0); the pipeline's
    # ``top_k`` and any slicing need an int.
    k = int(top_k)
    results = []
    for name in model_names:
        classifier = pipelines.get(name)
        if classifier is None:
            classifier = pipelines[name] = pipeline(
                "image-classification", model=name, device=device
            )
        # Ask the pipeline for exactly k predictions instead of slicing its
        # default-sized result list.
        predictions = classifier(image_file, top_k=k)
        results.append({p["label"]: p["score"] for p in predictions})
    return results
# Inputs: the uploaded image is passed to `process` as a file path, and the
# slider selects how many top classes each model reports.
image = gr.Image(type="filepath", label="Upload an image")
top_k = gr.Slider(minimum=1, maximum=5, step=1, value=5, label="Top k classes")
# Outputs: one Label per model, titled with its Hub id and ordered like
# `model_names`, matching the list of dicts returned by `process`.
labels = [gr.Label(label=name) for name in model_names]
# Text shown by the Interface to describe the comparison.
description = (
    "This Space compares popular image classifiers available on the "
    "Hugging Face hub, including NAT and DINAT models. All models have "
    "been fine-tuned on ImageNet-1k. The sample images were generated "
    "with Stable Diffusion."
)
# The former `layout="horizontal"` argument was removed: it is a deprecated
# no-op since Gradio 3, and newer releases may reject it outright.
iface = gr.Interface(
    fn=process,
    inputs=[image, top_k],
    outputs=labels,
    # Each example row is [image path, top_k]; the image files are assumed
    # to sit next to the app in the working directory — confirm on deploy.
    examples=[
        ["bike.jpg", 5],
        ["car.jpg", 5],
        ["food.jpg", 5],
    ],
    description=description,
    # NOTE(review): "huggingface" is a legacy theme name; newer Gradio
    # releases may not recognize it and fall back to the default theme.
    theme="huggingface",
    allow_flagging="never",
)
# Start the web server only when this file is executed directly, so that
# importing the module (e.g. for tests) does not launch the app.
if __name__ == "__main__":
    iface.launch()