Spaces:
Sleeping
Sleeping
add
Browse files
app.py
CHANGED
@@ -32,7 +32,8 @@ NUM_CLASSES = len(ID2LABEL)
|
|
32 |
def download_model_from_drive():
|
33 |
model_path = "tf_model.h5"
|
34 |
if not os.path.exists(model_path):
|
35 |
-
|
|
|
36 |
try:
|
37 |
gdown.download(url, model_path, quiet=False)
|
38 |
st.success("Model downloaded successfully from Google Drive.")
|
@@ -55,7 +56,7 @@ def load_model():
|
|
55 |
# Download the model first
|
56 |
model_path = download_model_from_drive()
|
57 |
|
58 |
-
if model_path is None:
|
59 |
st.warning("Using default pretrained model since download failed")
|
60 |
# Fall back to pretrained model
|
61 |
model = TFSegformerForSemanticSegmentation.from_pretrained(
|
@@ -66,28 +67,30 @@ def load_model():
|
|
66 |
ignore_mismatched_sizes=True
|
67 |
)
|
68 |
else:
|
69 |
-
#
|
70 |
-
|
71 |
-
|
72 |
-
#
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
|
92 |
return model
|
93 |
except Exception as e:
|
@@ -192,6 +195,38 @@ def colorize_mask(mask):
|
|
192 |
|
193 |
return rgb_mask
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
def create_overlay(image, mask, alpha=0.5):
|
196 |
"""
|
197 |
Create an overlay of mask on original image
|
@@ -266,6 +301,7 @@ def main():
|
|
266 |
# Image upload
|
267 |
st.header("Upload an Image")
|
268 |
uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
|
|
|
269 |
|
270 |
# Sample images option
|
271 |
st.markdown("### Or use a sample image:")
|
@@ -342,6 +378,42 @@ def main():
|
|
342 |
mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
|
343 |
st.image(mask_fg, caption="Foreground", use_column_width=True)
|
344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
# Download buttons
|
346 |
col1, col2 = st.columns(2)
|
347 |
|
|
|
32 |
def download_model_from_drive():
|
33 |
model_path = "tf_model.h5"
|
34 |
if not os.path.exists(model_path):
|
35 |
+
# Fix the Google Drive link format - this is why the download is failing
|
36 |
+
url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
37 |
try:
|
38 |
gdown.download(url, model_path, quiet=False)
|
39 |
st.success("Model downloaded successfully from Google Drive.")
|
|
|
56 |
# Download the model first
|
57 |
model_path = download_model_from_drive()
|
58 |
|
59 |
+
if model_path is None or not os.path.exists(model_path):
|
60 |
st.warning("Using default pretrained model since download failed")
|
61 |
# Fall back to pretrained model
|
62 |
model = TFSegformerForSemanticSegmentation.from_pretrained(
|
|
|
67 |
ignore_mismatched_sizes=True
|
68 |
)
|
69 |
else:
|
70 |
+
# For a HuggingFace model saved with SavedModel format
|
71 |
+
st.info("Loading SegFormer model...")
|
72 |
+
try:
|
73 |
+
# First try loading as HuggingFace model
|
74 |
+
model = TFSegformerForSemanticSegmentation.from_pretrained(
|
75 |
+
"nvidia/mit-b0",
|
76 |
+
num_labels=NUM_CLASSES,
|
77 |
+
id2label=ID2LABEL,
|
78 |
+
label2id={label: id for id, label in ID2LABEL.items()},
|
79 |
+
ignore_mismatched_sizes=True
|
80 |
+
)
|
81 |
+
# Then load weights from h5 file
|
82 |
+
model.load_weights(model_path)
|
83 |
+
st.success("Model weights loaded successfully")
|
84 |
+
except Exception as e:
|
85 |
+
st.error(f"Error loading model weights: {str(e)}")
|
86 |
+
st.warning("Falling back to pretrained model")
|
87 |
+
model = TFSegformerForSemanticSegmentation.from_pretrained(
|
88 |
+
"nvidia/mit-b0",
|
89 |
+
num_labels=NUM_CLASSES,
|
90 |
+
id2label=ID2LABEL,
|
91 |
+
label2id={label: id for id, label in ID2LABEL.items()},
|
92 |
+
ignore_mismatched_sizes=True
|
93 |
+
)
|
94 |
|
95 |
return model
|
96 |
except Exception as e:
|
|
|
195 |
|
196 |
return rgb_mask
|
197 |
|
198 |
+
def _binary_iou(y_true, y_pred, class_idx):
    """Return the IoU of one class, or None if neither mask contains it."""
    true_hits = y_true == class_idx
    pred_hits = y_pred == class_idx

    # Count on boolean masks: exact integer arithmetic, no float casts needed.
    union = np.count_nonzero(true_hits | pred_hits)
    if union == 0:
        # 0/0 is undefined; let the caller decide how to treat an absent class.
        return None
    return np.count_nonzero(true_hits & pred_hits) / union


def calculate_iou(y_true, y_pred, class_idx=None, num_classes=None):
    """
    Calculate IoU (Intersection over Union) for segmentation masks

    Args:
        y_true: Ground truth segmentation mask (array-like of class ids)
        y_pred: Predicted segmentation mask, same shape as y_true
        class_idx: Index of the class to calculate IoU for (None for mean IoU)
        num_classes: Number of classes averaged over when class_idx is None;
            defaults to the module-level NUM_CLASSES

    Returns:
        IoU score as a float in [0, 1]. For a single class that appears in
        neither mask the score is 0.0. For mean IoU, classes that appear in
        neither mask are left out of the average instead of counting as 0;
        if no class is present at all the result is 0.0.

    Raises:
        ValueError: If y_true and y_pred do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        # Without this check NumPy may broadcast the masks against each other
        # and produce a meaningless score.
        raise ValueError(
            f"Mask shapes differ: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )

    if class_idx is not None:
        # Binary IoU for specific class
        score = _binary_iou(y_true, y_pred, class_idx)
        return 0.0 if score is None else float(score)

    # Mean IoU across all classes
    if num_classes is None:
        num_classes = NUM_CLASSES

    per_class = (_binary_iou(y_true, y_pred, idx) for idx in range(num_classes))
    present = [score for score in per_class if score is not None]
    if not present:
        return 0.0
    return float(np.mean(present))
|
229 |
+
|
230 |
def create_overlay(image, mask, alpha=0.5):
|
231 |
"""
|
232 |
Create an overlay of mask on original image
|
|
|
301 |
# Image upload
|
302 |
st.header("Upload an Image")
|
303 |
uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
|
304 |
+
uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])
|
305 |
|
306 |
# Sample images option
|
307 |
st.markdown("### Or use a sample image:")
|
|
|
378 |
mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
|
379 |
st.image(mask_fg, caption="Foreground", use_column_width=True)
|
380 |
|
381 |
+
# Calculate IoU if ground truth is uploaded
|
382 |
+
if uploaded_mask is not None:
|
383 |
+
try:
|
384 |
+
# Read and process the mask file
|
385 |
+
mask_data = uploaded_mask.read()
|
386 |
+
st.write(f"Uploaded mask size: {len(mask_data)} bytes")
|
387 |
+
|
388 |
+
# Open the mask from bytes
|
389 |
+
mask_io = io.BytesIO(mask_data)
|
390 |
+
gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
|
391 |
+
|
392 |
+
# Handle different mask formats
|
393 |
+
if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
|
394 |
+
# Convert RGB to single channel if needed
|
395 |
+
gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
|
396 |
+
|
397 |
+
# Calculate and display IoU
|
398 |
+
iou_score = calculate_iou(gt_mask, mask)
|
399 |
+
st.success(f"Mean IoU: {iou_score:.4f}")
|
400 |
+
|
401 |
+
# Display specific class IoUs
|
402 |
+
st.markdown("### IoU by Class")
|
403 |
+
col1, col2, col3 = st.columns(3)
|
404 |
+
with col1:
|
405 |
+
bg_iou = calculate_iou(gt_mask, mask, 0)
|
406 |
+
st.metric("Background IoU", f"{bg_iou:.4f}")
|
407 |
+
with col2:
|
408 |
+
border_iou = calculate_iou(gt_mask, mask, 1)
|
409 |
+
st.metric("Border IoU", f"{border_iou:.4f}")
|
410 |
+
with col3:
|
411 |
+
fg_iou = calculate_iou(gt_mask, mask, 2)
|
412 |
+
st.metric("Foreground IoU", f"{fg_iou:.4f}")
|
413 |
+
except Exception as e:
|
414 |
+
st.error(f"Error processing ground truth mask: {e}")
|
415 |
+
st.write("Please ensure the mask is valid and has the correct format.")
|
416 |
+
|
417 |
# Download buttons
|
418 |
col1, col2 = st.columns(2)
|
419 |
|