# Spaces: Sleeping (page-status banner captured together with the file; not part of the program)
import streamlit as st | |
import tensorflow as tf | |
from tensorflow.keras import backend | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import cv2 | |
from PIL import Image | |
import os | |
import io | |
import gdown | |
from transformers import TFSegformerForSemanticSegmentation | |
# Set page config at the very beginning, before any other Streamlit call.
# The page icon is the dog-face emoji; the previous value was a mis-decoded
# byte sequence rather than a single emoji character.
st.set_page_config(
    page_title="Pet Segmentation with SegFormer",
    page_icon="🐶",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Image preprocessing constants ---
# Side length (pixels) that preprocess_image resizes every input to.
IMAGE_SIZE = 512
# Not referenced in the visible code; presumably the side length of the
# model's output mask -- TODO confirm.
OUTPUT_SIZE = 128
# Per-channel mean / standard deviation used by normalize_image.
MEAN = tf.constant([0.485, 0.456, 0.406])
STD = tf.constant([0.229, 0.224, 0.225])

# --- Class labels ---
# Class id -> human-readable name, in id order.
ID2LABEL = dict(enumerate(("background", "border", "foreground/pet")))
NUM_CLASSES = len(ID2LABEL)
def download_model_from_drive():
    """Download the trained weights from Google Drive unless already cached.

    The weights are stored at ``models/tf_model.h5``. Progress and failures
    are reported through Streamlit status messages.

    Returns:
        The local path of the weights file, or ``None`` if the download failed.
    """
    os.makedirs("models", exist_ok=True)
    model_path = "models/tf_model.h5"

    if os.path.exists(model_path):
        st.info("Model already exists locally.")
        return model_path

    # gdown only treats "uc?id=<file id>" URLs as direct download links. A
    # ".../file/d/<id>/view" share link makes it save the HTML viewer page
    # instead of the weights file.
    file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        downloaded = gdown.download(url, model_path, quiet=False)
    except Exception as e:
        st.error(f"Failed to download model: {e}")
        return None

    # Some gdown versions signal failure by returning None instead of raising.
    if downloaded is None or not os.path.exists(model_path):
        st.error("Failed to download model: no file was retrieved from Google Drive.")
        return None

    st.success("Model downloaded successfully from Google Drive.")
    return model_path
def load_model():
    """Build the SegFormer model and load the fine-tuned weights into it.

    The architecture is created from the ``nvidia/mit-b0`` checkpoint with a
    ``NUM_CLASSES``-way head, then the weights file returned by
    ``download_model_from_drive`` is loaded on top of it. If the weights
    cannot be obtained or loaded, the failure is reported and the base model
    is returned as a fallback.

    Returns:
        The model instance, or ``None`` if the model could not be created.
    """
    try:
        # Create a base model with the correct architecture
        base_model = TFSegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            num_labels=NUM_CLASSES,
            id2label=ID2LABEL,
            label2id={label: idx for idx, label in ID2LABEL.items()},
            ignore_mismatched_sizes=True,
        )

        # Download the trained weights
        model_path = download_model_from_drive()
        if not model_path:
            st.warning("Using base pretrained model instead.")
            return base_model

        try:
            base_model.load_weights(model_path)
        except Exception as e:
            # Report the failure honestly: the segmentation head is untrained
            # in this case, so predictions will not be meaningful.
            st.error(f"Error loading weights: {e}")
            st.warning("Using base pretrained model instead.")
        else:
            st.success("Model weights loaded successfully!")
        return base_model
    except Exception as e:
        st.error(f"Error in load_model: {e}")
        return None
def normalize_image(input_image):
    """Cast an image to float32 and standardise it with MEAN and STD."""
    # NOTE(review): convert_image_dtype appears to rescale only integer inputs;
    # preprocess_image passes the float output of tf.image.resize, so values
    # likely stay in the 0-255 range here -- confirm this matches training.
    as_float = tf.image.convert_image_dtype(input_image, tf.float32)
    # Clamp the divisor so a zero std can never cause a division by zero.
    safe_std = tf.maximum(STD, backend.epsilon())
    return (as_float - MEAN) / safe_std
