kishoreb4 commited on
Commit
cb7125b
·
verified ·
1 Parent(s): 3deb70d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -236
app.py CHANGED
@@ -1,12 +1,4 @@
1
  import streamlit as st
2
- # THIS MUST BE THE FIRST STREAMLIT COMMAND
3
- st.set_page_config(
4
- page_title="Pet Segmentation with SegFormer",
5
- page_icon="🐶",
6
- layout="wide",
7
- initial_sidebar_state="expanded"
8
- )
9
-
10
  import tensorflow as tf
11
  from tensorflow.keras import backend
12
  import numpy as np
@@ -18,18 +10,12 @@ import io
18
  import gdown
19
  from transformers import TFSegformerForSemanticSegmentation
20
 
21
-
22
- try:
23
- # Limit GPU memory growth
24
- gpus = tf.config.experimental.list_physical_devices('GPU')
25
- if gpus:
26
- for gpu in gpus:
27
- tf.config.experimental.set_memory_growth(gpu, True)
28
- st.sidebar.success(f"GPU available: {len(gpus)} device(s)")
29
- else:
30
- st.sidebar.warning("No GPU detected, using CPU")
31
- except Exception as e:
32
- st.sidebar.error(f"GPU config error: {e}")
33
 
34
  # Constants for image preprocessing
35
  IMAGE_SIZE = 512
@@ -49,7 +35,7 @@ def download_model_from_drive():
49
  model_path = "models/tf_model.h5"
50
 
51
  if not os.path.exists(model_path):
52
- # Extract the file ID from the sharing URL
53
  url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
54
  try:
55
  gdown.download(url, model_path, quiet=False)
@@ -124,13 +110,12 @@ def normalize_image(input_image):
124
  input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
125
  return input_image
126
 
127
- def preprocess_image(image, from_dataset=False):
128
  """
129
- Preprocess image for model input with special handling for dataset images
130
 
131
  Args:
132
  image: PIL Image to preprocess
133
- from_dataset: Whether the image is from the original dataset
134
 
135
  Returns:
136
  Preprocessed image tensor, original image
@@ -142,24 +127,12 @@ def preprocess_image(image, from_dataset=False):
142
  original_img = img_array.copy()
143
 
144
  # Resize to target size
145
- img_resized = tf.image.resize(
146
- img_array,
147
- (IMAGE_SIZE, IMAGE_SIZE),
148
- method='bilinear',
149
- preserve_aspect_ratio=False,
150
- antialias=True
151
- )
152
 
153
- # Special handling for dataset images
154
- if from_dataset:
155
- # The dataset already has specific dimensions, just normalize
156
- # Skip additional preprocessing that might have been applied
157
- img_normalized = normalize_image(img_resized)
158
- else:
159
- # Regular preprocessing for uploaded images
160
- img_normalized = normalize_image(img_resized)
161
 
162
- # Transpose from HWC to CHW (channels first)
163
  img_transposed = tf.transpose(img_normalized, (2, 0, 1))
164
 
165
  # Add batch dimension
@@ -177,21 +150,8 @@ def create_mask(pred_mask):
177
  Returns:
178
  Processed mask (2D array)
179
  """
180
- # Take argmax along the class dimension (axis=1 for batch data)
181
  pred_mask = tf.math.argmax(pred_mask, axis=1)
182
-
183
- # Remove batch dimension and convert to numpy
184
  pred_mask = tf.squeeze(pred_mask)
185
-
186
- # Resize to match original image size if needed
187
- if pred_mask.shape[0] != IMAGE_SIZE or pred_mask.shape[1] != IMAGE_SIZE:
188
- pred_mask = tf.image.resize(
189
- tf.expand_dims(pred_mask, axis=-1),
190
- (IMAGE_SIZE, IMAGE_SIZE),
191
- method='nearest'
192
- )
193
- pred_mask = tf.squeeze(pred_mask)
194
-
195
  return pred_mask.numpy()
196
 
197
  def colorize_mask(mask):
@@ -284,91 +244,6 @@ def create_overlay(image, mask, alpha=0.5):
284
 
285
  return overlay
286
 
