Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,4 @@
|
|
1 |
import streamlit as st
|
2 |
-
# THIS MUST BE THE FIRST STREAMLIT COMMAND
|
3 |
-
st.set_page_config(
|
4 |
-
page_title="Pet Segmentation with SegFormer",
|
5 |
-
page_icon="🐶",
|
6 |
-
layout="wide",
|
7 |
-
initial_sidebar_state="expanded"
|
8 |
-
)
|
9 |
-
|
10 |
import tensorflow as tf
|
11 |
from tensorflow.keras import backend
|
12 |
import numpy as np
|
@@ -18,18 +10,12 @@ import io
|
|
18 |
import gdown
|
19 |
from transformers import TFSegformerForSemanticSegmentation
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
tf.config.experimental.set_memory_growth(gpu, True)
|
28 |
-
st.sidebar.success(f"GPU available: {len(gpus)} device(s)")
|
29 |
-
else:
|
30 |
-
st.sidebar.warning("No GPU detected, using CPU")
|
31 |
-
except Exception as e:
|
32 |
-
st.sidebar.error(f"GPU config error: {e}")
|
33 |
|
34 |
# Constants for image preprocessing
|
35 |
IMAGE_SIZE = 512
|
@@ -49,7 +35,7 @@ def download_model_from_drive():
|
|
49 |
model_path = "models/tf_model.h5"
|
50 |
|
51 |
if not os.path.exists(model_path):
|
52 |
-
#
|
53 |
url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
|
54 |
try:
|
55 |
gdown.download(url, model_path, quiet=False)
|
@@ -124,13 +110,12 @@ def normalize_image(input_image):
|
|
124 |
input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
|
125 |
return input_image
|
126 |
|
127 |
-
def preprocess_image(image
|
128 |
"""
|
129 |
-
Preprocess image for model input
|
130 |
|
131 |
Args:
|
132 |
image: PIL Image to preprocess
|
133 |
-
from_dataset: Whether the image is from the original dataset
|
134 |
|
135 |
Returns:
|
136 |
Preprocessed image tensor, original image
|
@@ -142,24 +127,12 @@ def preprocess_image(image, from_dataset=False):
|
|
142 |
original_img = img_array.copy()
|
143 |
|
144 |
# Resize to target size
|
145 |
-
img_resized = tf.image.resize(
|
146 |
-
img_array,
|
147 |
-
(IMAGE_SIZE, IMAGE_SIZE),
|
148 |
-
method='bilinear',
|
149 |
-
preserve_aspect_ratio=False,
|
150 |
-
antialias=True
|
151 |
-
)
|
152 |
|
153 |
-
#
|
154 |
-
|
155 |
-
# The dataset already has specific dimensions, just normalize
|
156 |
-
# Skip additional preprocessing that might have been applied
|
157 |
-
img_normalized = normalize_image(img_resized)
|
158 |
-
else:
|
159 |
-
# Regular preprocessing for uploaded images
|
160 |
-
img_normalized = normalize_image(img_resized)
|
161 |
|
162 |
-
# Transpose from HWC to CHW (channels first)
|
163 |
img_transposed = tf.transpose(img_normalized, (2, 0, 1))
|
164 |
|
165 |
# Add batch dimension
|
@@ -177,21 +150,8 @@ def create_mask(pred_mask):
|
|
177 |
Returns:
|
178 |
Processed mask (2D array)
|
179 |
"""
|
180 |
-
# Take argmax along the class dimension (axis=1 for batch data)
|
181 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
182 |
-
|
183 |
-
# Remove batch dimension and convert to numpy
|
184 |
pred_mask = tf.squeeze(pred_mask)
|
185 |
-
|
186 |
-
# Resize to match original image size if needed
|
187 |
-
if pred_mask.shape[0] != IMAGE_SIZE or pred_mask.shape[1] != IMAGE_SIZE:
|
188 |
-
pred_mask = tf.image.resize(
|
189 |
-
tf.expand_dims(pred_mask, axis=-1),
|
190 |
-
(IMAGE_SIZE, IMAGE_SIZE),
|
191 |
-
method='nearest'
|
192 |
-
)
|
193 |
-
pred_mask = tf.squeeze(pred_mask)
|
194 |
-
|
195 |
return pred_mask.numpy()
|
196 |
|
197 |
def colorize_mask(mask):
|
@@ -284,91 +244,6 @@ def create_overlay(image, mask, alpha=0.5):
|
|
284 |
|
285 |
return overlay
|
286 |
|
287 |
-
def display_results_side_by_side(original_image, ground_truth_mask=None, predicted_mask=None):
|
288 |
-
"""
|
289 |
-
Display results in a side-by-side format similar to colab_code.py
|
290 |
-
|
291 |
-
Args:
|
292 |
-
original_image: Original input image
|
293 |
-
ground_truth_mask: Optional ground truth segmentation mask
|
294 |
-
predicted_mask: Predicted segmentation mask
|
295 |
-
"""
|
296 |
-
# Determine how many images to display
|
297 |
-
cols = 1 + (ground_truth_mask is not None) + (predicted_mask is not None)
|
298 |
-
|
299 |
-
# Create a figure with multiple columns
|
300 |
-
st.write("### Segmentation Results Comparison")
|
301 |
-
|
302 |
-
col_list = st.columns(cols)
|
303 |
-
|
304 |
-
# Display original image
|
305 |
-
with col_list[0]:
|
306 |
-
st.markdown("**Original Image**")
|
307 |
-
st.image(original_image, use_column_width=True)
|
308 |
-
|
309 |
-
# Display ground truth if available
|
310 |
-
if ground_truth_mask is not None:
|
311 |
-
with col_list[1]:
|
312 |
-
st.markdown("**Ground Truth Mask**")
|
313 |
-
|
314 |
-
# Colorize ground truth if needed
|
315 |
-
if len(ground_truth_mask.shape) == 2:
|
316 |
-
gt_display = colorize_mask(ground_truth_mask)
|
317 |
-
else:
|
318 |
-
gt_display = ground_truth_mask
|
319 |
-
|
320 |
-
st.image(gt_display, use_column_width=True)
|
321 |
-
|
322 |
-
# Display prediction
|
323 |
-
if predicted_mask is not None:
|
324 |
-
with col_list[2 if ground_truth_mask is not None else 1]:
|
325 |
-
st.markdown("**Predicted Mask**")
|
326 |
-
|
327 |
-
# Colorize prediction if needed
|
328 |
-
if len(predicted_mask.shape) == 2:
|
329 |
-
pred_display = colorize_mask(predicted_mask)
|
330 |
-
else:
|
331 |
-
pred_display = predicted_mask
|
332 |
-
|
333 |
-
st.image(pred_display, use_column_width=True)
|
334 |
-
|
335 |
-
def process_uploaded_mask(mask_array, from_dataset=False):
|
336 |
-
"""
|
337 |
-
Process an uploaded mask to ensure it has the correct format
|
338 |
-
|
339 |
-
Args:
|
340 |
-
mask_array: Numpy array of the mask
|
341 |
-
from_dataset: Whether the mask is from the original dataset
|
342 |
-
|
343 |
-
Returns:
|
344 |
-
Processed mask with values 0,1,2
|
345 |
-
"""
|
346 |
-
# Check for RGBA format and convert to RGB if needed
|
347 |
-
if len(mask_array.shape) == 3 and mask_array.shape[2] == 4:
|
348 |
-
# Convert RGBA to RGB (discard alpha channel)
|
349 |
-
mask_array = mask_array[:,:,:3]
|
350 |
-
|
351 |
-
# Convert RGB to grayscale if needed
|
352 |
-
if len(mask_array.shape) == 3 and mask_array.shape[2] >= 3:
|
353 |
-
# Convert RGB to grayscale
|
354 |
-
mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
|
355 |
-
|
356 |
-
if from_dataset:
|
357 |
-
# For dataset masks (saved from your colab code):
|
358 |
-
# Create an empty mask with the same shape
|
359 |
-
processed_mask = np.zeros_like(mask_array)
|
360 |
-
|
361 |
-
# Map the values correctly:
|
362 |
-
# Original dataset uses 1,2,3 but your app expects 0,1,2
|
363 |
-
processed_mask[mask_array == 1] = 2 # Foreground/pet (1→2)
|
364 |
-
processed_mask[mask_array == 2] = 1 # Border (2→1)
|
365 |
-
processed_mask[mask_array == 3] = 0 # Background (3→0)
|
366 |
-
|
367 |
-
return processed_mask
|
368 |
-
else:
|
369 |
-
# For non-dataset masks, we assume they have correct class values
|
370 |
-
return mask_array
|
371 |
-
|
372 |
def main():
|
373 |
st.title("🐶 Pet Segmentation with SegFormer")
|
374 |
st.markdown("""
|
@@ -404,9 +279,6 @@ def main():
|
|
404 |
step=0.1
|
405 |
)
|
406 |
|
407 |
-
# Add this checkbox to your app's UI
|
408 |
-
dataset_image = st.sidebar.checkbox("Image is from the Oxford-IIIT Pet dataset")
|
409 |
-
|
410 |
# Load model
|
411 |
with st.spinner("Loading SegFormer model..."):
|
412 |
model = load_model()
|
@@ -436,46 +308,21 @@ def main():
|
|
436 |
|
437 |
# Preprocess and predict
|
438 |
with st.spinner("Generating segmentation mask..."):
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
# Colorize the mask
|
455 |
-
colorized_mask = colorize_mask(mask)
|
456 |
-
|
457 |
-
# Create overlay
|
458 |
-
overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
|
459 |
-
except Exception as inference_error:
|
460 |
-
st.error(f"Inference error: {inference_error}")
|
461 |
-
st.write("Trying alternative approach...")
|
462 |
-
|
463 |
-
# Alternative: resize to exactly 512x512 with crop_or_pad
|
464 |
-
img_resized = tf.image.resize_with_crop_or_pad(
|
465 |
-
original_img, IMAGE_SIZE, IMAGE_SIZE
|
466 |
-
)
|
467 |
-
img_normalized = normalize_image(img_resized)
|
468 |
-
img_transposed = tf.transpose(img_normalized, (2, 0, 1))
|
469 |
-
img_tensor = tf.expand_dims(img_transposed, axis=0)
|
470 |
-
|
471 |
-
outputs = model(pixel_values=img_tensor, training=False)
|
472 |
-
logits = outputs.logits
|
473 |
-
mask = create_mask(logits)
|
474 |
-
colorized_mask = colorize_mask(mask)
|
475 |
-
overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
|
476 |
-
except Exception as e:
|
477 |
-
st.error(f"Failed to process image: {e}")
|
478 |
-
st.stop()
|
479 |
|
480 |
# Display results
|
481 |
with col2:
|
@@ -507,81 +354,37 @@ def main():
|
|
507 |
# Calculate IoU if ground truth is uploaded
|
508 |
if uploaded_mask is not None:
|
509 |
try:
|
510 |
-
# Reset the file pointer to the beginning
|
511 |
-
uploaded_mask.seek(0)
|
512 |
-
|
513 |
# Read the mask file
|
514 |
mask_data = uploaded_mask.read()
|
515 |
mask_io = io.BytesIO(mask_data)
|
|
|
516 |
|
517 |
-
#
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
st.write(f"Debug - Raw mask shape: {raw_mask.shape}")
|
522 |
-
st.write(f"Debug - Raw mask unique values: {np.unique(raw_mask)}")
|
523 |
-
|
524 |
-
# Process the mask based on source
|
525 |
-
processed_gt_mask = process_uploaded_mask(raw_mask, from_dataset=dataset_image)
|
526 |
-
|
527 |
-
# Resize for IoU calculation
|
528 |
-
gt_mask_resized = cv2.resize(processed_gt_mask, (OUTPUT_SIZE, OUTPUT_SIZE),
|
529 |
-
interpolation=cv2.INTER_NEAREST)
|
530 |
-
|
531 |
-
# Resize prediction for comparison
|
532 |
-
pred_mask_resized = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE),
|
533 |
-
interpolation=cv2.INTER_NEAREST)
|
534 |
-
|
535 |
-
# Show processed values
|
536 |
-
st.write(f"Debug - Processed GT mask unique values: {np.unique(gt_mask_resized)}")
|
537 |
-
st.write(f"Debug - Prediction mask unique values: {np.unique(pred_mask_resized)}")
|
538 |
|
539 |
# Calculate and display IoU
|
540 |
-
|
|
|
541 |
st.success(f"Mean IoU: {iou_score:.4f}")
|
542 |
|
543 |
# Display specific class IoUs
|
544 |
st.markdown("### IoU by Class")
|
545 |
col1, col2, col3 = st.columns(3)
|
546 |
with col1:
|
547 |
-
bg_iou = calculate_iou(
|
548 |
st.metric("Background IoU", f"{bg_iou:.4f}")
|
549 |
with col2:
|
550 |
-
border_iou = calculate_iou(
|
551 |
st.metric("Border IoU", f"{border_iou:.4f}")
|
552 |
with col3:
|
553 |
-
fg_iou = calculate_iou(
|
554 |
st.metric("Foreground IoU", f"{fg_iou:.4f}")
|
555 |
-
|
556 |
-
# For display, create a colorized version of the ground truth
|
557 |
-
gt_mask_for_display = colorize_mask(processed_gt_mask)
|
558 |
-
|
559 |
-
# Side-by-side display
|
560 |
-
display_results_side_by_side(
|
561 |
-
original_img,
|
562 |
-
ground_truth_mask=gt_mask_for_display,
|
563 |
-
predicted_mask=colorized_mask
|
564 |
-
)
|
565 |
except Exception as e:
|
566 |
st.error(f"Error processing ground truth mask: {e}")
|
567 |
st.write("Please ensure the mask is valid and has the correct format.")
|
568 |
-
|
569 |
-
st.code(traceback.format_exc()) # Show detailed error trace
|
570 |
-
|
571 |
-
# Even with an error, try to display results without the ground truth
|
572 |
-
display_results_side_by_side(
|
573 |
-
original_img,
|
574 |
-
ground_truth_mask=None,
|
575 |
-
predicted_mask=colorized_mask
|
576 |
-
)
|
577 |
-
else:
|
578 |
-
# No ground truth, just display original and prediction
|
579 |
-
display_results_side_by_side(
|
580 |
-
original_img,
|
581 |
-
ground_truth_mask=None,
|
582 |
-
predicted_mask=colorized_mask
|
583 |
-
)
|
584 |
-
|
585 |
# Download buttons
|
586 |
col1, col2 = st.columns(2)
|
587 |
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import tensorflow as tf
|
3 |
from tensorflow.keras import backend
|
4 |
import numpy as np
|
|
|
10 |
import gdown
|
11 |
from transformers import TFSegformerForSemanticSegmentation
|
12 |
|
13 |
+
st.set_page_config(
|
14 |
+
page_title="Pet Segmentation with SegFormer",
|
15 |
+
page_icon="🐶",
|
16 |
+
layout="wide",
|
17 |
+
initial_sidebar_state="expanded"
|
18 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
# Constants for image preprocessing
|
21 |
IMAGE_SIZE = 512
|
|
|
35 |
model_path = "models/tf_model.h5"
|
36 |
|
37 |
if not os.path.exists(model_path):
|
38 |
+
# Fixed Google Drive URL format for gdown
|
39 |
url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
|
40 |
try:
|
41 |
gdown.download(url, model_path, quiet=False)
|
|
|
110 |
input_image = (input_image - MEAN) / tf.maximum(STD, backend.epsilon())
|
111 |
return input_image
|
112 |
|
113 |
+
def preprocess_image(image):
|
114 |
"""
|
115 |
+
Preprocess image for model input
|
116 |
|
117 |
Args:
|
118 |
image: PIL Image to preprocess
|
|
|
119 |
|
120 |
Returns:
|
121 |
Preprocessed image tensor, original image
|
|
|
127 |
original_img = img_array.copy()
|
128 |
|
129 |
# Resize to target size
|
130 |
+
img_resized = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
+
# Normalize
|
133 |
+
img_normalized = normalize_image(img_resized)
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
+
# Transpose from HWC to CHW (SegFormer expects channels first)
|
136 |
img_transposed = tf.transpose(img_normalized, (2, 0, 1))
|
137 |
|
138 |
# Add batch dimension
|
|
|
150 |
Returns:
|
151 |
Processed mask (2D array)
|
152 |
"""
|
|
|
153 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
|
|
|
|
154 |
pred_mask = tf.squeeze(pred_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
return pred_mask.numpy()
|
156 |
|
157 |
def colorize_mask(mask):
|
|
|
244 |
|
245 |
return overlay
|
246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
def main():
|
248 |
st.title("🐶 Pet Segmentation with SegFormer")
|
249 |
st.markdown("""
|
|
|
279 |
step=0.1
|
280 |
)
|
281 |
|
|
|
|
|
|
|
282 |
# Load model
|
283 |
with st.spinner("Loading SegFormer model..."):
|
284 |
model = load_model()
|
|
|
308 |
|
309 |
# Preprocess and predict
|
310 |
with st.spinner("Generating segmentation mask..."):
|
311 |
+
# Preprocess the image
|
312 |
+
img_tensor, original_img = preprocess_image(image)
|
313 |
+
|
314 |
+
# Make prediction
|
315 |
+
outputs = model(pixel_values=img_tensor, training=False)
|
316 |
+
logits = outputs.logits
|
317 |
+
|
318 |
+
# Create visualization mask
|
319 |
+
mask = create_mask(logits)
|
320 |
+
|
321 |
+
# Colorize the mask
|
322 |
+
colorized_mask = colorize_mask(mask)
|
323 |
+
|
324 |
+
# Create overlay
|
325 |
+
overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
# Display results
|
328 |
with col2:
|
|
|
354 |
# Calculate IoU if ground truth is uploaded
|
355 |
if uploaded_mask is not None:
|
356 |
try:
|
|
|
|
|
|
|
357 |
# Read the mask file
|
358 |
mask_data = uploaded_mask.read()
|
359 |
mask_io = io.BytesIO(mask_data)
|
360 |
+
gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
|
361 |
|
362 |
+
# Handle different mask formats
|
363 |
+
if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
|
364 |
+
# Convert RGB to single channel if needed
|
365 |
+
gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
|
367 |
# Calculate and display IoU
|
368 |
+
resized_mask = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_NEAREST)
|
369 |
+
iou_score = calculate_iou(gt_mask, resized_mask)
|
370 |
st.success(f"Mean IoU: {iou_score:.4f}")
|
371 |
|
372 |
# Display specific class IoUs
|
373 |
st.markdown("### IoU by Class")
|
374 |
col1, col2, col3 = st.columns(3)
|
375 |
with col1:
|
376 |
+
bg_iou = calculate_iou(gt_mask, resized_mask, 0)
|
377 |
st.metric("Background IoU", f"{bg_iou:.4f}")
|
378 |
with col2:
|
379 |
+
border_iou = calculate_iou(gt_mask, resized_mask, 1)
|
380 |
st.metric("Border IoU", f"{border_iou:.4f}")
|
381 |
with col3:
|
382 |
+
fg_iou = calculate_iou(gt_mask, resized_mask, 2)
|
383 |
st.metric("Foreground IoU", f"{fg_iou:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
except Exception as e:
|
385 |
st.error(f"Error processing ground truth mask: {e}")
|
386 |
st.write("Please ensure the mask is valid and has the correct format.")
|
387 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
# Download buttons
|
389 |
col1, col2 = st.columns(2)
|
390 |
|