"""Diffusion sampling utilities: DDPM discretization, an eps-parameterized
discrete denoiser, classifier-free guidance (CFG) variants, and an Euler EDM
sampler."""

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm

from seva.geometry import get_camera_dist


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends singleton dimensions to the end of a tensor until it has
    `target_dims` dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"Input has {x.ndim} dims, but target_dims is {target_dims}, "
            "which is fewer."
        )
    return x[(...,) + (None,) * dims_to_append]
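# Example (illustrative only): broadcasting a per-sample sigma over a 4-D
# latent batch.
#
#   >>> append_dims(torch.tensor([0.5, 1.0]), 4).shape
#   torch.Size([2, 1, 1, 1])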
def append_zero(x: torch.Tensor) -> torch.Tensor:
    """Appends a single zero to a 1-D tensor of sigmas."""
    return torch.cat([x, x.new_zeros([1])])


def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """Converts a denoised estimate to the probability-flow ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)
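# `to_d` is the Karras et al. (2022) ODE derivative d = (x - D(x; sigma)) / sigma:
# the direction from the current sample toward its denoised estimate, scaled by
# the noise level. If the denoiser already returns x exactly, d is zero and the
# Euler step below leaves the sample unchanged.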
def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    """Builds the DDPM beta schedule that is linear in sqrt(beta)."""
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()
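# Sanity check of the schedule shape (hedged; default arguments assumed):
#
#   >>> betas = make_betas(1000)
#   >>> np.allclose(np.sqrt(betas), np.linspace(1e-4**0.5, 2e-2**0.5, 1000))
#   True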
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    """Selects `num_substeps` roughly evenly spaced timesteps in [0, max_step)."""
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
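# Example (illustrative only):
#
#   >>> generate_roughly_equally_spaced_steps(4, 1000)
#   array([249, 499, 749, 999])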
class EpsScaling(object):
    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
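# EpsScaling is the eps-prediction parameterization: DiscreteDenoiser below
# assembles its output as
#
#   D(x, sigma) = c_skip * x + c_out * network(c_in * x, c_noise)
#               = x - sigma * network(x / sqrt(sigma^2 + 1), sigma)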
class DDPMDiscretization(object):
    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: float | None = 2.4,
    ):
        self.num_timesteps = num_timesteps

        betas = make_betas(
            num_timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        self.log_snr_shift = log_snr_shift

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

    def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            # Multiplying sigma by exp(log_snr_shift) shifts the log-SNR by
            # -2 * log_snr_shift, biasing sampling toward higher noise levels.
            sigmas = sigmas * np.exp(self.log_snr_shift)
        return torch.flip(
            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
        )

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: str | torch.device = "cpu",
    ) -> torch.Tensor:
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))
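# Usage sketch (default constructor arguments assumed):
#
#   >>> disc = DDPMDiscretization()
#   >>> sigmas = disc(50)   # 50 descending noise levels plus an appended 0
#   >>> len(sigmas), bool(sigmas[0] > sigmas[1]), float(sigmas[-1])
#   (51, True, 0.0)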
class DiscreteDenoiser(object):
    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: str | torch.device = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device

        self.register_sigmas()

    def register_sigmas(self):
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        # Snap the continuous sigma to the nearest discrete level and condition
        # the network on the corresponding integer timestep index.
        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
        if "replace" in cond:
            # Optional hard-replacement conditioning: overwrite masked latents
            # with the provided values. Note this pops (mutates) `cond`.
            x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
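# Call pattern (sketch; `network` stands for any module with the signature
# used in __call__ above):
#
#   denoiser = DiscreteDenoiser(DDPMDiscretization())
#   denoised = denoiser(network, noisy_latents, sigma, cond)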
class ConstantScaleRule(object):
    def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
        return scale


class MultiviewScaleRule(object):
    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        c2w_input = c2w[input_frame_mask]
        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        translation_diff = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_diff = (
            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
        )
        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
        if isinstance(scale, torch.Tensor):
            scale = scale.clone()
            scale[close_frame] = self.min_scale
        elif isinstance(scale, float):
            scale = torch.where(close_frame, self.min_scale, scale)
        else:
            raise ValueError(f"Invalid scale type {type(scale)}.")
        return scale
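# MultiviewScaleRule drops guidance to `min_scale` for target views that
# (nearly) coincide with an input view: matching intrinsics, rotation distance
# below 10.0 (degrees, assuming `get_camera_dist` reports angles), and
# translation distance below 1e-5. Such frames essentially reconstruct known
# views and need little guidance.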
class ConstantScaleSchedule(object):
    def __call__(
        self, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> float | torch.Tensor:
        if isinstance(sigma, float):
            return scale
        elif isinstance(sigma, torch.Tensor):
            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
                sigma = append_dims(sigma, scale.ndim)
            return scale * torch.ones_like(sigma)
        else:
            raise ValueError(f"Invalid sigma type {type(sigma)}.")


class ConstantGuidance(object):
    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: float | torch.Tensor,
    ) -> torch.Tensor:
        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
            scale = append_dims(scale, cond.ndim)
        return uncond + scale * (cond - uncond)
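# ConstantGuidance is the standard classifier-free guidance combination:
#
#   x_pred = x_uncond + scale * (x_cond - x_uncond)
#
# scale = 1 recovers the conditional prediction; scale > 1 extrapolates away
# from the unconditional one.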
class VanillaCFG(object):
    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()

        for k in c:
            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
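# Round-trip sketch (names hypothetical): `prepare_inputs` doubles the batch
# so one network call evaluates both branches, and __call__ folds it back:
#
#   guider = VanillaCFG()
#   x_in, s_in, c_in = guider.prepare_inputs(x, s, cond, uc)
#   out = denoiser(network, x_in, s_in, c_in)   # unconditional half first
#   x_pred = guider(out, sigma, scale=2.0)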
class MultiviewCFG(VanillaCFG):
    def __init__(self, cfg_min: float = 1.0):
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred


class MultiviewTemporalCFG(MultiviewCFG):
    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)

        self.num_frames = num_frames
        # distance_matrix[i, j] = |i - j|: temporal distance between frames.
        distance_matrix = (
            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
        ).abs()
        self.distance_matrix = distance_matrix

    def __call__(
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )
        # Per frame: temporal distance to the nearest input frame, normalized
        # to [0, 1] within each batch element.
        min_distance = (
            self.distance_matrix[None].to(x.device)
            + (~input_frame_mask[:, None]) * self.num_frames
        ).min(-1)[0]
        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)
        return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
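# Example of the temporal ramp: with num_frames=5 and only frame 0 observed,
# min_distance per frame is [0, 1, 2, 3, 4], normalized to
# [0, 0.25, 0.5, 0.75, 1], so guidance interpolates from cfg_min at the input
# frame up to the full `scale` at the farthest frame.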
class EulerEDMSampler(object):
    def __init__(
        self,
        discretization: DDPMDiscretization,
        guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
        num_steps: int | None = None,
        verbose: bool = False,
        device: str | torch.device = "cuda",
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
        # Scale the unit-variance input up to the initial noise level.
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
        sigma_generator = range(num_sigmas - 1)
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator

    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
        sigma_hat = sigma * (gamma + 1.0) + 1e-6

        # Stochastic churn: lift sigma to sigma_hat and inject matching noise.
        # When gamma == 0 the injected noise is negligible (only the 1e-6
        # offset above contributes).
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        # Single Euler step from sigma_hat to next_sigma.
        return x + dt * d

    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict | None = None,
        num_steps: int | None = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
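# End-to-end usage sketch (hedged; `network`, the shapes, and the conditioning
# dicts `cond` / `uc` are placeholders, not defined in this module):
#
#   discretization = DDPMDiscretization()
#   denoiser = DiscreteDenoiser(discretization, device="cuda")
#   sampler = EulerEDMSampler(discretization, VanillaCFG(), num_steps=25)
#
#   def denoise(x, sigma, c):
#       return denoiser(network, x, sigma, c)
#
#   samples = sampler(
#       denoise,
#       torch.randn(b, c_ch, h, w, device="cuda"),
#       scale=2.0,
#       cond=cond,
#       uc=uc,
#   )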