import torch
import gradio as gr
import imageio
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor, Resize
import spaces
import tempfile
from scipy.ndimage import gaussian_filter
from aura_sr import AuraSR
import cv2
import torch.nn.functional as F

# Load AuraSR-v2 model once at startup
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")

# Post-processing functions
def apply_lens_distortion(image, k1=0.2):
    """Apply a lens-distortion warp using OpenCV.

    Note: cv2.undistort applies the inverse of the supplied distortion model,
    so a positive barrel coefficient k1 yields a pincushion-style warp here;
    either direction reads as a lens effect.
    """
    h, w = image.shape[:2]
    # Simple pinhole intrinsics: focal length ~ image width, principal point at the center
    camera_matrix = np.array([[w, 0, w / 2], [0, w, h / 2], [0, 0, 1]], dtype=np.float32)
    dist_coeffs = np.array([k1, 0, 0, 0], dtype=np.float32)
    return cv2.undistort(image, camera_matrix, dist_coeffs)

def apply_depth_of_field(image_tensor, depth_tensor, focus_depth=0.5, blur_size=5):
    """Apply a simple depth-of-field effect: box-blur the frame, then keep
    pixels whose depth is close to focus_depth sharp."""
    # Resize the depth map if it no longer matches the frame (e.g. after SSAA downscaling)
    if depth_tensor.shape[-2:] != image_tensor.shape[-2:]:
        depth_tensor = F.interpolate(depth_tensor.reshape(1, 1, *depth_tensor.shape[-2:]),
                                     size=image_tensor.shape[-2:], mode='bilinear',
                                     align_corners=False)[0, 0]
    depth_diff = torch.abs(depth_tensor - focus_depth)
    pad = blur_size // 2
    padded = F.pad(image_tensor.unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
    # Depthwise box filter: with groups=3 the kernel must be shaped (3, 1, k, k)
    kernel = torch.ones(3, 1, blur_size, blur_size, device=image_tensor.device) / (blur_size ** 2)
    blurred = F.conv2d(padded, kernel, groups=3).squeeze(0)
    # In-focus mask: pixels within 0.1 of the focus depth stay sharp
    mask = (depth_diff < 0.1).float()
    return image_tensor * mask + blurred * (1 - mask)

def apply_vignette(image):
    """Apply vignette effect."""
    h, w = image.shape[:2]
    x, y = np.meshgrid(np.arange(w), np.arange(h))
    center_x, center_y = w / 2, h / 2
    radius = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
    max_radius = np.sqrt(center_x ** 2 + center_y ** 2)
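    # Quadratic falloff: weight 1.0 at the image center, 0.0 at the corners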
    vignette = 1 - (radius / max_radius) ** 2
    vignette = np.clip(vignette, 0, 1)
    return (image * vignette[..., np.newaxis]).astype(np.uint8)

def parse_keyframes(keyframe_text):
    """Parse keyframe text into time-position pairs."""
    keyframes = []
    try:
        for entry in keyframe_text.split():
            time, pos = entry.split(':')
            x, y = map(float, pos.split(','))
            keyframes.append((float(time), x, y))
        keyframes.sort()  # Sort by time
        return keyframes
    except (ValueError, IndexError):
        return [(0, 0, 0), (1, 0, 0)]  # Fall back to a static two-keyframe path on malformed input

def interpolate_keyframes(t, keyframes):
    """Interpolate camera position between keyframes."""
    if t <= keyframes[0][0]:
        return keyframes[0][1], keyframes[0][2]
    if t >= keyframes[-1][0]:
        return keyframes[-1][1], keyframes[-1][2]
    for i in range(len(keyframes) - 1):
        t1, x1, y1 = keyframes[i]
        t2, x2, y2 = keyframes[i + 1]
        if t1 <= t <= t2:
            alpha = (t - t1) / (t2 - t1)
            return x1 + alpha * (x2 - x1), y1 + alpha * (y2 - y1)
    return 0, 0  # Fallback

@spaces.GPU
def generate_parallax_video(image, depth_map, animation_style, amplitude, k, fps, duration, ssaa_factor, use_taa, use_upscale, apply_lens, apply_dof, apply_vig, keyframe_text):
    """Generate a 3D parallax video with advanced features."""
    # Validate inputs
    if image is None or depth_map is None:
        raise gr.Error("Please upload both an image and a depth map")
    if image.size != depth_map.size:
        raise ValueError("Image and depth map must have the same dimensions")

    # Convert to tensors
    image_tensor = ToTensor()(image).to('cuda', dtype=torch.float32)
    depth_tensor = ToTensor()(depth_map.convert('L')).to('cuda', dtype=torch.float32)
    depth_tensor = (depth_tensor - depth_tensor.min()) / (depth_tensor.max() - depth_tensor.min() + 1e-6)

    # Smooth the depth map to soften hard depth edges, which otherwise cause
    # tearing/stretching artifacts in the warp
    depth_np = depth_tensor.squeeze().cpu().numpy()
    depth_np = gaussian_filter(depth_np, sigma=1)
    depth_tensor = torch.tensor(depth_np, device='cuda', dtype=torch.float32).unsqueeze(0)

    # Apply SSAA
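    # (supersampling anti-aliasing: warp at ssaa_factor x the native resolution,
    # then downscale each finished frame, trading compute for smoother edges)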
    if ssaa_factor > 1:
        upscale = Resize((int(image.height * ssaa_factor), int(image.width * ssaa_factor)), antialias=True)
        image_tensor = upscale(image_tensor)
        depth_tensor = upscale(depth_tensor)

    H, W = image_tensor.shape[1], image_tensor.shape[2]
    x = torch.arange(0, W).float().to('cuda')
    y = torch.arange(0, H).float().to('cuda')
    xx, yy = torch.meshgrid(x, y, indexing='xy')
    pixel_grid = torch.stack((xx, yy), dim=-1)
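    # Parallax model: each pixel is displaced in proportion to its depth value,
    # so near (bright) regions shift more than far (dark) ones as the virtual
    # camera moves -- displacement = k * camera_offset * depth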

    # Parse keyframes for custom path
    keyframes = parse_keyframes(keyframe_text) if animation_style == "custom" else None

    # Generate frames
    num_frames = int(fps * duration)
    frames = []
    prev_frame = None

    for frame in range(num_frames):
        t = frame / num_frames
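        # t sweeps [0, 1), so the sin/cos-based paths loop seamlessly when the video repeats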
        if animation_style == "zoom":
            zoom_factor = 1 + amplitude * np.sin(2 * np.pi * t)
            displacement_x = (pixel_grid[:, :, 0] - W / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
            displacement_y = (pixel_grid[:, :, 1] - H / 2) * (1 - zoom_factor) * depth_tensor.squeeze()
        elif animation_style == "horizontal":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = 0
        elif animation_style == "vertical":
            camera_y = amplitude * np.sin(2 * np.pi * t)
            displacement_x = 0
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "circle":
            camera_x = amplitude * np.sin(2 * np.pi * t)
            camera_y = amplitude * np.cos(2 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "spiral":
            radius = amplitude * (1 - t)
            camera_x = radius * np.sin(4 * np.pi * t)
            camera_y = radius * np.cos(4 * np.pi * t)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        elif animation_style == "custom":
            camera_x, camera_y = interpolate_keyframes(t, keyframes)
            displacement_x = k * camera_x * depth_tensor.squeeze()
            displacement_y = k * camera_y * depth_tensor.squeeze()
        else:
            raise ValueError(f"Unsupported animation style: {animation_style}")

        source_pixel_x = pixel_grid[:, :, 0] + displacement_x
        source_pixel_y = pixel_grid[:, :, 1] + displacement_y

        # Normalize to [-1, 1]
        grid_x = 2 * source_pixel_x / (W - 1) - 1
        grid_y = 2 * source_pixel_y / (H - 1) - 1
        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)
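        # grid_sample maps -1 to the first pixel and +1 to the last along each
        # axis (align_corners=True); coordinates pushed outside that range by
        # large displacements sample zeros, i.e. black borders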

        # Warp image
        warped = torch.nn.functional.grid_sample(image_tensor.unsqueeze(0), grid, mode='bicubic', align_corners=True)

        # Downsample if SSAA
        if ssaa_factor > 1:
            downscale = Resize((image.height, image.width), antialias=True)
            warped = downscale(warped.squeeze(0)).unsqueeze(0)

        # Apply depth of field if enabled
        if apply_dof:
            warped = apply_depth_of_field(warped.squeeze(0), depth_tensor.squeeze(0))

        # Convert to numpy; clip first because bicubic resampling can overshoot
        # [0, 1], which would wrap around in the uint8 cast
        frame_img = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
        frame_img = (np.clip(frame_img, 0, 1) * 255).astype(np.uint8)

        # Apply lens distortion if enabled
        if apply_lens:
            frame_img = apply_lens_distortion(frame_img)

        # Apply vignette if enabled
        if apply_vig:
            frame_img = apply_vignette(frame_img)

        # Apply upscaling if enabled
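        # (upscale_4x_overlapped runs the GAN upscaler on overlapping tiles and
        # blends them to avoid visible seams; this dominates per-frame cost)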
        if use_upscale:
            frame_pil = Image.fromarray(frame_img)
            frame_pil = aura_sr.upscale_4x_overlapped(frame_pil)
            frame_img = np.array(frame_pil)

        # Apply TAA if enabled: blend 20% of the previous output frame into the
        # current one (an exponential moving average) to damp temporal flicker
        if use_taa and prev_frame is not None:
            frame_img = (frame_img * 0.8 + prev_frame * 0.2).astype(np.uint8)

        frames.append(frame_img)
        prev_frame = frame_img.copy() if use_taa else None

    # Save video
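    # Note: imageio's ffmpeg backend may resize frames whose dimensions are not
    # divisible by its macro_block_size (16 by default) and warn about it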
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        output_path = tmpfile.name
        writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
        for frame in frames:
            writer.append_data(frame)
        writer.close()

    return output_path

# Gradio interface
with gr.Blocks(title="Ultimate 3D Parallax Video Generator") as demo:
    gr.Markdown("# Ultimate 3D Parallax Video Generator")
    gr.Markdown("Generate high-quality 3D parallax videos with advanced features, post-processing, and custom paths.")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        depth_input = gr.Image(type="pil", label="Upload Depth Map")

    with gr.Row():
        animation_style = gr.Dropdown(
            ["zoom", "horizontal", "vertical", "circle", "spiral", "custom"],
            label="Animation Style",
            value="horizontal"
        )
        amplitude_slider = gr.Slider(0, 10, value=2, label="Amplitude", step=0.1)
        k_slider = gr.Slider(1, 20, value=5, label="Depth Scale (k)", step=0.1)
        fps_slider = gr.Slider(10, 60, value=30, label="FPS", step=1)
        duration_slider = gr.Slider(1, 10, value=5, label="Duration (s)", step=0.1)

    with gr.Row():
        ssaa_factor = gr.Dropdown([1, 2, 4], label="SSAA Factor", value=1)
        use_taa = gr.Checkbox(label="Enable TAA", value=False)
        use_upscale = gr.Checkbox(label="Enable AuraSR-v2 Upscaling", value=False)
        apply_lens = gr.Checkbox(label="Apply Lens Distortion", value=False)
        apply_dof = gr.Checkbox(label="Apply Depth of Field", value=False)
        apply_vig = gr.Checkbox(label="Apply Vignette", value=False)

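    # The keyframe textbox is only consumed when Animation Style is "custom";
    # other styles ignore it.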
    with gr.Row():
        keyframe_text = gr.Textbox(
            label="Custom Keyframes (time:x,y)",
            value="0:0,0 0.5:5,0 1:0,0",
            placeholder="e.g., 0:0,0 0.5:5,0 1:0,0",
            visible=True
        )

    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Parallax Video")

    generate_btn.click(
        fn=generate_parallax_video,
        inputs=[
            image_input, depth_input, animation_style, amplitude_slider, k_slider,
            fps_slider, duration_slider, ssaa_factor, use_taa, use_upscale,
            apply_lens, apply_dof, apply_vig, keyframe_text
        ],
        outputs=video_output
    )

demo.launch()