kishoreb4 commited on
Commit
f94828c
·
1 Parent(s): 3176694
Files changed (1) hide show
  1. app.py +150 -177
app.py CHANGED
@@ -30,10 +30,13 @@ NUM_CLASSES = len(ID2LABEL)
30
 
31
  @st.cache_resource
32
  def download_model_from_drive():
33
- model_path = "tf_model.h5"
34
- if not os.path.exists(model_path):
35
- # Fix the Google Drive link format - this is why the download is failing
36
- url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
 
 
 
37
  try:
38
  gdown.download(url, model_path, quiet=False)
39
  st.success("Model downloaded successfully from Google Drive.")
@@ -44,6 +47,7 @@ def download_model_from_drive():
44
  st.info("Model already exists locally.")
45
  return model_path
46
 
 
47
  @st.cache_resource
48
  def load_model():
49
  """
@@ -53,58 +57,44 @@ def load_model():
53
  Loaded model
54
  """
55
  try:
56
- # Download the model first
 
 
 
 
 
 
 
 
 
57
  model_path = download_model_from_drive()
58
 
59
- if model_path is None or not os.path.exists(model_path):
60
- st.warning("Using default pretrained model since download failed")
61
- # Fall back to pretrained model
62
- model = TFSegformerForSemanticSegmentation.from_pretrained(
63
- "nvidia/mit-b0",
64
- num_labels=NUM_CLASSES,
65
- id2label=ID2LABEL,
66
- label2id={label: id for id, label in ID2LABEL.items()},
67
- ignore_mismatched_sizes=True
68
- )
69
- else:
70
- # For a HuggingFace model saved with SavedModel format
71
- st.info("Loading SegFormer model...")
72
  try:
73
- # First try loading as HuggingFace model
74
- model = TFSegformerForSemanticSegmentation.from_pretrained(
75
- "nvidia/mit-b0",
76
- num_labels=NUM_CLASSES,
77
- id2label=ID2LABEL,
78
- label2id={label: id for id, label in ID2LABEL.items()},
79
- ignore_mismatched_sizes=True
80
- )
81
- # Then load weights from h5 file
82
- model.load_weights(model_path)
83
- st.success("Model weights loaded successfully")
84
  except Exception as e:
85
- st.error(f"Error loading model weights: {str(e)}")
86
- st.warning("Falling back to pretrained model")
87
- model = TFSegformerForSemanticSegmentation.from_pretrained(
88
- "nvidia/mit-b0",
89
- num_labels=NUM_CLASSES,
90
- id2label=ID2LABEL,
91
- label2id={label: id for id, label in ID2LABEL.items()},
92
- ignore_mismatched_sizes=True
93
- )
94
-
95
- return model
96
  except Exception as e:
97
- st.error(f"Error loading model: {str(e)}")
98
- st.error("Falling back to pretrained model")
99
  # Fall back to pretrained model as a last resort
100
- model = TFSegformerForSemanticSegmentation.from_pretrained(
101
  "nvidia/mit-b0",
102
  num_labels=NUM_CLASSES,
103
  id2label=ID2LABEL,
104
  label2id={label: id for id, label in ID2LABEL.items()},
105
  ignore_mismatched_sizes=True
106
  )
107
- return model
108
 
109
  def normalize_image(input_image):
110
  """
@@ -161,7 +151,7 @@ def create_mask(pred_mask):
161
  Processed mask (2D array)
162
  """
163
  pred_mask = tf.math.argmax(pred_mask, axis=1)
164
- pred_mask = tf.squeeze(pred_mask, axis=0) # Remove batch dimension
165
  return pred_mask.numpy()
166
 
167
  def colorize_mask(mask):