def preprocess_image(image):
    """Turn a PIL image into a model-ready batch plus a copy for display.

    The original author notes this mirrors the preprocessing in
    colab_code.py (not visible here).

    Returns:
        A tuple ``(batch, original)``: ``batch`` is the resized, normalised
        image in channels-first layout with a leading batch axis, and
        ``original`` is the untouched RGB pixel array.
    """
    rgb_pixels = np.array(image.convert('RGB'))

    # Force a square IMAGE_SIZE x IMAGE_SIZE input, ignoring aspect ratio.
    resized = tf.image.resize(
        rgb_pixels,
        (IMAGE_SIZE, IMAGE_SIZE),
        method='bilinear',
        preserve_aspect_ratio=False,
        antialias=True,
    )

    # HWC -> CHW, then prepend the batch axis.
    channels_first = tf.transpose(normalize_image(resized), (2, 0, 1))
    batch = tf.expand_dims(channels_first, axis=0)

    return batch, rgb_pixels.copy()
def process_uploaded_mask(mask_array):
    """
    Convert an uploaded ground-truth mask to class ids 0, 1, 2.

    The original notes say the mask comes from a
    save_image_and_mask_to_local function (not visible here).

    The format is decided from the complete set of values in the mask:

    * only values from {0, 1, 2, 3} with 3 present: dataset trimap
      (1 = pet, 2 = border, 3 = background), remapped to 2 / 1 / 0;
    * only values from {0, 1, 2}: already class ids, returned unchanged;
    * anything else (e.g. a 0-255 grayscale mask): binary threshold at 127,
      giving background (0) and foreground (2).

    Args:
        mask_array: Numpy array of the mask (2-D, or 3-D with channels last)
    Returns:
        2-D mask with values 0,1,2
    """
    mask_array = np.asarray(mask_array)

    # 1-bit images load as booleans; stretch them to 0/255 so they go
    # through the binary-threshold fallback (True -> foreground).
    if mask_array.dtype == np.bool_:
        mask_array = mask_array.astype(np.uint8) * 255

    if mask_array.ndim == 3:
        # Handle RGBA images by dropping the alpha channel
        if mask_array.shape[2] == 4:
            mask_array = np.ascontiguousarray(mask_array[:, :, :3])
        if mask_array.shape[2] >= 3:
            # Convert RGB to grayscale
            mask_array = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
        else:
            # Single channel, or gray + alpha: keep the first channel
            mask_array = mask_array[:, :, 0]

    # Decide from ALL values present, not from single marker values, so a
    # 0-255 mask that merely contains a 2 or 3 is not mistaken for class ids.
    values = set(np.unique(mask_array).tolist())

    # Dataset trimap 1,2,3 -> 0,1,2
    if 3 in values and values <= {0, 1, 2, 3}:
        processed_mask = np.zeros_like(mask_array)
        processed_mask[mask_array == 1] = 2  # Foreground/pet (1 -> 2)
        processed_mask[mask_array == 2] = 1  # Border (2 -> 1)
        # Background (3 -> 0) is already covered by the zero initialisation
        return processed_mask

    # Already class ids (possibly with some classes absent)
    if values <= {0, 1, 2}:
        return mask_array

    # Unknown format: simple foreground/background split
    return np.where(mask_array > 127, 2, 0).astype(mask_array.dtype)
def create_mask(pred_mask):
    """Reduce model logits to a class-id mask returned as a numpy array."""
    # Pick the best-scoring class along axis 1, then drop size-1 dimensions.
    class_ids = tf.squeeze(tf.math.argmax(pred_mask, axis=1))
    return class_ids.numpy()
