import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
from transformers import DPTImageProcessor, DPTForDepthEstimation
import warnings

warnings.filterwarnings("ignore")

# Load segmentation model - SegFormer is compatible with AutoModelForSemanticSegmentation
seg_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
seg_model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# Load depth estimation model
depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
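
# Optional (a sketch, not part of the original app): run inference on a GPU when one
# is available. The inference helpers below would then also need to move their inputs,
# e.g. inputs = {k: v.to(device) for k, v in inputs.items()}.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# seg_model = seg_model.to(device)
# depth_model = depth_model.to(device)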

def safe_resize(image, target_size, interpolation=cv2.INTER_LINEAR):
    """Safely resize an image with validation checks."""
    if image is None:
        return None
    # Ensure image is a proper numpy array
    if not isinstance(image, np.ndarray):
        return None
    # Check that dimensions are valid (non-zero)
    h, w = target_size
    if h <= 0 or w <= 0 or image.shape[0] <= 0 or image.shape[1] <= 0:
        return image  # Return original if target dimensions are invalid
    # cv2.resize expects the target size in (width, height) order; the same call
    # handles both grayscale and color arrays
    return cv2.resize(image, (w, h), interpolation=interpolation)
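
# Usage sketch (illustrative shapes, not part of the app's flow): target_size is
# (height, width), matching numpy's image.shape order, and safe_resize handles the
# swap into cv2's (width, height) convention.
# half = safe_resize(np.zeros((480, 640, 3), np.uint8), (240, 320))  # -> 240x320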

def apply_gaussian_blur(image, mask, sigma=15):
    """Apply Gaussian blur to the background of an image based on a mask."""
    try:
        # Convert mask to 8-bit (0-255) regardless of input range
        if mask.max() <= 1.0:
            binary_mask = (mask * 255).astype(np.uint8)
        else:
            binary_mask = mask.astype(np.uint8)
        # Create a blurred version of the entire image
        blurred = cv2.GaussianBlur(image, (0, 0), sigma)
        # Resize mask to match image dimensions if needed
        if binary_mask.shape[:2] != image.shape[:2]:
            binary_mask = safe_resize(binary_mask, (image.shape[0], image.shape[1]))
        # Expand a single-channel mask to 3 channels
        if len(binary_mask.shape) == 2:
            mask_3ch = np.stack([binary_mask] * 3, axis=2)
        else:
            mask_3ch = binary_mask
        # Normalize mask to range [0, 1]
        mask_3ch = mask_3ch / 255.0
        # Composite: original image where the mask is 1 (foreground), blurred elsewhere
        result = image * mask_3ch + blurred * (1 - mask_3ch)
        return result.astype(np.uint8)
    except Exception as e:
        print(f"Error in apply_gaussian_blur: {e}")
        return image  # Return original image if there's an error
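
# Optional refinement (a sketch, not in the original app): feather the binary mask
# with a small Gaussian blur before compositing to soften the hard seam between
# foreground and background, e.g.
# soft_mask = cv2.GaussianBlur(binary_mask, (0, 0), 3) / 255.0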

def apply_depth_blur(image, depth_map, max_sigma=25):
    """Apply variable Gaussian blur based on depth map."""
    try:
        # Normalize depth map to range [0, 1]
        if depth_map.max() > 1.0:
            depth_norm = depth_map / depth_map.max()
        else:
            depth_norm = depth_map
        # Resize depth map to match image dimensions if needed
        if depth_norm.shape[:2] != image.shape[:2]:
            depth_norm = safe_resize(depth_norm, (image.shape[0], image.shape[1]))
        # Output image plus a running mask of covered pixels
        result = np.zeros_like(image)
        total_mask = np.zeros(depth_norm.shape[:2], dtype=np.float32)
        # Instead of many small blurs, use a few discrete blur levels for efficiency
        blur_levels = 5
        step = max_sigma / blur_levels
        for i in range(blur_levels):
            sigma = (i + 1) * step
            # Depth range for this blur level; half-open intervals so boundary
            # pixels are not counted twice (the last level keeps its upper bound)
            lower_bound = i / blur_levels
            upper_bound = (i + 1) / blur_levels
            if i < blur_levels - 1:
                mask = np.logical_and(depth_norm >= lower_bound, depth_norm < upper_bound).astype(np.float32)
            else:
                mask = np.logical_and(depth_norm >= lower_bound, depth_norm <= upper_bound).astype(np.float32)
            total_mask += mask
            # Skip if no pixels in this range
            if not np.any(mask):
                continue
            # Blur the whole image at this level's sigma and paste in the masked region
            blurred = cv2.GaussianBlur(image, (0, 0), sigma)
            mask_3ch = np.stack([mask] * 3, axis=2) if len(mask.shape) == 2 else mask
            result += (blurred * mask_3ch).astype(np.uint8)
        # Fill any pixels not covered by a level with the original image
        missing_mask = (total_mask < 0.5).astype(np.float32)
        if np.any(missing_mask):
            missing_mask_3ch = np.stack([missing_mask] * 3, axis=2)
            result += (image * missing_mask_3ch).astype(np.uint8)
        return result
    except Exception as e:
        print(f"Error in apply_depth_blur: {e}")
        return image  # Return original image if there's an error
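
# Note: blur_levels trades smoothness for speed; each level costs one full-image
# GaussianBlur, so a true per-pixel sigma would be far more expensive.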

def get_segmentation_mask(image_pil):
    """Get segmentation mask for person/foreground from an image."""
    try:
        # Process the image with the segmentation model
        inputs = seg_processor(images=image_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = seg_model(**inputs)
        # Upsample the predicted logits directly to the original image size
        logits = outputs.logits
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image_pil.size[::-1],  # PIL size is (w, h); interpolate wants (h, w)
            mode="bilinear",
            align_corners=False,
        )
        # Get the predicted class for each pixel
        predicted_mask = upsampled_logits.argmax(dim=1)[0]
        # Convert the mask to a numpy array
        mask_np = predicted_mask.cpu().numpy()
        # Build a binary foreground mask from selected ADE20K classes
        foreground_classes = [12]  # "person" (add more class indices as needed)
        foreground_mask = np.zeros_like(mask_np)
        for cls in foreground_classes:
            foreground_mask[mask_np == cls] = 1
        return foreground_mask
    except Exception as e:
        print(f"Error in get_segmentation_mask: {e}")
        # Return a default mask (all ones) in case of error
        return np.ones((image_pil.size[1], image_pil.size[0]), dtype=np.uint8)
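
# Tip: seg_model.config.id2label maps all 150 ADE20K class indices to names, which
# helps when extending foreground_classes beyond the "person" class.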

def get_depth_map(image_pil):
    """Get depth map from an image."""
    try:
        # Process the image with the depth estimation model
        inputs = depth_processor(images=image_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = depth_model(**inputs)
            predicted_depth = outputs.predicted_depth
        # Interpolate to original size
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image_pil.size[::-1],
            mode="bicubic",
            align_corners=False,
        )
        # Convert to numpy array
        depth_map = prediction.squeeze().cpu().numpy()
        # Normalize depth map to [0, 1]
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        if depth_max > depth_min:
            depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        else:
            depth_map = np.zeros_like(depth_map)
        return depth_map
    except Exception as e:
        print(f"Error in get_depth_map: {e}")
        # Return a default depth map (vertical gradient from top to bottom) on error
        h, w = image_pil.size[1], image_pil.size[0]
        row_values = np.linspace(0.0, 1.0, h, endpoint=False, dtype=np.float32)
        default_depth = np.repeat(row_values[:, None], w, axis=1)
        return default_depth
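
# Note: DPT predicts relative inverse depth (larger values are typically nearer the
# camera), so after normalization the strongest blur lands on near regions; use
# (1.0 - depth_map) if you want distant regions blurred instead.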

def process_image(input_image, blur_sigma=15, depth_blur_sigma=25):
    """Main function to process the input image."""
    try:
        # Input validation
        if input_image is None:
            print("No input image provided")
            return [None, None, None, None, None]
        # Convert to PIL Image if needed
        if isinstance(input_image, np.ndarray):
            # Make sure we have a valid image with at least 2 dimensions
            if input_image.ndim < 2 or input_image.shape[0] <= 0 or input_image.shape[1] <= 0:
                print("Invalid input image dimensions")
                return [None, None, None, None, None]
            pil_image = Image.fromarray(input_image)
        else:
            pil_image = input_image
            input_image = np.array(pil_image)
        # Get segmentation mask
        print("Getting segmentation mask...")
        seg_mask = get_segmentation_mask(pil_image)
        # Get depth map
        print("Getting depth map...")
        depth_map = get_depth_map(pil_image)
        # Apply Gaussian blur to the background
        print("Applying gaussian blur...")
        gaussian_result = apply_gaussian_blur(input_image, seg_mask, sigma=blur_sigma)
        # Apply depth-based blur
        print("Applying depth-based blur...")
        depth_result = apply_depth_blur(input_image, depth_map, max_sigma=depth_blur_sigma)
        # Visualize the depth map; applyColorMap returns BGR, so convert to RGB for Gradio
        depth_visualization = (depth_map * 255).astype(np.uint8)
        depth_colored = cv2.applyColorMap(depth_visualization, cv2.COLORMAP_INFERNO)
        depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
        # Visualize the segmentation mask
        seg_visualization = (seg_mask * 255).astype(np.uint8)
        print("Processing complete!")
        return [
            input_image,
            seg_visualization,
            gaussian_result,
            depth_colored,
            depth_result,
        ]
    except Exception as e:
        print(f"Error processing image: {e}")
        return [None, None, None, None, None]
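
# Minimal scripted usage (a sketch; assumes a local "photo.jpg", which is not part
# of this app):
# img = np.array(Image.open("photo.jpg").convert("RGB"))
# original, seg_vis, bg_blur, depth_vis, lens_blur = process_image(img)
# Image.fromarray(bg_blur).save("bg_blur.jpg")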

# Create Gradio interface
with gr.Blocks(title="Image Blur Effects with Segmentation and Depth Estimation") as demo:
    gr.Markdown("# Image Blur Effects App")
    gr.Markdown("This app demonstrates two types of blur effects: background blur using segmentation and depth-based lens blur.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload an image", type="numpy")
            blur_sigma = gr.Slider(minimum=1, maximum=50, value=15, step=1, label="Background Blur Intensity")
            depth_blur_sigma = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Depth Blur Max Intensity")
            process_btn = gr.Button("Process Image")
        with gr.Column():
            with gr.Tab("Original Image"):
                output_original = gr.Image(label="Original Image")
            with gr.Tab("Segmentation Mask"):
                output_segmentation = gr.Image(label="Segmentation Mask")
            with gr.Tab("Background Blur"):
                output_gaussian = gr.Image(label="Background Blur Result")
            with gr.Tab("Depth Map"):
                output_depth = gr.Image(label="Depth Map")
            with gr.Tab("Depth-based Lens Blur"):
                output_depth_blur = gr.Image(label="Depth-based Lens Blur Result")
    process_btn.click(
        fn=process_image,
        inputs=[input_image, blur_sigma, depth_blur_sigma],
        outputs=[output_original, output_segmentation, output_gaussian, output_depth, output_depth_blur],
    )
    gr.Markdown("""
    ## How it works
    1. **Background Blur**: Uses a SegFormer model to identify foreground objects (like people) and blurs only the background
    2. **Depth-based Lens Blur**: Uses a DPT depth estimation model to apply variable blur based on estimated distance

    Try uploading a photo of a person against a background to see the effects!
    """)

demo.launch()
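# When running outside Hugging Face Spaces, demo.launch(share=True) serves a
# temporary public URL instead of localhost only.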