File size: 4,138 Bytes
7cc4b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *


class DDIMScheduler(Scheduler):
    """DDIM sampler step (Song et al., https://arxiv.org/abs/2010.02502).

    Only ``step`` is defined here. The inference schedule (``timesteps``,
    ``num_inference_timesteps``), ``alphas_cumprod``, ``final_alpha_cumprod``,
    ``device`` and ``convert_output`` are provided by ``Scheduler``.
    """

    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> SchedulerStepOutput:
        """Run one DDIM update from ``timestep`` to the previous timestep.

        Args:
            model_output: Network prediction for ``sample`` at ``timestep``.
            model_output_type: How ``model_output`` should be interpreted; forwarded
                to ``self.convert_output``. ``"sample"`` denotes a predicted x_0;
                the other accepted values are defined by ``Scheduler.convert_output``.
            timestep: Current timestep, either a Python int or a tensor. A tensor
                may hold one value per batch element (it is reshaped to broadcast
                against ``sample``'s leading dimension). Every value must appear
                exactly once in ``self.timesteps``.
            sample: Current noisy sample x_t.
            eta: Scale of the stochastic term σ_t (formula (16) of the DDIM paper).
                ``0.0`` gives the deterministic DDIM update.
            clip_sample: If true, clamp the predicted x_0 to ``[-1, 1]``.
            dynamic_threshold: If set, the quantile (passed to ``torch.quantile``,
                so it must lie in ``[0, 1]``) used for dynamic thresholding of the
                predicted x_0, as in https://arxiv.org/abs/2205.11487.
            variance_noise: Noise added when ``eta > 0``. If omitted, it is drawn
                with ``torch.randn_like(model_output)``.

        Returns:
            ``SchedulerStepOutput`` holding ``prev_sample`` (the sample at the
            previous timestep) and ``pred_original_sample`` (the predicted x_0
            after any clipping/thresholding).

        Raises:
            ValueError: If some value of ``timestep`` does not occur exactly once
                in ``self.timesteps``.
        """
        # 1. Locate `timestep` in the inference schedule. The "previous" timestep
        #    is the following entry of `self.timesteps`, clamped at the last index.
        if not isinstance(timestep, torch.Tensor):
            # torch.long: older PyTorch releases reject int32 tensors as indices,
            # and this tensor is used to index `self.alphas_cumprod` below.
            timestep = torch.tensor(timestep, device=self.device, dtype=torch.long)

        # nonzero() yields (i, j) pairs with timestep[i] == self.timesteps[j];
        # column 1 is the position of each timestep within the schedule.
        idx = timestep.reshape(-1, 1).eq(self.timesteps.reshape(1, -1)).nonzero()[:, 1]
        if idx.numel() != timestep.numel():
            # Without this guard, an unknown (or duplicated) timestep yields an idx
            # of the wrong length. That silently broadcasts into an empty or
            # wrongly-shaped prev_sample instead of failing.
            raise ValueError(
                f"Each timestep must occur exactly once in self.timesteps; got {timestep.tolist()}."
            )
        last_idx = self.num_inference_timesteps - 1
        prev_timestep = self.timesteps[idx.add(1).clamp_max(last_idx)]

        # 2. Compute alphas and betas, reshaped to broadcast over `sample`'s
        #    non-batch dimensions. At the last schedule entry there is no previous
        #    timestep, so `self.final_alpha_cumprod` is used for alpha_prod_t_prev.
        view_shape = (-1, *([1] * (sample.ndim - 1)))
        alpha_prod_t = self.alphas_cumprod[timestep].reshape(view_shape)
        alpha_prod_t_prev = torch.where(
            idx < last_idx,
            self.alphas_cumprod[prev_timestep],
            self.final_alpha_cumprod,
        ).reshape(view_shape)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Convert the model output into the predicted original sample
        #    ("predicted x_0" of formula (12) in https://arxiv.org/pdf/2010.02502.pdf)
        #    and the matching predicted noise epsilon.
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon

        # 4. Clip or threshold "predicted x_0". After modifying x_0, epsilon is
        #    re-derived from it so the two predictions stay consistent with `sample`.
        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        if dynamic_threshold is not None:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487: per batch
            # element, s = max(quantile(|x_0|), 1); then x_0 <- clamp(x_0, -s, s) / s.
            # The quantile is computed in float32 because torch.quantile only
            # accepts float/double inputs.
            dynamic_max_val = pred_original_sample \
                .flatten(1) \
                .abs() \
                .float() \
                .quantile(dynamic_threshold, dim=1) \
                .type_as(pred_original_sample) \
                .clamp_min(1) \
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        # 5. Compute the variance "sigma_t(eta)" (formula (16), https://arxiv.org/pdf/2010.02502.pdf):
        #    sigma_t = sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_{t-1})
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** (0.5)

        # 6. Compute the "direction pointing to x_t" of formula (12).
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. Compute x_{t-1} without the "random noise" term of formula (12).
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # 8. Add the "random noise" term only for the stochastic (eta > 0) variant.
        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return SchedulerStepOutput(
            prev_sample=prev_sample,
            pred_original_sample=pred_original_sample)