lichih commited on
Commit
4cbcc7c
·
1 Parent(s): 74cb9a5

fix model/tensor device difference

Browse files
Files changed (1) hide show
  1. app.py +43 -53
app.py CHANGED
@@ -10,65 +10,55 @@ import torch
10
  import torch.nn.functional as F
11
 
12
 
13
- class Litton7Classifier:
14
- LABELS = [
15
- "Panoramic",
16
- "Feature",
17
- "Detail",
18
- "Enclosed",
19
- "Focal",
20
- "Ephemeral",
21
- "Canopied",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ]
23
-
24
- def __init__(
25
- self, model_path="Litton-7type-visual-landscape-model.pth", device=None
26
- ):
27
- if device is None:
28
- self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
29
- else:
30
- self.device = device
31
-
32
- self.model = torch.load(
33
- model_path, map_location=self.device, weights_only=False
34
- )
35
- if hasattr(self.model, "module"):
36
- self.model = self.model.module
37
- self.model.eval()
38
- self.preprocess = transforms.Compose(
39
- [
40
- transforms.Resize(256),
41
- transforms.CenterCrop(224),
42
- transforms.ToTensor(),
43
- transforms.Normalize(
44
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
45
- ),
46
- ]
47
- )
48
-
49
- @spaces.GPU(duration=60)
50
- def predict(self, image: Image.Image) -> Figure:
51
- image = image.convert("RGB")
52
- input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
53
-
54
- with torch.no_grad():
55
- logits = self.model(input_tensor)
56
- probs = F.softmax(logits[:, :7], dim=1).cpu()
57
-
58
- return draw_bar_chart(
59
- {
60
- "class": self.LABELS,
61
- "probs": probs[0] * 100,
62
- }
63
- )
64
 
65
 
66
  def draw_bar_chart(data: dict[str, list[str | float]]):
67
  classes = data["class"]
68
  probabilities = data["probs"]
69
 
70
- #fig = plt.figure()
71
- fig, ax = plt.subplots(figsize=(8, 6))
72
  ax.bar(classes, probabilities, color="skyblue")
73
 
74
  ax.set_xlabel("Class")
@@ -149,7 +139,7 @@ def get_layout():
149
  '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
150
  )
151
  start_button.click(
152
- fn=Litton7Classifier().predict,
153
  inputs=image_input,
154
  outputs=chart,
155
  )
 
10
  import torch.nn.functional as F
11
 
12
 
13
+ LABELS = [
14
+ "Panoramic",
15
+ "Feature",
16
+ "Detail",
17
+ "Enclosed",
18
+ "Focal",
19
+ "Ephemeral",
20
+ "Canopied",
21
+ ]
22
+
23
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
24
+
25
+ model = torch.load(
26
+ "Litton-7type-visual-landscape-model.pth", map_location=device, weights_only=False
27
+ ).module
28
+ model.eval()
29
+ preprocess = transforms.Compose(
30
+ [
31
+ transforms.Resize(256),
32
+ transforms.CenterCrop(224),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(
35
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
36
+ ),
37
  ]
38
+ )
39
+
40
+ @spaces.GPU
41
+ def predict(image: Image.Image) -> Figure:
42
+ image = image.convert("RGB")
43
+ input_tensor = preprocess(image).unsqueeze(0).to(device)
44
+
45
+ with torch.no_grad():
46
+ logits = model(input_tensor)
47
+ probs = F.softmax(logits[:, :7], dim=1).cpu()
48
+
49
+ return draw_bar_chart(
50
+ {
51
+ "class": LABELS,
52
+ "probs": probs[0] * 100,
53
+ }
54
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  def draw_bar_chart(data: dict[str, list[str | float]]):
58
  classes = data["class"]
59
  probabilities = data["probs"]
60
 
61
+ fig, ax = plt.subplots(figsize=(8, 6))
 
62
  ax.bar(classes, probabilities, color="skyblue")
63
 
64
  ax.set_xlabel("Class")
 
139
  '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
140
  )
141
  start_button.click(
142
+ fn=predict,
143
  inputs=image_input,
144
  outputs=chart,
145
  )