def colorize_mask(mask):
    """Render a 2-D class-id mask as an RGB uint8 image.

    Class ids without a palette entry are left black.
    """
    palette = {
        0: (0, 0, 0),        # background -> black
        1: (255, 255, 0),    # border -> yellow
        2: (255, 0, 0),      # foreground/pet -> red
    }

    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id, rgb in palette.items():
        canvas[mask == class_id] = rgb
    return canvas
def create_overlay(image, mask, alpha=0.5):
    """Blend a mask onto an image: image at full weight plus alpha * mask."""
    # Bring the mask to the image's height/width when they differ.
    if mask.shape[:2] != image.shape[:2]:
        width, height = image.shape[1], image.shape[0]
        mask = cv2.resize(mask, (width, height))

    mask_u8 = mask.astype(np.uint8)
    return cv2.addWeighted(image, 1, mask_u8, alpha, 0)
def calculate_iou(y_true, y_pred, class_idx=None):
    """Calculate IoU (Intersection over Union) between two class-id masks.

    Args:
        y_true: Ground-truth mask of class ids.
        y_pred: Predicted mask of class ids, same shape as ``y_true``.
        class_idx: Class to score. When ``None`` the mean over classes is
            returned.

    Returns:
        With ``class_idx``: IoU of that class, or 0.0 when the class occurs
        in neither mask. Without: mean IoU over the classes in
        ``range(NUM_CLASSES)`` that occur in at least one mask, so a class
        missing from both masks does not pull a perfect prediction below 1.0.
        Returns 0.0 if no class occurs at all.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if class_idx is not None:
        true_hits = y_true == class_idx
        pred_hits = y_pred == class_idx
        # Integer pixel counts are exact for any mask size.
        intersection = np.count_nonzero(true_hits & pred_hits)
        union = np.count_nonzero(true_hits | pred_hits)
        return intersection / union if union > 0 else 0.0

    # Mean IoU over the classes that actually occur
    present = [
        idx for idx in range(NUM_CLASSES)
        if np.any(y_true == idx) or np.any(y_pred == idx)
    ]
    if not present:
        return 0.0
    return float(np.mean([calculate_iou(y_true, y_pred, idx) for idx in present]))
def calculate_dice(y_true, y_pred, class_idx=None):
    """Calculate the Dice coefficient (F1 score) between two class-id masks.

    Args:
        y_true: Ground-truth mask of class ids.
        y_pred: Predicted mask of class ids, same shape as ``y_true``.
        class_idx: Class to score. When ``None`` the mean over classes is
            returned.

    Returns:
        With ``class_idx``: Dice score of that class, or 0.0 when the class
        occurs in neither mask. Without: mean Dice over the classes in
        ``range(NUM_CLASSES)`` that occur in at least one mask, so a class
        missing from both masks does not pull a perfect prediction below 1.0.
        Returns 0.0 if no class occurs at all.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if class_idx is not None:
        true_hits = y_true == class_idx
        pred_hits = y_pred == class_idx
        # Integer pixel counts are exact for any mask size.
        overlap = np.count_nonzero(true_hits & pred_hits)
        sum_areas = np.count_nonzero(true_hits) + np.count_nonzero(pred_hits)
        return 2.0 * overlap / sum_areas if sum_areas > 0 else 0.0

    # Mean Dice over the classes that actually occur
    present = [
        idx for idx in range(NUM_CLASSES)
        if np.any(y_true == idx) or np.any(y_pred == idx)
    ]
    if not present:
        return 0.0
    return float(np.mean([calculate_dice(y_true, y_pred, idx) for idx in present]))