@@ -176,7 +166,7 @@ def colorize_mask(mask):
176
  """
177
  # Ensure the mask is 2D
178
  if len(mask.shape) > 2:
179
- mask = np.squeeze(mask, axis=-1)
180
 
181
  # Define colors for each class (RGB)
182
  colors = [
@@ -298,152 +288,135 @@ def main():
298
  else:
299
  st.sidebar.success("Model loaded successfully!")
300
 
301
- # Image upload
302
  st.header("Upload an Image")
303
  uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
304
  uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])
305
 
306
- # Sample images option
307
- st.markdown("### Or use a sample image:")
308
- sample_dir = "samples"
309
-
310
- # Check if sample directory exists and contains images
311
- sample_files = []
312
- if os.path.exists(sample_dir):
313
- sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]
314
-
315
- if sample_files:
316
- selected_sample = st.selectbox("Select a sample image:", sample_files)
317
- use_sample = st.button("Use this sample")
318
-
319
- if use_sample:
320
- with open(os.path.join(sample_dir, selected_sample), "rb") as file:
321
- image_bytes = file.read()
322
- uploaded_image = io.BytesIO(image_bytes)
323
- st.success(f"Using sample image: {selected_sample}")
324
-
325
  # Process uploaded image
326
  if uploaded_image is not None:
327
- # Display original image
328
- image = Image.open(uploaded_image)
329
-
330
- col1, col2 = st.columns(2)
331
-
332
- with col1:
333
- st.subheader("Original Image")
334
- st.image(image, caption="Uploaded Image", use_column_width=True)
335
-
336
- # Preprocess and predict
337
- with st.spinner("Generating segmentation mask..."):
338
- # Preprocess the image
339
- img_tensor, original_img = preprocess_image(image)
340
-
341
- # Make prediction
342
- prediction = model(pixel_values=img_tensor, training=False)
343
- logits = prediction.logits
344
 
345
- # Create visualization mask
346
- mask = create_mask(logits).numpy()
 
347
 
348
- # Colorize the mask
349
- colorized_mask = colorize_mask(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
- # Create overlay
352
- overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
353
-
354
- # Display results
355
- with col2:
356
- st.subheader("Segmentation Result")
357
- st.image(overlay, caption="Segmentation Overlay", use_column_width=True)
358
-
359
- # Display segmentation details
360
- st.header("Segmentation Details")
361
- col1, col2, col3 = st.columns(3)
362
-
363
- with col1:
364
- st.subheader("Background")
365
- st.markdown("Areas surrounding the pet")
366
- mask_bg = np.where(mask == 0, 255, 0).astype(np.uint8)
367
- st.image(mask_bg, caption="Background", use_column_width=True)
368
 
369
- with col2:
370
- st.subheader("Border")
371
- st.markdown("Boundary around the pet")
372
- mask_border = np.where(mask == 1, 255, 0).astype(np.uint8)
373
- st.image(mask_border, caption="Border", use_column_width=True)
374
 
375
- with col3:
376
- st.subheader("Foreground (Pet)")
377
- st.markdown("The pet itself")
378
- mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
379
- st.image(mask_fg, caption="Foreground", use_column_width=True)
380
-
381
- # Calculate IoU if ground truth is uploaded
382
- if uploaded_mask is not None:
383
- try:
384
- # Read and process the mask file
385
- mask_data = uploaded_mask.read()
386
- st.write(f"Uploaded mask size: {len(mask_data)} bytes")
387
-
388
- # Open the mask from bytes
389
- mask_io = io.BytesIO(mask_data)
390
- gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
391
-
392
- # Handle different mask formats
393
- if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
394
- # Convert RGB to single channel if needed
395
- gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
396
 
397
- # Calculate and display IoU
398
- iou_score = calculate_iou(gt_mask, mask)
399
- st.success(f"Mean IoU: {iou_score:.4f}")
 
 
400
 
401
- # Display specific class IoUs
402
- st.markdown("### IoU by Class")
403
- col1, col2, col3 = st.columns(3)
404
- with col1:
405
- bg_iou = calculate_iou(gt_mask, mask, 0)
406
- st.metric("Background IoU", f"{bg_iou:.4f}")
407
- with col2:
408
- border_iou = calculate_iou(gt_mask, mask, 1)
409
- st.metric("Border IoU", f"{border_iou:.4f}")
410
- with col3:
411
- fg_iou = calculate_iou(gt_mask, mask, 2)
412
- st.metric("Foreground IoU", f"{fg_iou:.4f}")
413
- except Exception as e:
414
- st.error(f"Error processing ground truth mask: {e}")
415
- st.write("Please ensure the mask is valid and has the correct format.")
416
-
417
- # Download buttons
418
- col1, col2 = st.columns(2)
419
-
420
- with col1:
421
- # Convert mask to PNG for download
422
- mask_colored = Image.fromarray(colorized_mask)
423
- mask_bytes = io.BytesIO()
424
- mask_colored.save(mask_bytes, format='PNG')
425
- mask_bytes = mask_bytes.getvalue()
426
 
427
- st.download_button(
428
- label="Download Segmentation Mask",
429
- data=mask_bytes,
430
- file_name="pet_segmentation_mask.png",
431
- mime="image/png"
432
- )
433
-
434
- with col2:
435
- # Convert overlay to PNG for download
436
- overlay_img = Image.fromarray(overlay)
437
- overlay_bytes = io.BytesIO()
438
- overlay_img.save(overlay_bytes, format='PNG')
439
- overlay_bytes = overlay_bytes.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
 
441
- st.download_button(
442
- label="Download Overlay Image",
443
- data=overlay_bytes,
444
- file_name="pet_segmentation_overlay.png",
445
- mime="image/png"
446
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
  # Footer with additional information
449
  st.markdown("---")
 
30
 
31
  @st.cache_resource
32
  def download_model_from_drive():
33
+ # Create a models directory
34
+ os.makedirs("models", exist_ok=True)
35
+ model_path = "models/best_model"
36
+
37
+ if not os.path.exists(model_path):
38
+ # Fixed Google Drive URL format for gdown
39
+ url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
40
  try:
41
  gdown.download(url, model_path, quiet=False)
42
  st.success("Model downloaded successfully from Google Drive.")
 
47
  st.info("Model already exists locally.")
48
  return model_path
49
 
50
+
51
  @st.cache_resource
52
  def load_model():
53
  """
 
