ImageNetQuiz / app.py
sashavor
adding model
0490df4
raw
history blame
1.63 kB
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Hub checkpoint shared by the preprocessor and the classifier, so both are
# always loaded from the same model repository.
VIT_CHECKPOINT = "google/vit-base-patch16-224"

# Image preprocessor: turns input images into model-ready tensors.
extractor = AutoFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
# ViT image-classification model; model.config.id2label maps an output
# class index to its label string.
model = AutoModelForImageClassification.from_pretrained(VIT_CHECKPOINT)
# Title and description text for the quiz UI (not passed to gr.Blocks in the
# visible layout code).
title = "ImageNet Roulette"
# Adjacent string literals are joined at compile time. The trailing space on
# the first piece keeps the two sentences from running together; a backslash
# continuation inside a string literal would drop the line break without
# inserting any space.
description = (
    "Try guessing the category of each image displayed, from the options provided below. "
    "After 10 guesses, we will show you your accuracy!"
)
# Ordered mapping from the first space-separated token of each line of
# LOC_synset_mapping.txt to the rest of that line (the human-readable label).
# Lines are expected to look like "<id> <label text>"; presumably the id is an
# ImageNet synset id such as "n01440764" -- confirm against the data file.
classdict = OrderedDict()
# `with` guarantees the file handle is closed once the mapping is built.
with open('LOC_synset_mapping.txt', 'r', encoding='utf-8') as synset_file:
    for raw_line in synset_file:
        # Strip the newline before splitting, so a line without a label does
        # not leave "\n" glued to its key.
        line = raw_line.rstrip('\n')
        if not line:
            # Blank lines (e.g. a trailing empty line) carry no mapping.
            continue
        # partition splits on the first space only; the label keeps any
        # internal spaces, exactly as ' '.join(split(' ')[1:]) would.
        synset_id, _, label = line.partition(' ')
        classdict[synset_id] = label
# Map each image name (used to build the images/<name>.jpg path) to its label,
# read from image_labels.csv. The CSV must have 'image_name' and
# 'image_label' header columns.
# newline='' is what the csv module documents for files handed to a reader:
# it lets the reader handle quoted fields that contain line breaks.
with open('image_labels.csv', 'r', newline='', encoding='utf-8') as csv_file:
    imagedict = {
        row['image_name']: row['image_label']
        for row in csv.DictReader(csv_file)
    }
def model_classify(im):
    """Classify one image with the module-level ViT ``extractor`` and ``model``.

    Args:
        im: A single image accepted by ``extractor``, e.g. a PIL image. Only
            one image is supported: ``.item()`` below fails on a batch.

    Returns:
        The 2-tuple ``("Predicted class:", label)``, where ``label`` is the
        ``model.config.id2label`` name of the highest-scoring class.
    """
    import torch  # local import: only needed to disable autograd below

    # ``extractor`` is the preprocessor loaded at module level.
    inputs = extractor(images=im, return_tensors="pt")
    # Pure inference, so skip building the autograd graph (saves memory/time).
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return ("Predicted class:", model.config.id2label[predicted_class_idx])
def check_answer(im):
    """Stub scorer: ignore ``im`` and return fixed placeholder scores.

    Returns:
        A fresh dict mapping ``'cat'`` to 0.3 and ``'dog'`` to 0.7, in that
        order.
    """
    del im  # not used by this stub
    scores = dict(cat=0.3, dog=0.7)
    return scores
# Build the quiz layout: one randomly chosen image plus a category picker.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # random.sample() requires a sequence (Python 3.11+ raises
            # TypeError for a dict keys view), so materialize the names first.
            # The image is chosen once, when this layout is built.
            image_name = sample(list(imagedict), 1)[0]
            im = Image.open(f'images/{image_name}.jpg')
            image = gr.Image(im, shape=(224, 224))
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category")
demo.launch()