# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *


class DDIMScheduler(Scheduler):
    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> SchedulerStepOutput:
        # 1. get previous step value (t-1)
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep, device=self.device, dtype=torch.int)
        idx = timestep.reshape(-1, 1).eq(self.timesteps.reshape(1, -1)).nonzero()[:, 1]
        prev_timestep = self.timesteps[idx.add(1).clamp_max(self.num_inference_timesteps - 1)]

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep].reshape(-1, *([1] * (sample.ndim - 1)))
        alpha_prod_t_prev = torch.where(
            idx < self.num_inference_timesteps - 1,
            self.alphas_cumprod[prev_timestep],
            self.final_alpha_cumprod,
        ).reshape(-1, *([1] * (sample.ndim - 1)))
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. compute predicted original sample from predicted noise, also called
        #    "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon

        # 4. Clip or threshold "predicted x_0"
        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
        if dynamic_threshold is not None:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
            dynamic_max_val = pred_original_sample \
                .flatten(1) \
                .abs() \
                .float() \
                .quantile(dynamic_threshold, dim=1) \
                .type_as(pred_original_sample) \
                .clamp_min(1) \
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        # 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
        #    σ_t = sqrt((1 − α_{t−1}) / (1 − α_t)) * sqrt(1 − α_t / α_{t−1})
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_{t-1} without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # 8. add "random noise" if needed.
        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return SchedulerStepOutput(
            prev_sample=prev_sample,
            pred_original_sample=pred_original_sample,
        )
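

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API).
# Assumptions: the base `Scheduler` populates `timesteps` ordered from the
# noisiest to the cleanest step (the order `step` indexes them in), "epsilon"
# is an accepted `model_output_type` for `convert_output`, and `model` is a
# hypothetical denoiser returning a noise prediction shaped like `latents`.
# ---------------------------------------------------------------------------
def _demo_ddim_sampling_loop(scheduler: "DDIMScheduler", model, latents: torch.Tensor) -> torch.Tensor:
    """Run deterministic DDIM sampling (eta=0.0), feeding each step's
    prev_sample back into the denoiser until the final timestep."""
    for t in scheduler.timesteps:
        model_output = model(latents, t)  # predicted noise for the current step
        out = scheduler.step(
            model_output=model_output,
            model_output_type="epsilon",  # assumed conversion key; see convert_output
            timestep=t,
            sample=latents,
            eta=0.0,            # eta=0 -> no stochastic noise added (pure DDIM)
            clip_sample=False,
        )
        latents = out.prev_sample
    return latents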