# Pseudo3D / app.py
import torch
import gradio as gr
import imageio
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, Resize
import spaces
import tempfile
from scipy.ndimage import gaussian_filter
from aura_sr import AuraSR
import cv2
import torch.nn.functional as F
# Load AuraSR-v2 model once at startup
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
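# Note: upscale_4x_overlapped (used per frame below) is the tiled variant of the
# 4x upscaler; it is meant to reduce visible seams between tiles at extra cost.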

# Post-processing functions

def apply_lens_distortion(image, k1=0.2):
    """Apply a lens-distortion effect using OpenCV.

    cv2.undistort is normally used to *remove* distortion, so passing a
    positive k1 warps the image in the inverse direction, which serves here
    as a cheap way to impose a distortion look.
    """
    h, w = image.shape[:2]
    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
    return cv2.undistort(image, camera_matrix, dist_coeffs)

def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
    """Apply a simple depth-of-field effect using PyTorch.

    Pixels whose depth lies near focus_depth stay sharp; everywhere else the
    frame is replaced by a box-blurred copy.
    """
    depth_diff = torch.abs(depth_tensor - focus_depth)
    pad = blur_size // 2
    padded = F.pad(image_tensor.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    # Depthwise box blur: groups=3 requires one (1, k, k) averaging kernel per
    # RGB channel, i.e. a weight of shape (3, 1, k, k), not (1, 1, k, k).
    kernel = torch.ones(3, 1, blur_size, blur_size, device=image_tensor.device) / (blur_size ** 2)
    blurred = F.conv2d(padded, kernel, groups=3).squeeze(0)
    # Binary in-focus mask: depth within 0.1 of the focal plane stays sharp.
    mask = (depth_diff < 0.1).float()
    return image_tensor * mask + blurred * (1 - mask)


def apply_vignette(image):
"""Apply vignette effect."""
h, w = image.shape[:2]
x, y = np.meshgrid(np.arange(w), np.arange(h))
center_x, center_y = w / 2, h / 2
radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
vignette = 1 - (radius / max_radius) ** 2
vignette = np.clip(vignette, 0, 1)
return (image * vignette[..., np.newaxis]).astype(np.uint8)

def parse_keyframes(keyframe_text):
"""Parse keyframe text into time-position pairs."""
keyframes = []
try:
for entry in keyframe_text.split():
time, pos = entry.split(':')
x, y = map(float, pos.split(','))
keyframes.append((float(time), x, y))
keyframes.sort() # Sort by time
return keyframes
    except (ValueError, IndexError):
        return [(0, 0, 0), (1, 0, 0)]  # Fallback: static camera at the origin

def interpolate_keyframes(t, keyframes):
"""Interpolate camera position between keyframes."""
if t <= keyframes[0][0]:
return keyframes[0][1], keyframes[0][2]
if t >= keyframes[-1][0]:
return keyframes[-1][1], keyframes[-1][2]
for i in range(len(keyframes) - 1):
t1, x1, y1 = keyframes[i]
t2, x2, y2 = keyframes[i + 1]
if t1 <= t <= t2:
alpha = (t - t1) / (t2 - t1)
return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
return 0, 0 # Fallback
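
# Worked example: "0:0,0 0.5:5,0 1:0,0" parses to
# [(0.0, 0.0, 0.0), (0.5, 5.0, 0.0), (1.0, 0.0, 0.0)];
# interpolate_keyframes(0.25, keyframes) then returns (2.5, 0.0),
# halfway between the first two keyframes.
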
@spaces.GPU
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig, keyframe_text):
"""Generate a 3D parallax video with advanced features."""
    # Validate inputs
    if image is None or depth_map is None:
        raise gr.Error("Please upload both an image and a depth map")
    if image.size != depth_map.size:
        raise ValueError("Image and depth map must have the same dimensions")
    # Convert to tensors (force RGB so the image always has exactly 3 channels)
    image_tensor = ToTensor()(image.convert('RGB')).to('cuda', dtype=torch.float32)
depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
    # Smooth the depth map to soften hard depth edges, which otherwise produce tearing when warped
depth_np = depth_tensor.squeeze().cpu().numpy()
depth_np = gaussian_filter(depth_np, sigma=1)
depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
# Apply SSAA
if ssaa_factor > 1:
upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
image_tensor = upscale(image_tensor)
depth_tensor = upscale(depth_tensor)
H, W = image_tensor.shape[1], image_tensor.shape[2]
x = torch.arange(0, W).float().to('cuda')
y = torch.arange(0, H).float().to('cuda')
xx, yy = torch.meshgrid(x, y, indexing='xy')
pixel_grid = torch.stack((xx, yy), dim=-1)
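    # pixel_grid[y, x] == (x, y): a base sampling grid in pixel coordinates that
    # each animation style below shifts by a depth-weighted displacement.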
# Parse keyframes for custom path
keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None
# Generate frames
num_frames = int(fps * duration)
frames = []
prev_frame = None
for frame in range(num_frames):
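        # t sweeps [0, 1), so the periodic styles (zoom, horizontal, vertical,
        # circle) complete exactly one cycle and the clip loops seamlessly.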
t = frame / num_frames
if animation_style == "zoom":
zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
elif animation_style == "horizontal":
camera_x = amplitude * np.sin(2 * np.pi * t)
displacement_x = k * camera_x * depth_tensor.squeeze()
displacement_y = 0
elif animation_style == "vertical":
camera_y = amplitude * np.sin(2 * np.pi * t)
displacement_x = 0
displacement_y = k * camera_y * depth_tensor.squeeze()
elif animation_style == "circle":
camera_x = amplitude * np.sin(2 * np.pi * t)
camera_y = amplitude * np.cos(2 * np.pi * t)
displacement_x = k * camera_x * depth_tensor.squeeze()
displacement_y = k * camera_y * depth_tensor.squeeze()
elif animation_style == "spiral":
radius = amplitude * (1 - t)
camera_x = radius * np.sin(4 * np.pi * t)
camera_y = radius * np.cos(4 * np.pi * t)
displacement_x = k * camera_x * depth_tensor.squeeze()
displacement_y = k * camera_y * depth_tensor.squeeze()
elif animation_style == "custom":
camera_x, camera_y = interpolate_keyframes(t, keyframes)
displacement_x = k * camera_x * depth_tensor.squeeze()
displacement_y = k * camera_y * depth_tensor.squeeze()
else:
raise ValueError(f"Unsupported animation style: {animation_style}")
source_pixel_x = pixel_grid[:, :, 0] + displacement_x
source_pixel_y = pixel_grid[:, :, 1] + displacement_y
# Normalize to [-1, 1]
grid_x = 2 * source_pixel_x / (W - 1) - 1
grid_y = 2 * source_pixel_y / (H - 1) - 1
grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
        # Warp the image by sampling at the displaced coordinates. grid_sample
        # expects coords in [-1, 1]; with align_corners=True, -1 maps to pixel 0
        # and +1 to pixel W-1 / H-1. Out-of-frame samples fall back to zeros.
        warped = F.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
# Downsample if SSAA
if ssaa_factor > 1:
downscale = Resize((image.height, image.width), antialias=True)
warped = downscale(warped.squeeze(0)).unsqueeze(0)
        # Apply depth of field if enabled. With SSAA the depth map is still at
        # the supersampled size, so bring it down to match the frame first.
        if apply_dof:
            dof_depth = depth_tensor
            if ssaa_factor > 1:
                dof_depth = Resize((image.height, image.width), antialias=True)(depth_tensor)
            warped = apply_depth_of_field(warped.squeeze(0), dof_depth.squeeze(0)).unsqueeze(0)
        # Convert to numpy; clamp first because bicubic sampling can overshoot [0, 1]
        frame_img = warped.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
        frame_img = (frame_img * 255).astype(np.uint8)
# Apply lens distortion if enabled
if apply_lens:
frame_img = apply_lens_distortion(frame_img)
# Apply vignette if enabled
if apply_vig:
frame_img = apply_vignette(frame_img)
# Apply upscaling if enabled
if use_upscale:
frame_pil = Image.fromarray(frame_img)
frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
frame_img = np.array(frame_pil)
        # Apply TAA if enabled: an 80/20 blend with the previous frame acts as a
        # one-tap exponential moving average that damps temporal shimmer
        if use_taa and prev_frame is not None:
            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
frames.append(frame_img)
prev_frame = frame_img.copy() if use_taa else None
# Save video
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
output_path = tmpfile.name
writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
for frame in frames:
writer.append_data(frame)
writer.close()
return output_path
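
# Example of a direct call, bypassing the UI ("photo.png" / "depth.png" are
# placeholder filenames):
#   path = generate_parallax_video(Image.open("photo.png"), Image.open("depth.png"),
#                                  "horizontal", amplitude=2, k=5, fps=30, duration=5,
#                                  ssaa_factor=1, use_taa=False, use_upscale=False,
#                                  apply_lens=False, apply_dof=False, apply_vig=False,
#                                  keyframe_text="")
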
# Gradio interface
with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
gr.Markdown("# Ultimate 3D Parallax Video Generator")
gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")
with gr.Row():
image_input = gr.Image(type="pil", label="Upload Image")
depth_input = gr.Image(type="pil", label="Upload Depth Map")
with gr.Row():
animation_style = gr.Dropdown(
["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
label="Animation Style",
value="horizontal"
)
amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
with gr.Row():
ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
use_taa = gr.Checkbox(label="Enable TAA", value=False)
use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
apply_vig = gr.Checkbox(label="Apply Vignette", value=False)
with gr.Row():
keyframe_text = gr.Textbox(
label="Custom Keyframes (time:x,y)",
value="0:0,0 0.5:5,0 1:0,0",
placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
visible=True
)
generate_btn = gr.Button("Generate Video")
video_output = gr.Video(label="Parallax Video")
generate_btn.click(
fn=generate_parallax_video,
inputs=[
image_input, depth_input, animation_style, amplitude_slider, k_slider,
fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
apply_lens, apply_dof, apply_vig, keyframe_text
],
outputs=video_output
)
demo.launch()
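
# Rough dependency list (inferred from the imports above): torch, torchvision,
# gradio, numpy, Pillow, scipy, opencv-python, imageio (with its ffmpeg backend),
# aura-sr, and the Hugging Face `spaces` package.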