turtlegraphics commited on
Commit
4280a69
·
verified ·
1 Parent(s): dfd8178

switching to an image classifier

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -2,17 +2,19 @@
2
  # gradio demo
3
  #
4
  import gradio as gr
5
- from transformers import pipeline
6
 
7
- model = pipeline("text-generation", model="openai-community/gpt2")
 
8
 
9
- title = "openai-community/gpt2 prompt completion"
10
- description = "Basic demo from the NLP course"
11
 
12
- def predict(prompt):
13
- completion = model(prompt)[0]["generated_text"]
14
- return completion
15
-
16
- demo = gr.Interface(fn=predict, inputs="text", outputs="text")
 
17
 
18
  demo.launch()
 
2
  # gradio demo
3
  #
4
  import gradio as gr
5
+ from transformers import ViTFeatureExtractor, ViTModel
6
 
7
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
8
+ model = ViTModel.from_pretrained('akahana/vit-base-cats-vs-dogs')
9
 
10
+ title = "Sandbox"
11
+ description = "Place to try various models"
12
 
13
+ def classify(image):
14
+ inputs = feature_extractor(images=image, return_tensors="pt")
15
+ outputs = model(**inputs)
16
+ return "dog"
17
+
18
+ demo = gr.Interface(fn=classify, inputs="image", outputs="text")
19
 
20
  demo.launch()