kishoreb4 committed on
Commit
821e207
·
1 Parent(s): 4dde97b

Add Streamlit app for segmentation

Browse files
Files changed (2) hide show
  1. app.py +396 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ from tensorflow.keras import backend
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import cv2
7
+ from PIL import Image
8
+ import os
9
+ import io
10
+ import gdown
11
+ from transformers import TFSegformerForSemanticSegmentation
12
+
13
# Streamlit page setup: title, icon, wide layout, sidebar open on first load.
st.set_page_config(
    page_title="Pet Segmentation with SegFormer",
    page_icon="🐶",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Preprocessing constants.
IMAGE_SIZE = 512   # side length (pixels) images are resized to before inference
OUTPUT_SIZE = 128  # not referenced by the code visible in this file
# Per-channel statistics applied by normalize_image (the commonly used
# ImageNet RGB mean / standard deviation values).
MEAN = tf.constant([0.485, 0.456, 0.406])
STD = tf.constant([0.229, 0.224, 0.225])

# Mapping from predicted class id to a human-readable label.
ID2LABEL = {
    0: "background",
    1: "border",
    2: "foreground/pet",
}
NUM_CLASSES = len(ID2LABEL)
30
+
31
@st.cache_resource
def download_model_from_drive():
    """
    Download the fine-tuned model from Google Drive unless it is already on disk.

    The file is stored at ``models/saved_models/segformer_model``. Progress
    and outcome are reported through Streamlit status elements.

    Returns:
        Path to the local model file, or None when the download failed.
    """
    # NOTE(review): the function is wrapped in st.cache_resource, so a None
    # result from a failed download is presumably cached as well and not
    # retried until the cache is cleared - confirm this is acceptable.
    model_dir = os.path.join("models", "saved_models")
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "segformer_model")

    # Nothing to do when a previous run already fetched the file.
    if os.path.exists(model_path):
        st.info("Model already exists locally.")
        return model_path

    # Google Drive file ID from the shared link
    file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
    url = f"https://drive.google.com/uc?id={file_id}"

    with st.spinner("Downloading model from Google Drive..."):
        try:
            downloaded = gdown.download(url, model_path, quiet=False)
        except Exception as e:
            # Top-level boundary: surface any download error to the user
            # and let the caller fall back to another model.
            st.error(f"Error downloading model: {str(e)}")
            return None

    # gdown may signal failure by returning None instead of raising, so the
    # absence of an exception alone does not prove the file was written.
    if downloaded is None or not os.path.exists(model_path):
        st.error("Error downloading model: no file was received from Google Drive")
        return None

    st.success("Model downloaded successfully!")
    return model_path
62
+
63
def _load_fallback_model():
    """Load the generic ``nvidia/mit-b0`` checkpoint configured for this app's classes."""
    return TFSegformerForSemanticSegmentation.from_pretrained(
        "nvidia/mit-b0",
        num_labels=NUM_CLASSES,
        id2label=ID2LABEL,
        label2id={label: class_id for class_id, label in ID2LABEL.items()},
        # The checkpoint's head does not match NUM_CLASSES; allow re-initialisation.
        ignore_mismatched_sizes=True,
    )


@st.cache_resource
def load_model():
    """
    Load the SegFormer model used for inference.

    The downloaded model is preferred. When the download fails, or the
    downloaded model cannot be loaded, the ``nvidia/mit-b0`` checkpoint is
    loaded instead and the reason is shown in the UI.

    Returns:
        Loaded TFSegformerForSemanticSegmentation model

    Raises:
        Whatever ``from_pretrained`` raises if the fallback checkpoint
        itself cannot be loaded.
    """
    try:
        # Download the model first
        model_path = download_model_from_drive()
        if model_path is not None:
            return TFSegformerForSemanticSegmentation.from_pretrained(model_path)
        st.warning("Using default pretrained model since download failed")
    except Exception as e:
        # Top-level boundary: any failure while fetching/loading the
        # downloaded model is reported and answered with the fallback.
        st.error(f"Error loading model: {str(e)}")
        st.error("Falling back to pretrained model")

    # Single fallback path shared by "download failed" and "load failed".
    return _load_fallback_model()
102
+
103
def normalize_image(input_image):
    """
    Standardize an image with the module-level MEAN and STD.

    Args:
        input_image: Image tensor/array to normalize (channels last, so the
            three-element MEAN/STD broadcast over the final axis)

    Returns:
        float32 tensor ``(input_image - MEAN) / STD``
    """
    # NOTE(review): convert_image_dtype is understood to rescale only when
    # the dtype actually changes; preprocess_image passes the output of
    # tf.image.resize, which is presumably already float32 in the 0-255
    # range. Confirm this matches the training-time preprocessing before
    # changing the scaling here.
    as_float = tf.image.convert_image_dtype(input_image, tf.float32)
    # Guard against division by zero should STD ever contain a zero entry.
    safe_std = tf.maximum(STD, backend.epsilon())
    return (as_float - MEAN) / safe_std
