Spaces:
Build error
Build error
File size: 3,101 Bytes
f37cfad 0490df4 27938aa 0490df4 27938aa 0490df4 f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad 0490df4 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random
# Checkpoint shared by the image preprocessor and the classifier, so that the
# preprocessing always matches what the model expects.
_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)
title = "ImageNet Roulette"
# Implicit string concatenation rather than backslash continuations inside the
# literal: the continuation form glued the sentences together with no space
# ("below.After") and would also capture any indentation on the next line.
# NOTE(review): neither ``title`` nor ``description`` is passed to the UI yet,
# and the "accuracy after 10 guesses" feature is not implemented.
description = (
    "Try guessing the category of each image displayed, from the options "
    "provided below. After 10 guesses, we will show you your accuracy!"
)
# Map each synset ID (the first space-separated token of a line in
# LOC_synset_mapping.txt) to the first comma-separated name that follows it.
# Insertion order follows the file, so ``classes[i]`` is the name on the i-th
# non-blank line; the ``image_label`` values from image_labels.csv are used
# as these integer indices.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r', encoding='utf-8') as mapping_file:
    for raw_line in mapping_file:
        entry = raw_line.rstrip('\r\n')
        if not entry.strip():
            # A blank line would otherwise add a bogus entry and shift every
            # later index in ``classes``.
            continue
        synset_id, _, names = entry.partition(' ')
        classdict[synset_id] = names.split(',')[0]
classes = list(classdict.values())
# Map image file stem (``image_name``) to its label (``image_label``), an
# integer index into ``classes`` stored as a string.
# ``newline=''`` is what the csv module requires for correct parsing.
with open('image_labels.csv', 'r', newline='', encoding='utf-8') as csv_file:
    imagedict = {
        row['image_name']: row['image_label'] for row in csv.DictReader(csv_file)
    }
images = list(imagedict.keys())
# Distinct labels, sorted so the order (and hence ``sample`` results under a
# fixed random seed) does not depend on string hash randomization.
labels = sorted(set(imagedict.values()))
def model_classify(im):
    """Return the model's top-1 label for image *im*.

    ``im`` is passed straight to the feature extractor (presumably the NumPy
    array or PIL image that the Gradio image component supplies; confirm).
    The result is the raw ``model.config.id2label`` entry, which may hold
    several comma-separated names.
    """
    import torch  # inference only: no need to build an autograd graph

    inputs = feature_extractor(images=im, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
def random_image():
    """Pick a random image and build a four-way multiple-choice question for it.

    Returns:
        ``(image, label, radio_update, None)``:
        - the opened PIL image,
        - its label (an index into ``classes``, as a string),
        - a ``gr.Radio`` update whose choices are the correct class name plus
          three distinct distractors, in random order,
        - ``None``, used by the UI to clear the prediction output.

    Raises:
        ValueError: if fewer than three labels other than the image's own exist.
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    # Draw distractors from a filtered copy. Removing ``label`` from the shared
    # ``labels`` list would shrink the pool permanently and raise ValueError
    # the next time an image with an already-removed label is drawn.
    distractors = [candidate for candidate in labels if candidate != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(choices=options), None
def check_score(pred, truth, current_score):
    """Update the player's score after a guess.

    Args:
        pred: class name the player picked in the radio.
        truth: the current image's label, an index into ``classes`` (int or
            numeric string).
        current_score: the score before this guess.

    Returns:
        ``(new_score, message)``. ``new_score`` is one higher when ``pred``
        names the correct class, and the message reports ``new_score`` (the
        old code displayed the stale pre-increment value).
    """
    new_score = current_score + 1 if pred == classes[int(truth)] else current_score
    return new_score, f"Your score is {new_score}"
def compare_score(userclass, prediction):
    """Tell the player whether their pick matches the model's prediction.

    ``prediction`` may be the raw label string returned by ``model_classify``
    (possibly several comma-separated names) or a dict carrying a ``"label"``
    key, as read back from a ``gr.Label`` component. Only the first
    comma-separated name is compared with ``userclass``. A missing prediction
    (``None``) counts as disagreement.
    """
    if isinstance(prediction, dict):
        # NOTE(review): assumes the gr.Label payload shape {"label": ..., ...};
        # confirm against the installed Gradio version.
        prediction = prediction.get("label")
    if prediction is not None and userclass == str(prediction).split(',')[0]:
        return "Great! You and the model agree on the category"
    return "You and the model disagree"
def _classify_and_compare(choice, im):
    """Classify *im* and compare the result with the player's *choice*.

    Both steps run in one handler so the comparison always uses the
    prediction for the current image. As separate ``radio.change`` listeners,
    the comparison was fed the prediction Label's value from before
    ``model_classify`` produced its output (``None`` right after a new image).
    """
    predicted = model_classify(im)
    return predicted, compare_score(choice, predicted)


with gr.Blocks() as demo:
    user_score = gr.State(0)
    model_score = gr.State(0)  # NOTE(review): not wired to any event yet
    image_label = gr.State()  # label (index into ``classes``) of the shown image
    with gr.Row():
        with gr.Column():
            image = gr.Image(shape=(448, 448))
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="Model Prediction")
            score = gr.Label(label="Your Score")
            message = gr.Text()
            btn = gr.Button("Next image")

    # random_image returns (image, label, radio update, None); the None clears
    # the previous prediction.
    demo.load(random_image, None, [image, image_label, radio, prediction])
    radio.change(_classify_and_compare, [radio, image], [prediction, message])
    radio.change(check_score, [radio, image_label, user_score], [user_score, score])
    btn.click(random_image, None, [image, image_label, radio, prediction])

demo.launch()
|