# streamlit_app.py
import streamlit as st
import numpy as np
from PIL import Image
import cv2
import torch
from transformers import pipeline
import time
import os
from io import BytesIO  # <-- IMPORT BytesIO
# --- Page Config (MUST BE FIRST st command) ---
# Set page config early
st.set_page_config(
    page_title="Depth Blur Studio",
    page_icon="📸",
    layout="wide"
)
# --- Import Custom Class ---
# Assuming the PortraitBlurrer class is defined in Portrait/Portrait.py relative to this script
try:
    # If PortraitBlurrer is in ./Portrait/Portrait.py
    from Portrait.Portrait import PortraitBlurrer
except ImportError:
    # Fallback if PortraitBlurrer is in ./PortraitBlurrer.py
    try:
        from PortraitBlurrer import PortraitBlurrer  # type: ignore
        # st.warning("Assuming PortraitBlurrer class is in the root directory.")  # Optional warning
    except ImportError:
        st.error("Fatal Error: Could not find the PortraitBlurrer class. Please check the file structure and import path.")
        st.stop()  # Stop execution if class can't be found
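# NOTE (illustrative, not part of the original class): based on how PortraitBlurrer is used
# further down in this script, it is expected to expose roughly this interface; the real
# class in Portrait/Portrait.py may differ in details.
#
#     blurrer = PortraitBlurrer(max_blur=31, depth_threshold=120,
#                               feather_strength=5, sharpen_strength=1.0)
#     blurred_bgr, refined_depth, mask = blurrer.process_image(original_bgr_np, depth_image_pil)
#
# Inputs are a BGR NumPy array plus a grayscale PIL depth map; outputs are NumPy arrays
# (blurred BGR image, refined depth map, subject mask).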
# --- Model Loading (Cached) ---
# Use cache_resource for non-data objects like models/pipelines
@st.cache_resource
def load_depth_pipeline():
    """Loads the depth estimation pipeline and caches it. Returns tuple (pipeline, device_id)."""
    t_device = 0 if torch.cuda.is_available() else -1
    print(f"Attempting to load model on device: {'GPU (CUDA)' if t_device == 0 else 'CPU'}")
    try:
        # Use default precision (float32)
        t_pipe = pipeline(task="depth-estimation",
                          model="depth-anything/Depth-Anything-V2-Large-hf",
                          device=t_device)
        print("Depth Anything V2 Large model loaded successfully.")
        return t_pipe, t_device  # Return pipeline and device used
    except Exception as e:
        print(f"Error loading model: {e}")
        # Error will be displayed in the main app body after this function returns None
        return None, t_device  # Return None for pipe on error
# Load the model via the cached function
pipe, device_used = load_depth_pipeline()
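# Because load_depth_pipeline() is wrapped in st.cache_resource, the pipeline is created once
# per server process and reused across script reruns and sessions, so adjusting sliders or
# uploading a new image does not trigger a model reload.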
# --- Title and Model Status ---
# Display title and info AFTER attempting model load
st.title("Depth Blur Studio 📸 (Streamlit)")
st.markdown(
    "Upload a portrait image. The model will estimate depth and blur the background, keeping the subject sharp."
    "\n*Model: `depth-anything/Depth-Anything-V2-Large-hf`*"
)
st.caption(f"_(Using device: {'GPU (CUDA)' if device_used == 0 else 'CPU'})_")  # Display device info
# Handle model loading failure AFTER potential UI elements like title
if pipe is None:
    st.error("Error loading depth estimation model. Application cannot proceed.")
    st.stop()  # Stop if model loading failed
# --- Processing Function ---
def process_image_blur(pipeline_obj, input_image_pil, max_blur_ksize, depth_thresh, feather_ksize, sharpen_val):
    """
    Processes the image using the pipeline and PortraitBlurrer.
    Returns tuple: (blurred_pil, depth_pil, mask_pil, duration_seconds) on success,
    or (None, None, None) on failure.
    """
print("Processing image...") | |
processing_start_time = time.time() | |
# 1. Convert PIL Image (RGB) to NumPy array (BGR for OpenCV) | |
input_image_np_rgb = np.array(input_image_pil) | |
original_bgr_np = cv2.cvtColor(input_image_np_rgb, cv2.COLOR_RGB2BGR) | |
# 2. Perform depth estimation | |
try: | |
with torch.no_grad(): # Inference only | |
depth_output = pipeline_obj(input_image_pil) | |
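        # Note: recent versions of the transformers "depth-estimation" pipeline typically
        # return a dict such as {"predicted_depth": torch.Tensor, "depth": PIL.Image};
        # the handling below is written defensively in case the output format differs.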
        # Ensure depth map is PIL Image
        if isinstance(depth_output, dict) and "depth" in depth_output:
            depth_image_pil = depth_output["depth"]
            if not isinstance(depth_image_pil, Image.Image):
                # Attempt conversion if it's tensor/numpy (specifics might depend on pipeline output)
                # This is a basic attempt; might need refinement based on actual output type
                try:
                    depth_data = np.array(depth_image_pil)
                    # Normalize if needed (example: scale to 0-255)
                    depth_data = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
                    depth_image_pil = Image.fromarray(depth_data)
                except Exception as conversion_e:
                    print(f"Could not convert depth output to PIL Image: {conversion_e}")
                    raise ValueError("Depth estimation did not return a usable PIL Image.")
        else:
            # Handle cases where output might be directly the image or unexpected format
            if isinstance(depth_output, Image.Image):
                depth_image_pil = depth_output
            else:
                raise ValueError(f"Unexpected depth estimation output format: {type(depth_output)}")
        print("Depth map generated.")
    except Exception as e:
        print(f"Error during depth estimation: {e}")
        st.error(f"Depth estimation failed: {e}")  # Show error in UI
        return None, None, None
    # 3. Initialize Blurrer and Process
    portrait_blurrer = PortraitBlurrer(
        max_blur=int(max_blur_ksize),
        depth_threshold=int(depth_thresh),
        feather_strength=int(feather_ksize),
        sharpen_strength=float(sharpen_val)  # Use the passed sharpen value
    )
    try:
        # process_image returns blurred_bgr, depth_gray, mask_gray
        blurred_bgr_np, refined_depth_np, mask_np = portrait_blurrer.process_image(
            original_bgr_np, depth_image_pil
        )
    except Exception as e:
        print(f"Error during blurring/sharpening: {e}")
        st.error(f"Image processing (blur/sharpen) failed: {e}")  # Show error in UI
        return None, None, None
    # 4. Convert results back to RGB PIL Images for Streamlit display
    blurred_pil = Image.fromarray(cv2.cvtColor(blurred_bgr_np, cv2.COLOR_BGR2RGB))
    # Depth and mask are grayscale numpy, convert directly to PIL
    depth_pil = Image.fromarray(refined_depth_np)
    mask_pil = Image.fromarray(mask_np)
    processing_end_time = time.time()
    processing_duration = processing_end_time - processing_start_time
    print(f"Processing finished in {processing_duration:.2f} seconds.")
    # Success message is displayed outside this function, near where the results are shown
    # st.success(f"Processing finished in {processing_duration:.2f} seconds.")
    return blurred_pil, depth_pil, mask_pil, processing_duration  # Return duration
# --- Initialize Session State --- (Do this early)
if 'results' not in st.session_state:
    st.session_state.results = None  # Will store tuple (blurred, depth, mask) or None
if 'original_image_pil' not in st.session_state:
    st.session_state.original_image_pil = None
if 'processing_error_occurred' not in st.session_state:
    st.session_state.processing_error_occurred = False
if 'current_filename' not in st.session_state:
    st.session_state.current_filename = None
if 'last_process_duration' not in st.session_state:
    st.session_state.last_process_duration = None
# --- Sidebar for Controls ---
with st.sidebar:  # Use 'with' notation for clarity
    st.title("Controls")
    uploaded_file = st.file_uploader(
        "Upload Portrait Image",
        type=["jpg", "png", "jpeg"],
        label_visibility="collapsed"
    )
    # --- Handle New Upload for Instant Display ---
    if uploaded_file is not None:
        # Check if it's a new file by comparing names
        if uploaded_file.name != st.session_state.get('current_filename', None):
            print(f"New file uploaded: {uploaded_file.name}. Loading for display.")
            try:
                # Load the new image immediately
                st.session_state.original_image_pil = Image.open(uploaded_file).convert("RGB")
                # Clear previous results, error state and duration
                st.session_state.results = None
                st.session_state.processing_error_occurred = False
                st.session_state.last_process_duration = None
                # Update the tracked filename
                st.session_state.current_filename = uploaded_file.name
            except Exception as e:
                st.error(f"Error loading image: {e}")
                # Clear states if loading failed
                st.session_state.original_image_pil = None
                st.session_state.results = None
                st.session_state.processing_error_occurred = False
                st.session_state.current_filename = None
                st.session_state.last_process_duration = None
    elif st.session_state.current_filename is not None:
        # If file uploader is cleared by the user (uploaded_file becomes None)
        print("File upload cleared.")
        st.session_state.original_image_pil = None
        st.session_state.results = None
        st.session_state.processing_error_occurred = False
        st.session_state.current_filename = None
        st.session_state.last_process_duration = None
    # --- End Handle New Upload ---
st.markdown("---") # Separator | |
st.markdown("**Adjust Parameters:**") | |
slider_max_blur = st.slider("Blur Intensity (Kernel Size)", min_value=3, max_value=101, step=2, value=31) | |
slider_depth_thr = st.slider("Subject Depth Threshold (Lower=Far Away)", min_value=1, max_value=254, step=1, value=120) | |
slider_feather = st.slider("Feathering (Mask Smoothness)", min_value=1, max_value=51, step=2, value=5) # <-- Default changed to 5 | |
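    # Note: the blur and feather sliders step by 2 from an odd minimum, presumably so the
    # selected values stay odd; OpenCV kernel sizes (e.g. for Gaussian blur) must be odd.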
    # REMOVED: slider_sharpen = st.slider("Subject Sharpening Strength", min_value=0.0, max_value=2.5, step=0.1, value=1.0)
    st.markdown("---")  # Separator
    # Button to trigger processing - disable if no file *loaded* in session state
    process_button = st.button(
        "Apply Blur",
        type="primary",
        disabled=(st.session_state.original_image_pil is None)  # Disable if no original image is loaded
    )
# --- Main Area for Images ---
col1, col2 = st.columns(2)  # Create two columns for Original | Result
# --- Handle Processing Trigger ---
if process_button:  # Button is only enabled if original_image_pil exists
    if st.session_state.original_image_pil is not None:
        # Reset error flag on new processing attempt
        st.session_state.processing_error_occurred = False
        # Clear previous results and duration before showing spinner
        st.session_state.results = None
        st.session_state.last_process_duration = None
        with col2:  # Show spinner in the results column
            with st.spinner('Applying blur... This may take a moment...'):
                results_output = process_image_blur(
                    pipeline_obj=pipe,
                    input_image_pil=st.session_state.original_image_pil,  # Use the image from session state
                    max_blur_ksize=slider_max_blur,
                    depth_thresh=slider_depth_thr,
                    feather_ksize=slider_feather,
                    sharpen_val=1.0  # <-- Hardcoded sharpen value
                )
                # Check if processing returned successfully (4 values expected now)
                if results_output is not None and len(results_output) == 4:
                    # Unpack results and store duration separately
                    blurred_pil, depth_pil, mask_pil, duration = results_output
                    st.session_state.results = (blurred_pil, depth_pil, mask_pil)  # Store tuple
                    st.session_state.last_process_duration = duration
                else:
                    # Processing failed (returned None or wrong number of items)
                    st.session_state.results = None  # Ensure results are None
                    st.session_state.processing_error_occurred = True
                    st.session_state.last_process_duration = None
    else:
        # This case should technically not happen due to button disable logic, but good practice
        st.error("No image loaded to process.")
# --- Display Images based on Session State ---
# Display Original Image in Column 1 if available
if st.session_state.original_image_pil is not None:
    col1.image(st.session_state.original_image_pil, caption="Original Image", use_container_width=True)
else:
    col1.markdown("### Upload an image")
    col1.markdown("Use the sidebar controls to upload your portrait.")
# Display Results/Status in Column 2
if st.session_state.results is not None:
    # Check if the first element (blurred_img) is not None, indicating successful processing within the function
    blurred_img, depth_img, mask_img = st.session_state.results
    if blurred_img is not None:
        # Display success message with duration
        if st.session_state.last_process_duration is not None:
            st.success(f"Processing finished in {st.session_state.last_process_duration:.2f} seconds.")
        col2.image(blurred_img, caption="Blurred Background Result", use_container_width=True)
        # --- ADD DOWNLOAD BUTTON ---
        # 1. Convert PIL Image to Bytes
        buf = BytesIO()
        blurred_img.save(buf, format="PNG")  # Save image to buffer in PNG format
        byte_im = buf.getvalue()  # Get bytes from buffer
        # 2. Add Download Button
        col2.download_button(
            label="Download Blurred Image",
            data=byte_im,
            file_name=f"blurred_{os.path.splitext(st.session_state.current_filename or 'result')[0]}.png",  # Suggest filename based on the original, minus its extension
            mime="image/png"  # Set the MIME type for PNG
        )
        # --- END DOWNLOAD BUTTON ---
        # Optionally display depth and mask below the main images or in expanders
        with st.expander("Show Details (Depth Map & Mask)"):
            # Use columns inside expander for better layout if needed
            exp_col1, exp_col2 = st.columns(2)
            exp_col1.image(depth_img, caption="Refined Depth Map", use_container_width=True)
            exp_col2.image(mask_img, caption="Subject Mask", use_container_width=True)
    else:
        # This case might occur if results tuple was somehow malformed, treat as error
        st.session_state.processing_error_occurred = True  # Mark as error if blurred_img is None but results tuple exists
        col2.error("An unexpected issue occurred during processing. Please check logs or try again.")
# Handle explicit error state OR "Ready to Process" state OR default state
if st.session_state.processing_error_occurred:
    # Display specific error message if processing failed after button press
    # The error might already be shown by st.error inside process_image_blur,
    # but this provides a fallback message in col2.
    col2.warning("Image processing failed. Check messages above or terminal logs.")
elif st.session_state.original_image_pil is not None and st.session_state.results is None:
    # If file is uploaded/loaded but not processed yet (and no error occurred)
    col2.markdown("### Ready to Process")
    col2.markdown("Adjust parameters in the sidebar (if needed) and click **Apply Blur**.")
elif st.session_state.original_image_pil is None:
    # Default state when no file is uploaded/loaded and nothing processed
    col2.markdown("### Results")
    col2.markdown("The processed image and details will appear here after uploading an image and clicking 'Apply Blur'.")