# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch

from .base import *  # Scheduler, SchedulerStepOutput


class DDIMScheduler(Scheduler):
    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> SchedulerStepOutput:
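        """Take one DDIM denoising step (Song et al., https://arxiv.org/abs/2010.02502).

        `model_output` is first converted into predictions of x_0 and epsilon
        according to `model_output_type`, then the sample is moved from
        `timestep` to the previous timestep in the schedule. `eta` interpolates
        between deterministic DDIM (eta=0) and DDPM-like stochastic sampling
        (eta=1).
        """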
        # 1. get previous step value (t-1)
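        # `idx` is the position of each timestep within self.timesteps; the
        # previous (less noisy) timestep is the next entry, clamped so the
        # final step maps onto itself.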
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep, device=self.device, dtype=torch.int)
        idx = timestep.reshape(-1, 1).eq(self.timesteps.reshape(1, -1)).nonzero()[:, 1]
        prev_timestep = self.timesteps[idx.add(1).clamp_max(self.num_inference_timesteps - 1)]

        # 2. compute alphas, betas
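        # ᾱ_t and ᾱ_{t-1}, reshaped to broadcast over the non-batch dims of
        # `sample`; past the end of the schedule, ᾱ_{t-1} falls back to
        # self.final_alpha_cumprod.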
        alpha_prod_t = self.alphas_cumprod[timestep].reshape(-1, *([1] * (sample.ndim - 1)))
        alpha_prod_t_prev = torch.where(
            idx < self.num_inference_timesteps - 1,
            self.alphas_cumprod[prev_timestep],
            self.final_alpha_cumprod,
        ).reshape(-1, *([1] * (sample.ndim - 1)))
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. compute predicted original sample from predicted noise, also called
        #    "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
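        # For the epsilon parameterization this is x_0 ≈ (x_t − sqrt(1 − ᾱ_t) · ε̂) / sqrt(ᾱ_t).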
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon

        # 4. Clip or threshold "predicted x_0"
        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
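            # re-derive epsilon so it stays consistent with the clipped x_0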
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
        if dynamic_threshold is not None:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
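            # clamp x_0 to a per-sample quantile s (at least 1) and rescale by s,
            # which keeps values in [-1, 1] without saturating most entries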
            dynamic_max_val = (
                pred_original_sample
                .flatten(1)
                .abs()
                .float()
                .quantile(dynamic_threshold, dim=1)
                .type_as(pred_original_sample)
                .clamp_min(1)
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            )
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        # 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
        # σ_t = sqrt((1 − ᾱ_{t−1}) / (1 − ᾱ_t)) * sqrt(1 − ᾱ_t / ᾱ_{t−1})
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
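        # eta = 0 gives deterministic DDIM sampling; eta = 1 matches the DDPM posterior std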
        std_dev_t = eta * variance ** 0.5

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
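        #    x_{t-1} = sqrt(ᾱ_{t-1}) · x̂_0 + sqrt(1 − ᾱ_{t-1} − σ_t²) · ε̂  (+ σ_t · z when eta > 0)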
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction

        # 8. add "random noise" if needed.
        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return SchedulerStepOutput(
            prev_sample=prev_sample,
            pred_original_sample=pred_original_sample)
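

# Example usage (a sketch, not part of this module): how a sampling loop would
# typically drive `step`. The constructor arguments, the "epsilon" output type,
# and the `model(...)` call are assumptions about the surrounding codebase.
#
#     scheduler = DDIMScheduler(...)  # hypothetical construction
#     sample = torch.randn(1, 4, 64, 64, device=scheduler.device)
#     for t in scheduler.timesteps:
#         model_output = model(sample, t)  # your denoising network
#         sample = scheduler.step(model_output, "epsilon", t, sample, eta=0.0).prev_sample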