Spaces:
Sleeping
Sleeping
File size: 15,405 Bytes
280f1e8 ef06d35 280f1e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 |
# streamlit_app.py
import streamlit as st
import numpy as np
from PIL import Image
import cv2
import torch
from transformers import pipeline
import time
import os
from io import BytesIO # <-- IMPORT BytesIO
# --- Page Config (MUST BE FIRST st command) ---
# Set page config early, before any other st.* call in this script.
st.set_page_config(
    page_title="Depth Blur Studio",
    page_icon="📸",  # was mojibake ("๐ธ"): UTF-8 bytes of U+1F4F8 mis-decoded as cp874
    layout="wide",
)
# --- Import Custom Class ---
# PortraitBlurrer is looked up in ./Portrait/Portrait.py first, then ./PortraitBlurrer.py.
# Both ImportErrors are kept: an ImportError raised *inside* an existing module (e.g. a
# missing dependency) would otherwise be hidden behind a misleading "could not find" message.
try:
    from Portrait.Portrait import PortraitBlurrer
except ImportError as primary_import_error:
    try:
        from PortraitBlurrer import PortraitBlurrer  # type: ignore
    except ImportError as fallback_import_error:
        st.error(
            "Fatal Error: Could not find the PortraitBlurrer class. Please check the file structure and import path."
            f"\n\n- `Portrait.Portrait`: {primary_import_error}"
            f"\n- `PortraitBlurrer`: {fallback_import_error}"
        )
        st.stop()  # Stop execution if the class can't be imported
# --- Model Loading (Cached) ---
@st.cache_resource  # Use cache_resource for non-data objects like models/pipelines
def _build_depth_pipeline(t_device):
    """Create the depth-estimation pipeline on ``t_device`` (0 = CUDA GPU, -1 = CPU).

    Exceptions are intentionally NOT caught here. Streamlit only caches a value
    that is returned, so a failed load (e.g. a transient download error) is
    retried on the next rerun instead of being remembered for the whole process.
    """
    print(f"Attempting to load model on device: {'GPU (CUDA)' if t_device == 0 else 'CPU'}")
    # No torch_dtype is passed, so the library's default precision is used.
    t_pipe = pipeline(
        task="depth-estimation",
        model="depth-anything/Depth-Anything-V2-Large-hf",
        device=t_device,
    )
    print("Depth Anything V2 Large model loaded successfully.")
    return t_pipe


def load_depth_pipeline():
    """Load (or reuse the cached) depth-estimation pipeline.

    Returns:
        ``(pipeline, device_id)``. ``device_id`` is 0 when CUDA is available and
        -1 (CPU) otherwise. ``pipeline`` is ``None`` if loading failed; in that
        case the failure is not cached, so a later rerun will try again.
    """
    t_device = 0 if torch.cuda.is_available() else -1
    try:
        t_pipe = _build_depth_pipeline(t_device)
    except Exception as e:  # top-level boundary: report and let the caller decide
        print(f"Error loading model: {e}")
        # The caller displays the error in the main app body when the pipeline is None.
        return None, t_device
    return t_pipe, t_device
# Load the model via the cached function
pipe, device_used = load_depth_pipeline()

# --- Title and Model Status ---
# Rendered AFTER the load attempt so the caption reflects the device actually chosen.
st.title("Depth Blur Studio 📸 (Streamlit)")
st.markdown(
    "Upload a portrait image. The model will estimate depth and blur the background, keeping the subject sharp."
    # A blank line is needed: a single "\n" is only a soft break in Markdown.
    "\n\n*Model: `depth-anything/Depth-Anything-V2-Large-hf`*"
)
st.caption(f"_(Using device: {'GPU (CUDA)' if device_used == 0 else 'CPU'})_")

# Stop only after the header is drawn, so the error appears beneath the title.
if pipe is None:
    st.error("Error loading depth estimation model. Application cannot proceed.")
    st.stop()
# --- Processing Function ---
def _depth_output_to_pil(depth_output):
    """Normalise the depth pipeline's output to a PIL image.

    Accepts a PIL image, or a dict holding a ``"depth"`` entry. A non-PIL
    ``"depth"`` value is converted with ``np.array`` and min-max scaled to 8-bit.

    Raises:
        ValueError: if the output has an unexpected format or cannot be converted.
    """
    if isinstance(depth_output, Image.Image):
        return depth_output
    if not (isinstance(depth_output, dict) and "depth" in depth_output):
        raise ValueError(f"Unexpected depth estimation output format: {type(depth_output)}")

    depth_value = depth_output["depth"]
    if isinstance(depth_value, Image.Image):
        return depth_value

    # Basic conversion for array/tensor-like outputs; may need refinement for other formats.
    try:
        depth_data = np.array(depth_value)
        depth_data = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        return Image.fromarray(depth_data)
    except Exception as conversion_e:
        print(f"Could not convert depth output to PIL Image: {conversion_e}")
        raise ValueError("Depth estimation did not return a usable PIL Image.") from conversion_e


