brightlembo committed
Commit b54a6c4 · verified · 1 Parent(s): 3f7bc0f

Update app.py

Files changed (1): app.py +30 -4
app.py CHANGED
@@ -3,16 +3,42 @@ from torchvision import transforms
 from PIL import Image
 import streamlit as st
 import json
+from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
+
 
 
 # Load the class names
 with open("class_names.json", "r") as f:
     class_names = json.load(f)
+
+# Load the model with pre-trained weights
+weights = EfficientNet_B7_Weights.DEFAULT
+base_model = efficientnet_b7(weights=weights)
+
+# Adapt the model for classification
+class CustomEfficientNet(nn.Module):
+    def __init__(self, base_model, num_classes):
+        super(CustomEfficientNet, self).__init__()
+        self.base = nn.Sequential(*list(base_model.children())[:-2])
+        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc1 = nn.Linear(2560, 512)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(512, num_classes)
+
+    def forward(self, x):
+        x = self.base(x)
+        x = self.global_avg_pool(x)
+        x = x.view(x.size(0), -1)
+        x = self.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+# Define the final model
+num_classes = 2
+model = CustomEfficientNet(base_model, num_classes).to("cuda" if torch.cuda.is_available() else "cpu")
+model.load_state_dict(torch.load("efficientnet_b7_best.pth", weights_only=False))
+model.eval()  # Set the model to evaluation mode
 
-# Load the model
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = torch.load("efficientnet_b7_best.pth", map_location=device)
-model.eval()  # Evaluation mode
 
 # Define the image size
 image_size = (224, 224)
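
For context, a minimal sketch of how the rest of app.py presumably consumes this model: resize an uploaded image to image_size, run it through model, and map the winning index back to class_names. Everything below is layered on top of the diff, not part of the commit; the predict_image helper, the ImageNet normalization constants, and the list-typed class_names are assumptions.

# Hypothetical usage sketch; assumes the definitions from app.py above
# (model, class_names, image_size) are already in scope.
import torch
from PIL import Image
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

preprocess = transforms.Compose([
    transforms.Resize(image_size),                     # (224, 224) as defined in app.py
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # standard ImageNet stats (assumed)
                         std=[0.229, 0.224, 0.225]),
])

def predict_image(image: Image.Image) -> str:
    """Return the predicted class name for a PIL image (assumes class_names is a list)."""
    x = preprocess(image.convert("RGB")).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():                                         # no gradients needed at inference
        logits = model(x)
    return class_names[logits.argmax(dim=1).item()]

One caveat on the new loading code: torch.load("efficientnet_b7_best.pth", weights_only=False) with no map_location tries to restore tensors on the device they were saved from, so a checkpoint exported from a GPU will fail to load on a CPU-only host. The removed code passed map_location=device for that reason; keeping a map_location argument (e.g. map_location=device or "cpu") would preserve that behaviour.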