def calculate_pixel_accuracy(y_true, y_pred):
    """Calculate pixel accuracy: the fraction of positions where masks agree.

    Args:
        y_true: Ground-truth mask of class ids.
        y_pred: Predicted mask of class ids, same shape as ``y_true``.

    Returns:
        Accuracy in [0, 1]; 0.0 for empty masks.

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Reject mismatched shapes explicitly: broadcasting would otherwise
    # compare the wrong pixels and return a meaningless ratio.
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Mask shapes differ: {y_true.shape} vs {y_pred.shape}"
        )

    total = y_true.size
    if total == 0:
        return 0.0
    correct = np.count_nonzero(y_true == y_pred)
    return correct / total
def display_side_by_side(original_img, gt_mask=None, pred_mask=None, overlay=None):
    """Show the original image and any supplied extras in one row of columns.

    The original image always comes first; ground truth, prediction and
    overlay follow in that order, each only when it is not ``None``.
    """
    panels = [("### Original Image", original_img)]
    optional_panels = (
        ("### Ground Truth Mask", gt_mask),
        ("### Predicted Mask", pred_mask),
        ("### Overlay", overlay),
    )
    panels.extend(
        (heading, picture)
        for heading, picture in optional_panels
        if picture is not None
    )

    # One column per panel that is actually shown.
    for column, (heading, picture) in zip(st.columns(len(panels)), panels):
        with column:
            st.markdown(heading)
            st.image(picture, use_column_width=True)
def _to_png_bytes(array):
    """Encode an image array as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format='PNG')
    return buffer.getvalue()


