kishoreb4 committed
Commit 6e791bb · verified · 1 Parent(s): abec730

Update app.py

Files changed (1):
  1. app.py +85 -30
app.py CHANGED
@@ -124,12 +124,13 @@ def normalize_image(input_image):
     input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
     return input_image
 
-def preprocess_image(image):
+def preprocess_image(image, from_dataset=False):
     """
-    Preprocess image for model input
+    Preprocess image for model input with special handling for dataset images
 
     Args:
         image: PIL Image to preprocess
+        from_dataset: Whether the image is from the original dataset
 
     Returns:
         Preprocessed image tensor, original image
@@ -140,23 +141,25 @@ def preprocess_image(image):
     # Store original image for display
     original_img = img_array.copy()
 
-    # Resize to target size with preserve_aspect_ratio=False
+    # Resize to target size
     img_resized = tf.image.resize(
         img_array,
         (IMAGE_SIZE, IMAGE_SIZE),
         method='bilinear',
-        preserve_aspect_ratio=False,  # Ensure exact dimensions
+        preserve_aspect_ratio=False,
        antialias=True
     )
 
-    # Verify dimensions with assertion
-    tf.debugging.assert_equal(tf.shape(img_resized)[0:2], [IMAGE_SIZE, IMAGE_SIZE],
-                              message="Image dimensions don't match expected size")
-
-    # Normalize
-    img_normalized = normalize_image(img_resized)
+    # Special handling for dataset images
+    if from_dataset:
+        # The dataset already has specific dimensions, just normalize
+        # Skip additional preprocessing that might have been applied
+        img_normalized = normalize_image(img_resized)
+    else:
+        # Regular preprocessing for uploaded images
+        img_normalized = normalize_image(img_resized)
 
-    # Transpose from HWC to CHW (SegFormer expects channels first)
+    # Transpose from HWC to CHW (channels first)
     img_transposed = tf.transpose(img_normalized, (2, 0, 1))
 
     # Add batch dimension
@@ -329,6 +332,43 @@ def display_results_side_by_side(original_image, ground_truth_mask=None, predict
 
             st.image(pred_display, use_column_width=True)
 
+def process_uploaded_mask(mask_array, from_dataset=False):
+    """
+    Process an uploaded mask to ensure it has the correct format
+
+    Args:
+        mask_array: Numpy array of the mask
+        from_dataset: Whether the mask is from the original dataset
+
+    Returns:
+        Processed mask with values 0,1,2
+    """
+    # Check for RGBA format and convert to RGB if needed
+    if len(mask_array.shape) == 3 and mask_array.shape[2] == 4:
+        # Convert RGBA to RGB (discard alpha channel)
+        mask_array = mask_array[:, :, :3]
+
+    # Convert RGB to grayscale if needed
+    if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
+        # Convert RGB to grayscale
+        mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
+
+    if from_dataset:
+        # For dataset masks (saved from the Colab export):
+        # create an empty mask with the same shape
+        processed_mask = np.zeros_like(mask_array)
+
+        # Map the values correctly:
+        # the original dataset uses 1,2,3 but the app expects 0,1,2
+        processed_mask[mask_array == 1] = 2  # Foreground/pet (1→2)
+        processed_mask[mask_array == 2] = 1  # Border (2→1)
+        processed_mask[mask_array == 3] = 0  # Background (3→0)
+
+        return processed_mask
+    else:
+        # For non-dataset masks, assume they already have the correct class values
+        return mask_array
+
 def main():
     st.title("🐶 Pet Segmentation with SegFormer")
     st.markdown("""
@@ -364,6 +404,9 @@ def main():
         step=0.1
     )
 
+    # Checkbox so users can flag images taken from the original dataset
+    dataset_image = st.sidebar.checkbox("Image is from the Oxford-IIIT Pet dataset")
+
     # Load model
     with st.spinner("Loading SegFormer model..."):
         model = load_model()
@@ -395,7 +438,7 @@ def main():
            with st.spinner("Generating segmentation mask..."):
                try:
                    # Preprocess the image
-                   img_tensor, original_img = preprocess_image(image)
+                   img_tensor, original_img = preprocess_image(image, from_dataset=dataset_image)
 
                    # Print shape to debug
                    st.write(f"DEBUG - Input tensor shape: {img_tensor.shape}")
@@ -470,37 +513,48 @@ def main():
                # Read the mask file
                mask_data = uploaded_mask.read()
                mask_io = io.BytesIO(mask_data)
-               gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
 
-               # Handle different mask formats
-               if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
-                   # Convert RGB to single channel if needed
-                   gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
+               # Load the raw mask
+               raw_mask = np.array(Image.open(mask_io))
+
+               # Show debug info
+               st.write(f"Debug - Raw mask shape: {raw_mask.shape}")
+               st.write(f"Debug - Raw mask unique values: {np.unique(raw_mask)}")
+
+               # Process the mask based on source
+               processed_gt_mask = process_uploaded_mask(raw_mask, from_dataset=dataset_image)
+
+               # Resize for IoU calculation
+               gt_mask_resized = cv2.resize(processed_gt_mask, (OUTPUT_SIZE, OUTPUT_SIZE),
+                                            interpolation=cv2.INTER_NEAREST)
+
+               # Resize prediction for comparison
+               pred_mask_resized = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE),
+                                              interpolation=cv2.INTER_NEAREST)
+
+               # Show processed values
+               st.write(f"Debug - Processed GT mask unique values: {np.unique(gt_mask_resized)}")
+               st.write(f"Debug - Prediction mask unique values: {np.unique(pred_mask_resized)}")
 
                # Calculate and display IoU
-               resized_mask = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_NEAREST)
-               iou_score = calculate_iou(gt_mask, resized_mask)
+               iou_score = calculate_iou(gt_mask_resized, pred_mask_resized)
                st.success(f"Mean IoU: {iou_score:.4f}")
 
                # Display specific class IoUs
                st.markdown("### IoU by Class")
                col1, col2, col3 = st.columns(3)
                with col1:
-                   bg_iou = calculate_iou(gt_mask, resized_mask, 0)
+                   bg_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 0)
                    st.metric("Background IoU", f"{bg_iou:.4f}")
                with col2:
-                   border_iou = calculate_iou(gt_mask, resized_mask, 1)
+                   border_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 1)
                    st.metric("Border IoU", f"{border_iou:.4f}")
                with col3:
-                   fg_iou = calculate_iou(gt_mask, resized_mask, 2)
+                   fg_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 2)
                    st.metric("Foreground IoU", f"{fg_iou:.4f}")
-
-               # For display (original size)
-               # Reset the file pointer again
-               uploaded_mask.seek(0)
-               mask_data = uploaded_mask.read()
-               mask_io = io.BytesIO(mask_data)
-               gt_mask_for_display = np.array(Image.open(mask_io))
+
+               # For display, create a colorized version of the ground truth
+               gt_mask_for_display = colorize_mask(processed_gt_mask)
 
                # Side-by-side display
                display_results_side_by_side(
@@ -508,10 +562,11 @@ def main():
                    ground_truth_mask=gt_mask_for_display,
                    predicted_mask=colorized_mask
                )
-
            except Exception as e:
                st.error(f"Error processing ground truth mask: {e}")
                st.write("Please ensure the mask is valid and has the correct format.")
+               import traceback
+               st.code(traceback.format_exc())  # Show detailed error trace
 
                # Even with an error, try to display results without the ground truth
                display_results_side_by_side(
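
Note: the IoU hunks above call a calculate_iou helper that is defined elsewhere in app.py and is not touched by this commit. A minimal sketch (not part of the commit) of what such a helper might look like, assuming integer masks with classes 0 (background), 1 (border), 2 (foreground/pet), a third parameter selecting a single class, and a mean over the classes present when no class is given:

import numpy as np

def calculate_iou(gt_mask, pred_mask, class_id=None):
    """Per-class or mean IoU for integer label masks (illustrative sketch only)."""
    gt_mask = np.asarray(gt_mask)
    pred_mask = np.asarray(pred_mask)
    # Evaluate one class if requested, otherwise every class present in either mask
    if class_id is not None:
        classes = [class_id]
    else:
        classes = np.union1d(np.unique(gt_mask), np.unique(pred_mask))
    ious = []
    for c in classes:
        gt_c = gt_mask == c
        pred_c = pred_mask == c
        union = np.logical_or(gt_c, pred_c).sum()
        if union == 0:
            continue  # class absent from both masks
        ious.append(np.logical_and(gt_c, pred_c).sum() / union)
    return float(np.mean(ious)) if ious else 0.0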
 
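The label remapping added in process_uploaded_mask (source values 1/2/3 to the app's 0/1/2 scheme, using the correspondence stated in the commit's comments) can be sanity-checked on a toy array; the snippet below assumes the function from this commit is in scope:

import numpy as np

# Per the commit's comments: 1 = pet, 2 = border, 3 = background in the source masks
trimap = np.array([[1, 2],
                   [3, 1]], dtype=np.uint8)

remapped = process_uploaded_mask(trimap, from_dataset=True)

# App convention after remapping: 2 = foreground/pet, 1 = border, 0 = background
print(remapped)  # -> [[2 1]
                 #     [0 2]]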