116
+
117
def preprocess_image(image):
    """
    Turn a PIL image into a model-ready batch.

    Args:
        image: PIL Image to preprocess

    Returns:
        Tuple ``(batch, original)`` where ``batch`` is the resized,
        normalized, channels-first tensor with a leading batch axis of 1 and
        ``original`` is an untouched RGB numpy copy for display.
    """
    rgb_pixels = np.array(image.convert('RGB'))

    # Keep an unmodified copy so the overlay can be drawn at full resolution.
    original_img = rgb_pixels.copy()

    resized = tf.image.resize(rgb_pixels, (IMAGE_SIZE, IMAGE_SIZE))

    # HWC -> CHW (SegFormer expects channels first), then add the batch axis.
    channels_first = tf.transpose(normalize_image(resized), (2, 0, 1))
    return tf.expand_dims(channels_first, axis=0), original_img
146
+
147
def create_mask(pred_mask):
    """
    Convert model logits into a class-id mask for display.

    Args:
        pred_mask: Prediction (logits) from the model; classes are expected
            on axis 1

    Returns:
        Mask of the first batch item with a trailing channel axis, resized
        to ``IMAGE_SIZE x IMAGE_SIZE``
    """
    # Most likely class per pixel, with a trailing channel axis re-added so
    # tf.image.resize accepts the tensor.
    class_ids = tf.expand_dims(tf.math.argmax(pred_mask, axis=1), -1)

    # Nearest-neighbour keeps the values valid class ids while upsampling to
    # the model input size (not the size of the uploaded image).
    upsampled = tf.image.resize(
        class_ids,
        (IMAGE_SIZE, IMAGE_SIZE),
        method="nearest",
    )

    return upsampled[0]
171
+
172
def colorize_mask(mask):
    """
    Apply colors to a segmentation mask.

    Args:
        mask: Class-id mask shaped ``(H, W)`` or ``(H, W, 1)``

    Returns:
        ``(H, W, 3)`` uint8 RGB image; pixels whose value is not a known
        class id stay black
    """
    # Colors for each class (RGB), indexed by class id.
    colors = np.array(
        [
            [0, 0, 0],      # Background (black)
            [255, 0, 0],    # Border (red)
            [0, 0, 255],    # Foreground/pet (blue)
        ],
        dtype=np.uint8,
    )

    mask = np.asarray(mask)
    # create_mask returns a trailing singleton channel; drop it so the
    # boolean selection below is 2-D. The previous per-channel "+=" raised a
    # broadcasting ValueError for (H, W, 1) input.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]

    rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for class_id, color in enumerate(colors):
        rgb_mask[mask == class_id] = color

    return rgb_mask
199
+
200
def create_overlay(image, mask, alpha=0.5):
    """
    Create an overlay of a colorized mask on the original image.

    Args:
        image: Original image as an ``(H, W, 3)`` uint8 array
        mask: Colorized segmentation mask; resized to the image's height and
            width when they differ
        alpha: Weight of the mask in the blend (0-1)

    Returns:
        Overlay image computed as ``image * 1 + mask * alpha``
    """
    target_h, target_w = image.shape[:2]

    # Ensure mask shape matches image. Nearest-neighbour is used because the
    # mask encodes discrete classes: the default bilinear interpolation would
    # blend neighbouring class colors into hues that belong to no class.
    if mask.shape[:2] != (target_h, target_w):
        mask = cv2.resize(
            mask,
            (target_w, target_h),  # cv2 expects (width, height)
            interpolation=cv2.INTER_NEAREST,
        )

    return cv2.addWeighted(image, 1, mask.astype(np.uint8), alpha, 0)
226
+
227
# st.session_state key under which the chosen sample (name, bytes) is kept.
_SAMPLE_STATE_KEY = "sample_image"


def _list_sample_images(sample_dir):
    """Return the sorted image file names in *sample_dir* ([] if it is not a directory)."""
    if not os.path.isdir(sample_dir):
        return []
    return sorted(
        name
        for name in os.listdir(sample_dir)
        # Compare case-insensitively so files such as "CAT.JPG" are listed too.
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )


