kishoreb4 committed on
Commit
f753b2d
·
1 Parent(s): 15d674e
Files changed (1) hide show
  1. app.py +96 -24
app.py CHANGED
@@ -32,7 +32,8 @@ NUM_CLASSES = len(ID2LABEL)
32
  def download_model_from_drive():
33
  model_path = "tf_model.h5"
34
  if not os.path.exists(model_path):
35
- url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
 
36
  try:
37
  gdown.download(url, model_path, quiet=False)
38
  st.success("Model downloaded successfully from Google Drive.")
@@ -55,7 +56,7 @@ def load_model():
55
  # Download the model first
56
  model_path = download_model_from_drive()
57
 
58
- if model_path is None:
59
  st.warning("Using default pretrained model since download failed")
60
  # Fall back to pretrained model
61
  model = TFSegformerForSemanticSegmentation.from_pretrained(
@@ -66,28 +67,30 @@ def load_model():
66
  ignore_mismatched_sizes=True
67
  )
68
  else:
69
- # Check if this is a Keras .h5 model or a HuggingFace model directory
70
- if model_path.endswith('.h5'):
71
- st.info("Loading Keras H5 model...")
72
- # For a Keras .h5 file, use tf.keras.models.load_model
73
- try:
74
- model = tf.keras.models.load_model(model_path)
75
- st.success("Keras model loaded successfully")
76
- except Exception as ke:
77
- st.error(f"Error loading Keras model: {str(ke)}")
78
- st.warning("Falling back to pretrained model")
79
- model = TFSegformerForSemanticSegmentation.from_pretrained(
80
- "nvidia/mit-b0",
81
- num_labels=NUM_CLASSES,
82
- id2label=ID2LABEL,
83
- label2id={label: id for id, label in ID2LABEL.items()},
84
- ignore_mismatched_sizes=True
85
- )
86
- else:
87
- # For a HuggingFace model directory
88
- st.info("Loading HuggingFace model...")
89
- model = TFSegformerForSemanticSegmentation.from_pretrained(model_path)
90
- st.success("HuggingFace model loaded successfully")
 
 
91
 
92
  return model
93
  except Exception as e:
@@ -192,6 +195,38 @@ def colorize_mask(mask):
192
 
