joeWabbit commited on
Commit
4200d56
·
verified ·
1 Parent(s): 0055ad6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -15
app.py CHANGED
@@ -1,12 +1,44 @@
 
 
1
  import gradio as gr
2
  import torch
3
- from PIL import Image, ImageFilter
 
4
 
5
- def load_segmentation_model():
6
  """
7
- Loads and caches the segmentation model from BEN2.
8
- Ensure you have ben2 installed and accessible in your path.
9
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  global seg_model, seg_device
11
  if "seg_model" not in globals():
12
  from ben2 import BEN_Base # Import BEN2
@@ -16,11 +48,7 @@ def load_segmentation_model():
16
  return seg_model, seg_device
17
 
18
  def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
19
- """
20
- Processes the image with segmentation-based blur.
21
- The image is resized to 512x512. A Gaussian blur with the specified radius is applied,
22
- then the segmentation mask is computed to composite the sharp foreground over the blurred background.
23
- """
24
  if not isinstance(uploaded_image, Image.Image):
25
  uploaded_image = Image.open(uploaded_image)
26
  image = uploaded_image.convert("RGB").resize((512, 512))
@@ -35,13 +63,26 @@ def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
35
  final_image = Image.composite(image, blurred_image, binary_mask)
36
  return final_image
37
 
 
38
  with gr.Blocks() as demo:
39
- gr.Markdown("# Gaussian Blur using Image Segmentation BEN2 Model.")
40
- seg_img = gr.Image(type="pil", label="Upload Image")
41
- seg_blur = gr.Slider(5, 30, value=15, step=1, label="Gaussian Blur Radius")
42
- seg_out = gr.Image(label="Gaussian-Based Blurred Image")
43
- seg_button = gr.Button("Process Gaussian Blur")
44
- seg_button.click(process_segmentation_blur, inputs=[seg_img, seg_blur], outputs=seg_out)
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  if __name__ == "__main__":
47
  demo.launch()
 
1
+ from transformers import pipeline
2
+ from PIL import Image, ImageFilter
3
  import gradio as gr
4
  import torch
5
+ import numpy as np
6
+ depth_pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
7
 
8
+ def compute_depth_map_pipeline(image: Image.Image, scale_factor: float) -> np.ndarray:
9
  """
10
+ Computes a depth map using the HF pipeline.
11
+ The returned depth is inverted (so near=0 and far=1) and scaled.
12
  """
13
+ result = depth_pipe(image)[0]
14
+ depth_map = np.array(result["depth"])
15
+ depth_map = 1.0 - depth_map
16
+ depth_map *= scale_factor
17
+ return depth_map
18
+
19
+ def layered_blur(image: Image.Image, depth_map: np.ndarray, num_layers: int, max_blur: float) -> Image.Image:
20
+ blur_radii = np.linspace(0, max_blur, num_layers)
21
+ blur_versions = [image.filter(ImageFilter.GaussianBlur(r)) for r in blur_radii]
22
+ upper_bound = depth_map.max()
23
+ thresholds = np.linspace(0, upper_bound, num_layers + 1)
24
+ final_image = blur_versions[-1]
25
+ for i in range(num_layers - 1, -1, -1):
26
+ mask_array = np.logical_and(depth_map >= thresholds[i],
27
+ depth_map < thresholds[i + 1]).astype(np.uint8) * 255
28
+ mask_image = Image.fromarray(mask_array, mode="L")
29
+ final_image = Image.composite(blur_versions[i], final_image, mask_image)
30
+ return final_image
31
+
32
+ def process_depth_blur_pipeline(uploaded_image, max_blur_value, scale_factor, num_layers):
33
+ if not isinstance(uploaded_image, Image.Image):
34
+ uploaded_image = Image.open(uploaded_image)
35
+ image = uploaded_image.convert("RGB").resize((512, 512))
36
+ depth_map = compute_depth_map_pipeline(image, scale_factor)
37
+ final_image = layered_blur(image, depth_map, int(num_layers), max_blur_value)
38
+ return final_image
39
+
40
+ # --- Segmentation-Based Blur using BEN2 ---
41
+ def load_segmentation_model():
42
  global seg_model, seg_device
43
  if "seg_model" not in globals():
44
  from ben2 import BEN_Base # Import BEN2
 
48
  return seg_model, seg_device
49
 
50
  def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
51
+
 
 
 
 
52
  if not isinstance(uploaded_image, Image.Image):
53
  uploaded_image = Image.open(uploaded_image)
54
  image = uploaded_image.convert("RGB").resize((512, 512))
 
63
  final_image = Image.composite(image, blurred_image, binary_mask)
64
  return final_image
65
 
66
+ # --- Merged Gradio Interface ---
67
  with gr.Blocks() as demo:
68
+ gr.Markdown("# Lens Blur & Gaussian Blur")
69
+ with gr.Tabs():
70
+ with gr.Tab("Lens Blur"):
71
+ depth_img = gr.Image(type="pil", label="Upload Image")
72
+ depth_max_blur = gr.Slider(1.0, 5.0, value=3.0, step=0.1, label="Maximum Blur Radius")
73
+ depth_scale = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Depth Scale Factor")
74
+ depth_layers = gr.Slider(2, 20, value=8, step=1, label="Number of Layers")
75
+ depth_out = gr.Image(label="Lens Blurred Image")
76
+ depth_button = gr.Button("Process Lens Blur")
77
+ depth_button.click(process_depth_blur_pipeline,
78
+ inputs=[depth_img, depth_max_blur, depth_scale, depth_layers],
79
+ outputs=depth_out)
80
+ with gr.Tab("Guassian Blur"):
81
+ seg_img = gr.Image(type="pil", label="Upload Image")
82
+ seg_blur = gr.Slider(5, 30, value=15, step=1, label="Segmentation Blur Radius")
83
+ seg_out = gr.Image(label="Gaussian Blurred Image")
84
+ seg_button = gr.Button("Gaussian Blur")
85
+ seg_button.click(process_segmentation_blur, inputs=[seg_img, seg_blur], outputs=seg_out)
86
 
87
  if __name__ == "__main__":
88
  demo.launch()