def process_image_blur(pipeline_obj, input_image_pil, max_blur_ksize, depth_thresh, feather_ksize, sharpen_val):
    """Estimate depth for an image and blur its background with PortraitBlurrer.

    Args:
        pipeline_obj: Depth-estimation pipeline, called with a PIL image.
        input_image_pil: PIL image to process; converted to RGB if needed.
        max_blur_ksize: Passed to PortraitBlurrer as ``max_blur`` (cast to int).
        depth_thresh: Passed as ``depth_threshold`` (cast to int).
        feather_ksize: Passed as ``feather_strength`` (cast to int).
        sharpen_val: Passed as ``sharpen_strength`` (cast to float).

    Returns:
        ``(blurred_pil, depth_pil, mask_pil, duration_seconds)`` on success, or
        ``(None, None, None)`` on failure. On failure the error is also shown
        with ``st.error``. Callers should check the tuple length before unpacking.
    """
    print("Processing image...")
    processing_start_time = time.perf_counter()

    # 1. PIL -> NumPy BGR for OpenCV. Forcing RGB first keeps RGBA/L/P inputs from
    #    crashing cvtColor (a no-op copy for images that are already RGB).
    rgb_image_pil = input_image_pil.convert("RGB")
    original_bgr_np = cv2.cvtColor(np.array(rgb_image_pil), cv2.COLOR_RGB2BGR)

    # 2. Depth estimation.
    try:
        with torch.no_grad():  # inference only
            depth_output = pipeline_obj(rgb_image_pil)
        depth_image_pil = _depth_output_to_pil(depth_output)
        print("Depth map generated.")
    except Exception as e:
        print(f"Error during depth estimation: {e}")
        st.error(f"Depth estimation failed: {e}")  # Show error in UI
        return None, None, None

    # 3. Blur/sharpen. The constructor is inside the try so invalid parameters are
    #    reported like any other processing error instead of crashing the script.
    try:
        portrait_blurrer = PortraitBlurrer(
            max_blur=int(max_blur_ksize),
            depth_threshold=int(depth_thresh),
            feather_strength=int(feather_ksize),
            sharpen_strength=float(sharpen_val),
        )
        # process_image returns blurred_bgr, depth_gray, mask_gray
        blurred_bgr_np, refined_depth_np, mask_np = portrait_blurrer.process_image(
            original_bgr_np, depth_image_pil
        )
    except Exception as e:
        print(f"Error during blurring/sharpening: {e}")
        st.error(f"Image processing (blur/sharpen) failed: {e}")  # Show error in UI
        return None, None, None

    # 4. Back to PIL for Streamlit: the blurred image is BGR -> RGB; depth and mask
    #    are single-channel arrays converted directly.
    blurred_pil = Image.fromarray(cv2.cvtColor(blurred_bgr_np, cv2.COLOR_BGR2RGB))
    depth_pil = Image.fromarray(refined_depth_np)
    mask_pil = Image.fromarray(mask_np)

    processing_duration = time.perf_counter() - processing_start_time
    print(f"Processing finished in {processing_duration:.2f} seconds.")
    # The success message is shown by the caller, next to the results.
    return blurred_pil, depth_pil, mask_pil, processing_duration