57
  Loaded model
58
  """
59
  try:
60
+ # First create a base model with the correct architecture
61
+ base_model = TFSegformerForSemanticSegmentation.from_pretrained(
62
+ "nvidia/mit-b0",
63
+ num_labels=NUM_CLASSES,
64
+ id2label=ID2LABEL,
65
+ label2id={label: id for id, label in ID2LABEL.items()},
66
+ ignore_mismatched_sizes=True
67
+ )
68
+
69
+ # Download the trained weights
70
  model_path = download_model_from_drive()
71
 
72
+ if model_path is not None and os.path.exists(model_path):
73
+ st.info(f"Loading weights from {model_path}...")
 
 
 
 
 
 
 
 
 
 
 
74
  try:
75
+ # Try to load the weights
76
+ base_model.load_weights(model_path)
77
+ st.success("Model weights loaded successfully!")
78
+ return base_model
 
 
 
 
 
 
 
79
  except Exception as e:
80
+ st.error(f"Error loading weights: {e}")
81
+ st.info("Using base pretrained model instead")
82
+ return base_model
83
+ else:
84
+ st.warning("Using base pretrained model since download failed")
85
+ return base_model
86
+
 
 
 
 
87
  except Exception as e:
88
+ st.error(f"Error in load_model: {e}")
89
+ st.warning("Using default pretrained model")
90
  # Fall back to pretrained model as a last resort
91
+ return TFSegformerForSemanticSegmentation.from_pretrained(
92
  "nvidia/mit-b0",
93
  num_labels=NUM_CLASSES,
94
  id2label=ID2LABEL,
95
  label2id={label: id for id, label in ID2LABEL.items()},
96
  ignore_mismatched_sizes=True
97
  )
 
98
 
99
  def normalize_image(input_image):
100
  """
 
151
  Processed mask (2D array)
152
  """
153
  pred_mask = tf.math.argmax(pred_mask, axis=1)
154
+ pred_mask = tf.squeeze(pred_mask)
155
  return pred_mask.numpy()
156
 
157
  def colorize_mask(mask):
 
166
  """
167
  # Ensure the mask is 2D
168
  if len(mask.shape) > 2:
169
+ mask = np.squeeze(mask)
170
 
171
  # Define colors for each class (RGB)
