kishoreb4 commited on
Commit
46aa1b6
·
verified ·
1 Parent(s): 0caf545

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -121
app.py CHANGED
@@ -10,7 +10,7 @@ import io
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
13
- # Set page config at the very beginning of the app
14
  st.set_page_config(
15
  page_title="Pet Segmentation with SegFormer",
16
  page_icon="🐶",
@@ -18,13 +18,13 @@ st.set_page_config(
18
  initial_sidebar_state="expanded"
19
  )
20
 
21
- # Constants for image preprocessing - matching colab_code.py
22
  IMAGE_SIZE = 512
23
  OUTPUT_SIZE = 128
24
  MEAN = tf.constant([0.485, 0.456, 0.406])
25
  STD = tf.constant([0.229, 0.224, 0.225])
26
 
27
- # Class labels - DO NOT CHANGE
28
  ID2LABEL = {0: "background", 1: "border", 2: "foreground/pet"}
29
  NUM_CLASSES = len(ID2LABEL)
30
 
@@ -38,7 +38,8 @@ def download_model_from_drive():
38
 
39
  if not os.path.exists(model_path):
40
  # Correct format for gdown
41
- url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
 
42
  try:
43
  gdown.download(url, model_path, quiet=False)
44
  st.success("Model downloaded successfully from Google Drive.")
@@ -54,7 +55,7 @@ def download_model_from_drive():
54
  def load_model():
55
  """Load the SegFormer model"""
56
  try:
57
- # First create a base model with the correct architecture
58
  base_model = TFSegformerForSemanticSegmentation.from_pretrained(
59
  "nvidia/mit-b0",
60
  num_labels=NUM_CLASSES,
@@ -70,9 +71,8 @@ def load_model():
70
  base_model.load_weights(model_path)
71
  st.success("Model weights loaded successfully!")
72
  except Exception as e:
73
- st.success("Model weights loaded successfully!")
74
- # st.error(f"Error loading weights: {e}")
75
- # st.warning("Using base pretrained model instead.")
76
 
77
  return base_model
78
 
@@ -88,36 +88,27 @@ def normalize_image(input_image):
88
  return input_image
89
 
90
 
91
- def preprocess_image(image, is_dataset_image=False):
92
- """
93
- Preprocess image exactly like in colab_code.py
94
-
95
- Args:
96
- image: PIL Image to preprocess
97
- is_dataset_image: Whether the image is from the Oxford-IIIT Pet dataset
98
-
99
- Returns:
100
- Preprocessed image tensor, original image
101
- """
102
  # Convert PIL Image to numpy array
103
  img_array = np.array(image.convert('RGB'))
104
 
105
  # Store original image for display
106
  original_img = img_array.copy()
107
 
108
- # Resize to target size with preserve_aspect_ratio=False
109
  img_resized = tf.image.resize(
110
  img_array,
111
  (IMAGE_SIZE, IMAGE_SIZE),
112
  method='bilinear',
113
- preserve_aspect_ratio=False, # Ensure exact dimensions
114
  antialias=True
115
  )
116
 
117
  # Normalize
118
  img_normalized = normalize_image(img_resized)
119
 
120
- # Transpose from HWC to CHW (SegFormer expects channels first)
121
  img_transposed = tf.transpose(img_normalized, (2, 0, 1))
122
 
123
  # Add batch dimension
@@ -126,13 +117,12 @@ def preprocess_image(image, is_dataset_image=False):
126
  return img_batch, original_img
127
 
128
 
129
- def process_uploaded_mask(mask_array, from_dataset=True):
130
  """
131
- Process an uploaded mask from the dataset to match app's format
132
 
133
  Args:
134
  mask_array: Numpy array of the mask
135
- from_dataset: Whether the mask is from the original dataset
136
 
137
  Returns:
138
  Processed mask with values 0,1,2
@@ -145,50 +135,37 @@ def process_uploaded_mask(mask_array, from_dataset=True):
145
  if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
146
  mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
147
 
148
- # For dataset masks, convert from original values (1,2,3) to app values (0,1,2)
149
- if from_dataset:
 
 
 
150
  processed_mask = np.zeros_like(mask_array)
151
-
152
- # Map dataset values to app values
153
  processed_mask[mask_array == 1] = 2 # Foreground/pet (1→2)
154
  processed_mask[mask_array == 2] = 1 # Border (2→1)
155
  processed_mask[mask_array == 3] = 0 # Background (3→0)
156
-
157
  return processed_mask
158
- else:
159
- # For non-dataset masks, assume they're already in the right format
 
160
  return mask_array
 
 
 
 
 
 
161
 
162
 
163
  def create_mask(pred_mask):
164
- """
165
- Convert model prediction to displayable mask
166
-
167
- Args:
168
- pred_mask: Prediction logits from the model
169
-
170
- Returns:
171
- Processed mask (2D array)
172
- """
173
- # Take argmax along the class dimension
174
  pred_mask = tf.math.argmax(pred_mask, axis=1)
175
-
176
- # Remove batch dimension and convert to numpy
177
  pred_mask = tf.squeeze(pred_mask)
178
-
179
  return pred_mask.numpy()
180
 
181
 
182
  def colorize_mask(mask):
183
- """
184
- Colorize a segmentation mask for visualization
185
-
186
- Args:
187
- mask: Segmentation mask (2D array with class indices)
188
-
189
- Returns:
190
- Colorized mask (3D array with RGB colors)
191
- """
192
  # Define colors for visualization
193
  colors = [
194
  [0, 0, 0], # Black for background (0)
@@ -208,17 +185,7 @@ def colorize_mask(mask):
208
 
209
 
210
  def create_overlay(image, mask, alpha=0.5):
211
- """
212
- Create an overlay of mask on original image
213
-
214
- Args:
215
- image: Original image
216
- mask: Colorized segmentation mask
217
- alpha: Transparency level (0-1)
218
-
219
- Returns:
220
- Overlay image
221
- """
222
  # Ensure mask shape matches image
223
  if image.shape[:2] != mask.shape[:2]:
224
  mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
@@ -236,17 +203,7 @@ def create_overlay(image, mask, alpha=0.5):
236
 
237
 
238
  def calculate_iou(y_true, y_pred, class_idx=None):
239
- """
240
- Calculate IoU (Intersection over Union)
241
-
242
- Args:
243
- y_true: Ground truth mask
244
- y_pred: Predicted mask
245
- class_idx: Class index to compute IoU for (if None, compute mean IoU)
246
-
247
- Returns:
248
- IoU score
249
- """
250
  if class_idx is not None:
251
  # Convert to binary masks for specific class
252
  y_true_class = (y_true == class_idx).astype(np.float32)
@@ -265,23 +222,11 @@ def calculate_iou(y_true, y_pred, class_idx=None):
265
  class_iou = calculate_iou(y_true, y_pred, idx)
266
  class_ious.append(class_iou)
267
 
268
- iou = np.mean(class_ious)
269
-
270
- return iou
271
 
272
 
273
  def calculate_dice(y_true, y_pred, class_idx=None):
274
- """
275
- Calculate Dice coefficient (F1 score)
276
-
277
- Args:
278
- y_true: Ground truth mask
279
- y_pred: Predicted mask
280
- class_idx: Class index to compute Dice for (if None, compute mean Dice)
281
-
282
- Returns:
283
- Dice score
284
- """
285
  if class_idx is not None:
286
  # Convert to binary masks for specific class
287
  y_true_class = (y_true == class_idx).astype(np.float32)
@@ -300,37 +245,18 @@ def calculate_dice(y_true, y_pred, class_idx=None):
300
  class_dice = calculate_dice(y_true, y_pred, idx)
301
  class_dices.append(class_dice)
302
 
303
- dice = np.mean(class_dices)
304
-
305
- return dice
306
 
307
 
308
  def calculate_pixel_accuracy(y_true, y_pred):
309
- """
310
- Calculate pixel accuracy
311
-
312
- Args:
313
- y_true: Ground truth mask
314
- y_pred: Predicted mask
315
-
316
- Returns:
317
- Pixel accuracy
318
- """
319
  correct = np.sum(y_true == y_pred)
320
  total = y_true.size
321
  return float(correct) / float(total)
322
 
323
 
324
  def display_side_by_side(original_img, gt_mask=None, pred_mask=None, overlay=None):
325
- """
326
- Display images side by side
327
-
328
- Args:
329
- original_img: Original input image
330
- gt_mask: Ground truth segmentation mask (optional)
331
- pred_mask: Predicted segmentation mask
332
- overlay: Overlay of mask on original image
333
- """
334
  # Determine number of columns based on available images
335
  columns = 1 # Start with original image
336
  if gt_mask is not None:
@@ -385,9 +311,6 @@ def main():
385
  # Debug mode toggle
386
  debug_mode = st.sidebar.checkbox("Debug Mode", value=False)
387
 
388
- # Dataset image toggle - important for processing Oxford-IIIT Pet masks
389
- dataset_image = st.sidebar.checkbox("Image is from Oxford-IIIT Pet dataset", value=True)
390
-
391
  # Overlay opacity control
392
  overlay_opacity = st.sidebar.slider(
393
  "Overlay Opacity",
@@ -418,11 +341,15 @@ def main():
418
  # Read the image
419
  image_bytes = uploaded_image.read()
420
  image = Image.open(io.BytesIO(image_bytes))
 
 
 
 
421
 
422
  # Preprocess and predict
423
  with st.spinner("Generating segmentation mask..."):
424
  # Preprocess the image
425
- img_tensor, original_img = preprocess_image(image, is_dataset_image=dataset_image)
426
 
427
  # Make prediction
428
  outputs = model(pixel_values=img_tensor, training=False)
@@ -437,7 +364,7 @@ def main():
437
  # Create overlay
438
  overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
439
 
440
- # Prepare for metrics calculation (if ground truth is provided)
441
  gt_mask = None
442
  gt_mask_colorized = None
443
  metrics_calculated = False
@@ -457,15 +384,15 @@ def main():
457
  st.write(f"Ground truth mask shape: {gt_mask_raw.shape}")
458
  st.write(f"Ground truth mask unique values: {np.unique(gt_mask_raw)}")
459
 
460
- # Process the mask based on source
461
- gt_mask = process_uploaded_mask(gt_mask_raw, from_dataset=dataset_image)
462
 
463
  # Colorize for display
464
  gt_mask_colorized = colorize_mask(gt_mask)
465
 
466
  # Resize for comparison
467
  gt_mask_resized = cv2.resize(gt_mask, (mask.shape[0], mask.shape[1]),
468
- interpolation=cv2.INTER_NEAREST)
469
 
470
  if debug_mode:
471
  st.write(f"Processed GT mask shape: {gt_mask_resized.shape}")
@@ -485,6 +412,7 @@ def main():
485
  st.code(traceback.format_exc())
486
 
487
  # Display results
 
488
  display_side_by_side(
489
  original_img,
490
  gt_mask_colorized,
@@ -577,7 +505,7 @@ def main():
577
  # Create CSV with metrics
578
  metrics_csv = f"Metric,Overall,Background,Border,Foreground\n"
579
  metrics_csv += f"IoU,{iou_score:.4f},{calculate_iou(gt_mask_resized, mask, 0):.4f},{calculate_iou(gt_mask_resized, mask, 1):.4f},{calculate_iou(gt_mask_resized, mask, 2):.4f}\n"
580
- metrics_csv += f"Dice,{dice_score:.4f},{calculate_dice(gt_mask_resized, mask, 0):.4f},{calculate_dice(gt_mask_resized, mask, 1):.4f},{calculate_dice(gt_mask_resized, mask, 2)::.4f}\n"
581
  metrics_csv += f"Accuracy,{accuracy:.4f},,,"
582
 
583
  st.download_button(
@@ -592,6 +520,9 @@ def main():
592
  if debug_mode:
593
  import traceback
594
  st.code(traceback.format_exc())
 
 
 
595
 
596
 
597
  if __name__ == "__main__":
 
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
13
+ # Set page config at the very beginning
14
  st.set_page_config(
15
  page_title="Pet Segmentation with SegFormer",
16
  page_icon="🐶",
 
18
  initial_sidebar_state="expanded"
19
  )
20
 
21
+ # Constants for image preprocessing
22
  IMAGE_SIZE = 512
23
  OUTPUT_SIZE = 128
24
  MEAN = tf.constant([0.485, 0.456, 0.406])
25
  STD = tf.constant([0.229, 0.224, 0.225])
26
 
27
+ # Class labels
28
  ID2LABEL = {0: "background", 1: "border", 2: "foreground/pet"}
29
  NUM_CLASSES = len(ID2LABEL)
30
 
 
38
 
39
  if not os.path.exists(model_path):
40
  # Correct format for gdown
41
+ file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
42
+ url = f"https://drive.google.com/uc?id={file_id}"
43
  try:
44
  gdown.download(url, model_path, quiet=False)
45
  st.success("Model downloaded successfully from Google Drive.")
 
55
  def load_model():
56
  """Load the SegFormer model"""
57
  try:
58
+ # Create a base model with the correct architecture
59
  base_model = TFSegformerForSemanticSegmentation.from_pretrained(
60
  "nvidia/mit-b0",
61
  num_labels=NUM_CLASSES,
 
71
  base_model.load_weights(model_path)
72
  st.success("Model weights loaded successfully!")
73
  except Exception as e:
74
+ st.error(f"Error loading weights: {e}")
75
+ st.warning("Using base pretrained model instead.")
 
76
 
77
  return base_model
78
 
 
88
  return input_image
89
 
90
 
91
+ def preprocess_image(image):
92
+ """Preprocess image exactly like in colab_code.py"""
 
 
 
 
 
 
 
 
 
93
  # Convert PIL Image to numpy array
94
  img_array = np.array(image.convert('RGB'))
95
 
96
  # Store original image for display
97
  original_img = img_array.copy()
98
 
99
+ # Resize to target size
100
  img_resized = tf.image.resize(
101
  img_array,
102
  (IMAGE_SIZE, IMAGE_SIZE),
103
  method='bilinear',
104
+ preserve_aspect_ratio=False,
105
  antialias=True
106
  )
107
 
108
  # Normalize
109
  img_normalized = normalize_image(img_resized)
110
 
111
+ # Transpose from HWC to CHW (channels first)
112
  img_transposed = tf.transpose(img_normalized, (2, 0, 1))
113
 
114
  # Add batch dimension
 
117
  return img_batch, original_img
118
 
119
 
120
+ def process_uploaded_mask(mask_array):
121
  """
122
+ Process an uploaded mask from save_image_and_mask_to_local function
123
 
124
  Args:
125
  mask_array: Numpy array of the mask
 
126
 
127
  Returns:
128
  Processed mask with values 0,1,2
 
135
  if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
136
  mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
137
 
138
+ # Check the unique values in the mask to determine processing
139
+ unique_values = np.unique(mask_array)
140
+
141
+ # If mask has values 1,2,3 (from the dataset), convert to 0,1,2
142
+ if 3 in unique_values:
143
  processed_mask = np.zeros_like(mask_array)
 
 
144
  processed_mask[mask_array == 1] = 2 # Foreground/pet (1→2)
145
  processed_mask[mask_array == 2] = 1 # Border (2→1)
146
  processed_mask[mask_array == 3] = 0 # Background (3→0)
 
147
  return processed_mask
148
+
149
+ # If mask has values 0,1,2 already, just return it
150
+ elif 0 in unique_values and 2 in unique_values:
151
  return mask_array
152
+
153
+ # If we can't determine the format, use binary threshold as fallback
154
+ else:
155
+ # Use binary threshold to create a simple foreground/background mask
156
+ _, binary_mask = cv2.threshold(mask_array, 127, 2, cv2.THRESH_BINARY)
157
+ return binary_mask
158
 
159
 
160
  def create_mask(pred_mask):
161
+ """Convert model prediction to mask"""
 
 
 
 
 
 
 
 
 
162
  pred_mask = tf.math.argmax(pred_mask, axis=1)
 
 
163
  pred_mask = tf.squeeze(pred_mask)
 
164
  return pred_mask.numpy()
165
 
166
 
167
  def colorize_mask(mask):
168
+ """Colorize a segmentation mask for visualization"""
 
 
 
 
 
 
 
 
169
  # Define colors for visualization
170
  colors = [
171
  [0, 0, 0], # Black for background (0)
 
185
 
186
 
187
  def create_overlay(image, mask, alpha=0.5):
188
+ """Create an overlay of mask on original image"""
 
 
 
 
 
 
 
 
 
 
189
  # Ensure mask shape matches image
190
  if image.shape[:2] != mask.shape[:2]:
191
  mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
 
203
 
204
 
205
  def calculate_iou(y_true, y_pred, class_idx=None):
206
+ """Calculate IoU (Intersection over Union)"""
 
 
 
 
 
 
 
 
 
 
207
  if class_idx is not None:
208
  # Convert to binary masks for specific class
209
  y_true_class = (y_true == class_idx).astype(np.float32)
 
222
  class_iou = calculate_iou(y_true, y_pred, idx)
223
  class_ious.append(class_iou)
224
 
225
+ return np.mean(class_ious)
 
 
226
 
227
 
228
  def calculate_dice(y_true, y_pred, class_idx=None):
229
+ """Calculate Dice coefficient (F1 score)"""
 
 
 
 
 
 
 
 
 
 
230
  if class_idx is not None:
231
  # Convert to binary masks for specific class
232
  y_true_class = (y_true == class_idx).astype(np.float32)
 
245
  class_dice = calculate_dice(y_true, y_pred, idx)
246
  class_dices.append(class_dice)
247
 
248
+ return np.mean(class_dices)
 
 
249
 
250
 
251
  def calculate_pixel_accuracy(y_true, y_pred):
252
+ """Calculate pixel accuracy"""
 
 
 
 
 
 
 
 
 
253
  correct = np.sum(y_true == y_pred)
254
  total = y_true.size
255
  return float(correct) / float(total)
256
 
257
 
258
  def display_side_by_side(original_img, gt_mask=None, pred_mask=None, overlay=None):
259
+ """Display images side by side"""
 
 
 
 
 
 
 
 
260
  # Determine number of columns based on available images
261
  columns = 1 # Start with original image
262
  if gt_mask is not None:
 
311
  # Debug mode toggle
312
  debug_mode = st.sidebar.checkbox("Debug Mode", value=False)
313
 
 
 
 
314
  # Overlay opacity control
315
  overlay_opacity = st.sidebar.slider(
316
  "Overlay Opacity",
 
341
  # Read the image
342
  image_bytes = uploaded_image.read()
343
  image = Image.open(io.BytesIO(image_bytes))
344
+
345
+ # Display the original image first
346
+ st.subheader("Original Image")
347
+ st.image(image, caption="Uploaded Image", use_column_width=True)
348
 
349
  # Preprocess and predict
350
  with st.spinner("Generating segmentation mask..."):
351
  # Preprocess the image
352
+ img_tensor, original_img = preprocess_image(image)
353
 
354
  # Make prediction
355
  outputs = model(pixel_values=img_tensor, training=False)
 
364
  # Create overlay
365
  overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
366
 
367
+ # Prepare for metrics calculation
368
  gt_mask = None
369
  gt_mask_colorized = None
370
  metrics_calculated = False
 
384
  st.write(f"Ground truth mask shape: {gt_mask_raw.shape}")
385
  st.write(f"Ground truth mask unique values: {np.unique(gt_mask_raw)}")
386
 
387
+ # Process the mask
388
+ gt_mask = process_uploaded_mask(gt_mask_raw)
389
 
390
  # Colorize for display
391
  gt_mask_colorized = colorize_mask(gt_mask)
392
 
393
  # Resize for comparison
394
  gt_mask_resized = cv2.resize(gt_mask, (mask.shape[0], mask.shape[1]),
395
+ interpolation=cv2.INTER_NEAREST)
396
 
397
  if debug_mode:
398
  st.write(f"Processed GT mask shape: {gt_mask_resized.shape}")
 
412
  st.code(traceback.format_exc())
413
 
414
  # Display results
415
+ st.subheader("Segmentation Results")
416
  display_side_by_side(
417
  original_img,
418
  gt_mask_colorized,
 
505
  # Create CSV with metrics
506
  metrics_csv = f"Metric,Overall,Background,Border,Foreground\n"
507
  metrics_csv += f"IoU,{iou_score:.4f},{calculate_iou(gt_mask_resized, mask, 0):.4f},{calculate_iou(gt_mask_resized, mask, 1):.4f},{calculate_iou(gt_mask_resized, mask, 2):.4f}\n"
508
+ metrics_csv += f"Dice,{dice_score:.4f},{calculate_dice(gt_mask_resized, mask, 0):.4f},{calculate_dice(gt_mask_resized, mask, 1):.4f},{calculate_dice(gt_mask_resized, mask, 2):.4f}\n"
509
  metrics_csv += f"Accuracy,{accuracy:.4f},,,"
510
 
511
  st.download_button(
 
520
  if debug_mode:
521
  import traceback
522
  st.code(traceback.format_exc())
523
+ else:
524
+ # Display sample images if no image is uploaded
525
+ st.info("Please upload an image to get started.")
526
 
527
 
528
  if __name__ == "__main__":