brightlembo commited on
Commit
81165fa
·
verified ·
1 Parent(s): 9107f06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py CHANGED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import requests
6
+ import json
7
+ import os
8
+
9
+ # URL du modèle hébergé sur Hugging Face
10
+ MODEL_URL = "https://huggingface.co/brightlembo/nao_sad_happy/blob/main/efficientnet_b7_best.pth"
11
+ MODEL_PATH = "efficientnet_b7_best.pth"
12
+ CLASS_NAMES_PATH = "class_names.json"
13
+
14
+ # Télécharger le modèle s'il n'existe pas localement
15
+ if not os.path.exists(MODEL_PATH):
16
+ st.info("Téléchargement du modèle depuis Hugging Face...")
17
+ response = requests.get(MODEL_URL, stream=True)
18
+ response.raise_for_status()
19
+ with open(MODEL_PATH, "wb") as f:
20
+ f.write(response.content)
21
+ st.success("Modèle téléchargé avec succès.")
22
+
23
+ # Charger les noms des classes
24
+ if not os.path.exists(CLASS_NAMES_PATH):
25
+ st.error(f"Le fichier {CLASS_NAMES_PATH} est introuvable. Veuillez le charger.")
26
+ st.stop()
27
+
28
+ with open(CLASS_NAMES_PATH, "r") as f:
29
+ class_names = json.load(f)
30
+
31
+ # Charger le modèle
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ try:
34
+ model = torch.load(MODEL_PATH, map_location=device)
35
+ model.eval()
36
+ except Exception as e:
37
+ st.error(f"Erreur lors du chargement du modèle : {e}")
38
+ st.stop()
39
+
40
+ # Transformation pour les images
41
+ image_size = (224, 224)
42
+
43
+ class GrayscaleToRGB:
44
+ def __call__(self, img):
45
+ return img.convert("RGB")
46
+
47
+ transform = transforms.Compose([
48
+ transforms.Grayscale(num_output_channels=1),
49
+ transforms.Resize(image_size),
50
+ GrayscaleToRGB(),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
53
+ ])
54
+
55
+ # Interface utilisateur Streamlit
56
+ st.title("Prédiction d'Images avec PyTorch")
57
+ st.write("Chargez une image pour obtenir une prédiction de classe.")
58
+
59
+ uploaded_file = st.file_uploader("Choisissez une image...", type=["jpg", "png", "jpeg"])
60
+
61
+ if uploaded_file is not None:
62
+ try:
63
+ # Charger et afficher l'image
64
+ image = Image.open(uploaded_file)
65
+ st.image(image, caption="Image chargée", use_column_width=True)
66
+
67
+ # Transformation et prédiction
68
+ image_tensor = transform(image).unsqueeze(0).to(device)
69
+ with torch.no_grad():
70
+ outputs = model(image_tensor)
71
+ _, predicted_class = torch.max(outputs, 1)
72
+
73
+ predicted_label = class_names[predicted_class.item()]
74
+ st.success(f"Classe prédite : {predicted_label}")
75
+ except Exception as e:
76
+ st.error(f"Erreur lors de la prédiction : {e}")