def main():
    """Run the Streamlit app.

    Loads the model, lets the user upload an image and an optional
    ground-truth mask, shows the predicted segmentation with an overlay and
    per-class views, computes IoU / Dice / pixel accuracy when a ground
    truth is supplied, and offers the results for download.
    """
    import traceback  # only used for the debug-mode error details

    st.title("🐶 Pet Segmentation with SegFormer")
    st.markdown("""
    This app demonstrates semantic segmentation of pet images using a SegFormer model.
    The model segments images into three classes:
    - **Background**: Areas around the pet
    - **Border**: The boundary/outline around the pet
    - **Foreground**: The pet itself
    """)

    # Sidebar settings
    st.sidebar.title("Settings")
    debug_mode = st.sidebar.checkbox("Debug Mode", value=False)
    overlay_opacity = st.sidebar.slider(
        "Overlay Opacity",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.1,
    )

    # Load model
    with st.spinner("Loading SegFormer model..."):
        model = load_model()
    if model is None:
        st.error("Failed to load model. Please check your model path and try again.")
        return
    st.sidebar.success("Model loaded successfully!")

    # Image upload section
    st.header("Upload an Image")
    uploaded_image = st.file_uploader("Upload a pet image:", type=["jpg", "jpeg", "png"])
    uploaded_mask = st.file_uploader("Upload ground truth mask (optional):", type=["png", "jpg", "jpeg"])

    if uploaded_image is None:
        st.info("Please upload an image to get started.")
        return

    try:
        # Read and show the uploaded image
        image = Image.open(io.BytesIO(uploaded_image.read()))
        st.subheader("Original Image")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess and predict
        with st.spinner("Generating segmentation mask..."):
            img_tensor, original_img = preprocess_image(image)
            outputs = model(pixel_values=img_tensor, training=False)
            mask = create_mask(outputs.logits)
            colorized_mask = colorize_mask(mask)
            overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)

        gt_mask_colorized = None
        metrics_calculated = False

        # Calculate metrics if ground truth is uploaded
        if uploaded_mask is not None:
            try:
                # Reset the file pointer: Streamlit reruns may have consumed it
                uploaded_mask.seek(0)
                gt_mask_raw = np.array(Image.open(io.BytesIO(uploaded_mask.read())))
                if debug_mode:
                    st.write(f"Ground truth mask shape: {gt_mask_raw.shape}")
                    st.write(f"Ground truth mask unique values: {np.unique(gt_mask_raw)}")

                gt_mask = process_uploaded_mask(gt_mask_raw)
                gt_mask_colorized = colorize_mask(gt_mask)

                # cv2.resize takes dsize as (width, height) = (columns, rows).
                # Nearest-neighbour keeps the class ids intact.
                gt_mask_resized = cv2.resize(
                    gt_mask,
                    (mask.shape[1], mask.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
                if debug_mode:
                    st.write(f"Processed GT mask shape: {gt_mask_resized.shape}")
                    st.write(f"Processed GT unique values: {np.unique(gt_mask_resized)}")
                    st.write(f"Prediction mask unique values: {np.unique(mask)}")

                # Overall metrics
                iou_score = calculate_iou(gt_mask_resized, mask)
                dice_score = calculate_dice(gt_mask_resized, mask)
                accuracy = calculate_pixel_accuracy(gt_mask_resized, mask)
                # Per-class metrics, computed once and reused for display and CSV
                class_ious = [calculate_iou(gt_mask_resized, mask, i) for i in range(NUM_CLASSES)]
                class_dices = [calculate_dice(gt_mask_resized, mask, i) for i in range(NUM_CLASSES)]
                metrics_calculated = True
            except Exception as e:
                st.error(f"Error processing ground truth mask: {e}")
                if debug_mode:
                    st.code(traceback.format_exc())

        # Display results
        st.subheader("Segmentation Results")
        display_side_by_side(original_img, gt_mask_colorized, colorized_mask, overlay)

        # Display metrics if calculated
        if metrics_calculated:
            st.header("Segmentation Metrics")
            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("Mean IoU", f"{iou_score:.4f}")
            with col2:
                st.metric("Mean Dice", f"{dice_score:.4f}")
            with col3:
                st.metric("Pixel Accuracy", f"{accuracy:.4f}")

            st.subheader("Metrics by Class")
            class_names = ["Background", "Border", "Foreground/Pet"]
            class_rows = zip(st.columns(NUM_CLASSES), class_names, class_ious, class_dices)
            for col, name, class_iou, class_dice in class_rows:
                with col:
                    st.markdown(f"**{name}**")
                    st.metric("IoU", f"{class_iou:.4f}")
                    st.metric("Dice", f"{class_dice:.4f}")

        # Display segmentation details: one white-on-black view per class
        st.header("Segmentation Details")
        details = (
            ("Background", "Areas surrounding the pet", 0, "Background"),
            ("Border", "Boundary around the pet", 1, "Border"),
            ("Foreground (Pet)", "The pet itself", 2, "Foreground"),
        )
        for col, (title, description, class_id, caption) in zip(st.columns(3), details):
            with col:
                st.subheader(title)
                st.markdown(description)
                class_view = np.where(mask == class_id, 255, 0).astype(np.uint8)
                st.image(class_view, caption=caption, use_column_width=True)

        # Download buttons
        st.header("Download Results")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.download_button(
                label="Download Prediction",
                data=_to_png_bytes(colorized_mask),
                file_name="prediction.png",
                mime="image/png",
            )
        with col2:
            st.download_button(
                label="Download Overlay",
                data=_to_png_bytes(overlay),
                file_name="overlay.png",
                mime="image/png",
            )
        if metrics_calculated:
            with col3:
                iou_cells = ",".join(f"{value:.4f}" for value in class_ious)
                dice_cells = ",".join(f"{value:.4f}" for value in class_dices)
                metrics_csv = (
                    "Metric,Overall,Background,Border,Foreground\n"
                    f"IoU,{iou_score:.4f},{iou_cells}\n"
                    f"Dice,{dice_score:.4f},{dice_cells}\n"
                    f"Accuracy,{accuracy:.4f},,,"
                )
                st.download_button(
                    label="Download Metrics",
                    data=metrics_csv,
                    file_name="metrics.csv",
                    mime="text/csv",
                )
    except Exception as e:
        st.error(f"Error processing image: {e}")
        if debug_mode:
            st.code(traceback.format_exc())
if __name__ == "__main__":
    # Best-effort request for on-demand GPU memory allocation; any failure
    # is printed and the app starts regardless.
    # NOTE(review): TensorFlow may refuse this once its runtime is already
    # initialised, which the module-level tf.constant calls could have
    # triggered -- confirm the setting actually takes effect.
    try:
        for device in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(device, True)
    except Exception as e:
        print(f"GPU configuration error: {e}")

    main()