# Spaces: Sleeping
import urllib.request
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    DepthProForDepthEstimation,
    DepthProImageProcessorFast,
)
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Segmentation model (RMBG-2.0 by briaai)
# -----------------------------
# NOTE(review): trust_remote_code=True lets the checkpoint repository supply
# its own model code - confirm the repository is trusted before deploying.
seg_model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
torch.set_float32_matmul_precision("high")
seg_model.to(device)
seg_model.eval()

# Every image is resized to this size before segmentation, then converted
# to a tensor and normalized per channel.
seg_image_size = (1024, 1024)
seg_transform = transforms.Compose(
    [
        transforms.Resize(seg_image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# -----------------------------
# Depth estimation model (DepthPro by Apple)
# -----------------------------
_DEPTH_CHECKPOINT = "apple/DepthPro-hf"
depth_processor = DepthProImageProcessorFast.from_pretrained(_DEPTH_CHECKPOINT)
depth_model = DepthProForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT)
depth_model.to(device)
depth_model.eval()
# -----------------------------
# Define the Segmentation-Based Blur Effect
# -----------------------------
def segmentation_blur_effect(input_image: Image.Image):
    """Blur everything the segmentation model does not mark as foreground.

    The image is converted to RGB and resized to ``seg_image_size``, the
    segmentation model predicts a soft mask, and the mask is binarized at
    127. Masked pixels keep their original colour; all other pixels are
    taken from a Gaussian-blurred (sigma 15) copy of the image.

    Args:
        input_image: Source picture in any PIL mode; it is converted to RGB
            so that RGBA or grayscale uploads do not break the 3-channel
            normalization in ``seg_transform``.

    Returns:
        A ``(blurred_image, mask)`` tuple of PIL images. Both have size
        ``seg_image_size`` (not the size of ``input_image``); ``mask`` is
        the soft mask before thresholding.
    """
    resized = input_image.convert("RGB").resize(seg_image_size)
    input_tensor = seg_transform(resized).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = seg_model(input_tensor)[-1].sigmoid().cpu()

    # Soft mask as a PIL image, brought to the working image's size.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(resized.size)

    mask_np = np.array(mask.convert("L"))
    _, mask_binary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

    rgb = np.array(resized)
    blurred = cv2.GaussianBlur(rgb, (0, 0), sigmaX=15, sigmaY=15)

    # Pick the sharp pixel where the mask is set and the blurred one
    # elsewhere; the extra axis broadcasts the mask over the colour channels.
    composite = np.where(mask_binary[..., None] == 255, rgb, blurred)
    return Image.fromarray(composite), mask
def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_threshold: float = 170):
    """Apply a three-band, depth-dependent blur to an image.

    A depth map is predicted for the image and min-max normalized. Pixels
    whose normalized depth is below ``fg_threshold`` stay sharp, pixels
    between the two thresholds get a Gaussian blur with sigma 7, and the
    remaining pixels get a Gaussian blur with sigma 15.

    Args:
        input_image: Source picture in any PIL mode; converted to RGB.
        fg_threshold: Upper depth bound of the sharp foreground band.
            Values in [0, 1] are used as-is; values above 1 are read as
            being on the 0-255 scale of the returned depth image and are
            divided by 255 (so the default 85 behaves like ~0.33).
        mg_threshold: Upper depth bound of the middle band, using the same
            scale rule. It is raised to ``fg_threshold`` when smaller, so
            the three bands never overlap.

    Returns:
        A 5-tuple of PIL images: the 8-bit depth map, the blurred result,
        and the foreground, middle-ground and background masks (0/255).
    """
    rgb_image = input_image.convert("RGB")

    inputs = depth_processor(images=rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(rgb_image.height, rgb_image.width)]
    )
    depth = post_processed_output[0]["predicted_depth"]
    depth = depth.detach().cpu().numpy().astype(np.float32)

    # Min-max normalize; a completely flat depth map would divide by zero,
    # so it is mapped to all zeros instead.
    depth_min = float(depth.min())
    depth_span = float(depth.max()) - depth_min
    if depth_span > 0:
        depth_unit = (depth - depth_min) / depth_span
    else:
        depth_unit = np.zeros_like(depth)

    depth_u8 = (depth_unit * 255.0).astype(np.uint8)
    depth_img = Image.fromarray(depth_u8)

    # Thresholds are compared against the quantized map rescaled to [0, 1],
    # i.e. the same values the depth image shows.
    depth_unit = depth_u8.astype(np.float32) / 255.0

    fg_limit = fg_threshold / 255.0 if fg_threshold > 1 else float(fg_threshold)
    mg_limit = mg_threshold / 255.0 if mg_threshold > 1 else float(mg_threshold)
    # Keep the bands disjoint even when the thresholds are given in
    # reverse order; overlapping bands would be composited twice.
    mg_limit = max(mg_limit, fg_limit)

    mask_fg = depth_unit < fg_limit
    mask_bg = depth_unit >= mg_limit
    mask_mg = ~(mask_fg | mask_bg)

    sharp = np.array(rgb_image)
    blur_mid = cv2.GaussianBlur(sharp, (0, 0), sigmaX=7, sigmaY=7)
    blur_far = cv2.GaussianBlur(sharp, (0, 0), sigmaX=15, sigmaY=15)

    # Select per pixel from the three renditions; the extra axis broadcasts
    # each 2-D mask over the colour channels.
    final_img = np.where(
        mask_fg[..., None],
        sharp,
        np.where(mask_bg[..., None], blur_far, blur_mid),
    )
    lens_blur_image = Image.fromarray(final_img)

    mask_fg_img = Image.fromarray(mask_fg.astype(np.uint8) * 255)
    mask_mg_img = Image.fromarray(mask_mg.astype(np.uint8) * 255)
    mask_bg_img = Image.fromarray(mask_bg.astype(np.uint8) * 255)
    return depth_img, lens_blur_image, mask_fg_img, mask_mg_img, mask_bg_img