193
  return rgb_mask
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def create_overlay(image, mask, alpha=0.5):
196
  """
197
  Create an overlay of mask on original image
@@ -266,6 +301,7 @@ def main():
266
  # Image upload
267
  st.header("Upload an Image")
268
  uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
 
269
 
270
  # Sample images option
271
  st.markdown("### Or use a sample image:")
@@ -342,6 +378,42 @@ def main():
342
  mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
343
  st.image(mask_fg, caption="Foreground", use_column_width=True)
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  # Download buttons
346
  col1, col2 = st.columns(2)
347
 
 
32
  def download_model_from_drive():
33
  model_path = "tf_model.h5"
34
  if not os.path.exists(model_path):
35
+ # Fix the Google Drive link format - this is why the download is failing
36
+ url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
37
  try:
38
  gdown.download(url, model_path, quiet=False)
39
  st.success("Model downloaded successfully from Google Drive.")
 
56
  # Download the model first
57
  model_path = download_model_from_drive()
58
 
59
+ if model_path is None or not os.path.exists(model_path):
60
  st.warning("Using default pretrained model since download failed")
61
  # Fall back to pretrained model
62
  model = TFSegformerForSemanticSegmentation.from_pretrained(
 
67
  ignore_mismatched_sizes=True
68
  )
69
  else:
70
+ # Build the SegFormer architecture, then load the fine-tuned weights from the downloaded .h5 file
71
+ st.info("Loading SegFormer model...")
72
+ try:
73
+ # First try loading as HuggingFace model
74
+ model = TFSegformerForSemanticSegmentation.from_pretrained(
75
+ "nvidia/mit-b0",
76
+ num_labels=NUM_CLASSES,
77
+ id2label=ID2LABEL,
78
+ label2id={label: id for id, label in ID2LABEL.items()},
79
+ ignore_mismatched_sizes=True
80
+ )
81
+ # Then load weights from h5 file
82
+ model.load_weights(model_path)
83
+ st.success("Model weights loaded successfully")
84
+ except Exception as e:
85
+ st.error(f"Error loading model weights: {str(e)}")
86
+ st.warning("Falling back to pretrained model")
87
+ model = TFSegformerForSemanticSegmentation.from_pretrained(
88
+ "nvidia/mit-b0",
89
+ num_labels=NUM_CLASSES,
90
+ id2label=ID2LABEL,
91
+ label2id={label: id for id, label in ID2LABEL.items()},
92
+ ignore_mismatched_sizes=True
93
+ )
94
 
95
  return model
96
  except Exception as e:
 
195
 
196
  return rgb_mask
197
 
198
def calculate_iou(y_true, y_pred, class_idx=None, num_classes=None):
    """
    Calculate IoU (Intersection over Union) for segmentation masks.

    Args:
        y_true: Ground truth segmentation mask (array-like of class indices).
        y_pred: Predicted segmentation mask (array-like of class indices).
        class_idx: Index of the class to calculate IoU for (None for mean IoU).
        num_classes: Number of classes to average over when ``class_idx`` is
            None. Defaults to the module-level ``NUM_CLASSES``.

    Returns:
        IoU score as a float. For a single class that appears in neither
        mask the score is 0.0. The mean IoU only averages over classes that
        appear in at least one of the two masks, and is 0.0 when no class
        appears at all.
    """
    # Accept lists/tuples as well as ndarrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if class_idx is not None:
        # Binary IoU for a specific class, using exact integer pixel counts
        # instead of float32 sums.
        true_hits = y_true == class_idx
        pred_hits = y_pred == class_idx

        intersection = np.count_nonzero(true_hits & pred_hits)
        union = np.count_nonzero(true_hits | pred_hits)

        # The epsilon keeps the division defined when the class is absent
        # from both masks (the result is then 0.0).
        return float(intersection / (union + 1e-6))

    if num_classes is None:
        num_classes = NUM_CLASSES

    # Mean IoU: skip classes that occur in neither mask. Their union is
    # empty, so counting them as 0 would drag the mean down even for a
    # perfect prediction.
    present = [
        idx
        for idx in range(num_classes)
        if np.any(y_true == idx) or np.any(y_pred == idx)
    ]
    if not present:
        return 0.0

    return float(np.mean([calculate_iou(y_true, y_pred, idx) for idx in present]))
229
+
230
  def create_overlay(image, mask, alpha=0.5):
231
  """
232
  Create an overlay of mask on original image
 
301
  # Image upload
302
  st.header("Upload an Image")
303
  uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
304
+ uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])
305
 
306
  # Sample images option
307
  st.markdown("### Or use a sample image:")
 
378
  mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
379
  st.image(mask_fg, caption="Foreground", use_column_width=True)
380
 
381
+ # Calculate IoU if ground truth is uploaded
382
+ if uploaded_mask is not None:
383
+ try:
384
+ # Read and process the mask file
385
+ mask_data = uploaded_mask.read()
386
+ st.write(f"Uploaded mask size: {len(mask_data)} bytes")
387
+
388
+ # Open the mask from bytes
389
+ mask_io = io.BytesIO(mask_data)
390
+ gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
391
+
392
+ # Handle different mask formats
393
+ if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
394
+ # Convert RGB to single channel if needed
395
+ gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
396
+
397
+ # Calculate and display IoU
398
+ iou_score = calculate_iou(gt_mask, mask)
399
+ st.success(f"Mean IoU: {iou_score:.4f}")
400
+
401
+ # Display specific class IoUs
402
+ st.markdown("### IoU by Class")
403
+ col1, col2, col3 = st.columns(3)
404
+ with col1:
405
+ bg_iou = calculate_iou(gt_mask, mask, 0)
406
+ st.metric("Background IoU", f"{bg_iou:.4f}")
407
+ with col2:
408
+ border_iou = calculate_iou(gt_mask, mask, 1)
409
+ st.metric("Border IoU", f"{border_iou:.4f}")
410
+ with col3:
411
+ fg_iou = calculate_iou(gt_mask, mask, 2)
412
+ st.metric("Foreground IoU", f"{fg_iou:.4f}")
413
+ except Exception as e:
414
+ st.error(f"Error processing ground truth mask: {e}")
415
+ st.write("Please ensure the mask is valid and has the correct format.")
416
+
417
  # Download buttons
418
  col1, col2 = st.columns(2)
419