287
- def display_results_side_by_side(original_image, ground_truth_mask=None, predicted_mask=None):
288
- """
289
- Display results in a side-by-side format similar to colab_code.py
290
-
291
- Args:
292
- original_image: Original input image
293
- ground_truth_mask: Optional ground truth segmentation mask
294
- predicted_mask: Predicted segmentation mask
295
- """
296
- # Determine how many images to display
297
- cols = 1 + (ground_truth_mask is not None) + (predicted_mask is not None)
298
-
299
- # Create a figure with multiple columns
300
- st.write("### Segmentation Results Comparison")
301
-
302
- col_list = st.columns(cols)
303
-
304
- # Display original image
305
- with col_list[0]:
306
- st.markdown("**Original Image**")
307
- st.image(original_image, use_column_width=True)
308
-
309
- # Display ground truth if available
310
- if ground_truth_mask is not None:
311
- with col_list[1]:
312
- st.markdown("**Ground Truth Mask**")
313
-
314
- # Colorize ground truth if needed
315
- if len(ground_truth_mask.shape) == 2:
316
- gt_display = colorize_mask(ground_truth_mask)
317
- else:
318
- gt_display = ground_truth_mask
319
-
320
- st.image(gt_display, use_column_width=True)
321
-
322
- # Display prediction
323
- if predicted_mask is not None:
324
- with col_list[2 if ground_truth_mask is not None else 1]:
325
- st.markdown("**Predicted Mask**")
326
-
327
- # Colorize prediction if needed
328
- if len(predicted_mask.shape) == 2:
329
- pred_display = colorize_mask(predicted_mask)
330
- else:
331
- pred_display = predicted_mask
332
-
333
- st.image(pred_display, use_column_width=True)
334
-
335
- def process_uploaded_mask(mask_array, from_dataset=False):
336
- """
337
- Process an uploaded mask to ensure it has the correct format
338
-
339
- Args:
340
- mask_array: Numpy array of the mask
341
- from_dataset: Whether the mask is from the original dataset
342
-
343
- Returns:
344
- Processed mask with values 0,1,2
345
- """
346
- # Check for RGBA format and convert to RGB if needed
347
- if len(mask_array.shape) == 3 and mask_array.shape[2] == 4:
348
- # Convert RGBA to RGB (discard alpha channel)
349
- mask_array = mask_array[:,:,:3]
350
-
351
- # Convert RGB to grayscale if needed
352
- if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
353
- # Convert RGB to grayscale
354
- mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
355
-
356
- if from_dataset:
357
- # For dataset masks (saved from your colab code):
358
- # Create an empty mask with the same shape
359
- processed_mask = np.zeros_like(mask_array)
360
-
361
- # Map the values correctly:
362
- # Original dataset uses 1,2,3 but your app expects 0,1,2
363
- processed_mask[mask_array == 1] = 2 # Foreground/pet (1→2)
364
- processed_mask[mask_array == 2] = 1 # Border (2→1)
365
- processed_mask[mask_array == 3] = 0 # Background (3→0)
366
-
367
- return processed_mask
368
- else:
369
- # For non-dataset masks, we assume they have correct class values
370
- return mask_array
371
-
372
  def main():
373
  st.title("🐶 Pet Segmentation with SegFormer")
