kishoreb4 committed
Commit 1758d8a · verified · 1 Parent(s): 564c9e3

Update app.py

Files changed (1): app.py (+342, −172)
app.py CHANGED
@@ -10,6 +10,7 @@ import io
 import gdown
 from transformers import TFSegformerForSemanticSegmentation
 
+# Set page config at the very beginning of the app
 st.set_page_config(
     page_title="Pet Segmentation with SegFormer",
     page_icon="🐶",
@@ -17,25 +18,26 @@ st.set_page_config(
     initial_sidebar_state="expanded"
 )
 
-# Constants for image preprocessing
+# Constants for image preprocessing - matching colab_code.py
 IMAGE_SIZE = 512
 OUTPUT_SIZE = 128
 MEAN = tf.constant([0.485, 0.456, 0.406])
 STD = tf.constant([0.229, 0.224, 0.225])
 
-# Class labels
+# Class labels - DO NOT CHANGE
 ID2LABEL = {0: "background", 1: "border", 2: "foreground/pet"}
 NUM_CLASSES = len(ID2LABEL)
 
 
 @st.cache_resource
 def download_model_from_drive():
+    """Download the model from Google Drive"""
     # Create a models directory
     os.makedirs("models", exist_ok=True)
-    model_path = "models/best_model"
+    model_path = "models/tf_model.h5"
 
     if not os.path.exists(model_path):
-        # Fixed Google Drive URL format for gdown
+        # Correct format for gdown
        url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
         try:
             gdown.download(url, model_path, quiet=False)
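Note on this hunk: in my experience gdown does not reliably resolve a `.../file/d/<id>/view` share link passed verbatim; it expects the bare file id (or the `uc?id=<id>` form), or `fuzzy=True` so it can extract the id itself. A minimal sketch of the more robust call, assuming a reasonably recent gdown release:

    import gdown

    # fuzzy=True lets gdown pull the file id out of a share-style URL
    url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
    gdown.download(url, "models/tf_model.h5", quiet=False, fuzzy=True)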
@@ -50,12 +52,7 @@ def download_model_from_drive():
 
 @st.cache_resource
 def load_model():
-    """
-    Load the SegFormer model
-
-    Returns:
-        Loaded model
-    """
+    """Load the SegFormer model"""
     try:
         # First create a base model with the correct architecture
         base_model = TFSegformerForSemanticSegmentation.from_pretrained(
@@ -68,54 +65,35 @@ def load_model():
 
         # Download the trained weights
         model_path = download_model_from_drive()
-
-        if model_path is not None and os.path.exists(model_path):
-            st.info(f"Loading weights from {model_path}...")
+        if model_path:
             try:
-                # Try to load the weights
                 base_model.load_weights(model_path)
                 st.success("Model weights loaded successfully!")
-                return base_model
             except Exception as e:
                 # st.error(f"Error loading weights: {e}")
-                # st.info("Using base pretrained model instead")
-                return base_model
-        else:
-            st.warning("Using base pretrained model since download failed")
-            return base_model
+                # st.warning("Using base pretrained model instead.")
+                pass  # keep the base pretrained weights if loading fails
+
+        return base_model
 
     except Exception as e:
         st.error(f"Error in load_model: {e}")
-        st.warning("Using default pretrained model")
-        # Fall back to pretrained model as a last resort
-        return TFSegformerForSemanticSegmentation.from_pretrained(
-            "nvidia/mit-b0",
-            num_labels=NUM_CLASSES,
-            id2label=ID2LABEL,
-            label2id={label: id for id, label in ID2LABEL.items()},
-            ignore_mismatched_sizes=True
-        )
+        return None
+
 
 def normalize_image(input_image):
-    """
-    Normalize the input image
-
-    Args:
-        input_image: Image to normalize
-
-    Returns:
-        Normalized image
-    """
+    """Normalize image with ImageNet stats"""
     input_image = tf.image.convert_image_dtype(input_image, tf.float32)
     input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
     return input_image
 
-def preprocess_image(image):
+
+def preprocess_image(image, is_dataset_image=False):
     """
-    Preprocess image for model input
+    Preprocess image exactly like in colab_code.py
 
     Args:
         image: PIL Image to preprocess
+        is_dataset_image: Whether the image is from the Oxford-IIIT Pet dataset
 
     Returns:
         Preprocessed image tensor, original image
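The MEAN/STD pair is the standard ImageNet normalization: pixels are scaled to [0, 1], then shifted and scaled per channel. A quick worked example of what the transform does to a pure-white pixel (illustrative arithmetic only, not app data):

    import tensorflow as tf

    MEAN = tf.constant([0.485, 0.456, 0.406])
    STD = tf.constant([0.229, 0.224, 0.225])

    white = tf.constant([1.0, 1.0, 1.0])
    print(((white - MEAN) / STD).numpy())
    # ≈ [2.249, 2.429, 2.640] — white sits ~2.2-2.6 channel standard deviations above the mean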
@@ -126,8 +104,14 @@ def preprocess_image(image, is_dataset_image=False):
 
     # Store original image for display
     original_img = img_array.copy()
 
-    # Resize to target size
-    img_resized = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
+    # Resize to target size with preserve_aspect_ratio=False
+    img_resized = tf.image.resize(
+        img_array,
+        (IMAGE_SIZE, IMAGE_SIZE),
+        method='bilinear',
+        preserve_aspect_ratio=False,  # Ensure exact dimensions
+        antialias=True
+    )
 
     # Normalize
     img_normalized = normalize_image(img_resized)
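The IMAGE_SIZE/OUTPUT_SIZE pairing is not arbitrary: the Hugging Face SegFormer head emits logits at 1/4 of the input resolution, channels-first, which is what `create_mask` (below) relies on when it takes the argmax over axis 1. A runnable shape sketch with dummy logits, assuming that layout:

    import tensorflow as tf

    # Dummy logits in the layout TFSegformerForSemanticSegmentation returns:
    # (batch, num_labels, height/4, width/4) — here 512 / 4 = 128 = OUTPUT_SIZE
    logits = tf.random.normal((1, 3, 128, 128))

    # What create_mask does: argmax over the class axis, then drop the batch axis
    mask = tf.squeeze(tf.math.argmax(logits, axis=1))
    print(mask.shape)  # (128, 128), values in {0, 1, 2}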
@@ -140,6 +124,41 @@ def preprocess_image(image, is_dataset_image=False):
 
     return img_batch, original_img
 
+
+def process_uploaded_mask(mask_array, from_dataset=True):
+    """
+    Process an uploaded mask from the dataset to match app's format
+
+    Args:
+        mask_array: Numpy array of the mask
+        from_dataset: Whether the mask is from the original dataset
+
+    Returns:
+        Processed mask with values 0,1,2
+    """
+    # Handle RGBA images
+    if len(mask_array.shape) == 3 and mask_array.shape[2] == 4:
+        mask_array = mask_array[:,:,:3]
+
+    # Convert RGB to grayscale if needed
+    if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
+        mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
+
+    # For dataset masks, convert from original values (1,2,3) to app values (0,1,2)
+    if from_dataset:
+        processed_mask = np.zeros_like(mask_array)
+
+        # Map dataset values to app values
+        processed_mask[mask_array == 1] = 2  # Foreground/pet (1→2)
+        processed_mask[mask_array == 2] = 1  # Border (2→1)
+        processed_mask[mask_array == 3] = 0  # Background (3→0)
+
+        return processed_mask
+    else:
+        # For non-dataset masks, assume they're already in the right format
+        return mask_array
+
+
 def create_mask(pred_mask):
     """
     Convert model prediction to displayable mask
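The value remap in `process_uploaded_mask` above is easiest to sanity-check on a toy array (hypothetical values, not dataset data):

    import numpy as np

    raw = np.array([[1, 2],
                    [3, 1]], dtype=np.uint8)   # dataset-style trimap values

    remapped = np.zeros_like(raw)
    remapped[raw == 1] = 2   # pet        (1 → 2)
    remapped[raw == 2] = 1   # border     (2 → 1)
    remapped[raw == 3] = 0   # background (3 → 0)

    print(remapped)  # [[2 1]
                     #  [0 2]] — the 0/1/2 class ids the app predicts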
@@ -150,64 +169,96 @@ def create_mask(pred_mask):
 
     Returns:
         Processed mask (2D array)
     """
+    # Take argmax along the class dimension
     pred_mask = tf.math.argmax(pred_mask, axis=1)
+
+    # Remove batch dimension and convert to numpy
     pred_mask = tf.squeeze(pred_mask)
+
     return pred_mask.numpy()
 
+
 def colorize_mask(mask):
     """
-    Apply colors to segmentation mask
+    Colorize a segmentation mask for visualization
 
     Args:
-        mask: Segmentation mask (2D array)
+        mask: Segmentation mask (2D array with class indices)
 
     Returns:
-        Colorized mask (3D RGB array)
+        Colorized mask (3D array with RGB colors)
     """
-    # Ensure the mask is 2D
-    if len(mask.shape) > 2:
-        mask = np.squeeze(mask)
-
-    # Define colors for each class (RGB)
+    # Define colors for visualization
     colors = [
-        [0, 0, 0],       # Background (black)
-        [255, 0, 0],     # Border (red)
-        [0, 0, 255]      # Foreground/pet (blue)
+        [0, 0, 0],       # Black for background (0)
+        [255, 255, 0],   # Yellow for border (1)
+        [255, 0, 0]      # Red for foreground/pet (2)
     ]
 
     # Create RGB mask
-    rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+    height, width = mask.shape
+    colorized = np.zeros((height, width, 3), dtype=np.uint8)
 
+    # Apply colors
     for i, color in enumerate(colors):
-        class_mask = (mask == i).astype(np.uint8)
-        for c in range(3):
-            rgb_mask[:, :, c] += class_mask * color[c]
+        colorized[mask == i] = color
 
-    return rgb_mask
+    return colorized
+
+
+def create_overlay(image, mask, alpha=0.5):
+    """
+    Create an overlay of mask on original image
+
+    Args:
+        image: Original image
+        mask: Colorized segmentation mask
+        alpha: Transparency level (0-1)
+
+    Returns:
+        Overlay image
+    """
+    # Ensure mask shape matches image
+    if image.shape[:2] != mask.shape[:2]:
+        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
+
+    # Create blend
+    overlay = cv2.addWeighted(
+        image,
+        1,
+        mask.astype(np.uint8),
+        alpha,
+        0
+    )
+
+    return overlay
+
 
 def calculate_iou(y_true, y_pred, class_idx=None):
     """
-    Calculate IoU (Intersection over Union) for segmentation masks
+    Calculate IoU (Intersection over Union)
 
     Args:
-        y_true: Ground truth segmentation mask
-        y_pred: Predicted segmentation mask
-        class_idx: Index of the class to calculate IoU for (None for mean IoU)
+        y_true: Ground truth mask
+        y_pred: Predicted mask
+        class_idx: Class index to compute IoU for (if None, compute mean IoU)
 
     Returns:
         IoU score
     """
     if class_idx is not None:
-        # Binary IoU for specific class
+        # Convert to binary masks for specific class
         y_true_class = (y_true == class_idx).astype(np.float32)
         y_pred_class = (y_pred == class_idx).astype(np.float32)
 
+        # Calculate intersection and union
         intersection = np.sum(y_true_class * y_pred_class)
         union = np.sum(y_true_class) + np.sum(y_pred_class) - intersection
 
-        iou = intersection / (union + 1e-6)
+        # Return IoU score
+        return float(intersection) / float(union) if union > 0 else 0.0
     else:
-        # Mean IoU across all classes
+        # Calculate mean IoU across all classes
         class_ious = []
         for idx in range(NUM_CLASSES):
             class_iou = calculate_iou(y_true, y_pred, idx)
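One quirk worth knowing about the blend in `create_overlay` above: cv2.addWeighted computes image·1 + mask·alpha + 0 with saturation at 255, so bright pixels under a colored class clip rather than mix. Weighting the image by 1 − alpha is the usual convex alternative; a toy comparison (a design note, not what this commit ships):

    import cv2
    import numpy as np

    alpha = 0.5
    image = np.full((2, 2, 3), 200, dtype=np.uint8)  # bright toy image
    mask = np.zeros((2, 2, 3), dtype=np.uint8)
    mask[..., 0] = 255                               # one class colored red

    clipped = cv2.addWeighted(image, 1.0, mask, alpha, 0)          # red channel: 200 + 127.5 saturates to 255
    blended = cv2.addWeighted(image, 1.0 - alpha, mask, alpha, 0)  # red channel: 100 + 127.5 → 228
    print(clipped[0, 0], blended[0, 0])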
@@ -217,32 +268,105 @@ def calculate_iou(y_true, y_pred, class_idx=None):
 
     return iou
 
-def create_overlay(image, mask, alpha=0.5):
+
+def calculate_dice(y_true, y_pred, class_idx=None):
     """
-    Create an overlay of mask on original image
+    Calculate Dice coefficient (F1 score)
 
     Args:
-        image: Original image
-        mask: Segmentation mask
-        alpha: Transparency level (0-1)
+        y_true: Ground truth mask
+        y_pred: Predicted mask
+        class_idx: Class index to compute Dice for (if None, compute mean Dice)
 
     Returns:
-        Overlay image
+        Dice score
     """
-    # Ensure mask shape matches image
-    if image.shape[:2] != mask.shape[:2]:
-        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
-
-    # Create blend
-    overlay = cv2.addWeighted(
-        image,
-        1,
-        mask.astype(np.uint8),
-        alpha,
-        0
-    )
-
-    return overlay
+    if class_idx is not None:
+        # Convert to binary masks for specific class
+        y_true_class = (y_true == class_idx).astype(np.float32)
+        y_pred_class = (y_pred == class_idx).astype(np.float32)
+
+        # Calculate intersection and sum of areas
+        intersection = 2.0 * np.sum(y_true_class * y_pred_class)
+        sum_areas = np.sum(y_true_class) + np.sum(y_pred_class)
+
+        # Return Dice score
+        return float(intersection) / float(sum_areas) if sum_areas > 0 else 0.0
+    else:
+        # Calculate mean Dice across all classes
+        class_dices = []
+        for idx in range(NUM_CLASSES):
+            class_dice = calculate_dice(y_true, y_pred, idx)
+            class_dices.append(class_dice)
+
+        dice = np.mean(class_dices)
+
+        return dice
+
+
+def calculate_pixel_accuracy(y_true, y_pred):
+    """
+    Calculate pixel accuracy
+
+    Args:
+        y_true: Ground truth mask
+        y_pred: Predicted mask
+
+    Returns:
+        Pixel accuracy
+    """
+    correct = np.sum(y_true == y_pred)
+    total = y_true.size
+    return float(correct) / float(total)
+
+
+def display_side_by_side(original_img, gt_mask=None, pred_mask=None, overlay=None):
+    """
+    Display images side by side
+
+    Args:
+        original_img: Original input image
+        gt_mask: Ground truth segmentation mask (optional)
+        pred_mask: Predicted segmentation mask
+        overlay: Overlay of mask on original image
+    """
+    # Determine number of columns based on available images
+    columns = 1  # Start with original image
+    if gt_mask is not None:
+        columns += 1
+    if pred_mask is not None:
+        columns += 1
+    if overlay is not None:
+        columns += 1
+
+    cols = st.columns(columns)
+
+    # Display original image
+    with cols[0]:
+        st.markdown("### Original Image")
+        st.image(original_img, use_column_width=True)
+
+    # Display ground truth mask if available
+    col_idx = 1
+    if gt_mask is not None:
+        with cols[col_idx]:
+            st.markdown("### Ground Truth Mask")
+            st.image(gt_mask, use_column_width=True)
+        col_idx += 1
+
+    # Display predicted mask if available
+    if pred_mask is not None:
+        with cols[col_idx]:
+            st.markdown("### Predicted Mask")
+            st.image(pred_mask, use_column_width=True)
+        col_idx += 1
+
+    # Display overlay if available
+    if overlay is not None:
+        with cols[col_idx]:
+            st.markdown("### Overlay")
+            st.image(overlay, use_column_width=True)
+
 
 def main():
     st.title("🐶 Pet Segmentation with SegFormer")
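For binary masks the two metrics added in this hunk are linked by Dice = 2·IoU / (1 + IoU), which makes a handy consistency check on `calculate_iou` and `calculate_dice`. A tiny worked example on toy masks:

    import numpy as np

    gt   = np.array([[1, 1], [0, 0]])
    pred = np.array([[1, 0], [0, 0]])

    inter = np.sum((gt == 1) & (pred == 1))                   # 1
    union = np.sum(gt == 1) + np.sum(pred == 1) - inter       # 2
    iou = inter / union                                       # 0.5
    dice = 2 * inter / (np.sum(gt == 1) + np.sum(pred == 1))  # 2/3
    assert abs(dice - 2 * iou / (1 + iou)) < 1e-9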
@@ -254,23 +378,16 @@ def main():
 
     - **Foreground**: The pet itself
     """)
 
-    # Sidebar
-    st.sidebar.header("Model Information")
-    st.sidebar.markdown("""
-    **SegFormer** is a state-of-the-art semantic segmentation model based on transformers.
-
-    Key features:
-    - Hierarchical transformer encoder
-    - Lightweight MLP decoder
-    - Efficient mix of local and global attention
-
-    This implementation uses the MIT-B0 variant fine-tuned on the Oxford-IIIT Pet dataset.
-    """)
+    # Sidebar settings
+    st.sidebar.title("Settings")
+
+    # Debug mode toggle
+    debug_mode = st.sidebar.checkbox("Debug Mode", value=False)
 
-    # Advanced settings in sidebar
-    st.sidebar.header("Settings")
+    # Dataset image toggle - important for processing Oxford-IIIT Pet masks
+    dataset_image = st.sidebar.checkbox("Image is from Oxford-IIIT Pet dataset", value=True)
 
-    # Overlay opacity
+    # Overlay opacity control
     overlay_opacity = st.sidebar.slider(
         "Overlay Opacity",
         min_value=0.1,
@@ -284,7 +401,8 @@ def main():
 
     model = load_model()
 
     if model is None:
-        st.error("Failed to load model. Using default pretrained model instead.")
+        st.error("Failed to load model. Please check your model path and try again.")
+        return
     else:
         st.sidebar.success("Model loaded successfully!")
@@ -300,22 +418,16 @@ def main():
 
             image_bytes = uploaded_image.read()
             image = Image.open(io.BytesIO(image_bytes))
 
-            col1, col2 = st.columns(2)
-
-            with col1:
-                st.subheader("Original Image")
-                st.image(image, caption="Uploaded Image", use_column_width=True)
-
             # Preprocess and predict
             with st.spinner("Generating segmentation mask..."):
                 # Preprocess the image
-                img_tensor, original_img = preprocess_image(image)
+                img_tensor, original_img = preprocess_image(image, is_dataset_image=dataset_image)
 
                 # Make prediction
                 outputs = model(pixel_values=img_tensor, training=False)
                 logits = outputs.logits
 
-                # Create visualization mask
+                # Create mask
                 mask = create_mask(logits)
 
                 # Colorize the mask
@@ -324,10 +436,86 @@ def main():
 
                 # Create overlay
                 overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
 
+                # Prepare for metrics calculation (if ground truth is provided)
+                gt_mask = None
+                gt_mask_colorized = None
+                metrics_calculated = False
+
+                # Calculate metrics if ground truth is uploaded
+                if uploaded_mask is not None:
+                    try:
+                        # Reset the file pointer to the beginning
+                        uploaded_mask.seek(0)
+
+                        # Read the mask file
+                        mask_data = uploaded_mask.read()
+                        mask_io = io.BytesIO(mask_data)
+                        gt_mask_raw = np.array(Image.open(mask_io))
+
+                        if debug_mode:
+                            st.write(f"Ground truth mask shape: {gt_mask_raw.shape}")
+                            st.write(f"Ground truth mask unique values: {np.unique(gt_mask_raw)}")
+
+                        # Process the mask based on source
+                        gt_mask = process_uploaded_mask(gt_mask_raw, from_dataset=dataset_image)
+
+                        # Colorize for display
+                        gt_mask_colorized = colorize_mask(gt_mask)
+
+                        # Resize for comparison (cv2 dsize is (width, height))
+                        gt_mask_resized = cv2.resize(gt_mask, (mask.shape[1], mask.shape[0]),
+                                                     interpolation=cv2.INTER_NEAREST)
+
+                        if debug_mode:
+                            st.write(f"Processed GT mask shape: {gt_mask_resized.shape}")
+                            st.write(f"Processed GT unique values: {np.unique(gt_mask_resized)}")
+                            st.write(f"Prediction mask unique values: {np.unique(mask)}")
+
+                        # Calculate metrics
+                        iou_score = calculate_iou(gt_mask_resized, mask)
+                        dice_score = calculate_dice(gt_mask_resized, mask)
+                        accuracy = calculate_pixel_accuracy(gt_mask_resized, mask)
+
+                        metrics_calculated = True
+                    except Exception as e:
+                        st.error(f"Error processing ground truth mask: {e}")
+                        if debug_mode:
+                            import traceback
+                            st.code(traceback.format_exc())
+
                 # Display results
-                with col2:
-                    st.subheader("Segmentation Result")
-                    st.image(overlay, caption="Segmentation Overlay", use_column_width=True)
+                display_side_by_side(
+                    original_img,
+                    gt_mask_colorized,
+                    colorized_mask,
+                    overlay
+                )
+
+                # Display metrics if calculated
+                if metrics_calculated:
+                    st.header("Segmentation Metrics")
+
+                    # Display overall metrics
+                    col1, col2, col3 = st.columns(3)
+                    with col1:
+                        st.metric("Mean IoU", f"{iou_score:.4f}")
+                    with col2:
+                        st.metric("Mean Dice", f"{dice_score:.4f}")
+                    with col3:
+                        st.metric("Pixel Accuracy", f"{accuracy:.4f}")
+
+                    # Display class-specific metrics
+                    st.subheader("Metrics by Class")
+                    cols = st.columns(NUM_CLASSES)
+                    class_names = ["Background", "Border", "Foreground/Pet"]
+
+                    for i, (col, name) in enumerate(zip(cols, class_names)):
+                        with col:
+                            st.markdown(f"**{name}**")
+                            class_iou = calculate_iou(gt_mask_resized, mask, i)
+                            class_dice = calculate_dice(gt_mask_resized, mask, i)
+                            st.metric("IoU", f"{class_iou:.4f}")
+                            st.metric("Dice", f"{class_dice:.4f}")
 
                 # Display segmentation details
                 st.header("Segmentation Details")
@@ -351,86 +539,68 @@ def main():
 
                     mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
                     st.image(mask_fg, caption="Foreground", use_column_width=True)
 
-                # Calculate IoU if ground truth is uploaded
-                if uploaded_mask is not None:
-                    try:
-                        # Read the mask file
-                        mask_data = uploaded_mask.read()
-                        mask_io = io.BytesIO(mask_data)
-                        gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
-
-                        # Handle different mask formats
-                        if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
-                            # Convert RGB to single channel if needed
-                            gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
-
-                        # Calculate and display IoU
-                        resized_mask = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_NEAREST)
-                        iou_score = calculate_iou(gt_mask, resized_mask)
-                        st.success(f"Mean IoU: {iou_score:.4f}")
-
-                        # Display specific class IoUs
-                        st.markdown("### IoU by Class")
-                        col1, col2, col3 = st.columns(3)
-                        with col1:
-                            bg_iou = calculate_iou(gt_mask, resized_mask, 0)
-                            st.metric("Background IoU", f"{bg_iou:.4f}")
-                        with col2:
-                            border_iou = calculate_iou(gt_mask, resized_mask, 1)
-                            st.metric("Border IoU", f"{border_iou:.4f}")
-                        with col3:
-                            fg_iou = calculate_iou(gt_mask, resized_mask, 2)
-                            st.metric("Foreground IoU", f"{fg_iou:.4f}")
-                    except Exception as e:
-                        st.error(f"Error processing ground truth mask: {e}")
-                        st.write("Please ensure the mask is valid and has the correct format.")
-
                 # Download buttons
-                col1, col2 = st.columns(2)
+                st.header("Download Results")
+                col1, col2, col3 = st.columns(3)
 
                 with col1:
-                    # Convert mask to PNG for download
-                    mask_colored = Image.fromarray(colorized_mask)
-                    mask_bytes = io.BytesIO()
-                    mask_colored.save(mask_bytes, format='PNG')
-                    mask_bytes = mask_bytes.getvalue()
+                    # Download prediction as PNG
+                    pred_pil = Image.fromarray(colorized_mask)
+                    pred_bytes = io.BytesIO()
+                    pred_pil.save(pred_bytes, format='PNG')
+                    pred_bytes = pred_bytes.getvalue()
 
                     st.download_button(
-                        label="Download Segmentation Mask",
-                        data=mask_bytes,
-                        file_name="pet_segmentation_mask.png",
+                        label="Download Prediction",
+                        data=pred_bytes,
+                        file_name="prediction.png",
                         mime="image/png"
                     )
 
                 with col2:
-                    # Convert overlay to PNG for download
-                    overlay_img = Image.fromarray(overlay)
+                    # Download overlay as PNG
+                    overlay_pil = Image.fromarray(overlay)
                     overlay_bytes = io.BytesIO()
-                    overlay_img.save(overlay_bytes, format='PNG')
+                    overlay_pil.save(overlay_bytes, format='PNG')
                     overlay_bytes = overlay_bytes.getvalue()
 
                     st.download_button(
-                        label="Download Overlay Image",
+                        label="Download Overlay",
                         data=overlay_bytes,
-                        file_name="pet_segmentation_overlay.png",
+                        file_name="overlay.png",
                         mime="image/png"
                     )
+
+                if metrics_calculated:
+                    with col3:
+                        # Create CSV with metrics
+                        metrics_csv = "Metric,Overall,Background,Border,Foreground\n"
+                        metrics_csv += f"IoU,{iou_score:.4f},{calculate_iou(gt_mask_resized, mask, 0):.4f},{calculate_iou(gt_mask_resized, mask, 1):.4f},{calculate_iou(gt_mask_resized, mask, 2):.4f}\n"
+                        metrics_csv += f"Dice,{dice_score:.4f},{calculate_dice(gt_mask_resized, mask, 0):.4f},{calculate_dice(gt_mask_resized, mask, 1):.4f},{calculate_dice(gt_mask_resized, mask, 2):.4f}\n"
+                        metrics_csv += f"Accuracy,{accuracy:.4f},,,"
+
+                        st.download_button(
+                            label="Download Metrics",
+                            data=metrics_csv,
+                            file_name="metrics.csv",
+                            mime="text/csv"
+                        )
+
         except Exception as e:
             st.error(f"Error processing image: {e}")
+            if debug_mode:
+                import traceback
+                st.code(traceback.format_exc())
+
-
-    # Footer with additional information
-    st.markdown("---")
-    st.markdown("### About the Model")
-    st.markdown("""
-    This segmentation model is based on the SegFormer architecture and was fine-tuned on the Oxford-IIIT Pet dataset.
-
-    **Key Performance Metrics:**
-    - Mean IoU (Intersection over Union): Measures overlap between predictions and ground truth
-    - Dice Coefficient: Similar to F1-score, balances precision and recall
-
-    The model segments pet images into three semantic classes (background, border, and pet/foreground),
-    making it useful for applications like pet image editing, background removal, and object detection.
-    """)
 
 if __name__ == "__main__":
+    # Try to configure GPU memory growth
+    try:
+        gpus = tf.config.experimental.list_physical_devices('GPU')
+        if gpus:
+            for gpu in gpus:
+                tf.config.experimental.set_memory_growth(gpu, True)
+    except Exception as e:
+        print(f"GPU configuration error: {e}")
+
     main()
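The memory-growth loop in the last hunk stops TensorFlow from reserving all GPU memory at startup, which matters on shared hardware. On TF 2.x, `tf.config.list_physical_devices` is the non-experimental alias for device listing, while `set_memory_growth` still lives under `tf.config.experimental`; a sketch of the equivalent, which must run before any op touches the GPU:

    import tensorflow as tf

    for gpu in tf.config.list_physical_devices('GPU'):
        # Allocate GPU memory on demand instead of all at once
        tf.config.experimental.set_memory_growth(gpu, True)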