fix model/tensor device difference
Browse files
app.py
CHANGED
@@ -10,65 +10,55 @@ import torch
|
|
10 |
import torch.nn.functional as F
|
11 |
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
]
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
transforms.Resize(256),
|
41 |
-
transforms.CenterCrop(224),
|
42 |
-
transforms.ToTensor(),
|
43 |
-
transforms.Normalize(
|
44 |
-
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
45 |
-
),
|
46 |
-
]
|
47 |
-
)
|
48 |
-
|
49 |
-
@spaces.GPU(duration=60)
|
50 |
-
def predict(self, image: Image.Image) -> Figure:
|
51 |
-
image = image.convert("RGB")
|
52 |
-
input_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
|
53 |
-
|
54 |
-
with torch.no_grad():
|
55 |
-
logits = self.model(input_tensor)
|
56 |
-
probs = F.softmax(logits[:, :7], dim=1).cpu()
|
57 |
-
|
58 |
-
return draw_bar_chart(
|
59 |
-
{
|
60 |
-
"class": self.LABELS,
|
61 |
-
"probs": probs[0] * 100,
|
62 |
-
}
|
63 |
-
)
|
64 |
|
65 |
|
66 |
def draw_bar_chart(data: dict[str, list[str | float]]):
|
67 |
classes = data["class"]
|
68 |
probabilities = data["probs"]
|
69 |
|
70 |
-
|
71 |
-
fig, ax = plt.subplots(figsize=(8, 6))
|
72 |
ax.bar(classes, probabilities, color="skyblue")
|
73 |
|
74 |
ax.set_xlabel("Class")
|
@@ -149,7 +139,7 @@ def get_layout():
|
|
149 |
'<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
|
150 |
)
|
151 |
start_button.click(
|
152 |
-
fn=
|
153 |
inputs=image_input,
|
154 |
outputs=chart,
|
155 |
)
|
|
|
10 |
import torch.nn.functional as F
|
11 |
|
12 |
|
13 |
+
LABELS = [
|
14 |
+
"Panoramic",
|
15 |
+
"Feature",
|
16 |
+
"Detail",
|
17 |
+
"Enclosed",
|
18 |
+
"Focal",
|
19 |
+
"Ephemeral",
|
20 |
+
"Canopied",
|
21 |
+
]
|
22 |
+
|
23 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
24 |
+
|
25 |
+
model = torch.load(
|
26 |
+
"Litton-7type-visual-landscape-model.pth", map_location=device, weights_only=False
|
27 |
+
).module
|
28 |
+
model.eval()
|
29 |
+
preprocess = transforms.Compose(
|
30 |
+
[
|
31 |
+
transforms.Resize(256),
|
32 |
+
transforms.CenterCrop(224),
|
33 |
+
transforms.ToTensor(),
|
34 |
+
transforms.Normalize(
|
35 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
36 |
+
),
|
37 |
]
|
38 |
+
)
|
39 |
+
|
40 |
+
@spaces.GPU
|
41 |
+
def predict(image: Image.Image) -> Figure:
|
42 |
+
image = image.convert("RGB")
|
43 |
+
input_tensor = preprocess(image).unsqueeze(0).to(device)
|
44 |
+
|
45 |
+
with torch.no_grad():
|
46 |
+
logits = model(input_tensor)
|
47 |
+
probs = F.softmax(logits[:, :7], dim=1).cpu()
|
48 |
+
|
49 |
+
return draw_bar_chart(
|
50 |
+
{
|
51 |
+
"class": LABELS,
|
52 |
+
"probs": probs[0] * 100,
|
53 |
+
}
|
54 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
def draw_bar_chart(data: dict[str, list[str | float]]):
|
58 |
classes = data["class"]
|
59 |
probabilities = data["probs"]
|
60 |
|
61 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
|
|
62 |
ax.bar(classes, probabilities, color="skyblue")
|
63 |
|
64 |
ax.set_xlabel("Class")
|
|
|
139 |
'<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>',
|
140 |
)
|
141 |
start_button.click(
|
142 |
+
fn=predict,
|
143 |
inputs=image_input,
|
144 |
outputs=chart,
|
145 |
)
|