File size: 15,405 Bytes
280f1e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef06d35
280f1e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# streamlit_app.py

import os
import time
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageOps
from transformers import pipeline

# --- Page Config ---
# Keep this as the first Streamlit call in the script: set_page_config configures
# the browser tab (title/icon) and the page layout before anything is rendered.
st.set_page_config(
    page_title="Depth Blur Studio",
    page_icon="📸",  # U+1F4F8 CAMERA WITH FLASH
    layout="wide",
)

# --- Import Custom Class ---
# PortraitBlurrer is looked up in two places, in order:
#   1. ./Portrait/Portrait.py   (package layout)
#   2. ./PortraitBlurrer.py     (flat layout)
try:
    from Portrait.Portrait import PortraitBlurrer
except ImportError as primary_import_error:
    try:
        from PortraitBlurrer import PortraitBlurrer  # type: ignore
    except ImportError as fallback_import_error:
        # Report both underlying errors: an ImportError raised *inside* an existing
        # module (e.g. one of its own dependencies is missing) would otherwise look
        # exactly like the module not being present at all.
        print(f"Import from Portrait.Portrait failed: {primary_import_error}")
        print(f"Import from PortraitBlurrer failed: {fallback_import_error}")
        st.error(
            "Fatal Error: Could not find the PortraitBlurrer class. Please check the file structure and import path."
            f"\n\nDetails: {primary_import_error} | {fallback_import_error}"
        )
        st.stop()  # Nothing below can run without the class.


# --- Model Loading (Cached) ---
@st.cache_resource  # Shared across sessions/reruns; only successful returns are cached.
def _load_depth_pipeline_cached(t_device):
    """Build the Depth Anything V2 Large depth-estimation pipeline on ``t_device``.

    Raises whatever ``transformers.pipeline`` raises. Letting the exception escape
    (instead of returning None) matters: Streamlit does not cache a call that
    raises, so a transient failure (e.g. a download error) is retried on the next
    rerun rather than being remembered forever.
    """
    print(f"Attempting to load model on device: {'GPU (CUDA)' if t_device == 0 else 'CPU'}")
    # Use default precision (float32)
    t_pipe = pipeline(task="depth-estimation",
                      model="depth-anything/Depth-Anything-V2-Large-hf",
                      device=t_device)
    print("Depth Anything V2 Large model loaded successfully.")
    return t_pipe


def load_depth_pipeline():
    """Load (or fetch from cache) the depth estimation pipeline.

    Returns:
        ``(pipeline, device_id)`` where ``device_id`` is 0 when CUDA is available
        and -1 (CPU) otherwise. On failure ``pipeline`` is None; the caller shows
        the error in the UI.
    """
    t_device = 0 if torch.cuda.is_available() else -1
    try:
        t_pipe = _load_depth_pipeline_cached(t_device)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, t_device
    return t_pipe, t_device

# Load the model via the cached function
pipe, device_used = load_depth_pipeline()

# --- Title and Model Status ---
# The header is rendered before the failure check so the page is never blank.
st.title("Depth Blur Studio 📸 (Streamlit)")
st.markdown(
    "Upload a portrait image. The model will estimate depth and blur the background, keeping the subject sharp."
    "\n*Model: `depth-anything/Depth-Anything-V2-Large-hf`*"
)
# device_used follows the pipeline convention: 0 = first CUDA GPU, -1 = CPU.
st.caption(f"_(Using device: {'GPU (CUDA)' if device_used == 0 else 'CPU'})_")

# Without a model nothing else can work, so halt the script here.
if pipe is None:
    st.error("Error loading depth estimation model. Application cannot proceed.")
    st.stop()