# --- Initialize Session State --- (Do this early)
# Per-session defaults. A key is only seeded when absent, so values survive reruns.
_SESSION_DEFAULTS = {
    'results': None,  # tuple (blurred, depth, mask) after a successful run, else None
    'original_image_pil': None,
    'processing_error_occurred': False,
    'current_filename': None,
    'last_process_duration': None,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# --- Sidebar for Controls ---
def _reset_image_state():
    """Forget the loaded image, its result/error/timing, and the tracked upload."""
    st.session_state.original_image_pil = None
    st.session_state.results = None
    st.session_state.processing_error_occurred = False
    st.session_state.current_filename = None
    st.session_state.current_file_id = None
    st.session_state.last_process_duration = None


def _upload_identity(uploaded_file):
    """Return a key that distinguishes one upload from another.

    Prefers Streamlit's per-upload ``file_id`` and falls back to ``(name, size)``
    when that attribute is unavailable. Comparing names alone would ignore a new
    image that happens to share the previous file's name.
    """
    return getattr(uploaded_file, "file_id", None) or (uploaded_file.name, uploaded_file.size)


with st.sidebar:
    st.title("Controls")
    uploaded_file = st.file_uploader(
        "Upload Portrait Image",
        type=["jpg", "png", "jpeg"],
        label_visibility="collapsed",
    )

    # --- Handle New Upload for Instant Display ---
    if uploaded_file is not None:
        upload_id = _upload_identity(uploaded_file)
        if upload_id != st.session_state.get("current_file_id"):
            print(f"New file uploaded: {uploaded_file.name}. Loading for display.")
            try:
                new_image = Image.open(uploaded_file).convert("RGB")
            except Exception as e:
                st.error(f"Error loading image: {e}")
                # Nothing usable is tracked, so the next rerun retries this upload.
                _reset_image_state()
            else:
                # A new image invalidates any previous result, error flag and timing.
                _reset_image_state()
                st.session_state.original_image_pil = new_image
                st.session_state.current_filename = uploaded_file.name
                st.session_state.current_file_id = upload_id
    elif st.session_state.current_filename is not None:
        # The user cleared the uploader (it now returns None).
        print("File upload cleared.")
        _reset_image_state()

    st.markdown("---")
    st.markdown("**Adjust Parameters:**")
    slider_max_blur = st.slider("Blur Intensity (Kernel Size)", min_value=3, max_value=101, step=2, value=31)
    slider_depth_thr = st.slider("Subject Depth Threshold (Lower=Far Away)", min_value=1, max_value=254, step=1, value=120)
    slider_feather = st.slider("Feathering (Mask Smoothness)", min_value=1, max_value=51, step=2, value=5)
    # No sharpening slider: the sharpening strength is passed as a fixed value when processing.
    st.markdown("---")

    # Disabled until an image has been loaded into session state.
    process_button = st.button(
        "Apply Blur",
        type="primary",
        disabled=(st.session_state.original_image_pil is None),
    )
# --- Main Area for Images ---
# Two equal-width columns: the original on the left, the result on the right.
col1, col2 = st.columns(2)

# --- Handle Processing Trigger ---
if process_button and st.session_state.original_image_pil is None:
    # Defensive only: the button is disabled while no image is loaded.
    st.error("No image loaded to process.")
elif process_button:
    # A fresh attempt starts with no error flag, no stale result and no stale timing.
    st.session_state.processing_error_occurred = False
    st.session_state.results = None
    st.session_state.last_process_duration = None

    # The spinner is drawn inside the results column while the pipeline runs.
    with col2, st.spinner('Applying blur... This may take a moment...'):
        outcome = process_image_blur(
            pipeline_obj=pipe,
            input_image_pil=st.session_state.original_image_pil,
            max_blur_ksize=slider_max_blur,
            depth_thresh=slider_depth_thr,
            feather_ksize=slider_feather,
            sharpen_val=1.0,  # sharpening strength is fixed, not user-adjustable
        )

    # Success is a 4-tuple (blurred, depth, mask, duration); anything else is a failure.
    if outcome is None or len(outcome) != 4:
        st.session_state.results = None
        st.session_state.processing_error_occurred = True
        st.session_state.last_process_duration = None
    else:
        *images, duration = outcome
        st.session_state.results = tuple(images)
        st.session_state.last_process_duration = duration
# --- Display Images based on Session State ---
# Column 1: the original upload, or a prompt to upload one.
if st.session_state.original_image_pil is not None:
    col1.image(st.session_state.original_image_pil, caption="Original Image", use_container_width=True)
else:
    col1.markdown("### Upload an image")
    col1.markdown("Use the sidebar controls to upload your portrait.")

# Column 2: the processed result, its download button, and optional details.
if st.session_state.results is not None:
    blurred_img, depth_img, mask_img = st.session_state.results
    if blurred_img is not None:
        # Shown in col2 alongside the result (st.success would land below the columns).
        if st.session_state.last_process_duration is not None:
            col2.success(f"Processing finished in {st.session_state.last_process_duration:.2f} seconds.")
        col2.image(blurred_img, caption="Blurred Background Result", use_container_width=True)

        # Encode the result as PNG in memory for the download button.
        with BytesIO() as buf:
            blurred_img.save(buf, format="PNG")
            byte_im = buf.getvalue()
        # Strip the original extension: "photo.jpg" -> "blurred_photo.png", not "blurred_photo.jpg.png".
        base_name = os.path.splitext(st.session_state.current_filename or "")[0] or "result"
        col2.download_button(
            label="Download Blurred Image",
            data=byte_im,
            file_name=f"blurred_{base_name}.png",
            mime="image/png",
        )

        # Depth map and mask go in an expander below the main columns.
        with st.expander("Show Details (Depth Map & Mask)"):
            exp_col1, exp_col2 = st.columns(2)
            exp_col1.image(depth_img, caption="Refined Depth Map", use_container_width=True)
            exp_col2.image(mask_img, caption="Subject Mask", use_container_width=True)
    else:
        # A results tuple without a blurred image is malformed; treat it as a failure.
        st.session_state.processing_error_occurred = True
        col2.error("An unexpected issue occurred during processing. Please check logs or try again.")

# Status for column 2: error, "ready to process", or the empty default.
if st.session_state.processing_error_occurred:
    # Fallback message; process_image_blur may already have shown a specific st.error.
    col2.warning("Image processing failed. Check messages above or terminal logs.")
elif st.session_state.original_image_pil is not None and st.session_state.results is None:
    col2.markdown("### Ready to Process")
    col2.markdown("Adjust parameters in the sidebar (if needed) and click **Apply Blur**.")
elif st.session_state.original_image_pil is None:
    col2.markdown("### Results")
    col2.markdown("The processed image and details will appear here after uploading an image and clicking 'Apply Blur'.")