brightlembo commited on
Commit
8a9662a
·
verified ·
1 Parent(s): 3803aa1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py CHANGED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import io
5
+ import base64
6
+ import numpy as np
7
+ from flask import Flask, request, jsonify
8
+
9
+ class HuggingFaceClassifier:
10
+ def __init__(self, model_name="microsoft/resnet-50"):
11
+ """
12
+ Initialize Hugging Face model and feature extractor
13
+
14
+ Args:
15
+ model_name (str): Hugging Face model identifier
16
+ """
17
+ try:
18
+ # Load pre-trained model and feature extractor
19
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
20
+ self.model = AutoModelForImageClassification.from_pretrained(model_name)
21
+
22
+ # Move to GPU if available
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ self.model.to(self.device)
25
+ self.model.eval()
26
+
27
+ except Exception as e:
28
+ raise ValueError(f"Model loading error: {e}")
29
+
30
+ def preprocess_image(self, image):
31
+ """
32
+ Preprocess image for model input
33
+
34
+ Args:
35
+ image (PIL.Image): Input image
36
+
37
+ Returns:
38
+ torch.Tensor: Preprocessed image tensor
39
+ """
40
+ # Preprocess image using feature extractor
41
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
42
+ return inputs.pixel_values.to(self.device)
43
+
44
+ def predict(self, image):
45
+ """
46
+ Predict image classification
47
+
48
+ Args:
49
+ image (PIL.Image): Input image
50
+
51
+ Returns:
52
+ list: Top prediction results
53
+ """
54
+ try:
55
+ # Preprocess image
56
+ inputs = self.preprocess_image(image)
57
+
58
+ # Perform prediction
59
+ with torch.no_grad():
60
+ outputs = self.model(inputs)
61
+ logits = outputs.logits
62
+ probabilities = torch.softmax(logits, dim=-1)
63
+ top_k = torch.topk(probabilities, k=5)
64
+
65
+ # Process results
66
+ predicted_classes = [
67
+ {
68
+ "label": self.model.config.id2label[idx.item()],
69
+ "score": prob.item()
70
+ }
71
+ for idx, prob in zip(top_k.indices[0], top_k.values[0])
72
+ ]
73
+
74
+ return predicted_classes
75
+
76
+ except Exception as e:
77
+ raise RuntimeError(f"Prediction error: {e}")
78
+
79
+ # Flask API Setup
80
+ app = Flask(__name__)
81
+
82
+ # Initialize classifier (can be changed to any model)
83
+ classifier = HuggingFaceClassifier(
84
+ model_name="microsoft/resnet-50"
85
+ )
86
+
87
+ @app.route('/predict', methods=['POST'])
88
+ def predict_image():
89
+ """
90
+ Image classification endpoint
91
+ Supports base64 and file upload
92
+ """
93
+ try:
94
+ # Handle base64 encoded image
95
+ if 'image' in request.json:
96
+ image_data = base64.b64decode(request.json['image'])
97
+ image = Image.open(io.BytesIO(image_data))
98
+
99
+ # Handle file upload
100
+ elif 'file' in request.files:
101
+ image = Image.open(request.files['file'])
102
+
103
+ else:
104
+ return jsonify({
105
+ 'error': 'No image provided',
106
+ 'status': 'failed'
107
+ }), 400
108
+
109
+ # Perform prediction
110
+ predictions = classifier.predict(image)
111
+
112
+ return jsonify({
113
+ 'predictions': predictions,
114
+ 'status': 'success'
115
+ })
116
+
117
+ except Exception as e:
118
+ return jsonify({
119
+ 'error': str(e),
120
+ 'status': 'failed'
121
+ }), 500
122
+
123
+ @app.route('/models', methods=['GET'])
124
+ def available_models():
125
+ """
126
+ List available pre-trained models
127
+ """
128
+ models = [
129
+ "microsoft/resnet-50",
130
+ "google/vit-base-patch16-224",
131
+ "facebook/vit-mae-base",
132
+ "microsoft/beit-base-patch16-224"
133
+ ]
134
+
135
+ return jsonify({
136
+ 'models': models,
137
+ 'total_models': len(models)
138
+ })
139
+
140
+ @app.route('/health', methods=['GET'])
141
+ def health_check():
142
+ """
143
+ API health check endpoint
144
+ """
145
+ return jsonify({
146
+ 'status': 'healthy',
147
+ 'model': classifier.model.config.model_type,
148
+ 'device': str(classifier.device)
149
+ })
150
+
151
+ if __name__ == '__main__':
152
+ app.run(host='0.0.0.0', port=5000, debug=True)