File size: 3,629 Bytes
1dce535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import torch
from PIL import Image
import cv2
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import gradio as gr

# Initialize the SegFormer model for segmentation.
# The processor and the model are loaded from one shared checkpoint id so
# their preprocessing and label set always come from the same source.
_SEGFORMER_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"
segformer_processor = SegformerImageProcessor.from_pretrained(_SEGFORMER_CHECKPOINT)
segformer_model = SegformerForSemanticSegmentation.from_pretrained(_SEGFORMER_CHECKPOINT)

# Function to segment the person in the image
def segment_person(image_input):
    """Return a boolean mask of the pixels classified as "person".

    The image is converted to RGB, resized to 512x512 and run through the
    module-level SegFormer model. The person class is extracted and cleaned
    with a morphological close/open pass, then smoothed by blur-and-threshold.
    Finally it is resized back to the input's original dimensions.

    Args:
        image_input: Image as a numpy array accepted by ``PIL.Image.fromarray``
            (it is converted to RGB before inference).

    Returns:
        numpy.ndarray of bool with shape ``(height, width)`` of the input
        image; True where the pixel belongs to a person.
    """
    image = Image.fromarray(image_input).convert("RGB")
    original_width, original_height = image.size

    # Resize image to 512x512 for the model
    model_input = image.resize((512, 512), Image.Resampling.LANCZOS)

    # Prepare the image for SegFormer
    inputs = segformer_processor(images=model_input, return_tensors="pt")

    # Perform inference
    with torch.no_grad():
        outputs = segformer_model(**inputs)
        logits = outputs.logits

    # SegFormer logits are lower-resolution than the input; upsample them to
    # the 512x512 working size before taking the per-pixel argmax.
    upsampled_logits = torch.nn.functional.interpolate(
        logits, size=(512, 512), mode="bilinear", align_corners=False
    )

    # Get the predicted segmentation mask (person class = 12 in ADE20K dataset)
    person_class_id = 12
    predicted_mask = upsampled_logits.argmax(dim=1)[0]  # Shape: (512, 512)
    binary_mask = (predicted_mask == person_class_id).cpu().numpy()  # Boolean mask

    # Post-process the mask: closing fills small holes inside the person,
    # opening removes isolated specks, blur + threshold smooths jagged edges.
    mask_uint8 = (binary_mask * 255).astype(np.uint8)
    kernel = np.ones((5, 5), np.uint8)
    mask_cleaned = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel, iterations=2)
    mask_cleaned = cv2.morphologyEx(mask_cleaned, cv2.MORPH_OPEN, kernel, iterations=2)
    mask_smoothed = cv2.GaussianBlur(mask_cleaned, (7, 7), 0)
    _, mask_final = cv2.threshold(mask_smoothed, 127, 255, cv2.THRESH_BINARY)

    # Resize the 0/255 mask back to the original size with a monotone
    # (bilinear) filter and threshold at the midpoint. Thresholding at "> 0"
    # instead would count every interpolated edge pixel (and any LANCZOS
    # ringing) as foreground, growing the mask and leaving a halo of
    # unblurred background around the person.
    mask_resized = Image.fromarray(mask_final).resize(
        (original_width, original_height), Image.Resampling.BILINEAR
    )
    mask_array = np.asarray(mask_resized) > 127  # Boolean mask

    return mask_array

# Function to apply background blur
def blur_background(image_input, blur_strength):
    """Blur everything except the segmented person in *image_input*.

    Args:
        image_input: Image convertible with ``np.array``. Both 2-D grayscale
            ``(H, W)`` and channelled ``(H, W, C)`` arrays are accepted.
        blur_strength: Gaussian sigma used for the background blur. A value
            <= 0 means "no blur" and the image is returned unchanged.

    Returns:
        numpy.ndarray of uint8 with the input's shape: person pixels come from
        the original image, all other pixels from the blurred copy.
    """
    # Ensure image is in numpy array format
    image_array = np.array(image_input)

    # With ksize=(0, 0), cv2.GaussianBlur derives the kernel size from sigma
    # and fails its size assertion for sigma <= 0, so treat that as "no blur".
    if blur_strength <= 0:
        return image_array.astype(np.uint8)

    # Segment the person; mask is a 2-D (H, W) boolean array
    mask = segment_person(image_array)

    # Apply Gaussian blur to the entire image
    sigma = blur_strength
    blurred_image = cv2.GaussianBlur(image_array, (0, 0), sigmaX=sigma, sigmaY=sigma)

    # For channelled images add a trailing axis so the mask broadcasts over
    # channels; a 2-D grayscale image already matches the mask's shape.
    if image_array.ndim == 3:
        mask = mask[:, :, np.newaxis]

    # Composite the original foreground with the blurred background
    result = np.where(mask, image_array, blurred_image).astype(np.uint8)

    return result

# Gradio interface function
def gradio_interface(image, blur_strength):
    """Gradio callback: check that an image was supplied, then blur its background.

    Args:
        image: The uploaded image, or None when nothing was uploaded.
        blur_strength: Sigma forwarded unchanged to ``blur_background``.

    Returns:
        Whatever ``blur_background`` produces for the given image.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is not None:
        return blur_background(image, blur_strength)
    raise ValueError("Please upload an image.")

# Create the Gradio app: an image upload and a sigma slider feed
# gradio_interface, whose numpy result is shown in the output image.
_upload_component = gr.Image(type="numpy", label="Upload Image")
_sigma_slider = gr.Slider(minimum=1, maximum=25, value=10, step=1, label="Blur Strength (Sigma)")
_result_component = gr.Image(type="numpy", label="Output Image")

app = gr.Interface(
    fn=gradio_interface,
    inputs=[_upload_component, _sigma_slider],
    outputs=_result_component,
    title="Person Segmentation and Background Blur",
    description="Upload an image to segment the person and blur the background. Adjust the blur strength using the slider.",
)

# Launch the app only when this file is executed as a script, so that merely
# importing the module does not start a blocking web server.
if __name__ == "__main__":
    app.launch()