File size: 4,214 Bytes
f37cfad
 
 
 
 
 
0490df4
27938aa
0490df4
27938aa
0490df4
f37cfad
 
 
 
 
27938aa
f37cfad
 
27938aa
f37cfad
 
 
 
 
27938aa
 
f37cfad
a0a2e9f
 
 
 
 
 
367e0cd
 
31e8d7f
b444b38
f37cfad
27938aa
 
 
 
 
 
 
 
 
dc0eec0
820d8a7
31e8d7f
 
b444b38
31e8d7f
b444b38
31e8d7f
 
 
2e4f4c8
a0a2e9f
2e4f4c8
 
a0a2e9f
f37cfad
088c61e
367e0cd
 
 
 
 
 
 
f37cfad
 
27938aa
 
 
b444b38
a0a2e9f
31e8d7f
27938aa
b444b38
367e0cd
 
 
f37cfad
b444b38
9e1d8e2
a0a2e9f
27938aa
 
b444b38
27938aa
367e0cd
27938aa
 
f37cfad
27938aa
b444b38
31e8d7f
088c61e
27938aa
31e8d7f
f37cfad
dc0eec0
f37cfad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random

# The preprocessing pipeline and the classifier are both loaded from the same
# pretrained checkpoint, so the identifier is spelled out only once.
_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)


# Build an ordered mapping from the first space-delimited token of each line
# (the synset ID, going by the file name) to the first comma-separated name
# that follows it. Insertion order is kept so that `classes` can be indexed
# by position.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r') as synset_file:
    for line in synset_file:
        # "<id> <name>[, <alias>...]\n" -> key "<id>", value "<name>".
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.replace('\n', '').split(',')[0]
# Human-readable class names, in file order.
classes = list(classdict.values())

# Map every image name to its label as stored in the CSV (a string; the rest
# of the file converts it with int() to index into `classes`).
imagedict = {}
# newline='' is what the csv module asks for so it can handle line endings
# (including newlines embedded in quoted fields) itself.
with open('image_labels.csv', 'r', newline='') as csv_file:
    reader = csv.DictReader(csv_file)
    for row in reader:
        imagedict[row['image_name']] = row['image_label']
images = list(imagedict.keys())
# Distinct labels that occur in the CSV; used as the pool of answer options.
labels = list(set(imagedict.values()))

def model_classify(radio, im):
    """Classify the displayed image once the user has picked an option.

    Args:
        radio: The currently selected radio option, or ``None`` when nothing
            is selected (for example right after the choices were reset).
        im: The image to classify, passed straight to ``feature_extractor``.

    Returns:
        A 3-tuple ``(name, class_index, has_guessed)``. When ``radio`` is
        ``None`` this is ``(None, None, False)`` and the model is not run.
        Otherwise ``name`` is the first comma-separated name of the predicted
        class, ``class_index`` is the arg-max index of the logits, and the
        flag is ``True``.
    """
    if radio is None:
        return None, None, False

    # Function-scope import: the module does not import torch at top level,
    # and this keeps the no-selection path free of the dependency.
    import torch

    inputs = feature_extractor(images=im, return_tensors="pt")
    # Pure inference: no gradients are needed, so skip building the autograd
    # graph for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    modelclass = model.config.id2label[predicted_class_idx]
    return modelclass.split(',')[0], predicted_class_idx, True

def random_image():
    """Pick a random image and build a shuffled four-option question for it.

    Returns:
        A 4-tuple ``(image, label, radio_update, None)``: the opened PIL
        image, its label as a string, a ``gr.Radio`` update that clears the
        selection and installs the four class-name choices (the correct one
        plus three distractors), and ``None`` for the last output.
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    # Build the distractor pool without mutating the module-level `labels`.
    # Removing the label from the shared list in place would make a later
    # image with the same label fail in list.remove() and would shrink the
    # pool on every call.
    distractors = [candidate for candidate in labels if candidate != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    # Labels index into `classes`; show the readable names to the user.
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(value=None, choices=options), None

def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the running score after a radio selection change.

    Args:
        pred: The class name the user selected, or ``None``.
        truth: Label of the current image; ``int(truth)`` indexes ``classes``.
        current_score: Number of correct guesses so far.
        total_score: Number of counted guesses so far.
        has_guessed: When true, the scores are returned unchanged.

    Returns:
        ``(current_score, message, total_score)`` with the possibly updated
        counters and the score message built from them.
    """
    if not has_guessed:
        if pred == classes[int(truth)]:
            # Correct guess: counts towards both tallies.
            current_score += 1
            total_score += 1
        elif pred is not None:
            # Wrong guess: only the number of attempts grows.
            total_score += 1
    message = f"Your score is {current_score} out of {total_score}!"
    return current_score, message, total_score



def compare_score(userclass, truth):
    """Return the feedback message for the user's current selection.

    Args:
        userclass: The class name the user selected, or ``None``.
        truth: Label of the current image; ``int(truth)`` indexes ``classes``.

    Returns:
        A prompt to guess when nothing is selected, a congratulation when the
        selection equals the correct class name, and otherwise a message
        revealing the correct class name.
    """
    if userclass is None:
        return "Try guessing a category!"
    answer = classes[int(truth)]
    if userclass == answer:
        return "Great! You guessed it right"
    return "The right answer was " + str(answer) + "! Try guessing the next image."

# Build the quiz page: state holders, the layout, and the event wiring.
with gr.Blocks() as demo:
    # Values carried between callbacks through gr.State components.
    user_score = gr.State(0)      # correct guesses; read and written by check_score
    model_score = gr.State(0)     # not referenced by any callback wired below
    image_label = gr.State()      # label of the current image; set by random_image
    model_class = gr.State()      # class index returned by model_classify
    total_score = gr.State(0)     # counted guesses; read and written by check_score
    has_guessed = gr.State(False) # set by model_classify, reset by the button

    gr.Markdown("# ImageNet Quiz")
    gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
    gr.Markdown("### But many of its categories are hard to guess, even for humans.")
    gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
    with gr.Row():

        # Left column: the image and the answer options.
        with gr.Column(min_width= 900):
            image = gr.Image(shape=(600, 600))
            # Placeholder choices; random_image replaces them with real options.
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        # Right column: model prediction, running score, and feedback message.
        with gr.Column():
            prediction = gr.Label(label="The AI model predicts:")
            score = gr.Label(label="Your Score")
            message = gr.Label(label="Did you guess it right?")

    btn = gr.Button("Next image")

    # Populate the first question when the app loads.
    demo.load(random_image, None, [image, image_label, radio, prediction])
    # A change of the radio selection triggers three independent callbacks.
    # NOTE(review): check_score reads has_guessed while model_classify, fired
    # by the same change event, writes it -- confirm the ordering Gradio
    # guarantees between these handlers.
    radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    radio.change(compare_score, [radio, image_label], message)
    # "Next image": load a new question and clear the has_guessed flag.
    btn.click(random_image, None, [image, image_label, radio, prediction])
    btn.click(lambda :False, None, has_guessed)


# Start the Gradio server (runs at import time; there is no __main__ guard).
demo.launch()