File size: 1,189 Bytes
0e4149c 3e61f53 d1e8a6c 0e4149c 3e61f53 0e4149c a98b34c d1e8a6c 0e4149c d1e8a6c 3e61f53 d1e8a6c 3e61f53 d1e8a6c a98b34c d1e8a6c 0e4149c 93ca441 0e4149c a98b34c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
# Load the pretrained classification model and its matching image preprocessor.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
# AutoImageProcessor supersedes AutoFeatureExtractor for vision models (the
# BEiT feature-extractor class is deprecated). The variable keeps its old name
# because classify_image refers to `feature_extractor`.
feature_extractor = AutoImageProcessor.from_pretrained(model_name)
# Class-index -> label-name mapping shipped with the model's config.
labels = model.config.id2label
# Classification callback used by the Gradio interface.
def classify_image(image):
    """Classify an image and return a human-readable prediction string.

    Args:
        image: The value supplied by the Gradio ``"image"`` input, either a
            ``PIL.Image.Image`` or a NumPy array. ``None`` when the user
            submits without providing an image.

    Returns:
        ``"Predicted class: <name> (ID: <idx>)"`` for the top-1 prediction,
        or a prompt asking for an image when ``image`` is ``None``.
    """
    # Without this guard, an empty submission would crash inside the extractor.
    if image is None:
        return "Please upload an image."

    # Normalise PIL input to 3-channel RGB before converting to an array:
    # grayscale ("L") or RGBA/palette images would otherwise produce 2-D or
    # 4-channel arrays, which the RGB preprocessing may reject or mis-normalise.
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    # Resize/normalise into model-ready PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: gradient tracking is unnecessary.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Fall back to a placeholder name if the index is missing from id2label.
    class_name = labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
# Build the Gradio UI: one image input, one text output.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) has no launch side effect.
if __name__ == "__main__":
    demo.launch()
|