def _to_png_bytes(array):
    """Encode an image array as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format='PNG')
    return buffer.getvalue()


def _render_segmentation(model, image, overlay_opacity):
    """Run the model on *image* and render the overlay, per-class masks and downloads."""
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original Image")
        st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Generating segmentation mask..."):
        img_tensor, original_img = preprocess_image(image)

        prediction = model(pixel_values=img_tensor, training=False)
        logits = prediction.logits

        mask = create_mask(logits).numpy()
        colorized_mask = colorize_mask(mask)
        overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)

    with col2:
        st.subheader("Segmentation Result")
        st.image(overlay, caption="Segmentation Overlay", use_column_width=True)

    # One binary (0/255) panel per class.
    st.header("Segmentation Details")
    class_panels = [
        ("Background", "Areas surrounding the pet", 0, "Background"),
        ("Border", "Boundary around the pet", 1, "Border"),
        ("Foreground (Pet)", "The pet itself", 2, "Foreground"),
    ]
    for column, (title, description, class_id, caption) in zip(st.columns(3), class_panels):
        with column:
            st.subheader(title)
            st.markdown(description)
            class_mask = np.where(mask == class_id, 255, 0).astype(np.uint8)
            st.image(class_mask, caption=caption, use_column_width=True)

    # Download buttons
    col1, col2 = st.columns(2)

    with col1:
        st.download_button(
            label="Download Segmentation Mask",
            data=_to_png_bytes(colorized_mask),
            file_name="pet_segmentation_mask.png",
            mime="image/png"
        )

    with col2:
        st.download_button(
            label="Download Overlay Image",
            data=_to_png_bytes(overlay),
            file_name="pet_segmentation_overlay.png",
            mime="image/png"
        )


def main():
    """
    Build the Streamlit page.

    Renders the introduction and sidebar, loads the model, lets the user
    upload an image or pick one from the ``samples`` directory, shows the
    segmentation results and finishes with an "About the Model" footer.
    """
    st.title("🐶 Pet Segmentation with SegFormer")
    st.markdown("""
    This app demonstrates semantic segmentation of pet images using a SegFormer model.
    The model segments images into three classes:
    - **Background**: Areas around the pet
    - **Border**: The boundary/outline around the pet
    - **Foreground**: The pet itself
    """)

    # Sidebar
    st.sidebar.header("Model Information")
    st.sidebar.markdown("""
    **SegFormer** is a state-of-the-art semantic segmentation model based on transformers.

    Key features:
    - Hierarchical transformer encoder
    - Lightweight MLP decoder
    - Efficient mix of local and global attention

    This implementation uses the MIT-B0 variant fine-tuned on the Oxford-IIIT Pet dataset.
    """)

    # Advanced settings in sidebar
    st.sidebar.header("Settings")

    overlay_opacity = st.sidebar.slider(
        "Overlay Opacity",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.1
    )

    # Load model
    with st.spinner("Loading SegFormer model..."):
        model = load_model()

    if model is None:
        # Nothing can be predicted without a model; stop instead of failing
        # later when the model would be called.
        st.error("Failed to load model. Segmentation is unavailable.")
        return
    st.sidebar.success("Model loaded successfully!")

    # Image upload
    st.header("Upload an Image")
    uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])

    # Sample images option
    st.markdown("### Or use a sample image:")
    sample_dir = "samples"
    sample_files = _list_sample_images(sample_dir)

    sample_clicked = False
    if sample_files:
        selected_sample = st.selectbox("Select a sample image:", sample_files)
        sample_clicked = st.button("Use this sample")

        if sample_clicked:
            with open(os.path.join(sample_dir, selected_sample), "rb") as file:
                # The button is only True for the run triggered by the click.
                # Keep the bytes in session state so the result survives
                # later reruns (opacity slider, download buttons, ...).
                st.session_state[_SAMPLE_STATE_KEY] = (selected_sample, file.read())

    # A fresh click overrides an uploaded file; on later reruns the stored
    # sample is only used while no file is uploaded.
    stored_sample = st.session_state.get(_SAMPLE_STATE_KEY)
    if stored_sample is not None and (sample_clicked or uploaded_image is None):
        sample_name, sample_bytes = stored_sample
        uploaded_image = io.BytesIO(sample_bytes)
        st.success(f"Using sample image: {sample_name}")

    # Process uploaded image
    if uploaded_image is not None:
        try:
            image = Image.open(uploaded_image)
            # Image.open is lazy; force decoding here so corrupt files are
            # reported instead of crashing further down.
            image.load()
        except OSError as err:
            st.error(f"Could not read the image: {err}")
        else:
            _render_segmentation(model, image, overlay_opacity)

    # Footer with additional information
    st.markdown("---")
    st.markdown("### About the Model")
    st.markdown("""
    This segmentation model is based on the SegFormer architecture and was fine-tuned on the Oxford-IIIT Pet dataset.

    **Key Performance Metrics:**
    - Mean IoU (Intersection over Union): Measures overlap between predictions and ground truth
    - Dice Coefficient: Similar to F1-score, balances precision and recall

    The model segments pet images into three semantic classes (background, border, and pet/foreground),
    making it useful for applications like pet image editing, background removal, and object detection.
    """)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.27.0
2
+ tensorflow==2.11.0
3
+ tf-keras
4
+ transformers==4.30.0
5
+ numpy>=1.22.0
6
+ matplotlib>=3.5.0
7
+ opencv-python-headless>=4.5.0
8
+ pillow>=9.0.0
9
+ gdown>=4.6.0
10
+ requests>=2.28.0