import cv2
import numpy as np
import requests
from io import BytesIO
from PIL import Image, ImageFilter
import torch
import gradio as gr
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    DepthProImageProcessorFast,
    DepthProForDepthEstimation,
)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Load Segmentation Model (RMBG-2.0 by briaai)
# -----------------------------
seg_model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
torch.set_float32_matmul_precision("high")
seg_model.to(device)
seg_model.eval()

# Define segmentation image size and transform
seg_image_size = (1024, 1024)
seg_transform = transforms.Compose([
    transforms.Resize(seg_image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# -----------------------------
# Load Depth Estimation Model (DepthPro by Apple)
# -----------------------------
depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")
depth_model.to(device)
depth_model.eval()

# -----------------------------
# Define the Segmentation-Based Blur Effect
# -----------------------------
def segmentation_blur_effect(input_image: Image.Image):
    imageResized = input_image.resize(seg_image_size)
    input_tensor = seg_transform(imageResized).unsqueeze(0).to(device)

    # Predict the foreground mask with RMBG-2.0
    with torch.no_grad():
        preds = seg_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(imageResized.size)

    # Binarize the predicted mask
    mask_np = np.array(mask.convert("L"))
    _, maskBinary = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

    # Sharp foreground comes from the original image; the background copy is Gaussian-blurred
    img = cv2.cvtColor(np.array(imageResized), cv2.COLOR_RGB2BGR)
    blurredBg = cv2.GaussianBlur(np.array(imageResized), (0, 0), sigmaX=15, sigmaY=15)

    maskInv = cv2.bitwise_not(maskBinary)
    maskInv3 = cv2.cvtColor(maskInv, cv2.COLOR_GRAY2BGR)

    # Foreground is taken from the BGR image, background from the blurred RGB array,
    # so the foreground is converted back to RGB before compositing.
    foreground = cv2.bitwise_and(img, cv2.bitwise_not(maskInv3))
    background = cv2.bitwise_and(blurredBg, maskInv3)
    finalImg = cv2.add(cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), background)

    finalImg_pil = Image.fromarray(finalImg)
    return finalImg_pil, mask


# -----------------------------
# Define the Depth-Based Lens Blur Effect
# -----------------------------
def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 0.33, mg_threshold: float = 0.66):
    # Thresholds are on the normalized [0, 1] depth scale, matching the Gradio sliders.
    inputs = depth_processor(images=input_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)

    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(input_image.height, input_image.width)]
    )
    depth = post_processed_output[0]["predicted_depth"]

    # Normalize the predicted depth to [0, 255] for visualization
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth = depth * 255.
    depth = depth.detach().cpu().numpy()
    depth_map = depth.astype(np.uint8)
    depthImg = Image.fromarray(depth_map)

    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Three blur levels: sharp foreground, moderately blurred middleground, strongly blurred background
    img_foreground = img.copy()  # No blur for foreground
    img_middleground = cv2.GaussianBlur(img, (0, 0), sigmaX=7, sigmaY=7)
    img_background = cv2.GaussianBlur(img, (0, 0), sigmaX=15, sigmaY=15)

    # Normalize depth to [0, 1] so the slider thresholds apply directly
    depth_map = depth_map.astype(np.float32) / depth_map.max()

    threshold1 = fg_threshold
    threshold2 = mg_threshold

    # Partition the image into three depth bands
    mask_fg = (depth_map < threshold1).astype(np.float32)
    mask_mg = ((depth_map >= threshold1) & (depth_map < threshold2)).astype(np.float32)
    mask_bg = (depth_map >= threshold2).astype(np.float32)

    mask_fg_3 = np.stack([mask_fg] * 3, axis=-1)
    mask_mg_3 = np.stack([mask_mg] * 3, axis=-1)
    mask_bg_3 = np.stack([mask_bg] * 3, axis=-1)

    # Composite the three blur levels according to the depth masks
    final_img = (img_foreground * mask_fg_3 +
                 img_middleground * mask_mg_3 +
                 img_background * mask_bg_3).astype(np.uint8)

    final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
    lensBlurImage = Image.fromarray(final_img_rgb)

    mask_fg_img = Image.fromarray((mask_fg * 255).astype(np.uint8))
    mask_mg_img = Image.fromarray((mask_mg * 255).astype(np.uint8))
    mask_bg_img = Image.fromarray((mask_bg * 255).astype(np.uint8))

    return depthImg, lensBlurImage, mask_fg_img, mask_mg_img, mask_bg_img


def process_image(input_image: Image.Image, fg_threshold: float, mg_threshold: float):
    seg_blur, seg_mask = segmentation_blur_effect(input_image)
    depth_map_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img = lens_blur_effect(
        input_image, fg_threshold, mg_threshold
    )
    return (
        seg_blur,
        # seg_mask,
        depth_map_img,
        lens_blur_img,
        # mask_fg_img,
        # mask_mg_img,
        # mask_bg_img
    )


# Helper for loading a preset image and its thresholds (not currently wired into the interface)
def update_preset(preset: str):
    presets = {
        "Preset 1": {
            "image_url": "https://i.ibb.co/fznz2b2b/hw3-q2.jpg",
            "fg_threshold": 0.33,
            "mg_threshold": 0.66
        },
        "Preset 2": {
            "image_url": "https://i.ibb.co/HLZGW7qH/q26.jpg",
            "fg_threshold": 0.2,
            "mg_threshold": 0.66
        }
    }
    preset_info = presets[preset]
    response = requests.get(preset_info["image_url"])
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return image, preset_info["fg_threshold"], preset_info["mg_threshold"]


title = "Blur Effects: Segmentation-Based Gaussian Blur & Depth-Based Lens Blur with Adjustable Depth Thresholds"

demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image", value="https://i.ibb.co/fznz2b2b/hw3-q2.jpg"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.33, label="Foreground Depth Threshold"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.66, label="Middleground Depth Threshold")
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation-Based Blur"),
        gr.Image(type="pil", label="Depth Map"),
        gr.Image(type="pil", label="Depth-Based Lens Blur")
    ],
    title=title,
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
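
# -----------------------------
# Optional: programmatic usage (sketch)
# -----------------------------
# A minimal sketch of calling the pipeline without launching the Gradio UI,
# assuming a local file "example.jpg" (hypothetical path, adjust as needed):
#
#   from PIL import Image
#   img = Image.open("example.jpg").convert("RGB")
#   seg_blur, depth_map_img, lens_blur_img = process_image(img, fg_threshold=0.33, mg_threshold=0.66)
#   seg_blur.save("segmentation_blur.jpg")
#   lens_blur_img.save("lens_blur.jpg")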