Spaces:
Sleeping
Sleeping
add
Browse files
app.py
CHANGED
@@ -30,10 +30,13 @@ NUM_CLASSES = len(ID2LABEL)
|
|
30 |
|
31 |
@st.cache_resource
|
32 |
def download_model_from_drive():
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
37 |
try:
|
38 |
gdown.download(url, model_path, quiet=False)
|
39 |
st.success("Model downloaded successfully from Google Drive.")
|
@@ -44,6 +47,7 @@ def download_model_from_drive():
|
|
44 |
st.info("Model already exists locally.")
|
45 |
return model_path
|
46 |
|
|
|
47 |
@st.cache_resource
|
48 |
def load_model():
|
49 |
"""
|
@@ -53,58 +57,44 @@ def load_model():
|
|
53 |
Loaded model
|
54 |
"""
|
55 |
try:
|
56 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
model_path = download_model_from_drive()
|
58 |
|
59 |
-
if model_path is None
|
60 |
-
st.
|
61 |
-
# Fall back to pretrained model
|
62 |
-
model = TFSegformerForSemanticSegmentation.from_pretrained(
|
63 |
-
"nvidia/mit-b0",
|
64 |
-
num_labels=NUM_CLASSES,
|
65 |
-
id2label=ID2LABEL,
|
66 |
-
label2id={label: id for id, label in ID2LABEL.items()},
|
67 |
-
ignore_mismatched_sizes=True
|
68 |
-
)
|
69 |
-
else:
|
70 |
-
# For a HuggingFace model saved with SavedModel format
|
71 |
-
st.info("Loading SegFormer model...")
|
72 |
try:
|
73 |
-
#
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
id2label=ID2LABEL,
|
78 |
-
label2id={label: id for id, label in ID2LABEL.items()},
|
79 |
-
ignore_mismatched_sizes=True
|
80 |
-
)
|
81 |
-
# Then load weights from h5 file
|
82 |
-
model.load_weights(model_path)
|
83 |
-
st.success("Model weights loaded successfully")
|
84 |
except Exception as e:
|
85 |
-
st.error(f"Error loading
|
86 |
-
st.
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
ignore_mismatched_sizes=True
|
93 |
-
)
|
94 |
-
|
95 |
-
return model
|
96 |
except Exception as e:
|
97 |
-
st.error(f"Error
|
98 |
-
st.
|
99 |
# Fall back to pretrained model as a last resort
|
100 |
-
|
101 |
"nvidia/mit-b0",
|
102 |
num_labels=NUM_CLASSES,
|
103 |
id2label=ID2LABEL,
|
104 |
label2id={label: id for id, label in ID2LABEL.items()},
|
105 |
ignore_mismatched_sizes=True
|
106 |
)
|
107 |
-
return model
|
108 |
|
109 |
def normalize_image(input_image):
|
110 |
"""
|
@@ -161,7 +151,7 @@ def create_mask(pred_mask):
|
|
161 |
Processed mask (2D array)
|
162 |
"""
|
163 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
164 |
-
pred_mask = tf.squeeze(pred_mask
|
165 |
return pred_mask.numpy()
|
166 |
|
167 |
def colorize_mask(mask):
|
@@ -176,7 +166,7 @@ def colorize_mask(mask):
|
|
176 |
"""
|
177 |
# Ensure the mask is 2D
|
178 |
if len(mask.shape) > 2:
|
179 |
-
mask = np.squeeze(mask
|
180 |
|
181 |
# Define colors for each class (RGB)
|
182 |
colors = [
|
@@ -298,152 +288,135 @@ def main():
|
|
298 |
else:
|
299 |
st.sidebar.success("Model loaded successfully!")
|
300 |
|
301 |
-
# Image upload
|
302 |
st.header("Upload an Image")
|
303 |
uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
|
304 |
uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])
|
305 |
|
306 |
-
# Sample images option
|
307 |
-
st.markdown("### Or use a sample image:")
|
308 |
-
sample_dir = "samples"
|
309 |
-
|
310 |
-
# Check if sample directory exists and contains images
|
311 |
-
sample_files = []
|
312 |
-
if os.path.exists(sample_dir):
|
313 |
-
sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]
|
314 |
-
|
315 |
-
if sample_files:
|
316 |
-
selected_sample = st.selectbox("Select a sample image:", sample_files)
|
317 |
-
use_sample = st.button("Use this sample")
|
318 |
-
|
319 |
-
if use_sample:
|
320 |
-
with open(os.path.join(sample_dir, selected_sample), "rb") as file:
|
321 |
-
image_bytes = file.read()
|
322 |
-
uploaded_image = io.BytesIO(image_bytes)
|
323 |
-
st.success(f"Using sample image: {selected_sample}")
|
324 |
-
|
325 |
# Process uploaded image
|
326 |
if uploaded_image is not None:
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
st.subheader("Original Image")
|
334 |
-
st.image(image, caption="Uploaded Image", use_column_width=True)
|
335 |
-
|
336 |
-
# Preprocess and predict
|
337 |
-
with st.spinner("Generating segmentation mask..."):
|
338 |
-
# Preprocess the image
|
339 |
-
img_tensor, original_img = preprocess_image(image)
|
340 |
-
|
341 |
-
# Make prediction
|
342 |
-
prediction = model(pixel_values=img_tensor, training=False)
|
343 |
-
logits = prediction.logits
|
344 |
|
345 |
-
|
346 |
-
|
|
|
347 |
|
348 |
-
#
|
349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
-
#
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
with col2:
|
356 |
-
st.subheader("Segmentation Result")
|
357 |
-
st.image(overlay, caption="Segmentation Overlay", use_column_width=True)
|
358 |
-
|
359 |
-
# Display segmentation details
|
360 |
-
st.header("Segmentation Details")
|
361 |
-
col1, col2, col3 = st.columns(3)
|
362 |
-
|
363 |
-
with col1:
|
364 |
-
st.subheader("Background")
|
365 |
-
st.markdown("Areas surrounding the pet")
|
366 |
-
mask_bg = np.where(mask == 0, 255, 0).astype(np.uint8)
|
367 |
-
st.image(mask_bg, caption="Background", use_column_width=True)
|
368 |
|
369 |
-
|
370 |
-
st.
|
371 |
-
st.
|
372 |
-
mask_border = np.where(mask == 1, 255, 0).astype(np.uint8)
|
373 |
-
st.image(mask_border, caption="Border", use_column_width=True)
|
374 |
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
# Calculate IoU if ground truth is uploaded
|
382 |
-
if uploaded_mask is not None:
|
383 |
-
try:
|
384 |
-
# Read and process the mask file
|
385 |
-
mask_data = uploaded_mask.read()
|
386 |
-
st.write(f"Uploaded mask size: {len(mask_data)} bytes")
|
387 |
-
|
388 |
-
# Open the mask from bytes
|
389 |
-
mask_io = io.BytesIO(mask_data)
|
390 |
-
gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
|
391 |
-
|
392 |
-
# Handle different mask formats
|
393 |
-
if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
|
394 |
-
# Convert RGB to single channel if needed
|
395 |
-
gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
|
396 |
|
397 |
-
|
398 |
-
|
399 |
-
st.
|
|
|
|
|
400 |
|
401 |
-
|
402 |
-
st.
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
st.metric("Background IoU", f"{bg_iou:.4f}")
|
407 |
-
with col2:
|
408 |
-
border_iou = calculate_iou(gt_mask, mask, 1)
|
409 |
-
st.metric("Border IoU", f"{border_iou:.4f}")
|
410 |
-
with col3:
|
411 |
-
fg_iou = calculate_iou(gt_mask, mask, 2)
|
412 |
-
st.metric("Foreground IoU", f"{fg_iou:.4f}")
|
413 |
-
except Exception as e:
|
414 |
-
st.error(f"Error processing ground truth mask: {e}")
|
415 |
-
st.write("Please ensure the mask is valid and has the correct format.")
|
416 |
-
|
417 |
-
# Download buttons
|
418 |
-
col1, col2 = st.columns(2)
|
419 |
-
|
420 |
-
with col1:
|
421 |
-
# Convert mask to PNG for download
|
422 |
-
mask_colored = Image.fromarray(colorized_mask)
|
423 |
-
mask_bytes = io.BytesIO()
|
424 |
-
mask_colored.save(mask_bytes, format='PNG')
|
425 |
-
mask_bytes = mask_bytes.getvalue()
|
426 |
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
447 |
|
448 |
# Footer with additional information
|
449 |
st.markdown("---")
|
|
|
30 |
|
31 |
@st.cache_resource
|
32 |
def download_model_from_drive():
|
33 |
+
# Create a models directory
|
34 |
+
os.makedirs("models", exist_ok=True)
|
35 |
+
model_path = "models/best_model"
|
36 |
+
|
37 |
+
if not os.path.exists(model_path):
|
38 |
+
# Fixed Google Drive URL format for gdown
|
39 |
+
url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
40 |
try:
|
41 |
gdown.download(url, model_path, quiet=False)
|
42 |
st.success("Model downloaded successfully from Google Drive.")
|
|
|
47 |
st.info("Model already exists locally.")
|
48 |
return model_path
|
49 |
|
50 |
+
|
51 |
@st.cache_resource
|
52 |
def load_model():
|
53 |
"""
|
|
|
57 |
Loaded model
|
58 |
"""
|
59 |
try:
|
60 |
+
# First create a base model with the correct architecture
|
61 |
+
base_model = TFSegformerForSemanticSegmentation.from_pretrained(
|
62 |
+
"nvidia/mit-b0",
|
63 |
+
num_labels=NUM_CLASSES,
|
64 |
+
id2label=ID2LABEL,
|
65 |
+
label2id={label: id for id, label in ID2LABEL.items()},
|
66 |
+
ignore_mismatched_sizes=True
|
67 |
+
)
|
68 |
+
|
69 |
+
# Download the trained weights
|
70 |
model_path = download_model_from_drive()
|
71 |
|
72 |
+
if model_path is not None and os.path.exists(model_path):
|
73 |
+
st.info(f"Loading weights from {model_path}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
try:
|
75 |
+
# Try to load the weights
|
76 |
+
base_model.load_weights(model_path)
|
77 |
+
st.success("Model weights loaded successfully!")
|
78 |
+
return base_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
except Exception as e:
|
80 |
+
st.error(f"Error loading weights: {e}")
|
81 |
+
st.info("Using base pretrained model instead")
|
82 |
+
return base_model
|
83 |
+
else:
|
84 |
+
st.warning("Using base pretrained model since download failed")
|
85 |
+
return base_model
|
86 |
+
|
|
|
|
|
|
|
|
|
87 |
except Exception as e:
|
88 |
+
st.error(f"Error in load_model: {e}")
|
89 |
+
st.warning("Using default pretrained model")
|
90 |
# Fall back to pretrained model as a last resort
|
91 |
+
return TFSegformerForSemanticSegmentation.from_pretrained(
|
92 |
"nvidia/mit-b0",
|
93 |
num_labels=NUM_CLASSES,
|
94 |
id2label=ID2LABEL,
|
95 |
label2id={label: id for id, label in ID2LABEL.items()},
|
96 |
ignore_mismatched_sizes=True
|
97 |
)
|
|
|
98 |
|
99 |
def normalize_image(input_image):
|
100 |
"""
|
|
|
151 |
Processed mask (2D array)
|
152 |
"""
|
153 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
154 |
+
pred_mask = tf.squeeze(pred_mask)
|
155 |
return pred_mask.numpy()
|
156 |
|
157 |
def colorize_mask(mask):
|
|
|
166 |
"""
|
167 |
# Ensure the mask is 2D
|
168 |
if len(mask.shape) > 2:
|
169 |
+
mask = np.squeeze(mask)
|
170 |
|
171 |
# Define colors for each class (RGB)
|
172 |
colors = [
|
|
|
288 |
else:
|
289 |
st.sidebar.success("Model loaded successfully!")
|
290 |
|
291 |
+
# Image upload section
|
292 |
st.header("Upload an Image")
|
293 |
uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
|
294 |
uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])
|
295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
# Process uploaded image
|
297 |
if uploaded_image is not None:
|
298 |
+
try:
|
299 |
+
# Read the image
|
300 |
+
image_bytes = uploaded_image.read()
|
301 |
+
image = Image.open(io.BytesIO(image_bytes))
|
302 |
+
|
303 |
+
col1, col2 = st.columns(2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
|
305 |
+
with col1:
|
306 |
+
st.subheader("Original Image")
|
307 |
+
st.image(image, caption="Uploaded Image", use_column_width=True)
|
308 |
|
309 |
+
# Preprocess and predict
|
310 |
+
with st.spinner("Generating segmentation mask..."):
|
311 |
+
# Preprocess the image
|
312 |
+
img_tensor, original_img = preprocess_image(image)
|
313 |
+
|
314 |
+
# Make prediction
|
315 |
+
outputs = model(pixel_values=img_tensor, training=False)
|
316 |
+
logits = outputs.logits
|
317 |
+
|
318 |
+
# Create visualization mask
|
319 |
+
mask = create_mask(logits)
|
320 |
+
|
321 |
+
# Colorize the mask
|
322 |
+
colorized_mask = colorize_mask(mask)
|
323 |
+
|
324 |
+
# Create overlay
|
325 |
+
overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
|
326 |
|
327 |
+
# Display results
|
328 |
+
with col2:
|
329 |
+
st.subheader("Segmentation Result")
|
330 |
+
st.image(overlay, caption="Segmentation Overlay", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
|
332 |
+
# Display segmentation details
|
333 |
+
st.header("Segmentation Details")
|
334 |
+
col1, col2, col3 = st.columns(3)
|
|
|
|
|
335 |
|
336 |
+
with col1:
|
337 |
+
st.subheader("Background")
|
338 |
+
st.markdown("Areas surrounding the pet")
|
339 |
+
mask_bg = np.where(mask == 0, 255, 0).astype(np.uint8)
|
340 |
+
st.image(mask_bg, caption="Background", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
+
with col2:
|
343 |
+
st.subheader("Border")
|
344 |
+
st.markdown("Boundary around the pet")
|
345 |
+
mask_border = np.where(mask == 1, 255, 0).astype(np.uint8)
|
346 |
+
st.image(mask_border, caption="Border", use_column_width=True)
|
347 |
|
348 |
+
with col3:
|
349 |
+
st.subheader("Foreground (Pet)")
|
350 |
+
st.markdown("The pet itself")
|
351 |
+
mask_fg = np.where(mask == 2, 255, 0).astype(np.uint8)
|
352 |
+
st.image(mask_fg, caption="Foreground", use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
|
354 |
+
# Calculate IoU if ground truth is uploaded
|
355 |
+
if uploaded_mask is not None:
|
356 |
+
try:
|
357 |
+
# Read the mask file
|
358 |
+
mask_data = uploaded_mask.read()
|
359 |
+
mask_io = io.BytesIO(mask_data)
|
360 |
+
gt_mask = np.array(Image.open(mask_io).resize((OUTPUT_SIZE, OUTPUT_SIZE), Image.NEAREST))
|
361 |
+
|
362 |
+
# Handle different mask formats
|
363 |
+
if len(gt_mask.shape) == 3 and gt_mask.shape[2] == 3:
|
364 |
+
# Convert RGB to single channel if needed
|
365 |
+
gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_RGB2GRAY)
|
366 |
+
|
367 |
+
# Calculate and display IoU
|
368 |
+
resized_mask = cv2.resize(mask, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_NEAREST)
|
369 |
+
iou_score = calculate_iou(gt_mask, resized_mask)
|
370 |
+
st.success(f"Mean IoU: {iou_score:.4f}")
|
371 |
+
|
372 |
+
# Display specific class IoUs
|
373 |
+
st.markdown("### IoU by Class")
|
374 |
+
col1, col2, col3 = st.columns(3)
|
375 |
+
with col1:
|
376 |
+
bg_iou = calculate_iou(gt_mask, resized_mask, 0)
|
377 |
+
st.metric("Background IoU", f"{bg_iou:.4f}")
|
378 |
+
with col2:
|
379 |
+
border_iou = calculate_iou(gt_mask, resized_mask, 1)
|
380 |
+
st.metric("Border IoU", f"{border_iou:.4f}")
|
381 |
+
with col3:
|
382 |
+
fg_iou = calculate_iou(gt_mask, resized_mask, 2)
|
383 |
+
st.metric("Foreground IoU", f"{fg_iou:.4f}")
|
384 |
+
except Exception as e:
|
385 |
+
st.error(f"Error processing ground truth mask: {e}")
|
386 |
+
st.write("Please ensure the mask is valid and has the correct format.")
|
387 |
+
|
388 |
+
# Download buttons
|
389 |
+
col1, col2 = st.columns(2)
|
390 |
|
391 |
+
with col1:
|
392 |
+
# Convert mask to PNG for download
|
393 |
+
mask_colored = Image.fromarray(colorized_mask)
|
394 |
+
mask_bytes = io.BytesIO()
|
395 |
+
mask_colored.save(mask_bytes, format='PNG')
|
396 |
+
mask_bytes = mask_bytes.getvalue()
|
397 |
+
|
398 |
+
st.download_button(
|
399 |
+
label="Download Segmentation Mask",
|
400 |
+
data=mask_bytes,
|
401 |
+
file_name="pet_segmentation_mask.png",
|
402 |
+
mime="image/png"
|
403 |
+
)
|
404 |
+
|
405 |
+
with col2:
|
406 |
+
# Convert overlay to PNG for download
|
407 |
+
overlay_img = Image.fromarray(overlay)
|
408 |
+
overlay_bytes = io.BytesIO()
|
409 |
+
overlay_img.save(overlay_bytes, format='PNG')
|
410 |
+
overlay_bytes = overlay_bytes.getvalue()
|
411 |
+
|
412 |
+
st.download_button(
|
413 |
+
label="Download Overlay Image",
|
414 |
+
data=overlay_bytes,
|
415 |
+
file_name="pet_segmentation_overlay.png",
|
416 |
+
mime="image/png"
|
417 |
+
)
|
418 |
+
except Exception as e:
|
419 |
+
st.error(f"Error processing image: {e}")
|
420 |
|
421 |
# Footer with additional information
|
422 |
st.markdown("---")
|