brightlembo commited on
Commit
7a3cf1c
·
verified ·
1 Parent(s): 4ba5d40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -143
app.py CHANGED
@@ -1,152 +1,52 @@
1
  import torch
2
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
  from PIL import Image
4
- import io
5
- import base64
6
- import numpy as np
7
- from flask import Flask, request, jsonify
8
 
9
- class HuggingFaceClassifier:
10
- def __init__(self, model_name="microsoft/resnet-50"):
11
- """
12
- Initialize Hugging Face model and feature extractor
13
-
14
- Args:
15
- model_name (str): Hugging Face model identifier
16
- """
17
- try:
18
- # Load pre-trained model and feature extractor
19
- self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
20
- self.model = AutoModelForImageClassification.from_pretrained(model_name)
21
-
22
- # Move to GPU if available
23
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- self.model.to(self.device)
25
- self.model.eval()
26
-
27
- except Exception as e:
28
- raise ValueError(f"Model loading error: {e}")
29
-
30
- def preprocess_image(self, image):
31
- """
32
- Preprocess image for model input
33
-
34
- Args:
35
- image (PIL.Image): Input image
36
-
37
- Returns:
38
- torch.Tensor: Preprocessed image tensor
39
- """
40
- # Preprocess image using feature extractor
41
- inputs = self.feature_extractor(images=image, return_tensors="pt")
42
- return inputs.pixel_values.to(self.device)
43
-
44
- def predict(self, image):
45
- """
46
- Predict image classification
47
-
48
- Args:
49
- image (PIL.Image): Input image
50
-
51
- Returns:
52
- list: Top prediction results
53
- """
54
- try:
55
- # Preprocess image
56
- inputs = self.preprocess_image(image)
57
-
58
- # Perform prediction
59
- with torch.no_grad():
60
- outputs = self.model(inputs)
61
- logits = outputs.logits
62
- probabilities = torch.softmax(logits, dim=-1)
63
- top_k = torch.topk(probabilities, k=5)
64
-
65
- # Process results
66
- predicted_classes = [
67
- {
68
- "label": self.model.config.id2label[idx.item()],
69
- "score": prob.item()
70
- }
71
- for idx, prob in zip(top_k.indices[0], top_k.values[0])
72
- ]
73
-
74
- return predicted_classes
75
-
76
- except Exception as e:
77
- raise RuntimeError(f"Prediction error: {e}")
78
 
79
- # Flask API Setup
80
- app = Flask(__name__)
 
 
81
 
82
- # Initialize classifier (can be changed to any model)
83
- classifier = HuggingFaceClassifier(
84
- model_name="microsoft/resnet-50"
85
- )
86
 
87
- @app.route('/predict', methods=['POST'])
88
- def predict_image():
89
- """
90
- Image classification endpoint
91
- Supports base64 and file upload
92
- """
93
- try:
94
- # Handle base64 encoded image
95
- if 'image' in request.json:
96
- image_data = base64.b64decode(request.json['image'])
97
- image = Image.open(io.BytesIO(image_data))
98
-
99
- # Handle file upload
100
- elif 'file' in request.files:
101
- image = Image.open(request.files['file'])
102
-
103
- else:
104
- return jsonify({
105
- 'error': 'No image provided',
106
- 'status': 'failed'
107
- }), 400
108
-
109
- # Perform prediction
110
- predictions = classifier.predict(image)
111
-
112
- return jsonify({
113
- 'predictions': predictions,
114
- 'status': 'success'
115
- })
116
-
117
- except Exception as e:
118
- return jsonify({
119
- 'error': str(e),
120
- 'status': 'failed'
121
- }), 500
122
 
123
- @app.route('/models', methods=['GET'])
124
- def available_models():
125
- """
126
- List available pre-trained models
127
- """
128
- models = [
129
- "microsoft/resnet-50",
130
- "google/vit-base-patch16-224",
131
- "facebook/vit-mae-base",
132
- "microsoft/beit-base-patch16-224"
133
- ]
134
-
135
- return jsonify({
136
- 'models': models,
137
- 'total_models': len(models)
138
- })
139
 
140
- @app.route('/health', methods=['GET'])
141
- def health_check():
142
- """
143
- API health check endpoint
144
- """
145
- return jsonify({
146
- 'status': 'healthy',
147
- 'model': classifier.model.config.model_type,
148
- 'device': str(classifier.device)
149
- })
150
 
151
- if __name__ == '__main__':
152
- app.run(host='0.0.0.0', port=5000, debug=True)
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from torchvision import transforms
3
  from PIL import Image
4
+ import json
5
+ import streamlit as st
 
 
6
 
7
+ # Charger les noms des classes
8
+ with open("class_names.json", "r") as f:
9
+ class_names = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Charger le modèle
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model = torch.load("efficientnet_b7_best.pth", map_location=device)
14
+ model.eval() # Mode évaluation
15
 
16
+ # Définir la taille de l'image
17
+ image_size = (224, 224)
 
 
18
 
19
+ # Transformation pour l'image
20
+ class GrayscaleToRGB:
21
+ def __call__(self, img):
22
+ return img.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ valid_test_transforms = transforms.Compose([
25
+ transforms.Grayscale(num_output_channels=1),
26
+ transforms.Resize(image_size),
27
+ GrayscaleToRGB(), # Conversion en RGB
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
30
+ ])
 
 
 
 
 
 
 
 
 
31
 
32
+ # Fonction de prédiction
33
+ def predict_image(image):
34
+ image_tensor = valid_test_transforms(image).unsqueeze(0).to(device)
35
+ with torch.no_grad():
36
+ outputs = model(image_tensor)
37
+ _, predicted_class = torch.max(outputs, 1)
38
+ predicted_label = class_names[predicted_class.item()]
39
+ return predicted_label
 
 
40
 
41
+ # Interface Streamlit
42
+ st.title("Prédiction d'images avec PyTorch")
43
+ st.write("Chargez une image pour obtenir une prédiction de classe.")
44
+
45
+ uploaded_image = st.file_uploader("Téléchargez une image", type=["jpg", "jpeg", "png"])
46
+
47
+ if uploaded_image is not None:
48
+ image = Image.open(uploaded_image)
49
+ st.image(image, caption="Image téléchargée", use_column_width=True)
50
+
51
+ predicted_label = predict_image(image)
52
+ st.write(f"Prédiction de la classe : {predicted_label}")