Questaaaa commited on
Commit
3e61f53
·
verified ·
1 Parent(s): 93ca441

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -1,13 +1,21 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
- # 选择 Hugging Face 预训练模型
5
- classifier = pipeline("feature-extraction", model="facebook/deit-base-distilled-patch16-224")
 
 
6
 
7
  # 定义分类函数
8
  def classify_image(image):
9
- predictions = classifier(image)
10
- return {"feature_vector": predictions[0]} # 返回特征向量
 
 
 
 
11
 
12
  # 创建 Gradio 界面
13
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")
 
1
  import gradio as gr
2
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
+ import torch
4
+ from PIL import Image
5
 
6
+ # 加载模型和特征提取器
7
+ model_name = "microsoft/beit-base-patch16-224"
8
+ model = AutoModelForImageClassification.from_pretrained(model_name)
9
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
10
 
11
  # 定义分类函数
12
  def classify_image(image):
13
+ image = feature_extractor(images=image, return_tensors="pt")
14
+ with torch.no_grad():
15
+ outputs = model(**image)
16
+ logits = outputs.logits
17
+ predicted_class = logits.argmax(-1).item()
18
+ return f"Predicted class: {predicted_class}"
19
 
20
  # 创建 Gradio 界面
21
  demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")