def process_image(input_image: Image.Image, fg_threshold: float, mg_threshold: float):
    """Run both blur effects on one image for the Gradio interface.

    Args:
        input_image: Uploaded picture, or ``None`` when no image is set.
        fg_threshold: Foreground depth threshold passed to
            ``lens_blur_effect``.
        mg_threshold: Middle-ground depth threshold passed to
            ``lens_blur_effect``.

    Returns:
        A 3-tuple ``(segmentation_blur, depth_map, lens_blur)`` of PIL
        images. The masks produced by the two effects are not returned.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        # Report a readable message instead of failing later with an
        # AttributeError on None.
        raise gr.Error("Please provide an input image.")

    seg_blur, _seg_mask = segmentation_blur_effect(input_image)
    depth_map_img, lens_blur_img, *_band_masks = lens_blur_effect(
        input_image, fg_threshold, mg_threshold
    )
    return seg_blur, depth_map_img, lens_blur_img
def update_preset(preset: str):
    """Download a preset's example image and return it with its thresholds.

    Args:
        preset: Name of the preset, ``"Preset 1"`` or ``"Preset 2"``.

    Returns:
        A ``(image, fg_threshold, mg_threshold)`` tuple, where ``image`` is
        the downloaded picture converted to RGB.

    Raises:
        KeyError: If ``preset`` is not a known preset name.
        urllib.error.URLError: If the image cannot be downloaded.
    """
    presets = {
        "Preset 1": {
            "image_url": "https://i.ibb.co/fznz2b2b/hw3-q2.jpg",
            "fg_threshold": 0.33,
            "mg_threshold": 0.66,
        },
        "Preset 2": {
            "image_url": "https://i.ibb.co/HLZGW7qH/q26.jpg",
            "fg_threshold": 0.2,
            "mg_threshold": 0.66,
        },
    }
    try:
        preset_info = presets[preset]
    except KeyError:
        raise KeyError(
            f"Unknown preset {preset!r}; expected one of {sorted(presets)}"
        ) from None

    # Some hosts reject urllib's default User-Agent, so send a generic one.
    request = urllib.request.Request(
        preset_info["image_url"], headers={"User-Agent": "Mozilla/5.0"}
    )
    # The timeout keeps a stalled download from hanging the caller forever.
    with urllib.request.urlopen(request, timeout=10) as response:
        payload = response.read()

    image = Image.open(BytesIO(payload)).convert("RGB")
    return image, preset_info["fg_threshold"], preset_info["mg_threshold"]
title = "Blur Effects on Segmentation-Based Gaussian Blur & Depth-Based Lens Blur with Adjustable Depth Thresholds"

# Inputs, in the order process_image receives them: the image followed by
# the two depth thresholds.
input_components = [
    gr.Image(
        type="pil",
        label="Input Image",
        value="https://i.ibb.co/fznz2b2b/hw3-q2.jpg",
    ),
    gr.Slider(
        minimum=0,
        maximum=1,
        step=0.01,
        value=0.33,
        label="Foreground Depth Threshold",
    ),
    gr.Slider(
        minimum=0,
        maximum=1,
        step=0.01,
        value=0.66,
        label="Middleground Depth Threshold",
    ),
]

# Outputs, in the order process_image returns them.
output_components = [
    gr.Image(type="pil", label="Segmentation-Based Blur"),
    gr.Image(type="pil", label="Depth Map"),
    gr.Image(type="pil", label="Depth-Based Lens Blur"),
]

demo = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
    title=title,
    allow_flagging="never",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()