172
  colors = [
 
288
  else:
289
  st.sidebar.success("Model loaded successfully!")
290
 
291
+ # Image upload section
292
  st.header("Upload an Image")
293
  uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
294
  uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  # Process uploaded image
297
  if uploaded_image is not None:
298
+ try:
299
+ # Read the image
300
+ image_bytes = uploaded_image.read()
301
+ image = Image.open(io.BytesIO(image_bytes))
302
+
303
+ col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
304
 
305
+ with col1:
306
+ st.subheader("Original Image")
307
+ st.image(image, caption="Uploaded Image", use_column_width=True)
308
 
309
+ # Preprocess and predict
310
+ with st.spinner("Generating segmentation mask..."):
311
+ # Preprocess the image
312
+ img_tensor, original_img = preprocess_image(image)
313
+
314
+ # Make prediction
315
+ outputs = model(pixel_values=img_tensor, training=False)
316
+ logits = outputs.logits
317
+
318
+ # Create visualization mask
319
+ mask = create_mask(logits)
320
+
321
+ # Colorize the mask
322
+ colorized_mask = colorize_mask(mask)
323
+
324
+ # Create overlay
325
+ overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
326
 
327
+ # Display results
328
+ with col2:
329
+ st.subheader("Segmentation Result")
330
+ st.image(overlay, caption="Segmentation Overlay", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ # Display segmentation details
333
+ st.header("Segmentation Details")
334
+ col1, col2, col3 = st.columns(3)
 
 
335
 
336
+ with col1:
337
+ st.subheader("Background")
338
+ st.markdown("Areas surrounding the pet")
339
+ mask_bg = np.where(mask == 0, 255, 0).astype(np.uint8)
340
+ st.image(mask_bg, caption="Background", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
+ with col2:
343
+ st.subheader("Border")
344
+ st.markdown("Boundary around the pet")
345
+ mask_border = np.where(mask == 1, 255, 0).astype(np.uint8)
346
+ st.image(mask_border, caption="Border", use_column_width=True)
347
 
348
+ with col3:
349
+ st.subheader("Foreground (Pet)")
350
+ st.markdown("The pet itself")
351
+ mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
352
+ st.image(mask_fg, caption="Foreground", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
+ # Calculate IoU if ground truth is uploaded
355
+ if uploaded_mask is not None:
356
+ try:
357
+ # Read the mask file
358
+ mask_data = uploaded_mask.read()
359
+ mask_io = io.BytesIO(mask_data)
360
+ gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
361
+
362
+ # Handle different mask formats
363
+ if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
364
+ # Convert RGB to single channel if needed
365
+ gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
366
+
367
+ # Calculate and display IoU
368
+ resized_mask = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_NEAREST)
369
+ iou_score = calculate_iou(gt_mask, resized_mask)
370
+ st.success(f"Mean IoU: {iou_score:.4f}")
371
+
372
+ # Display specific class IoUs
373
+ st.markdown("### IoU by Class")
374
+ col1, col2, col3 = st.columns(3)
375
+ with col1:
376
+ bg_iou = calculate_iou(gt_mask, resized_mask, 0)
377
+ st.metric("Background IoU", f"{bg_iou:.4f}")
378
+ with col2:
379
+ border_iou = calculate_iou(gt_mask, resized_mask, 1)
380
+ st.metric("Border IoU", f"{border_iou:.4f}")
381
+ with col3:
382
+ fg_iou = calculate_iou(gt_mask, resized_mask, 2)
383
+ st.metric("Foreground IoU", f"{fg_iou:.4f}")
384
+ except Exception as e:
385
+ st.error(f"Error processing ground truth mask: {e}")
386
+ st.write("Please ensure the mask is valid and has the correct format.")
387
+
388
+ # Download buttons
389
+ col1, col2 = st.columns(2)
390
 
391
+ with col1:
392
+ # Convert mask to PNG for download
393
+ mask_colored = Image.fromarray(colorized_mask)
394
+ mask_bytes = io.BytesIO()
395
+ mask_colored.save(mask_bytes, format='PNG')
396
+ mask_bytes = mask_bytes.getvalue()
397
+
398
+ st.download_button(
399
+ label="Download Segmentation Mask",
400
+ data=mask_bytes,
401
+ file_name="pet_segmentation_mask.png",
402
+ mime="image/png"
403
+ )
404
+
405
+ with col2:
406
+ # Convert overlay to PNG for download
407
+ overlay_img = Image.fromarray(overlay)
408
+ overlay_bytes = io.BytesIO()
409
+ overlay_img.save(overlay_bytes, format='PNG')
410
+ overlay_bytes = overlay_bytes.getvalue()
411
+
412
+ st.download_button(
413
+ label="Download Overlay Image",
414
+ data=overlay_bytes,
415
+ file_name="pet_segmentation_overlay.png",
416
+ mime="image/png"
417
+ )
418
+ except Exception as e:
419
+ st.error(f"Error processing image: {e}")
420
 
421
  # Footer with additional information
422
  st.markdown("---")