import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np

# Load the pretrained image classifier and its matching preprocessor.
# NOTE(review): recent transformers releases recommend AutoImageProcessor for
# vision models; AutoFeatureExtractor is kept here for compatibility with the
# transformers version this script was written against.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Class-index -> label-name mapping shipped in the model's config.
labels = model.config.id2label


def classify_image(image):
    """Classify a single image and describe the top-scoring class.

    Args:
        image: The value from the Gradio ``"image"`` input. This is either a
            NumPy array or a ``PIL.Image.Image``. It is ``None`` when the user
            submits without providing an image.

    Returns:
        A string naming the predicted class label and its index, or a short
        message when no image was provided.
    """
    # Gradio passes None when nothing was uploaded. Handing None to the
    # feature extractor would raise instead of giving the user feedback.
    if image is None:
        return "No image provided."

    # Normalize PIL images to RGB before converting to an array, so RGBA,
    # palette, or grayscale inputs become 3-channel. The preprocessor
    # presumably expects 3-channel RGB input (TODO confirm for this model).
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    # Resize/normalize the image into model-ready PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was submitted, so argmax yields one index; .item()
    # turns it into a plain Python int for the id2label lookup.
    predicted_class_idx = logits.argmax(-1).item()

    # Fall back to a descriptive placeholder if the index has no label.
    class_name = labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"


# Build the Gradio UI: one image input, one text output.
demo = gr.Interface(fn=classify_image, inputs="image", outputs="text", title="Image Classification Demo")

# Start the web server only when executed as a script, so that importing this
# module does not launch (and block on) the UI.
if __name__ == "__main__":
    demo.launch()