Spaces:
Build error
Build error
sashavor
commited on
Commit
·
a0a2e9f
1
Parent(s):
9e1d8e2
trying total score
Browse files
app.py
CHANGED
@@ -31,18 +31,19 @@ with open('image_labels.csv', 'r') as csv_file:
|
|
31 |
images= list(imagedict.keys())
|
32 |
labels = list(set(imagedict.values()))
|
33 |
|
34 |
-
def model_classify(im):
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
|
42 |
def random_image():
|
43 |
imname = random.choice(images)
|
44 |
im = Image.open('images/'+ imname +'.jpg')
|
45 |
label = str(imagedict[imname])
|
|
|
46 |
labels.remove(label)
|
47 |
options = sample(labels,3)
|
48 |
options.append(label)
|
@@ -50,14 +51,16 @@ def random_image():
|
|
50 |
options = [classes[int(i)] for i in options]
|
51 |
return im, label, gr.Radio.update(value=None, choices=options), None
|
52 |
|
53 |
-
def check_score(pred, truth, current_score):
|
54 |
if pred == classes[int(truth)]:
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
57 |
|
58 |
def compare_score(userclass, prediction):
|
59 |
-
print(userclass)
|
60 |
-
print(prediction)
|
61 |
if userclass == str(prediction).split(',')[0]:
|
62 |
return "Great! You and the model agree on the category"
|
63 |
return "You and the model disagree"
|
@@ -67,10 +70,11 @@ with gr.Blocks() as demo:
|
|
67 |
model_score = gr.State(0)
|
68 |
image_label = gr.State()
|
69 |
prediction = gr.State()
|
|
|
70 |
|
71 |
with gr.Row():
|
72 |
with gr.Column(min_width= 900):
|
73 |
-
image = gr.Image(shape=(
|
74 |
radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
|
75 |
with gr.Column():
|
76 |
prediction = gr.Label(label="Model Prediction")
|
@@ -80,8 +84,8 @@ with gr.Blocks() as demo:
|
|
80 |
btn = gr.Button("Next image")
|
81 |
|
82 |
demo.load(random_image, None, [image, image_label, radio, prediction])
|
83 |
-
radio.change(model_classify, image, prediction)
|
84 |
-
radio.change(check_score, [radio, image_label, user_score], [user_score, score])
|
85 |
#radio.change(compare_score, [radio, prediction], message)
|
86 |
btn.click(random_image, None, [image, image_label, radio, prediction])
|
87 |
|
|
|
31 |
images= list(imagedict.keys())
|
32 |
labels = list(set(imagedict.values()))
|
33 |
|
34 |
+
def model_classify(radio, im):
|
35 |
+
if radio is not None:
|
36 |
+
inputs = feature_extractor(images=im, return_tensors="pt")
|
37 |
+
outputs = model(**inputs)
|
38 |
+
logits = outputs.logits
|
39 |
+
predicted_class_idx = logits.argmax(-1).item()
|
40 |
+
return model.config.id2label[predicted_class_idx]
|
41 |
|
42 |
def random_image():
|
43 |
imname = random.choice(images)
|
44 |
im = Image.open('images/'+ imname +'.jpg')
|
45 |
label = str(imagedict[imname])
|
46 |
+
print(label)
|
47 |
labels.remove(label)
|
48 |
options = sample(labels,3)
|
49 |
options.append(label)
|
|
|
51 |
options = [classes[int(i)] for i in options]
|
52 |
return im, label, gr.Radio.update(value=None, choices=options), None
|
53 |
|
54 |
+
def check_score(pred, truth, current_score, total_score):
|
55 |
if pred == classes[int(truth)]:
|
56 |
+
total_score +=1
|
57 |
+
return current_score + 1, f"Your score is {current_score+1} out of {total_score}"
|
58 |
+
else:
|
59 |
+
total_score +=1
|
60 |
+
return current_score, f"Your score is {current_score} out of {total_score}"
|
61 |
+
|
62 |
|
63 |
def compare_score(userclass, prediction):
|
|
|
|
|
64 |
if userclass == str(prediction).split(',')[0]:
|
65 |
return "Great! You and the model agree on the category"
|
66 |
return "You and the model disagree"
|
|
|
70 |
model_score = gr.State(0)
|
71 |
image_label = gr.State()
|
72 |
prediction = gr.State()
|
73 |
+
total_score = gr.State(0)
|
74 |
|
75 |
with gr.Row():
|
76 |
with gr.Column(min_width= 900):
|
77 |
+
image = gr.Image(shape=(600, 600))
|
78 |
radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
|
79 |
with gr.Column():
|
80 |
prediction = gr.Label(label="Model Prediction")
|
|
|
84 |
btn = gr.Button("Next image")
|
85 |
|
86 |
demo.load(random_image, None, [image, image_label, radio, prediction])
|
87 |
+
radio.change(model_classify, [radio, image], prediction)
|
88 |
+
radio.change(check_score, [radio, image_label, user_score, total_score], [user_score, score])
|
89 |
#radio.change(compare_score, [radio, prediction], message)
|
90 |
btn.click(random_image, None, [image, image_label, radio, prediction])
|
91 |
|