import torch
import gradio as gr
import imageio
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, Resize
import spaces
import tempfile
from scipy.ndimage import gaussian_filter
from aura_sr import AuraSR
import cv2
import torch.nn.functional as F
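
# Load the AuraSR-v2 upscaler once at startup so its weights are reused across calls.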
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")


def apply_lens_distortion(image, k1=0.2):
    """Apply a radial lens-distortion effect using OpenCV."""
    h, w = image.shape[:2]
    # Pinhole intrinsics: focal length w, principal point at the image center.
    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
    # cv2.undistort remaps the image through the inverse of the radial model (k1, k2, p1, p2),
    # which serves here as a cheap lens-warp effect controlled by k1.
    distorted = cv2.undistort(image, camera_matrix, dist_coeffs)
    return distorted


def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
    """Apply depth-of-field blur using PyTorch: pixels near focus_depth stay sharp."""
    depth_diff = torch.abs(depth_tensor - focus_depth)
    pad = blur_size // 2
    padded_image = F.pad(image_tensor.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    # Depthwise box blur: with groups=3, conv2d needs one (1, k, k) filter per RGB channel.
    kernel = torch.ones(3, 1, blur_size, blur_size, device=image_tensor.device) / (blur_size ** 2)
    blurred = F.conv2d(padded_image, kernel, groups=3).squeeze(0)
    # Keep pixels within 0.1 of the focal plane sharp; use the blurred image elsewhere.
    mask = (depth_diff < 0.1).float()
    return image_tensor * mask + blurred * (1 - mask)


def apply_vignette(image):
    """Apply vignette effect."""
    h, w = image.shape[:2]
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    center_x, center_y = w / 2, h / 2
    radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
    max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
    vignette = 1 - (radius / max_radius) ** 2
    vignette = np.clip(vignette, 0, 1)
    return (image * vignette[..., np.newaxis]).astype(np.uint8)


def parse_keyframes(keyframe_text):
    """Parse keyframe text ("time:x,y" entries separated by whitespace) into sorted (time, x, y) tuples."""
    keyframes = []
    try:
        for entry in keyframe_text.split():
            time, pos = entry.split(':')
            x, y = map(float, pos.split(','))
            keyframes.append((float(time), x, y))
        keyframes.sort()
        return keyframes
    except ValueError:
        # Fall back to a static camera when the text cannot be parsed.
        return [(0, 0, 0), (1, 0, 0)]


def interpolate_keyframes(t, keyframes):
    """Linearly interpolate the camera position between keyframes."""
    if t <= keyframes[0][0]:
        return keyframes[0][1], keyframes[0][2]
    if t >= keyframes[-1][0]:
        return keyframes[-1][1], keyframes[-1][2]
    for i in range(len(keyframes) - 1):
        t1, x1, y1 = keyframes[i]
        t2, x2, y2 = keyframes[i + 1]
        if t1 <= t <= t2:
            if t2 == t1:
                # Duplicate timestamps would divide by zero; snap to the earlier keyframe.
                return x1, y1
            alpha = (t - t1) / (t2 - t1)
            return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
    return 0, 0
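

# On Hugging Face Spaces, @spaces.GPU requests a GPU for the duration of each call.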
@spaces.GPU
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig, keyframe_text):
    """Generate a 3D parallax video with advanced features."""
    if image.size != depth_map.size:
        raise ValueError("Image and depth map must have the same dimensions")

    # Force RGB so downstream code can assume exactly three channels.
    image_tensor = ToTensor()(image.convert('RGB')).to('cuda', dtype=torch.float32)
    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
    # Normalize depth to [0, 1]; the epsilon guards against a constant depth map.
    depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)

    # Lightly smooth the depth map to avoid jagged parallax edges.
    depth_np = depth_tensor.squeeze().cpu().numpy()
    depth_np = gaussian_filter(depth_np, sigma=1)
    depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
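
    # Supersampling anti-aliasing (SSAA): render at a higher resolution and
    # downscale each finished frame to smooth jagged edges.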
    if ssaa_factor > 1:
        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
        image_tensor = upscale(image_tensor)
        depth_tensor = upscale(depth_tensor)

    # Pixel-coordinate grid (x, y) used to compute per-pixel displacements.
    H, W = image_tensor.shape[1], image_tensor.shape[2]
    x = torch.arange(0, W).float().to('cuda')
    y = torch.arange(0, H).float().to('cuda')
    xx, yy = torch.meshgrid(x, y, indexing='xy')
    pixel_grid = torch.stack((xx, yy), dim=-1)

    keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None

    num_frames = int(fps * duration)
    frames = []
    prev_frame = None

    for frame in range(num_frames):
        t = frame / num_frames
        if animation_style == "zoom":
            zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
            displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
            displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
        elif animation_style == "horizontal":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = 0
        elif animation_style == "vertical":
            camera_y = amplitude * np.sin(2 * np.pi * t)
            displacement_x = 0
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "circle":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            camera_y = amplitude * np.cos(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "spiral":
            radius = amplitude * (1 - t)
            camera_x = radius * np.sin(4 * np.pi * t)
            camera_y = radius * np.cos(4 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "custom":
            camera_x, camera_y = interpolate_keyframes(t, keyframes)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        else:
            raise ValueError(f"Unsupported animation style: {animation_style}")
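
        # Inverse warp: shift each output pixel by its depth-scaled displacement
        # and resample the source image at the shifted location.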
        source_pixel_x = pixel_grid[:, :, 0] + displacement_x
        source_pixel_y = pixel_grid[:, :, 1] + displacement_y

        # grid_sample expects sampling coordinates normalized to [-1, 1].
        grid_x = 2 * source_pixel_x / (W - 1) - 1
        grid_y = 2 * source_pixel_y / (H - 1) - 1
        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)

        warped = F.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)

        if ssaa_factor > 1:
            downscale = Resize((image.height, image.width), antialias=True)
            warped = downscale(warped.squeeze(0)).unsqueeze(0)

        if apply_dof:
            dof_depth = depth_tensor
            # After SSAA downscaling the depth map no longer matches the frame; resize it to fit.
            if dof_depth.shape[-2:] != warped.shape[-2:]:
                dof_depth = F.interpolate(dof_depth.unsqueeze(0), size=warped.shape[-2:], mode='bilinear', align_corners=False).squeeze(0)
            warped = apply_depth_of_field(warped.squeeze(0), dof_depth.squeeze(0)).unsqueeze(0)

        # Bicubic resampling can overshoot [0, 1], so clamp before converting to uint8.
        frame_img = warped.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
        frame_img = (frame_img * 255).astype(np.uint8)

        if apply_lens:
            frame_img = apply_lens_distortion(frame_img)

        if apply_vig:
            frame_img = apply_vignette(frame_img)

        if use_upscale:
            frame_pil = Image.fromarray(frame_img)
            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
            frame_img = np.array(frame_pil)
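
        # Temporal anti-aliasing: blend the current frame with the previous one to reduce flicker.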
        if use_taa and prev_frame is not None:
            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)

        frames.append(frame_img)
        prev_frame = frame_img.copy() if use_taa else None

    # Encode the frames as an H.264 MP4; delete=False keeps the file for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        output_path = tmpfile.name
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    for frame in frames:
        writer.append_data(frame)
    writer.close()

    return output_path
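

# Gradio UI: wires the image, depth map, and effect controls to generate_parallax_video.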
with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
    gr.Markdown("# Ultimate 3D Parallax Video Generator")
    gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        depth_input = gr.Image(type="pil", label="Upload Depth Map")

    with gr.Row():
        animation_style = gr.Dropdown(
            ["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
            label="Animation Style",
            value="horizontal"
        )
        amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
        k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
        fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
        duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)

    with gr.Row():
        ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
        use_taa = gr.Checkbox(label="Enable TAA", value=False)
        use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
        apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
        apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
        apply_vig = gr.Checkbox(label="Apply Vignette", value=False)

    with gr.Row():
        keyframe_text = gr.Textbox(
            label="Custom Keyframes (time:x,y)",
            value="0:0,0 0.5:5,0 1:0,0",
            placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
            visible=True
        )

    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Parallax Video")

    generate_btn.click(
        fn=generate_parallax_video,
        inputs=[
            image_input, depth_input, animation_style, amplitude_slider, k_slider,
            fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
            apply_lens, apply_dof, apply_vig, keyframe_text
        ],
        outputs=video_output
    )

demo.launch()