374
  st.markdown("""
@@ -404,9 +279,6 @@ def main():
404
  step=0.1
405
  )
406
 
407
- # Add this checkbox to your app's UI
408
- dataset_image = st.sidebar.checkbox("Image is from the Oxford-IIIT Pet dataset")
409
-
410
  # Load model
411
  with st.spinner("Loading SegFormer model..."):
412
  model = load_model()
@@ -436,46 +308,21 @@ def main():
436
 
437
  # Preprocess and predict
438
  with st.spinner("Generating segmentation mask..."):
439
- try:
440
- # Preprocess the image
441
- img_tensor, original_img = preprocess_image(image, from_dataset=dataset_image)
442
-
443
- # Print shape to debug
444
- st.write(f"DEBUG - Input tensor shape: {img_tensor.shape}")
445
-
446
- # Make prediction with error handling
447
- try:
448
- outputs = model(pixel_values=img_tensor, training=False)
449
- logits = outputs.logits
450
-
451
- # Create visualization mask
452
- mask = create_mask(logits)
453
-
454
- # Colorize the mask
455
- colorized_mask = colorize_mask(mask)
456
-
457
- # Create overlay
458
- overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
459
- except Exception as inference_error:
460
- st.error(f"Inference error: {inference_error}")
461
- st.write("Trying alternative approach...")
462
-
463
- # Alternative: resize to exactly 512x512 with crop_or_pad
464
- img_resized = tf.image.resize_with_crop_or_pad(
465
- original_img, IMAGE_SIZE, IMAGE_SIZE
466
- )
467
- img_normalized = normalize_image(img_resized)
468
- img_transposed = tf.transpose(img_normalized, (2, 0, 1))
469
- img_tensor = tf.expand_dims(img_transposed, axis=0)
470
-
471
- outputs = model(pixel_values=img_tensor, training=False)
472
- logits = outputs.logits
473
- mask = create_mask(logits)
474
- colorized_mask = colorize_mask(mask)
475
- overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
476
- except Exception as e:
477
- st.error(f"Failed to process image: {e}")
478
- st.stop()
479
 
480
  # Display results
481
  with col2:
@@ -507,81 +354,37 @@ def main():
507
  # Calculate IoU if ground truth is uploaded
508
  if uploaded_mask is not None:
509
  try:
510
- # Reset the file pointer to the beginning
511
- uploaded_mask.seek(0)
512
-
513
  # Read the mask file
514
  mask_data = uploaded_mask.read()
515
  mask_io = io.BytesIO(mask_data)
 
516
 
517
- # Load the raw mask
518
- raw_mask = np.array(Image.open(mask_io))
519
-
520
- # Show debug info
521
- st.write(f"Debug - Raw mask shape: {raw_mask.shape}")
522
- st.write(f"Debug - Raw mask unique values: {np.unique(raw_mask)}")
523
-
524
- # Process the mask based on source
525
- processed_gt_mask = process_uploaded_mask(raw_mask, from_dataset=dataset_image)
526
-
527
- # Resize for IoU calculation
528
- gt_mask_resized = cv2.resize(processed_gt_mask, (OUTPUT_SIZE, OUTPUT_SIZE),
529
- interpolation=cv2.INTER_NEAREST)
530
-
531
- # Resize prediction for comparison
532
- pred_mask_resized = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE),
533
- interpolation=cv2.INTER_NEAREST)
534
-
535
- # Show processed values
536
- st.write(f"Debug - Processed GT mask unique values: {np.unique(gt_mask_resized)}")
537
- st.write(f"Debug - Prediction mask unique values: {np.unique(pred_mask_resized)}")
538
 
539
  # Calculate and display IoU
540
- iou_score = calculate_iou(gt_mask_resized, pred_mask_resized)
 
541
  st.success(f"Mean IoU: {iou_score:.4f}")
542
 
543
  # Display specific class IoUs
544
  st.markdown("### IoU by Class")
545
  col1, col2, col3 = st.columns(3)
546
  with col1:
547
- bg_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 0)
548
  st.metric("Background IoU", f"{bg_iou:.4f}")
549
  with col2:
550
- border_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 1)
551
  st.metric("Border IoU", f"{border_iou:.4f}")
552
  with col3:
553
- fg_iou = calculate_iou(gt_mask_resized, pred_mask_resized, 2)
554
  st.metric("Foreground IoU", f"{fg_iou:.4f}")
555
-
556
- # For display, create a colorized version of the ground truth
557
- gt_mask_for_display = colorize_mask(processed_gt_mask)
558
-
559
- # Side-by-side display
560
- display_results_side_by_side(
561
- original_img,
562
- ground_truth_mask=gt_mask_for_display,
563
- predicted_mask=colorized_mask
564
- )
565
  except Exception as e:
566
  st.error(f"Error processing ground truth mask: {e}")
567
  st.write("Please ensure the mask is valid and has the correct format.")
568
- import traceback
569
- st.code(traceback.format_exc()) # Show detailed error trace
570
-
571
- # Even with an error, try to display results without the ground truth
572
- display_results_side_by_side(
573
- original_img,
574
- ground_truth_mask=None,
575
- predicted_mask=colorized_mask
576
- )
577
- else:
578
- # No ground truth, just display original and prediction
579
- display_results_side_by_side(
580
- original_img,
581
- ground_truth_mask=None,
582
- predicted_mask=colorized_mask
583
- )
584
-
585
  # Download buttons
586
  col1, col2 = st.columns(2)
587
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
2
  import tensorflow as tf
3
  from tensorflow.keras import backend
4
  import numpy as np
 
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
13
+ st.set_page_config(
14
+ page_title="Pet Segmentation with SegFormer",
15
+ page_icon="🐶",
16
+ layout="wide",
17
+ initial_sidebar_state="expanded"
18
+ )
 
 
 
 
 
 
19
 
20
  # Constants for image preprocessing
21
  IMAGE_SIZE = 512
 
35
  model_path = "models/tf_model.h5"
36
 
37
  if not os.path.exists(model_path):
38
+ # Fixed Google Drive URL format for gdown
39
  url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
40
  try:
41
  gdown.download(url, model_path, quiet=False)
 
110
  input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
111
  return input_image
112
 
113
+ def preprocess_image(image):
114
  """
115
+ Preprocess image for model input
116
 
117
  Args:
118
  image: PIL Image to preprocess
 
119
 
120
  Returns:
121
  Preprocessed image tensor, original image
 
127
  original_img = img_array.copy()
128
 
129
  # Resize to target size
130
+ img_resized = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
 
 
 
 
 
 
131
 
132
+ # Normalize
133
+ img_normalized = normalize_image(img_resized)
 
 
 
 
 
 
134
 
135
+ # Transpose from HWC to CHW (SegFormer expects channels first)
136
  img_transposed = tf.transpose(img_normalized, (2, 0, 1))
137
 
138
  # Add batch dimension
 
150
  Returns:
151
  Processed mask (2D array)
152
  """
 
153
  pred_mask = tf.math.argmax(pred_mask, axis=1)
 
 
154
  pred_mask = tf.squeeze(pred_mask)
 
 
 
 
 
 
 
 
 
 
155
  return pred_mask.numpy()
156
 
157
  def colorize_mask(mask):
 
244
 
245
  return overlay
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def main():
248
  st.title("🐶 Pet Segmentation with SegFormer")
249
  st.markdown("""
 
279
  step=0.1
280
  )
