Spaces:
Build error
Build error
File size: 4,214 Bytes
f37cfad 0490df4 27938aa 0490df4 27938aa 0490df4 f37cfad 27938aa f37cfad 27938aa f37cfad 27938aa f37cfad a0a2e9f 367e0cd 31e8d7f b444b38 f37cfad 27938aa dc0eec0 820d8a7 31e8d7f b444b38 31e8d7f b444b38 31e8d7f 2e4f4c8 a0a2e9f 2e4f4c8 a0a2e9f f37cfad 088c61e 367e0cd f37cfad 27938aa b444b38 a0a2e9f 31e8d7f 27938aa b444b38 367e0cd f37cfad b444b38 9e1d8e2 a0a2e9f 27938aa b444b38 27938aa 367e0cd 27938aa f37cfad 27938aa b444b38 31e8d7f 088c61e 27938aa 31e8d7f f37cfad dc0eec0 f37cfad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random
# Pretrained ViT classifier and its preprocessor. Both are loaded from the
# same checkpoint id, so the image preprocessing matches the model.
_VIT_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_VIT_CHECKPOINT)
# Map each id (the text before the first space on a line of
# LOC_synset_mapping.txt) to the first comma-separated name after it.
# Presumably lines look like "n01440764 tench, Tinca tinca" -> "tench";
# TODO confirm against the data file.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r') as synset_file:
    for line in synset_file:
        # A blank line (e.g. a trailing empty line) would otherwise add a
        # bogus '\n' -> '' entry to the class list.
        if not line.strip():
            continue
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.replace('\n', '').split(',')[0]
# List position is the class index: image labels are converted with int()
# and used to index this list, so file order matters.
classes = list(classdict.values())
# Map each image file stem (CSV column "image_name") to its label (CSV column
# "image_label"). Labels are kept as strings and are later passed through
# int() to index ``classes``. If a name repeats, the last row wins.
# newline='' is what the csv module requires for correct quoting/newline handling.
with open('image_labels.csv', 'r', newline='') as csv_file:
    imagedict = {row['image_name']: row['image_label'] for row in csv.DictReader(csv_file)}
images = list(imagedict)
# Distinct label values (used as the pool of answer options).
labels = list(set(imagedict.values()))
def model_classify(radio, im):
    """Classify the displayed image with the ViT model once the user has picked an option.

    Args:
        radio: The user's current radio selection. ``None`` means nothing is
            selected (e.g. right after ``random_image`` clears the radio).
        im: The displayed image value delivered by the ``gr.Image`` component.

    Returns:
        A ``(label, class_index, has_guessed)`` tuple:
        - ``label`` is the first comma-separated name of the model's top-1
          class from ``model.config.id2label``;
        - ``class_index`` is that class's index;
        - ``has_guessed`` is ``True``.
        When ``radio`` is ``None``, returns ``(None, None, False)``.
    """
    if radio is None:
        return None, None, False
    import torch  # function-scope: the module does not import torch at the top
    inputs = feature_extractor(images=im, return_tensors="pt")
    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    modelclass = model.config.id2label[predicted_class_idx]
    return modelclass.split(',')[0], predicted_class_idx, True
def random_image():
    """Pick a random quiz image and build four shuffled answer options for it.

    Returns:
        ``(image, label, radio_update, prediction)`` for the outputs
        ``[image, image_label, radio, prediction]``:
        - the opened image;
        - its label string from ``imagedict``;
        - a ``gr.Radio`` update that clears the selection and installs the
          four class-name options;
        - ``None`` to clear the model prediction.
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    # Build the distractor pool without touching the shared module-level
    # ``labels`` list. Removing from it would shrink the pool on every call
    # and raise ValueError the next time the same label is drawn.
    distractors = [other for other in labels if other != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    # Labels are integer indices (stored as strings) into ``classes``;
    # show the human-readable class names instead.
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(value=None, choices=options), None
def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the user's running score after a radio selection.

    Only the first selection for an image is scored. Once ``has_guessed`` is
    truthy, the scores come back unchanged. A ``None`` selection (cleared
    radio) is not counted as an attempt.

    Args:
        pred: The class name the user selected, or ``None``.
        truth: The image's label; ``int(truth)`` indexes ``classes``.
        current_score: Number of correct guesses so far.
        total_score: Number of scored attempts so far.
        has_guessed: Whether this image has already been guessed.

    Returns:
        ``(current_score, message, total_score)`` after applying this selection.
    """
    if has_guessed:
        return current_score, f"Your score is {current_score} out of {total_score}!", total_score
    if pred == classes[int(truth)]:
        current_score += 1
        total_score += 1
    elif pred is not None:
        total_score += 1
    return current_score, f"Your score is {current_score} out of {total_score}!", total_score
def compare_score(userclass, truth):
    """Return feedback text comparing the user's pick with the correct class.

    Args:
        userclass: The class name the user selected, or ``None`` if nothing
            is selected.
        truth: The image's label; ``int(truth)`` indexes ``classes``.

    Returns:
        A prompt to guess when nothing is selected. Otherwise a success
        message or one revealing the correct class name.
    """
    if userclass is None:
        return "Try guessing a category!"
    answer = classes[int(truth)]
    if userclass == answer:
        return "Great! You guessed it right"
    return f"The right answer was {answer!s}! Try guessing the next image."
# Build the quiz UI and wire its events.
# NOTE(review): the original indentation was lost when this file was captured.
# The nesting below is a reconstruction; in particular, confirm whether the
# "Next image" button belongs inside the right-hand column or after the Row.
with gr.Blocks() as demo:
    # Per-session state.
    user_score = gr.State(0)  # number of correct guesses (updated by check_score)
    model_score = gr.State(0)  # not read or written by any visible handler
    image_label = gr.State()  # label of the current image, set by random_image
    model_class = gr.State()  # model's predicted class index, written by model_classify; not read by any visible handler
    total_score = gr.State(0)  # number of scored attempts (updated by check_score)
    has_guessed = gr.State(False)  # whether the current image has already been scored
    gr.Markdown("# ImageNet Quiz")
    gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
    gr.Markdown("### But many of its categories are hard to guess, even for humans.")
    gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
    with gr.Row():
        with gr.Column(min_width= 900):
            # NOTE(review): `shape=` here and `gr.Radio.update` in random_image
            # are believed to be Gradio 3.x-only APIs. If the Space fails to
            # build, verify the pinned gradio version.
            image = gr.Image(shape=(600, 600))
            # Placeholder choices; random_image replaces them for each image.
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="The AI model predicts:")
            score = gr.Label(label="Your Score")
            message = gr.Label(label="Did you guess it right?")
            btn = gr.Button("Next image")
    # Show a first image as soon as the page loads.
    demo.load(random_image, None, [image, image_label, radio, prediction])
    # One radio selection fans out to three independent handlers.
    # NOTE(review): model_classify writes has_guessed on the same event that
    # check_score reads it from. Scoring only the first pick relies on
    # check_score receiving the value from before this event. Confirm this
    # holds in the deployed Gradio version.
    radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    radio.change(compare_score, [radio, image_label], message)
    # "Next image": load a new image (which also clears the radio) and reset
    # the has-guessed flag so the next pick is scored.
    btn.click(random_image, None, [image, image_label, radio, prediction])
    btn.click(lambda :False, None, has_guessed)
demo.launch()
|