# --- Processing Function ---
def _depth_output_to_pil(depth_output):
    """Extract the depth map from a depth-estimation pipeline result as a PIL Image.

    Accepts either a bare PIL Image or a dict with a ``"depth"`` entry. A non-PIL
    ``"depth"`` value (tensor / array-like) is converted to a 2-D uint8 image by
    min-max scaling to 0..255.

    Raises:
        ValueError: if the output format is unexpected or cannot be converted.
    """
    if isinstance(depth_output, Image.Image):
        return depth_output
    if not (isinstance(depth_output, dict) and "depth" in depth_output):
        raise ValueError(f"Unexpected depth estimation output format: {type(depth_output)}")

    depth = depth_output["depth"]
    if isinstance(depth, Image.Image):
        return depth

    try:
        if isinstance(depth, torch.Tensor):
            # np.array() cannot read a CUDA tensor directly; move it to host memory first.
            depth = depth.detach().cpu().numpy()
        # Drop singleton batch/channel axes so Image.fromarray receives a 2-D array.
        depth_data = np.squeeze(np.asarray(depth))
        depth_data = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        return Image.fromarray(depth_data)
    except Exception as conversion_e:
        print(f"Could not convert depth output to PIL Image: {conversion_e}")
        raise ValueError("Depth estimation did not return a usable PIL Image.") from conversion_e


def process_image_blur(pipeline_obj, input_image_pil, max_blur_ksize, depth_thresh, feather_ksize, sharpen_val):
    """Estimate depth for an image and blur its background with PortraitBlurrer.

    Args:
        pipeline_obj: depth-estimation pipeline, called with a PIL Image.
        input_image_pil: source image; converted to RGB if it is in another mode.
        max_blur_ksize: maximum blur kernel size (cast to int).
        depth_thresh: subject depth threshold (cast to int).
        feather_ksize: mask feathering kernel size (cast to int).
        sharpen_val: subject sharpening strength (cast to float).

    Returns:
        ``(blurred_pil, depth_pil, mask_pil, duration_seconds)`` on success, or
        ``(None, None, None)`` on failure. The failure tuple is intentionally
        shorter: callers tell the outcomes apart with ``len(result) == 4``.
        Failures are also reported in the UI via ``st.error``.
    """
    print("Processing image...")
    processing_start_time = time.perf_counter()  # monotonic clock for durations

    # 1. PIL (RGB) -> NumPy BGR, the channel order OpenCV / PortraitBlurrer work in.
    if input_image_pil.mode != "RGB":
        input_image_pil = input_image_pil.convert("RGB")
    original_bgr_np = cv2.cvtColor(np.array(input_image_pil), cv2.COLOR_RGB2BGR)

    # 2. Depth estimation (inference only, so no autograd bookkeeping).
    try:
        with torch.no_grad():
            depth_image_pil = _depth_output_to_pil(pipeline_obj(input_image_pil))
    except Exception as e:
        print(f"Error during depth estimation: {e}")
        st.error(f"Depth estimation failed: {e}")
        return None, None, None
    print("Depth map generated.")

    # 3 + 4. Blur/sharpen and convert results back to PIL. Construction and the
    # final conversions are inside the try so bad parameters or unexpected array
    # types are reported in the UI instead of crashing the script.
    try:
        portrait_blurrer = PortraitBlurrer(
            max_blur=int(max_blur_ksize),
            depth_threshold=int(depth_thresh),
            feather_strength=int(feather_ksize),
            sharpen_strength=float(sharpen_val),
        )
        # process_image returns blurred_bgr, depth_gray, mask_gray
        blurred_bgr_np, refined_depth_np, mask_np = portrait_blurrer.process_image(
            original_bgr_np, depth_image_pil
        )
        blurred_pil = Image.fromarray(cv2.cvtColor(blurred_bgr_np, cv2.COLOR_BGR2RGB))
        # Depth and mask are single-channel arrays; PIL takes them directly.
        depth_pil = Image.fromarray(refined_depth_np)
        mask_pil = Image.fromarray(mask_np)
    except Exception as e:
        print(f"Error during blurring/sharpening: {e}")
        st.error(f"Image processing (blur/sharpen) failed: {e}")
        return None, None, None

    processing_duration = time.perf_counter() - processing_start_time
    print(f"Processing finished in {processing_duration:.2f} seconds.")
    # The success message is shown by the caller, next to the displayed result.
    return blurred_pil, depth_pil, mask_pil, processing_duration