281
 
 
 
 
282
  # Load model
283
  with st.spinner("Loading SegFormer model..."):
284
  model = load_model()
 
308
 
309
  # Preprocess and predict
310
  with st.spinner("Generating segmentation mask..."):
311
+ # Preprocess the image
312
+ img_tensor, original_img = preprocess_image(image)
313
+
314
+ # Make prediction
315
+ outputs = model(pixel_values=img_tensor, training=False)
316
+ logits = outputs.logits
317
+
318
+ # Create visualization mask
319
+ mask = create_mask(logits)
320
+
321
+ # Colorize the mask
322
+ colorized_mask = colorize_mask(mask)
323
+
324
+ # Create overlay
325
+ overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  # Display results
328
  with col2:
 
354
  # Calculate IoU if ground truth is uploaded
355
  if uploaded_mask is not None:
356
  try:
 
 
 
357
  # Read the mask file
358
  mask_data = uploaded_mask.read()
359
  mask_io = io.BytesIO(mask_data)
360
+ gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
361
 
362
+ # Handle different mask formats
363
+ if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
364
+ # Convert RGB to single channel if needed
365
+ gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
  # Calculate and display IoU
368
+ resized_mask = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_NEAREST)
369
+ iou_score = calculate_iou(gt_mask, resized_mask)
370
  st.success(f"Mean IoU: {iou_score:.4f}")
371
 
372
  # Display specific class IoUs
373
  st.markdown("### IoU by Class")
374
  col1, col2, col3 = st.columns(3)
375
  with col1:
376
+ bg_iou = calculate_iou(gt_mask, resized_mask, 0)
377
  st.metric("Background IoU", f"{bg_iou:.4f}")
378
  with col2:
379
+ border_iou = calculate_iou(gt_mask, resized_mask, 1)
380
  st.metric("Border IoU", f"{border_iou:.4f}")
381
  with col3:
382
+ fg_iou = calculate_iou(gt_mask, resized_mask, 2)
383
  st.metric("Foreground IoU", f"{fg_iou:.4f}")
 
 
 
 
 
 
 
 
 
 
384
  except Exception as e:
385
  st.error(f"Error processing ground truth mask: {e}")
386
  st.write("Please ensure the mask is valid and has the correct format.")
387
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  # Download buttons
389
  col1, col2 = st.columns(2)
390