"""ImageNet Roulette: a small Gradio demo for guessing image categories.

At import time this script loads a pretrained ViT image classifier, reads
the synset-id -> label mapping and the image-name -> label table from local
files, builds the Gradio UI and launches it.
"""

import csv
from collections import OrderedDict
from random import sample

import gradio as gr
from datasets import load_dataset
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Pretrained checkpoint shared by the preprocessor and the classifier.
extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

title = "ImageNet Roulette"
description = (
    "Try guessing the category of each image displayed, "
    "from the options provided below. "
    "After 10 guesses, we will show you your accuracy!"
)

# Maps the first space-separated token of each line to the rest of the line
# (trailing newline removed), preserving file order.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r') as mapping_file:
    for line in mapping_file:
        if not line.strip():
            # Skip blank lines instead of recording a junk "\n" key.
            continue
        synset_id, _, label = line.partition(' ')
        classdict[synset_id] = label.replace('\n', '')

# Maps each 'image_name' column value to its 'image_label' column value.
imagedict = {}
# newline='' is what the csv module recommends for files handed to a reader.
with open('image_labels.csv', 'r', newline='') as csv_file:
    reader = csv.DictReader(csv_file)
    for row in reader:
        imagedict[row['image_name']] = row['image_label']


def model_classify(im):
    """Classify an image with the pretrained model.

    Args:
        im: The image to classify; it is passed unchanged to the feature
            extractor as ``images=im`` (presumably a PIL image -- confirm
            against callers).

    Returns:
        A 2-tuple ``("Predicted class:", label)`` where ``label`` is the
        ``model.config.id2label`` entry for the highest-scoring logit.
    """
    # The module-level preprocessor is named `extractor`; the previous
    # reference to `feature_extractor` raised NameError on every call.
    inputs = extractor(images=im, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return ("Predicted class:", model.config.id2label[predicted_class_idx])


def check_answer(im):
    """Return fixed placeholder scores; the ``im`` argument is ignored."""
    return {'cat': 0.3, 'dog': 0.7}


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # random.sample needs a sequence; a dict view is rejected on
            # Python 3.11+, so sample from a list of the image names.
            image_name = sample(list(imagedict), 1)[0]
            im = Image.open('images/' + image_name + '.jpg')
            image = gr.Image(im, shape=(224, 224))
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category")

demo.launch()