File size: 3,101 Bytes
f37cfad
 
 
 
 
 
0490df4
27938aa
0490df4
27938aa
0490df4
f37cfad
 
 
 
 
 
 
 
 
 
27938aa
f37cfad
 
27938aa
f37cfad
 
 
 
 
27938aa
 
f37cfad
0490df4
 
 
 
 
27938aa
 
f37cfad
27938aa
 
 
 
 
 
 
 
 
 
f37cfad
27938aa
 
 
 
f37cfad
27938aa
 
 
 
 
 
f37cfad
 
27938aa
 
 
 
 
f37cfad
 
27938aa
 
 
 
 
 
 
 
f37cfad
27938aa
 
 
 
 
f37cfad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random

# Load the image preprocessor and the image-classification model from the
# "google/vit-base-patch16-224" checkpoint once, at import time. Both objects
# are used by model_classify() further down.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")


# Human-readable title and description of the app.
# NOTE(review): neither `title` nor `description` is referenced anywhere else
# in the visible part of this file, and no "10 guesses" accuracy summary is
# implemented in the visible handlers -- confirm whether they are still needed.
title="ImageNet Roulette"
# The trailing backslashes continue the string literal across lines, so the
# leading spaces of the continuation lines are part of the resulting text.
description="Try guessing the category of each image displayed, from the options provided below.\
            After 10 guesses, we will show you your accuracy!\
            "

# Class-name table, one entry per line of the synset mapping file, in file
# order. The key is the text before the first space; the value is the rest of
# the line with the newline stripped, cut at the first comma (i.e. the first
# listed name).
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r') as mapping_file:
    # The context manager closes the file deterministically. No try/except is
    # needed here: partition/replace/split never raise on a str, so a blanket
    # handler would only hide unrelated errors.
    for line in mapping_file:
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.replace('\n', '').split(',')[0]
# Positional list of class names; the handlers below index it with the integer
# labels read from image_labels.csv.
classes = list(classdict.values())

# Map each image name to its label as stored in the CSV (a string).
# newline='' is the mode the csv module documents for files handed to a reader.
with open('image_labels.csv', 'r', newline='') as csv_file:
    imagedict = {
        row['image_name']: row['image_label']
        for row in csv.DictReader(csv_file)
    }
images = list(imagedict)
# Distinct labels present in the CSV; used as the pool of wrong answers.
labels = list(set(imagedict.values()))

def model_classify(im):
    """Return the model's top predicted class name for ``im``.

    The image is preprocessed with the module-level ``feature_extractor``
    (PyTorch tensors requested), passed through the module-level ``model``,
    and the index of the largest logit along the last axis is translated to a
    name via ``model.config.id2label``.

    NOTE(review): ``.item()`` requires a single-element result, so this
    presumably expects exactly one image per call -- confirm against callers.
    """
    encoded = feature_extractor(images=im, return_tensors="pt")
    top_index = model(**encoded).logits.argmax(-1).item()
    return model.config.id2label[top_index]


def random_image():
    """Pick a random image and build a shuffled multiple-choice question for it.

    Returns:
        A 4-tuple ``(im, label, radio_update, None)``:

        * ``im`` -- the opened image ``images/<name>.jpg``;
        * ``label`` -- the image's label from ``imagedict`` as a string
          (an index into ``classes``);
        * ``radio_update`` -- a ``gr.Radio.update`` whose choices are the
          correct class name plus up to three other class names, shuffled;
        * ``None`` -- the event wiring routes this to the prediction output
          (presumably to clear the previous prediction).
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    # Build the pool of wrong answers as a fresh list. ``labels`` is shared by
    # every call, so removing the correct label from it in place would shrink
    # it permanently and raise ValueError the next time the same label is
    # drawn.
    distractors = [candidate for candidate in labels if candidate != label]
    # Cap the sample size so a dataset with fewer than four distinct labels
    # still yields a question instead of raising ValueError.
    options = sample(distractors, min(3, len(distractors)))
    options.append(label)
    random.shuffle(options)
    # Translate label indices into the class names shown to the user.
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(choices=options), None

def check_score(pred, truth, current_score, class_names=None):
    """Update the running score after a guess and build the score message.

    Args:
        pred: Class name chosen by the user.
        truth: Index of the correct class (an int or numeric string), or
            ``None`` when no image label is available yet.
        current_score: Score before this guess.
        class_names: Optional sequence of class names indexed by ``truth``.
            Defaults to the module-level ``classes`` list.

    Returns:
        ``(new_score, message)`` where ``message`` reports ``new_score``.
    """
    new_score = current_score
    # Without a known label there is nothing to compare against; leave the
    # score untouched rather than failing inside int(None).
    if truth is not None:
        names = classes if class_names is None else class_names
        if pred == names[int(truth)]:
            new_score = current_score + 1
    # Report the updated score so the message matches the returned value.
    return new_score, f"Your score is {new_score}"

def compare_score(userclass, prediction):
    """Tell the user whether their pick matches the model's prediction.

    Both arguments are echoed to stdout. ``prediction`` is converted to a
    string and only the text before its first comma is compared with
    ``userclass``.

    Returns:
        The agreement message when the two are equal, otherwise the
        disagreement message.
    """
    print(userclass)
    print(prediction)
    model_class, _, _ = str(prediction).partition(',')
    agree = userclass == model_class
    if not agree:
        return "You and the model disagree"
    return "Great! You and the model agree on the category"

# Build the Gradio UI: an image with a multiple-choice radio on the left, the
# model prediction, the user's score and a comparison message on the right,
# and a "Next image" button underneath.
with gr.Blocks() as demo:
    # State components carry values between event handlers without being
    # displayed: the running score and the label of the current image.
    user_score = gr.State(0)
    # NOTE(review): model_score is not used by any event registered below.
    model_score = gr.State(0)
    image_label = gr.State()
    # NOTE(review): this name is rebound to a gr.Label inside the second
    # column below, so this State component is never wired to any event.
    prediction = gr.State()

    with gr.Row():
        with gr.Column():
            image = gr.Image(shape=(448, 448))
            # Placeholder choices; random_image() replaces them with real
            # class names through a gr.Radio.update.
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="Model Prediction")
            score = gr.Label(label="Your Score")
            message = gr.Text()

    btn = gr.Button("Next image")

    # On page load, show a first question. random_image() returns
    # (image, label, radio update, None), mapped onto the four outputs in order.
    demo.load(random_image, None, [image, image_label, radio, prediction])
    # Every change of the radio selection triggers three handlers:
    # 1) classify the current image and show the model's prediction;
    radio.change(model_classify, image, prediction)
    # 2) update the user's score; the returned message goes to the score label;
    radio.change(check_score, [radio, image_label, user_score], [user_score, score])
    # 3) compare the user's pick with the model's prediction.
    # NOTE(review): this handler reads `prediction`, which handler 1 writes on
    # the same event -- confirm that the ordering between them is guaranteed.
    radio.change(compare_score, [radio, prediction], message)
    # "Next image" draws a new question with the same output mapping as on load.
    btn.click(random_image, None, [image, image_label, radio, prediction])

# Start the app server.
demo.launch()