Spaces:
Build error
Build error
import csv
from collections import OrderedDict
from random import sample

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Hugging Face checkpoint shared by the preprocessor and the classifier so
# the two always stay in sync.
_CHECKPOINT = "google/vit-base-patch16-224"

# Preprocessor that turns raw images into model inputs.
extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
# Image classification model loaded from the same checkpoint.
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)
# Title and instructions for the guessing game.
title = "ImageNet Roulette"
# Adjacent string literals are concatenated at compile time; each piece
# keeps its own trailing space so the sentences do not run together.
description = (
    "Try guessing the category of each image displayed, from the options "
    "provided below. After 10 guesses, we will show you your accuracy!"
)
# Map each synset ID (the first space-separated token of a line in
# LOC_synset_mapping.txt) to the remainder of that line, preserving file order.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r', encoding='utf-8') as mapping_file:
    for line in mapping_file:
        line = line.rstrip('\n')
        if not line:
            # Skip blank lines (e.g. a trailing newline at end of file)
            # rather than registering an empty/newline key.
            continue
        # partition() splits once: everything after the first space is the
        # name, which may itself contain spaces and commas.
        synset_id, _, name = line.partition(' ')
        classdict[synset_id] = name
# Map each image name (``image_name`` column) to its label (``image_label``
# column) from image_labels.csv.
# newline='' is required by the csv module so quoted fields containing
# newlines are parsed correctly.
with open('image_labels.csv', 'r', newline='', encoding='utf-8') as csv_file:
    imagedict = {
        row['image_name']: row['image_label']
        for row in csv.DictReader(csv_file)
    }
def model_classify(im):
    """Classify an image with the module-level ViT model.

    Args:
        im: An image in a form the module-level ``extractor`` accepts
            (e.g. a PIL image).

    Returns:
        A 2-tuple ``("Predicted class:", label)`` where ``label`` is the
        top-1 class name looked up in ``model.config.id2label``.
    """
    # Preprocess with the module-level ``extractor`` loaded from the same
    # checkpoint as ``model``.
    inputs = extractor(images=im, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # .item() assumes a single image (batch of one).
    predicted_class_idx = logits.argmax(-1).item()
    return ("Predicted class:", model.config.id2label[predicted_class_idx])
def check_answer(im):
    """Return fixed placeholder scores per label.

    The ``im`` argument is accepted but not yet used; a new dict is built on
    every call.
    """
    scores = dict(cat=0.3, dog=0.7)
    return scores
# Build the UI: one row containing a column with a randomly chosen image and
# a radio group of category choices, then start the server.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # random.sample() requires a sequence; a dict keys view raises
            # TypeError on Python 3.11+, so materialize the keys first.
            _image_name = sample(list(imagedict), 1)[0]
            im = Image.open(f'images/{_image_name}.jpg')
            # NOTE(review): ``shape`` exists on gr.Image in Gradio 3.x but was
            # removed in Gradio 4 — confirm the pinned Gradio version.
            image = gr.Image(im, shape=(224, 224))
            radio = gr.Radio(
                ["option1", "option2", "option3"],
                label="Pick a category",
            )
# Launch only after the Blocks context has closed so the layout is complete.
demo.launch()