# EEE515-HW3 / app.py
# Ash2505's picture
# Update app.py
# e7b1b3f verified
import cv2
import numpy as np
from PIL import Image, ImageFilter
import torch
import gradio as gr
from torchvision import transforms
from transformers import (
AutoModelForImageSegmentation,
DepthProImageProcessorFast,
DepthProForDepthEstimation,
)
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Load Segmentation Model (RMBG-2.0 by briaai)
# -----------------------------
seg_model = AutoModelForImageSegmentation.from_pretrained(
"briaai/RMBG-2.0", trust_remote_code=True
)
torch.set_float32_matmul_precision(["high", "highest"][0])
seg_model.to(device)
seg_model.eval()
# Define segmentation image size and transform
seg_image_size = (1024, 1024)
seg_transform = transforms.Compose([
transforms.Resize(seg_image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# -----------------------------
# Load Depth Estimation Model (DepthPro by Apple)
# -----------------------------
depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf")
depth_model.to(device)
depth_model.eval()
# -----------------------------
# Define the Segmentation-Based Blur Effect
# -----------------------------
def segmentation_blur_effect(input_image: Image.Image):
    """Blur the background of an image, keeping the segmented foreground sharp.

    The image is converted to RGB, resized to ``seg_image_size`` and passed
    through ``seg_model``. The sigmoid of the model's last output is used as
    a soft foreground mask; pixels whose mask value exceeds 127 (on a 0-255
    scale) keep their original colour and all other pixels are taken from a
    Gaussian-blurred copy (sigma 15).

    Args:
        input_image: Source picture in any PIL mode. It is converted to RGB
            first so that RGBA, grayscale, or palette images do not break
            the 3-channel normalization and compositing.

    Returns:
        A ``(blurred_image, mask)`` tuple of PIL images: the composited
        result and the soft (un-thresholded) mask, both at
        ``seg_image_size``.
    """
    # NOTE(review): the result stays at the square working resolution, so the
    # original aspect ratio is not restored -- confirm this is intended.
    image_resized = input_image.convert("RGB").resize(seg_image_size)

    input_tensor = seg_transform(image_resized).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = seg_model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    mask = transforms.ToPILImage()(pred).resize(image_resized.size)
    # Same cut-off as cv2.threshold(mask, 127, 255, THRESH_BINARY): strictly
    # greater than 127 counts as foreground.
    is_foreground = np.array(mask.convert("L")) > 127

    sharp = np.array(image_resized)
    blurred = cv2.GaussianBlur(sharp, (0, 0), sigmaX=15, sigmaY=15)

    # Per-pixel select between the sharp and blurred images; no colour-space
    # round trip is needed because both arrays are already RGB.
    final_img = np.where(is_foreground[..., None], sharp, blurred)
    return Image.fromarray(final_img), mask
def lens_blur_effect(input_image: Image.Image, fg_threshold: float = 85, mg_threshold: float = 170):
    """Apply a three-level, depth-dependent blur to an image.

    Depth is estimated with ``depth_model``, min-max normalized, and
    quantized to an 8-bit depth map. Pixels are then split into three bands
    by normalized depth: below ``fg_threshold`` stays sharp, between the two
    thresholds gets a Gaussian blur with sigma 7, and at or above
    ``mg_threshold`` gets a Gaussian blur with sigma 15.

    Args:
        input_image: Source picture in any PIL mode; converted to RGB.
        fg_threshold: Upper depth bound of the sharp band. Values in
            ``[0, 1]`` are fractions of the normalized depth range; values
            above 1 are read on the 0-255 depth-map scale and divided by
            255 (so the default 85 means one third).
        mg_threshold: Upper depth bound of the lightly blurred band, on the
            same scale rules as ``fg_threshold``. If it is smaller than
            ``fg_threshold`` it is raised to ``fg_threshold`` so the bands
            never overlap.

    Returns:
        A 5-tuple of PIL images: ``(depth_map, lens_blur, foreground_mask,
        middleground_mask, background_mask)``. The masks are 0/255 images.
    """
    rgb_image = input_image.convert("RGB")

    inputs = depth_processor(images=rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    post_processed_output = depth_processor.post_process_depth_estimation(
        outputs, target_sizes=[(rgb_image.height, rgb_image.width)]
    )
    depth = post_processed_output[0]["predicted_depth"]

    # Min-max normalize to 0-255; a constant depth map would divide by zero,
    # so it is mapped to all zeros instead.
    depth_min = depth.min()
    depth_span = depth.max() - depth_min
    if depth_span > 0:
        depth = (depth - depth_min) / depth_span * 255.0
    else:
        depth = torch.zeros_like(depth)
    depth_map = depth.detach().cpu().numpy().astype(np.uint8)
    depth_img = Image.fromarray(depth_map)

    # Rescale the quantized map to [0, 1] for thresholding.
    peak = int(depth_map.max())
    if peak > 0:
        depth_unit = depth_map.astype(np.float32) / peak
    else:
        depth_unit = np.zeros(depth_map.shape, dtype=np.float32)

    fg_limit = _unit_depth_threshold(fg_threshold)
    # Keep the bands ordered; overlapping bands would otherwise be summed
    # together and overflow the 8-bit pixel range.
    mg_limit = max(_unit_depth_threshold(mg_threshold), fg_limit)

    is_fg = depth_unit < fg_limit
    is_bg = depth_unit >= mg_limit
    is_mg = ~(is_fg | is_bg)

    # Gaussian blur works per channel, so it can run directly on RGB data.
    sharp = np.array(rgb_image)
    medium_blur = cv2.GaussianBlur(sharp, (0, 0), sigmaX=7, sigmaY=7)
    heavy_blur = cv2.GaussianBlur(sharp, (0, 0), sigmaX=15, sigmaY=15)

    # The three bands are disjoint and cover every pixel, so a select is
    # equivalent to the weighted sum of the three images.
    final_img = np.where(
        is_fg[..., None],
        sharp,
        np.where(is_bg[..., None], heavy_blur, medium_blur),
    )
    lens_blur_img = Image.fromarray(final_img)

    mask_fg_img = Image.fromarray(is_fg.astype(np.uint8) * 255)
    mask_mg_img = Image.fromarray(is_mg.astype(np.uint8) * 255)
    mask_bg_img = Image.fromarray(is_bg.astype(np.uint8) * 255)
    return depth_img, lens_blur_img, mask_fg_img, mask_mg_img, mask_bg_img


def _unit_depth_threshold(value: float) -> float:
    """Return a depth threshold on the 0-1 scale (values above 1 are taken as 0-255)."""
    return value / 255.0 if value > 1 else value
def process_image(input_image: Image.Image, fg_threshold: float, mg_threshold: float):
    """Run both blur effects on one image and return the images to display.

    Args:
        input_image: Picture to process.
        fg_threshold: Foreground depth threshold forwarded to
            ``lens_blur_effect``.
        mg_threshold: Middleground depth threshold forwarded to
            ``lens_blur_effect``.

    Returns:
        A 3-tuple of PIL images: ``(segmentation_blur, depth_map,
        lens_blur)``. The segmentation mask and the three depth-band masks
        produced by the helpers are not returned.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard a missing image fails deep inside the helpers with
    # an AttributeError on None.
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    seg_blur, _seg_mask = segmentation_blur_effect(input_image)
    depth_map_img, lens_blur_img, *_band_masks = lens_blur_effect(
        input_image, fg_threshold, mg_threshold
    )
    return seg_blur, depth_map_img, lens_blur_img
def update_preset(preset: str):
    """Download a preset's example image and return it with its thresholds.

    Args:
        preset: Preset name; one of ``"Preset 1"`` or ``"Preset 2"``.

    Returns:
        A ``(image, fg_threshold, mg_threshold)`` tuple, where ``image`` is
        the downloaded picture converted to RGB.

    Raises:
        KeyError: If ``preset`` is not a known preset name.
        urllib.error.URLError: If the image cannot be downloaded.
    """
    # Imported here because the module's top-level imports do not provide
    # these names; only the standard library is needed for the download.
    from io import BytesIO
    from urllib.request import urlopen

    presets = {
        "Preset 1": {
            "image_url": "https://i.ibb.co/fznz2b2b/hw3-q2.jpg",
            "fg_threshold": 0.33,
            "mg_threshold": 0.66,
        },
        "Preset 2": {
            "image_url": "https://i.ibb.co/HLZGW7qH/q26.jpg",
            "fg_threshold": 0.2,
            "mg_threshold": 0.66,
        },
    }
    # Look the preset up before any network access so that an unknown name
    # fails immediately.
    preset_info = presets[preset]

    # The timeout keeps an unresponsive host from blocking the caller forever.
    with urlopen(preset_info["image_url"], timeout=30) as response:
        payload = response.read()

    image = Image.open(BytesIO(payload)).convert("RGB")
    return image, preset_info["fg_threshold"], preset_info["mg_threshold"]
title = "Blur Effects on Segmentation-Based Gaussian Blur & Depth-Based Lens Blur with Adjustable Depth Thresholds"


def _threshold_slider(label, default):
    """Build a 0-1 slider with 0.01 steps for one depth threshold."""
    return gr.Slider(minimum=0, maximum=1, step=0.01, value=default, label=label)


# Input widgets, listed in the positional order of process_image's parameters.
_demo_inputs = [
    gr.Image(type="pil", label="Input Image", value="https://i.ibb.co/fznz2b2b/hw3-q2.jpg"),
    _threshold_slider("Foreground Depth Threshold", 0.33),
    _threshold_slider("Middleground Depth Threshold", 0.66),
]

# One output image per value returned by process_image, in the same order.
_demo_outputs = [
    gr.Image(type="pil", label=output_label)
    for output_label in (
        "Segmentation-Based Blur",
        "Depth Map",
        "Depth-Based Lens Blur",
    )
]

demo = gr.Interface(
    fn=process_image,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title=title,
    allow_flagging="never",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()