joeWabbit committed
Commit dfbbeb5 · verified · 1 Parent(s): f58540c

Update app.py

Files changed (1)
  1. app.py +16 -79
app.py CHANGED
@@ -5,54 +5,32 @@ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
 from PIL import Image, ImageFilter
 import matplotlib.pyplot as plt
 import matplotlib.cm as cm
-
-# ---------------------------
-# Depth Estimation Utilities
-# ---------------------------
 def compute_depth_map(image: Image.Image, scale_factor: float) -> np.ndarray:
-    """
-    Loads the LiheYoung/depth-anything-large-hf model and computes a depth map.
-    The depth map is normalized, inverted (so that near=0 and far=1),
-    and multiplied by the given scale_factor.
-    """
-    # Load model and processor from pretrained weights
     image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-large-hf")
     model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-large-hf")
-
-    # Prepare image for the model
     inputs = image_processor(images=image, return_tensors="pt")
     with torch.no_grad():
         outputs = model(**inputs)
         predicted_depth = outputs.predicted_depth
 
-    # Interpolate predicted depth map to match image size
     prediction = torch.nn.functional.interpolate(
         predicted_depth.unsqueeze(1),
-        size=image.size[::-1], # PIL image size is (width, height)
+        size=image.size[::-1],
         mode="bicubic",
         align_corners=False,
     )
-    # Normalize for visualization
+
     depth_min = prediction.min()
     depth_max = prediction.max()
     depth_vis = (prediction - depth_min) / (depth_max - depth_min + 1e-8)
     depth_map = depth_vis.squeeze().cpu().numpy()
-    # Invert so that near=0 and far=1, then scale
     depth_map_inverted = 1.0 - depth_map
     depth_map_inverted *= scale_factor
     return depth_map_inverted
 
-# ---------------------------
-# Depth-Based Blur Functions
-# ---------------------------
 def layered_blur(image: Image.Image, depth_map: np.ndarray, num_layers: int, max_blur: float) -> Image.Image:
-    """
-    Creates multiple blurred versions of the image (using Gaussian blur with radii from 0 to max_blur)
-    and composites them using masks generated from bins of the normalized depth map.
-    """
    blur_radii = np.linspace(0, max_blur, num_layers)
    blur_versions = [image.filter(ImageFilter.GaussianBlur(radius)) for radius in blur_radii]
-    # Use a fixed range (0 to 1) since the depth map is normalized
    thresholds = np.linspace(0, 1, num_layers + 1)
    final_image = blur_versions[-1]
    for i in range(num_layers - 1, -1, -1):
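The body of this loop is unchanged by the commit and therefore collapsed in the diff (old lines 59-64 are hidden). A minimal sketch of how the depth-binned compositing plausibly proceeds, reusing the names defined above; the mask construction below is an illustration, not the committed code:

    # Hypothetical reconstruction of the collapsed loop body: select the pixels
    # whose depth falls in the i-th bin and paste the matching blur layer there.
    for i in range(num_layers - 1, -1, -1):
        in_bin = (depth_map >= thresholds[i]) & (depth_map < thresholds[i + 1])
        mask = Image.fromarray((in_bin * 255).astype(np.uint8), mode="L")
        final_image = Image.composite(blur_versions[i], final_image, mask)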
@@ -65,51 +43,27 @@ def layered_blur(image: Image.Image, depth_map: np.ndarray, num_layers: int, max
     return final_image
 
 def process_depth_blur(uploaded_image, max_blur_value, scale_factor, num_layers):
-    """
-    Resizes the uploaded image to 512x512, computes its depth map,
-    and applies layered blur based on the depth.
-    """
     if not isinstance(uploaded_image, Image.Image):
         uploaded_image = Image.open(uploaded_image)
     image = uploaded_image.convert("RGB").resize((512, 512))
     depth_map = compute_depth_map(image, scale_factor)
     final_image = layered_blur(image, depth_map, int(num_layers), max_blur_value)
     return final_image
-
-# ---------------------------
-# Depth Heatmap Functions
-# ---------------------------
 def create_heatmap(depth_map: np.ndarray, intensity: float) -> Image.Image:
-    """
-    Applies a colormap to the normalized depth map.
-    The 'intensity' slider multiplies the normalized depth values (clipped to [0,1])
-    before applying the "inferno" colormap.
-    """
-    # Multiply depth map by intensity and clip to 0-1 range
     normalized = np.clip(depth_map * intensity, 0, 1)
     colormap = cm.get_cmap("inferno")
-    colored = colormap(normalized) # Returns an RGBA image in [0, 1]
-    heatmap = (colored[:, :, :3] * 255).astype(np.uint8) # drop alpha and convert to [0,255]
+    colored = colormap(normalized)
+    heatmap = (colored[:, :, :3] * 255).astype(np.uint8)
     return Image.fromarray(heatmap)
 
 def process_depth_heatmap(uploaded_image, intensity):
-    """
-    Resizes the uploaded image to 512x512, computes its depth map (with scale factor 1.0),
-    and returns a heatmap visualization.
-    """
     if not isinstance(uploaded_image, Image.Image):
         uploaded_image = Image.open(uploaded_image)
     image = uploaded_image.convert("RGB").resize((512, 512))
     depth_map = compute_depth_map(image, scale_factor=1.0)
     heatmap_img = create_heatmap(depth_map, intensity)
     return heatmap_img
-
-# --- Segmentation-Based Blur using BEN2 ---
 def load_segmentation_model():
-    """
-    Loads and caches the segmentation model from BEN2.
-    Ensure you have ben2 installed and accessible in your path.
-    """
     global seg_model, seg_device
     if "seg_model" not in globals():
         from ben2 import BEN_Base # Import BEN2
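The lines that actually instantiate the BEN2 model (old lines 116-118) fall between these hunks and are not shown; only the caching check and the return are visible. A hedged sketch of what they likely do, assuming BEN_Base follows the from_pretrained / to / eval pattern from the BEN2 model card (the checkpoint id below is an assumption, not taken from this commit):

    # Assumed reconstruction of the collapsed lines; verify against the BEN2 docs.
    seg_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seg_model = BEN_Base.from_pretrained("PramaLLC/BEN2")  # checkpoint id assumed
    seg_model.to(seg_device).eval()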
@@ -119,18 +73,11 @@ def load_segmentation_model():
     return seg_model, seg_device
 
 def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
-    """
-    Processes the image with segmentation-based blur.
-    The image is resized to 512x512. A Gaussian blur with the specified radius is applied,
-    then the segmentation mask is computed to composite the sharp foreground over the blurred background.
-    """
     if not isinstance(uploaded_image, Image.Image):
         uploaded_image = Image.open(uploaded_image)
     image = uploaded_image.convert("RGB").resize((512, 512))
     seg_model, seg_device = load_segmentation_model()
     blurred_image = image.filter(ImageFilter.GaussianBlur(seg_blur_radius))
-
-    # Generate segmentation mask (foreground)
     foreground = seg_model.inference(image, refine_foreground=False)
     foreground_rgba = foreground.convert("RGBA")
     _, _, _, alpha = foreground_rgba.split()
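binary_mask, used two lines below, is built on a line that falls between these hunks (old line 137) and is not shown. A minimal sketch of the usual approach, hard-thresholding the alpha channel with PIL's point(); the 128 cutoff is an assumption:

    # Assumed reconstruction of the hidden line: make the compositing mask fully
    # opaque where the segmented foreground is confident, fully clear elsewhere.
    binary_mask = alpha.point(lambda p: 255 if p > 128 else 0)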
@@ -138,39 +85,29 @@ def process_segmentation_blur(uploaded_image, seg_blur_radius: float):
     final_image = Image.composite(image, blurred_image, binary_mask)
     return final_image
 
-# --- Merged Gradio Interface ---
+
 with gr.Blocks() as demo:
-    gr.Markdown("# Depth-Based vs Segmentation-Based Blur")
+    gr.Markdown("#Gaussian Blur & Lens Blur Effect")
     with gr.Tabs():
-        with gr.Tab("Depth Blur"):
+        with gr.Tab("Gaussian Blur"):
+            seg_img = gr.Image(type="pil", label="Upload Image")
+            seg_blur = gr.Slider(5, 30, value=15, step=1, label="Segmentation Blur Radius")
+            seg_out = gr.Image(label="Gaussian Blurred Image")
+            seg_button = gr.Button("Process Gaussian Blur")
+            seg_button.click(process_segmentation_blur, inputs=[seg_img, seg_blur], outputs=seg_out)
+        with gr.Tab("Lens Blur"):
             img_input = gr.Image(type="pil", label="Upload Image")
             blur_slider = gr.Slider(1, 50, value=6, label="Maximum Blur Radius")
             scale_slider = gr.Slider(0.1, 2.0, value=0.72, label="Depth Scale Factor")
             layers_slider = gr.Slider(2, 10, value=2.91, label="Number of Layers")
-            blur_output = gr.Image(label="Depth Blur Result")
-            blur_button = gr.Button("Process Depth Blur")
+            blur_output = gr.Image(label="Lens Blur Result")
+            blur_button = gr.Button("Process Blur")
             blur_button.click(
                 process_depth_blur,
                 inputs=[img_input, blur_slider, scale_slider, layers_slider],
                 outputs=blur_output
             )
-        with gr.Tab("Depth Heatmap"):
-            img_input2 = gr.Image(type="pil", label="Upload Image")
-            intensity_slider = gr.Slider(0.5, 5.0, value=1.0, label="Heatmap Intensity")
-            heatmap_output = gr.Image(label="Depth Heatmap")
-            heatmap_button = gr.Button("Generate Depth Heatmap")
-            heatmap_button.click(
-                process_depth_heatmap,
-                inputs=[img_input2, intensity_slider],
-                outputs=heatmap_output
-            )
-        with gr.Tab("Segmentation-Based Blur (BEN2)"):
-            seg_img = gr.Image(type="pil", label="Upload Image")
-            seg_blur = gr.Slider(5, 30, value=15, step=1, label="Segmentation Blur Radius")
-            seg_out = gr.Image(label="Segmentation-Based Blurred Image")
-            seg_button = gr.Button("Process Segmentation Blur")
-            seg_button.click(process_segmentation_blur, inputs=[seg_img, seg_blur], outputs=seg_out)
+
 
 if __name__ == "__main__":
-    # Optionally, set share=True to generate a public link.
     demo.launch(share=True)
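For reference, the two handlers wired into the new interface can also be exercised without launching Gradio. A minimal sketch, assuming the functions above are importable from app.py and a local test image exists; the file names are placeholders:

    from PIL import Image

    img = Image.open("input.jpg")  # placeholder path
    process_depth_blur(img, max_blur_value=6, scale_factor=0.72, num_layers=3).save("lens_blur.png")
    process_segmentation_blur(img, seg_blur_radius=15).save("gaussian_blur.png")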
 