# --- Initialize Session State --- (Do this early)
# Seed every key the UI reads on the first run. st.session_state survives reruns,
# so keys that already exist keep their current values.
for _state_key, _initial_value in (
    ('results', None),                     # (blurred, depth, mask) tuple once processed
    ('original_image_pil', None),          # currently loaded upload, as an RGB PIL Image
    ('processing_error_occurred', False),  # set when an "Apply Blur" run fails
    ('current_filename', None),            # name of the currently loaded upload
    ('last_process_duration', None),       # seconds taken by the last successful run
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
del _state_key, _initial_value


# --- Sidebar for Controls ---
def _reset_image_state():
    """Forget the loaded image and everything derived from it."""
    st.session_state.original_image_pil = None
    st.session_state.results = None
    st.session_state.processing_error_occurred = False
    st.session_state.current_filename = None
    st.session_state.last_process_duration = None
    st.session_state.current_upload_key = None


with st.sidebar:
    st.title("Controls")
    uploaded_file = st.file_uploader(
        "Upload Portrait Image",
        type=["jpg", "png", "jpeg"],
        label_visibility="collapsed",
    )

    # --- Handle New Upload for Instant Display ---
    if uploaded_file is not None:
        # Identify an upload by Streamlit's per-upload file_id when available, so
        # replacing a file with a *different* one of the same name still reloads.
        # Older Streamlit versions without file_id fall back to the file name.
        upload_key = getattr(uploaded_file, "file_id", None) or uploaded_file.name
        if upload_key != st.session_state.get('current_upload_key'):
            print(f"New file uploaded: {uploaded_file.name}. Loading for display.")
            try:
                loaded_image = Image.open(uploaded_file)
                try:
                    # Honour the EXIF Orientation tag; PIL does not rotate on open,
                    # so phone portraits would otherwise appear sideways. Best-effort:
                    # malformed EXIF must not block an otherwise readable image.
                    loaded_image = ImageOps.exif_transpose(loaded_image)
                except Exception as exif_error:
                    print(f"Could not apply EXIF orientation: {exif_error}")
                loaded_image = loaded_image.convert("RGB")
            except Exception as e:
                st.error(f"Error loading image: {e}")
                _reset_image_state()  # retried on the next rerun
            else:
                # New image replaces old results, error state and duration.
                _reset_image_state()
                st.session_state.original_image_pil = loaded_image
                st.session_state.current_filename = uploaded_file.name
                st.session_state.current_upload_key = upload_key

    elif st.session_state.current_filename is not None:
        # The user cleared the uploader (uploaded_file became None).
        print("File upload cleared.")
        _reset_image_state()
    # --- End Handle New Upload ---

    st.markdown("---")
    st.markdown("**Adjust Parameters:**")
    slider_max_blur = st.slider("Blur Intensity (Kernel Size)", min_value=3, max_value=101, step=2, value=31)
    slider_depth_thr = st.slider("Subject Depth Threshold (Lower=Far Away)", min_value=1, max_value=254, step=1, value=120)
    slider_feather = st.slider("Feathering (Mask Smoothness)", min_value=1, max_value=51, step=2, value=5)
    # NOTE: subject sharpening has no slider; the processing call passes a fixed 1.0.
    st.markdown("---")

    # Processing requires a loaded image, so the button is disabled until one exists.
    process_button = st.button(
        "Apply Blur",
        type="primary",
        disabled=(st.session_state.original_image_pil is None),
    )


# --- Main Area for Images ---
# Two side-by-side columns: original on the left, result on the right.
col1, col2 = st.columns(2)

# --- Handle Processing Trigger ---
# The button is disabled while no image is loaded, so the first branch is only a safety net.
if process_button and st.session_state.original_image_pil is None:
    st.error("No image loaded to process.")
elif process_button:
    # Start from a clean slate so stale results or errors never linger during a run.
    st.session_state.processing_error_occurred = False
    st.session_state.results = None
    st.session_state.last_process_duration = None

    # Spinner lives in the results column while the model runs.
    with col2, st.spinner('Applying blur... This may take a moment...'):
        blur_outcome = process_image_blur(
            pipeline_obj=pipe,
            input_image_pil=st.session_state.original_image_pil,
            max_blur_ksize=slider_max_blur,
            depth_thresh=slider_depth_thr,
            feather_ksize=slider_feather,
            sharpen_val=1.0,  # fixed sharpening strength (no UI control)
        )

        # Success is a 4-tuple (blurred, depth, mask, duration); failure is anything else.
        if blur_outcome is None or len(blur_outcome) != 4:
            st.session_state.results = None
            st.session_state.processing_error_occurred = True
            st.session_state.last_process_duration = None
        else:
            *result_images, elapsed_seconds = blur_outcome
            st.session_state.results = tuple(result_images)
            st.session_state.last_process_duration = elapsed_seconds


# --- Display Images based on Session State ---

# Column 1: the original image, or upload instructions when nothing is loaded.
if st.session_state.original_image_pil is not None:
    col1.image(st.session_state.original_image_pil, caption="Original Image", use_container_width=True)
else:
    col1.markdown("### Upload an image")
    col1.markdown("Use the sidebar controls to upload your portrait.")

# Column 2: results of the last successful run.
if st.session_state.results is not None:
    blurred_img, depth_img, mask_img = st.session_state.results
    if blurred_img is not None:
        if st.session_state.last_process_duration is not None:
            # Rendered inside the result column; a bare st.success here would be
            # placed below both columns, away from the image it describes.
            col2.success(f"Processing finished in {st.session_state.last_process_duration:.2f} seconds.")

        col2.image(blurred_img, caption="Blurred Background Result", use_container_width=True)

        # Offer the result as a PNG download.
        buf = BytesIO()
        blurred_img.save(buf, format="PNG")
        # Use the original name without its extension, so "photo.jpg" downloads as
        # "blurred_photo.png" rather than "blurred_photo.jpg.png".
        original_stem = os.path.splitext(st.session_state.current_filename or 'result')[0]
        col2.download_button(
            label="Download Blurred Image",
            data=buf.getvalue(),
            file_name=f"blurred_{original_stem}.png",
            mime="image/png",
        )

        # Intermediate outputs, collapsed by default below the main columns.
        with st.expander("Show Details (Depth Map & Mask)"):
            exp_col1, exp_col2 = st.columns(2)
            exp_col1.image(depth_img, caption="Refined Depth Map", use_container_width=True)
            exp_col2.image(mask_img, caption="Subject Mask", use_container_width=True)
    else:
        # A results tuple without a blurred image is malformed; treat it as a failure.
        st.session_state.processing_error_occurred = True
        col2.error("An unexpected issue occurred during processing. Please check logs or try again.")


# Column 2 status: failure, "ready to process", or the empty-state hint.
if st.session_state.processing_error_occurred:
    # process_image_blur may already have shown a specific st.error; this is the
    # fallback message in the results column.
    col2.warning("Image processing failed. Check messages above or terminal logs.")

elif st.session_state.original_image_pil is not None and st.session_state.results is None:
    # Image loaded, not yet processed, no error.
    col2.markdown("### Ready to Process")
    col2.markdown("Adjust parameters in the sidebar (if needed) and click **Apply Blur**.")

elif st.session_state.original_image_pil is None:
    # Nothing uploaded yet.
    col2.markdown("### Results")
    col2.markdown("The processed image and details will appear here after uploading an image and clicking 'Apply Blur'.")