Questaaaa commited on
Commit
f9746a2
·
verified ·
1 Parent(s): 90fb249

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -9,12 +9,14 @@ model_name = "microsoft/beit-base-patch16-224"
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
- # 加载 ImageNet 1000 类别名称
13
- imagenet_labels = {
14
- idx: entry.strip() for idx, entry in enumerate(
15
- open("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt").readlines()
16
- )
17
- }
 
 
18
 
19
  # 定义分类函数
20
  def classify_image(image):
@@ -32,8 +34,11 @@ def classify_image(image):
32
  predicted_class_idx = logits.argmax(-1).item()
33
 
34
  # 获取类别名称
35
- class_name = imagenet_labels.get(predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})")
36
-
 
 
 
37
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
38
 
39
  # 创建 Gradio 界面
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
+ # ImageNet 1000 类别名称(手动加载)
13
+ imagenet_labels = [
14
+ "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark",
15
+ "electric ray", "stingray", "cock", "hen", "ostrich",
16
+ "brambling", "goldfinch", "house finch", "junco", "indigo bunting",
17
+ # ...... (省略 900 多个类别)
18
+ "sports car", "convertible", "minivan", "pickup", "SUV"
19
+ ]
20
 
21
  # 定义分类函数
22
  def classify_image(image):
 
34
  predicted_class_idx = logits.argmax(-1).item()
35
 
36
  # 获取类别名称
37
+ if predicted_class_idx < len(imagenet_labels):
38
+ class_name = imagenet_labels[predicted_class_idx]
39
+ else:
40
+ class_name = f"Unknown Class (ID: {predicted_class_idx})"
41
+
42
  return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
43
 
44
  # 创建 Gradio 界面