sashavor commited on
Commit
0490df4
·
1 Parent(s): 877066a

adding model

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -4,6 +4,10 @@ from PIL import Image
4
  from collections import OrderedDict
5
  from random import sample
6
  import csv
 
 
 
 
7
 
8
 
9
  title="ImageNet Roulette"
@@ -24,7 +28,12 @@ with open('image_labels.csv', 'r') as csv_file:
24
  for row in reader:
25
  imagedict[row['image_name']] = row['image_label']
26
 
27
-
 
 
 
 
 
28
 
29
  def check_answer(im):
30
 
 
4
  from collections import OrderedDict
5
  from random import sample
6
  import csv
7
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
8
+
9
+ extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
10
+ model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
11
 
12
 
13
  title="ImageNet Roulette"
 
28
  for row in reader:
29
  imagedict[row['image_name']] = row['image_label']
30
 
31
+ def model_classify(im):
32
+ inputs = feature_extractor(images=im, return_tensors="pt")
33
+ outputs = model(**inputs)
34
+ logits = outputs.logits
35
+ predicted_class_idx = logits.argmax(-1).item()
36
+ return("Predicted class:", model.config.id2label[predicted_class_idx])
37
 
38
  def check_answer(im):
39