File size: 1,189 Bytes
0e4149c
3e61f53
 
 
d1e8a6c
0e4149c
3e61f53
 
 
 
0e4149c
a98b34c
 
d1e8a6c
0e4149c
 
d1e8a6c
 
 
 
 
 
 
 
3e61f53
d1e8a6c
3e61f53
d1e8a6c
 
 
a98b34c
 
d1e8a6c
0e4149c
 
93ca441
0e4149c
a98b34c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np

# Hugging Face Hub checkpoint used for both the model and its preprocessor.
model_name = "microsoft/beit-base-patch16-224"

# Preprocessor that turns raw images into the tensor inputs the model expects,
# followed by the image-classification model itself (both fetched by name).
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Class-index -> class-name mapping stored in the model's config.
labels = model.config.id2label

# 定义分类函数
def classify_image(image):
    # 转换 PIL Image 为 numpy 数组
    if isinstance(image, Image.Image):
        image = np.array(image)

    # 进行特征提取
    inputs = feature_extractor(images=image, return_tensors="pt")

    # 预测类别
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()

    # 获取类别名称
    class_name = labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")

    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"

# 创建 Gradio 界面
demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")
demo.launch()