Questaaaa commited on
Commit
90fb249
·
verified ·
1 Parent(s): a7b8099

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -9,14 +9,12 @@ model_name = "microsoft/beit-base-patch16-224"
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
- # ImageNet 1000 类别名称(从 Hugging Face 官方下载)
13
- imagenet_classes = [
14
- "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", # 0-4
15
- "electric ray", "stingray", "cock", "hen", "ostrich", # 5-9
16
- "brambling", "goldfinch", "house finch", "junco", "indigo bunting", # 10-14
17
- # 省略中间 900 多个类别...
18
- "sports car", "convertible", "minivan", "pickup", "SUV" # 817-821(汽车类)
19
- ]
20
 
21
  # 定义分类函数
22
  def classify_image(image):
@@ -34,8 +32,8 @@ def classify_image(image):
34
  predicted_class_idx = logits.argmax(-1).item()
35
 
36
  # 获取类别名称
37
- class_name = imagenet_classes[predicted_class_idx] if predicted_class_idx < len(imagenet_classes) else f"Unknown Class (ID: {predicted_class_idx})"
38
-
39
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
40
 
41
  # 创建 Gradio 界面
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
+ # 加载 ImageNet 1000 类别名称
13
+ imagenet_labels = {
14
+ idx: entry.strip() for idx, entry in enumerate(
15
+ open("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt").readlines()
16
+ )
17
+ }
 
 
18
 
19
  # 定义分类函数
20
  def classify_image(image):
 
32
  predicted_class_idx = logits.argmax(-1).item()
33
 
34
  # 获取类别名称
35
+ class_name = imagenet_labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")
36
+
37
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
38
 
39
  # 创建 Gradio 界面