import torch
import gradio as gr
import imageio
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, Resize
import spaces
import tempfile
from scipy.ndimage import gaussian_filter
from aura_sr import AuraSR
import cv2
import torch.nn.functional as F
# Load AuraSR-v2 model once at startup
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
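# Note (assumption from the aura_sr package docs): upscale_4x_overlapped runs the
# 4x GAN upscaler over overlapping tiles to reduce visible seams on large frames.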
# Post-processing functions
def apply_lens_distortion(image, k1=0.2):
    """Approximate a lens-distortion effect using OpenCV.

    cv2.undistort warps the image with the inverse of the given radial model,
    so a nonzero k1 produces a visible lens-like warp.
    """
    h, w = image.shape[:2]
    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
    distorted = cv2.undistort(image, camera_matrix, dist_coeffs)
    return distorted
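# The sign and magnitude of k1 control whether the warp reads as barrel- or
# pincushion-like; the default 0.2 gives a subtle effect at typical resolutions.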
def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
    """Apply a simple depth-of-field effect: box-blur the frame, then keep
    regions whose depth is close to focus_depth sharp via a binary mask."""
    depth_diff = torch.abs(depth_tensor - focus_depth)
    pad = blur_size // 2
    padded_image = F.pad(image_tensor.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    # Depthwise box blur: one averaging kernel per RGB channel. groups=3 requires
    # a weight of shape (3, 1, k, k), not (1, 1, k, k).
    kernel = torch.ones(3, 1, blur_size, blur_size, device=image_tensor.device) / (blur_size ** 2)
    blurred = F.conv2d(padded_image, kernel, groups=3).squeeze(0)
    mask = (depth_diff < 0.1).float()  # 1 = in focus, 0 = blurred
    return image_tensor * mask + blurred * (1 - mask)
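# Shape sketch: image_tensor is (3, H, W) and depth_tensor is (H, W) or (1, H, W);
# the in-focus mask broadcasts across the RGB channels.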
def apply_vignette(image):
    """Apply vignette effect."""
    h, w = image.shape[:2]
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    center_x, center_y = w / 2, h / 2
    radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
    max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
    vignette = 1 - (radius / max_radius) ** 2
    vignette = np.clip(vignette, 0, 1)
    return (image * vignette[..., np.newaxis]).astype(np.uint8)
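# Example: centre pixels keep full brightness (vignette == 1), while corner
# pixels fall to 0 because radius equals max_radius there.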
def parse_keyframes(keyframe_text):
    """Parse keyframe text ("time:x,y" entries separated by spaces) into sorted (time, x, y) tuples."""
    keyframes = []
    try:
        for entry in keyframe_text.split():
            time_str, pos = entry.split(':')
            x, y = map(float, pos.split(','))
            keyframes.append((float(time_str), x, y))
        keyframes.sort()  # Sort by time
        return keyframes
    except (ValueError, IndexError):
        return [(0, 0, 0), (1, 0, 0)]  # Static fallback path on malformed input
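# Example: parse_keyframes("0:0,0 0.5:5,0 1:0,0")
#   -> [(0.0, 0.0, 0.0), (0.5, 5.0, 0.0), (1.0, 0.0, 0.0)]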
def interpolate_keyframes(t, keyframes):
    """Linearly interpolate the camera position between keyframes."""
    if t <= keyframes[0][0]:
        return keyframes[0][1], keyframes[0][2]
    if t >= keyframes[-1][0]:
        return keyframes[-1][1], keyframes[-1][2]
    for i in range(len(keyframes) - 1):
        t1, x1, y1 = keyframes[i]
        t2, x2, y2 = keyframes[i + 1]
        if t1 <= t <= t2:
            alpha = (t - t1) / (t2 - t1)
            return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
    return 0, 0  # Fallback
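# Example: with keyframes [(0, 0, 0), (0.5, 5, 0), (1, 0, 0)],
# interpolate_keyframes(0.25, keyframes) returns (2.5, 0.0).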
@spaces.GPU
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig, keyframe_text):
    """Generate a 3D parallax video with advanced features."""
    # Validate input dimensions
    if image.size != depth_map.size:
        raise ValueError("Image and depth map must have the same dimensions")
    # Convert to tensors
    image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
    depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)
    # Smooth the depth map to reduce warping artifacts at depth edges
    depth_np = depth_tensor.squeeze().cpu().numpy()
    depth_np = gaussian_filter(depth_np, sigma=1)
    depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)
    # Supersample (SSAA): render at higher resolution, downsample per frame
    if ssaa_factor > 1:
        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
        image_tensor = upscale(image_tensor)
        depth_tensor = upscale(depth_tensor)
    H, W = image_tensor.shape[1], image_tensor.shape[2]
    x = torch.arange(0, W).float().to('cuda')
    y = torch.arange(0, H).float().to('cuda')
    xx, yy = torch.meshgrid(x, y, indexing='xy')
    pixel_grid = torch.stack((xx, yy), dim=-1)
    # Parse keyframes for the custom path
    keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None
    # Generate frames
    num_frames = int(fps * duration)
    frames = []
    prev_frame = None
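    # Each animation style below builds a displacement field proportional to the
    # depth map, so pixels with larger depth values shift more between frames,
    # which produces the parallax cue; t sweeps from 0 toward 1 over the clip.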
    for frame in range(num_frames):
        t = frame / num_frames
        if animation_style == "zoom":
            zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
            displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
            displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
        elif animation_style == "horizontal":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = 0
        elif animation_style == "vertical":
            camera_y = amplitude * np.sin(2 * np.pi * t)
            displacement_x = 0
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "circle":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            camera_y = amplitude * np.cos(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "spiral":
            radius = amplitude * (1 - t)
            camera_x = radius * np.sin(4 * np.pi * t)
            camera_y = radius * np.cos(4 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "custom":
            camera_x, camera_y = interpolate_keyframes(t, keyframes)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        else:
            raise ValueError(f"Unsupported animation style: {animation_style}")
        source_pixel_x = pixel_grid[:, :, 0] + displacement_x
        source_pixel_y = pixel_grid[:, :, 1] + displacement_y
        # Normalize to [-1, 1] for grid_sample
        grid_x = 2 * source_pixel_x / (W - 1) - 1
        grid_y = 2 * source_pixel_y / (H - 1) - 1
        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
        # Warp the image by sampling source pixels
        warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)
        # Downsample back to the original resolution if SSAA was used
        if ssaa_factor > 1:
            downscale = Resize((image.height, image.width), antialias=True)
            warped = downscale(warped.squeeze(0)).unsqueeze(0)
        # Apply depth of field if enabled (resize the depth map to match the
        # frame when SSAA is on, otherwise the shapes would not broadcast)
        if apply_dof:
            dof_depth = depth_tensor
            if ssaa_factor > 1:
                dof_depth = Resize((image.height, image.width), antialias=True)(depth_tensor)
            warped = apply_depth_of_field(warped.squeeze(0), dof_depth.squeeze(0))
        # Convert to numpy (clip first: bicubic sampling can overshoot [0, 1])
        frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
        frame_img = (np.clip(frame_img, 0, 1) * 255).astype(np.uint8)
        # Apply lens distortion if enabled
        if apply_lens:
            frame_img = apply_lens_distortion(frame_img)
        # Apply vignette if enabled
        if apply_vig:
            frame_img = apply_vignette(frame_img)
        # Apply AuraSR-v2 upscaling if enabled
        if use_upscale:
            frame_pil = Image.fromarray(frame_img)
            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
            frame_img = np.array(frame_pil)
        # Temporal anti-aliasing: blend with the previous frame
        if use_taa and prev_frame is not None:
            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)
        frames.append(frame_img)
        prev_frame = frame_img.copy() if use_taa else None
    # Save video
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        output_path = tmpfile.name
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    for frame in frames:
        writer.append_data(frame)
    writer.close()
    return output_path
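# Local smoke test sketch (hypothetical file names; requires a CUDA device,
# since all tensors in this app are created on 'cuda'):
#   img = Image.open("photo.png").convert("RGB")
#   depth = Image.open("photo_depth.png")
#   path = generate_parallax_video(img, depth, "horizontal", 2, 5, 30, 5, 1,
#                                  False, False, False, False, False, "")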
# Gradio interface
with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
    gr.Markdown("# Ultimate 3D Parallax Video Generator")
    gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        depth_input = gr.Image(type="pil", label="Upload Depth Map")
    with gr.Row():
        animation_style = gr.Dropdown(
            ["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
            label="Animation Style",
            value="horizontal"
        )
        amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
        k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
        fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
        duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)
    with gr.Row():
        ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
        use_taa = gr.Checkbox(label="Enable TAA", value=False)
        use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
        apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
        apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
        apply_vig = gr.Checkbox(label="Apply Vignette", value=False)
    with gr.Row():
        keyframe_text = gr.Textbox(
            label="Custom Keyframes (time:x,y)",
            value="0:0,0 0.5:5,0 1:0,0",
            placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
            visible=True
        )
    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Parallax Video")
    generate_btn.click(
        fn=generate_parallax_video,
        inputs=[
            image_input, depth_input, animation_style, amplitude_slider, k_slider,
            fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
            apply_lens, apply_dof, apply_vig, keyframe_text
        ],
        outputs=video_output
    )
demo.launch()