CoreloneH committed on
Commit 7cc4b41 · 1 Parent(s): 5207cf4

Add application file

Files changed (47)
  1. configs/realcustom_sigdino_highres.json +119 -0
  2. configs/realcustom_sigdino_highres_shallow.json +114 -0
  3. inference/__pycache__/inference_utils.cpython-310.pyc +0 -0
  4. inference/__pycache__/mask_generation.cpython-310.pyc +0 -0
  5. inference/__pycache__/pipeline.cpython-310.pyc +0 -0
  6. inference/app.py +82 -0
  7. inference/inference_single_image.py +317 -0
  8. inference/inference_single_image.sh +55 -0
  9. inference/inference_utils.py +76 -0
  10. inference/mask_generation.py +114 -0
  11. inference/pipeline.py +359 -0
  12. models/__pycache__/attention_custom.cpython-310.pyc +0 -0
  13. models/__pycache__/attention_processor_custom_cross.cpython-310.pyc +0 -0
  14. models/__pycache__/base_vision.cpython-310.pyc +0 -0
  15. models/__pycache__/dino.cpython-310.pyc +0 -0
  16. models/__pycache__/image_encoder_siglipdino_shallowdeep.cpython-310.pyc +0 -0
  17. models/__pycache__/projectors.cpython-310.pyc +0 -0
  18. models/__pycache__/sigclip.cpython-310.pyc +0 -0
  19. models/__pycache__/text.cpython-310.pyc +0 -0
  20. models/__pycache__/transformer_2d_custom.cpython-310.pyc +0 -0
  21. models/__pycache__/unet_2d_blocks_custom.cpython-310.pyc +0 -0
  22. models/__pycache__/unet_2d_condition_custom.cpython-310.pyc +0 -0
  23. models/__pycache__/vae.cpython-310.pyc +0 -0
  24. models/attention_custom.py +425 -0
  25. models/attention_processor_custom_cross.py +1778 -0
  26. models/base_vision.py +227 -0
  27. models/dino.py +203 -0
  28. models/image_encoder_siglipdino_shallowdeep.py +162 -0
  29. models/projectors.py +150 -0
  30. models/sigclip.py +159 -0
  31. models/text.py +113 -0
  32. models/transformer_2d_custom.py +388 -0
  33. models/unet_2d_blocks_custom.py +0 -0
  34. models/unet_2d_condition_custom.py +1059 -0
  35. models/vae.py +36 -0
  36. prompts/validation_negative.txt +1 -0
  37. requirements.txt +34 -0
  38. schedulers/__pycache__/base.cpython-310.pyc +0 -0
  39. schedulers/__pycache__/ddim.cpython-310.pyc +0 -0
  40. schedulers/__pycache__/dpm_s.cpython-310.pyc +0 -0
  41. schedulers/__pycache__/utils.cpython-310.pyc +0 -0
  42. schedulers/base.py +133 -0
  43. schedulers/ddim.py +85 -0
  44. schedulers/dpm_m.py +412 -0
  45. schedulers/dpm_s.py +243 -0
  46. schedulers/utils.py +124 -0
  47. utils.py +55 -0
configs/realcustom_sigdino_highres.json ADDED
@@ -0,0 +1,119 @@
1
+ {
2
+ "act_fn": "silu",
3
+ "addition_embed_type": "text_time",
4
+ "addition_embed_type_num_heads": 64,
5
+ "addition_time_embed_dim": 256,
6
+ "attention_head_dim": [
7
+ 5,
8
+ 10,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280
15
+ ],
16
+ "center_input_sample": false,
17
+ "class_embed_type": null,
18
+ "class_embeddings_concat": false,
19
+ "conv_in_kernel": 3,
20
+ "conv_out_kernel": 3,
21
+ "cross_attention_dim": 2048,
22
+ "cross_attention_norm": null,
23
+ "down_block_types": [
24
+ "DownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D"
27
+ ],
28
+ "downsample_padding": 1,
29
+ "dual_cross_attention": false,
30
+ "encoder_hid_dim": null,
31
+ "encoder_hid_dim_type": null,
32
+ "flip_sin_to_cos": true,
33
+ "freq_shift": 0,
34
+ "in_channels": 4,
35
+ "layers_per_block": 2,
36
+ "mid_block_only_cross_attention": null,
37
+ "mid_block_scale_factor": 1,
38
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
39
+ "norm_eps": 1e-05,
40
+ "norm_num_groups": 32,
41
+ "num_attention_heads": null,
42
+ "num_class_embeds": null,
43
+ "only_cross_attention": false,
44
+ "out_channels": 4,
45
+ "projection_class_embeddings_input_dim": 2816,
46
+ "resnet_out_scale_factor": 1.0,
47
+ "resnet_skip_time_act": false,
48
+ "resnet_time_scale_shift": "default",
49
+ "sample_size": 128,
50
+ "time_cond_proj_dim": null,
51
+ "time_embedding_act_fn": null,
52
+ "time_embedding_dim": null,
53
+ "time_embedding_type": "positional",
54
+ "timestep_post_act": null,
55
+ "transformer_layers_per_block": [
56
+ 1,
57
+ 2,
58
+ 10
59
+ ],
60
+ "up_block_types": [
61
+ "CrossAttnUpBlock2D",
62
+ "CrossAttnUpBlock2D",
63
+ "UpBlock2D"
64
+ ],
65
+ "upcast_attention": false,
66
+ "use_linear_projection": true,
67
+ "image_ref_processor_config": {
68
+ "target": "utils.image_ref_processor.default.NaiveResizeProcessor",
69
+ "params": {
70
+ "target_image_size": 768,
71
+ "resize_mode": "resize",
72
+ "crop_min_ratio": 1.0,
73
+ "crop_max_ratio": 1.0,
74
+ "image_dropout": 0.1
75
+ }
76
+ },
77
+ "image_ref_processor_input_keys": [
78
+ "image_ref"
79
+ ],
80
+ "vision_model_config": {
81
+ "vision_model_config": {
82
+ "target": "models.image_encoder_siglipdino_shallowdeep.ShallowDeepPatchfySiglipDinoEncoder",
83
+ "params": {
84
+ "siglip_config": {
85
+ "backbone_name_or_path": "vit_so400m_patch14_siglip_384",
86
+ "image_resize_strategy": "resize-naive",
87
+ "default_image_size": 384,
88
+ "feature_index": [
89
+ 25
90
+ ]
91
+ },
92
+ "dino_config": {
93
+ "backbone_name_or_path": "vit_large_patch14_reg4_dinov2.lvd142m",
94
+ "image_resize_strategy": "resize-naive",
95
+ "default_image_size": 384,
96
+ "feature_index": [
97
+ 22
98
+ ]
99
+ },
100
+ "patchfy_scale": 2,
101
+ "default_image_size": 384
102
+ }
103
+ }
104
+ },
105
+ "image_prompt_settings": {
106
+ "vision_projection_type": "custom",
107
+ "vision_projection_config": {
108
+ "target": "models.projectors.ProjectorHighResMinAttn",
109
+ "params": {
110
+ "vision_dim": 2176,
111
+ "out_dim": 2048,
112
+ "dim_head": 64,
113
+ "adaptive_scale": false
114
+ }
115
+ },
116
+ "image_prompt_mode": "naive",
117
+ "cross_attention_id": 70
118
+ }
119
+ }
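
The inference scripts added later in this commit consume this config by popping the nested `vision_model_config` block (which is not a UNet constructor argument) before building the custom UNet, and by instantiating the vision encoder separately. A minimal sketch of that loading pattern, with checkpoint loading omitted:

```python
import json

from models.unet_2d_condition_custom import UNet2DConditionModel
from utils import instantiate_from_config

with open("configs/realcustom_sigdino_highres.json") as f:
    unet_config = json.load(f)

# The vision-encoder settings are nested one level deep and are handled
# outside the UNet, so they are popped off first (the inference scripts
# in this commit do the same).
vision_model_config = unet_config.pop("vision_model_config", None)
vision_model_config = vision_model_config.pop("vision_model_config", None)
unet_config.pop("type", None)  # not a constructor argument either

unet = UNet2DConditionModel(**unet_config).eval()
vision_model = instantiate_from_config(vision_model_config).eval()
```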
configs/realcustom_sigdino_highres_shallow.json ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "act_fn": "silu",
3
+ "addition_embed_type": "text_time",
4
+ "addition_embed_type_num_heads": 64,
5
+ "addition_time_embed_dim": 256,
6
+ "attention_head_dim": [
7
+ 5,
8
+ 10,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280
15
+ ],
16
+ "center_input_sample": false,
17
+ "class_embed_type": null,
18
+ "class_embeddings_concat": false,
19
+ "conv_in_kernel": 3,
20
+ "conv_out_kernel": 3,
21
+ "cross_attention_dim": 2048,
22
+ "cross_attention_norm": null,
23
+ "down_block_types": [
24
+ "DownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D"
27
+ ],
28
+ "downsample_padding": 1,
29
+ "dual_cross_attention": false,
30
+ "encoder_hid_dim": null,
31
+ "encoder_hid_dim_type": null,
32
+ "flip_sin_to_cos": true,
33
+ "freq_shift": 0,
34
+ "in_channels": 4,
35
+ "layers_per_block": 2,
36
+ "mid_block_only_cross_attention": null,
37
+ "mid_block_scale_factor": 1,
38
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
39
+ "norm_eps": 1e-05,
40
+ "norm_num_groups": 32,
41
+ "num_attention_heads": null,
42
+ "num_class_embeds": null,
43
+ "only_cross_attention": false,
44
+ "out_channels": 4,
45
+ "projection_class_embeddings_input_dim": 2816,
46
+ "resnet_out_scale_factor": 1.0,
47
+ "resnet_skip_time_act": false,
48
+ "resnet_time_scale_shift": "default",
49
+ "sample_size": 128,
50
+ "time_cond_proj_dim": null,
51
+ "time_embedding_act_fn": null,
52
+ "time_embedding_dim": null,
53
+ "time_embedding_type": "positional",
54
+ "timestep_post_act": null,
55
+ "transformer_layers_per_block": [
56
+ 1,
57
+ 2,
58
+ 10
59
+ ],
60
+ "up_block_types": [
61
+ "CrossAttnUpBlock2D",
62
+ "CrossAttnUpBlock2D",
63
+ "UpBlock2D"
64
+ ],
65
+ "upcast_attention": false,
66
+ "use_linear_projection": true,
67
+
68
+ "image_ref_processor_input_keys": ["image_ref"],
69
+ "vision_model_config": {
70
+ "vision_model_config": {
71
+ "target": "models.image_encoder_siglipdino_shallowdeep.ShallowDeepPatchfySiglipDinoEncoder_v2",
72
+ "params": {
73
+ "siglip_config": {
74
+ "backbone_name_or_path": "vit_so400m_patch14_siglip_384",
75
+ "image_resize_strategy": "resize-naive",
76
+ "default_image_size": 384,
77
+ "feature_index": [
78
+ 7,
79
+ 13,
80
+ 19,
81
+ 25
82
+ ]
83
+ },
84
+ "dino_config": {
85
+ "backbone_name_or_path": "vit_large_patch14_reg4_dinov2.lvd142m",
86
+ "image_resize_strategy": "resize-naive",
87
+ "default_image_size": 384,
88
+ "feature_index": [
89
+ 4,
90
+ 10,
91
+ 16,
92
+ 22
93
+ ]
94
+ },
95
+ "patchfy_scale": 2,
96
+ "default_image_size": 384
97
+ }
98
+ }
99
+ },
100
+
101
+ "image_prompt_settings": {
102
+ "vision_projection_type": "custom",
103
+ "vision_projection_config": {
104
+ "target": "models.projectors.ProjectorHighResShallowMinAttnV1",
105
+ "params": {
106
+ "vision_dim": 2176,
107
+ "out_dim": 2048,
108
+ "dim_head": 64
109
+ }
110
+ },
111
+
112
+ "image_prompt_mode": "naive"
113
+ }
114
+ }
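
Compared with `realcustom_sigdino_highres.json`, this shallow variant points `target` at `ShallowDeepPatchfySiglipDinoEncoder_v2`, taps several intermediate `feature_index` layers of each backbone instead of one, and swaps the projector for `ProjectorHighResShallowMinAttnV1`. Both files rely on the `target`/`params` convention resolved by `utils.instantiate_from_config` (utils.py is part of this commit but not shown here). A hypothetical minimal resolver for that convention, for orientation only:

```python
import importlib

def instantiate_from_config(config: dict):
    """Turn a {"target": "pkg.mod.Class", "params": {...}} block into an object.

    Hypothetical sketch of the convention these configs assume; the helper
    actually shipped in utils.py may differ in details.
    """
    module_path, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**config.get("params", {}))
```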
inference/__pycache__/inference_utils.cpython-310.pyc ADDED
Binary file (1.58 kB).
 
inference/__pycache__/mask_generation.cpython-310.pyc ADDED
Binary file (2.58 kB).
 
inference/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (10.2 kB).
 
inference/app.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import gradio as gr
16
+
17
+ from inference.pipeline import RealCustomInferencePipeline
18
+
19
+ def create_demo():
20
+ pipeline = RealCustomInferencePipeline(
21
+ unet_config="configs/realcustom_sigdino_highres.json",
22
+ unet_checkpoint="ckpts/sdxl/unet/sdxl-unet.bin",
23
+ realcustom_checkpoint="ckpts/realcustom/RealCustom_highres.pth",
24
+ vae_config="ckpts/sdxl/vae/sdxl.json",
25
+ vae_checkpoint="ckpts/sdxl/vae/sdxl-vae.pth",
26
+ )
27
+
28
+ badges_text = r"""
29
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
30
+ <a href="https://corleone-huang.github.io/RealCustom_plus_plus/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-RealCustom-yellow"></a>
31
+ <a href="https://arxiv.org/pdf/2408.09744?"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-RealCustom-b31b1b.svg"></a>
32
+ </div>
33
+ """.strip()
34
+
35
+ with gr.Blocks() as demo:
36
+ gr.Markdown(f"# RealCustom")
37
+ gr.Markdown(badges_text)
38
+ with gr.Row():
39
+ with gr.Column():
40
+ prompt = gr.Textbox(label="Prompt", value="")
41
+ target_phrase = gr.Textbox(label="Target Phrase", value="")
42
+ with gr.Row():
43
+ image_prompt = gr.Image(label="Ref Img", visible=True, interactive=True, type="pil")
44
+
45
+ with gr.Row():
46
+ with gr.Column():
47
+ width = gr.Slider(512, 2048, 1024, step=16, label="Generation Width")
48
+ height = gr.Slider(512, 2048, 1024, step=16, label="Generation Height")
49
+
50
+ with gr.Accordion("Advanced Options", open=False):
51
+ with gr.Row():
52
+ guidance = gr.Slider(1.0, 15, 3.5, step=0.5, label="Guidance Scale", interactive=True)
53
+ mask_scope = gr.Slider(0.05, 1.0, 0.2, step=0.05, label="Mask Scope", interactive=True)
54
+ seed = gr.Number(0, label="Seed (-1 for random)")
55
+ num = gr.Number(4, label="Generation Number")
56
+ new_unet_local_path = gr.Textbox(label="New Unet Local Path", value="")
57
+ new_realcustom_local_path = gr.Textbox(label="New RealCustom Local Path", value="")
58
+
59
+ generate_btn = gr.Button("Generate")
60
+
61
+ with gr.Column():
62
+ output_image = gr.Image(label="Generated Image")
63
+ output_mask = gr.Image(label="Guidance Mask")
64
+
65
+ inputs = [
66
+ prompt, image_prompt, target_phrase,
67
+ height, width, guidance, seed, num,
68
+ mask_scope,
69
+ new_unet_local_path, new_realcustom_local_path,
70
+ ]
71
+ generate_btn.click(
72
+ fn=pipeline.generation,
73
+ inputs=inputs,
74
+ outputs=[output_image, output_mask],
75
+ )
76
+
77
+ return demo
78
+
79
+
80
+ if __name__ == "__main__":
81
+ demo = create_demo()
82
+ demo.launch(server_name='0.0.0.0', server_port=7860)
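
The `inputs` list above is wired positionally to `RealCustomInferencePipeline.generation` (defined in `inference/pipeline.py` below), so the pipeline can also be driven without Gradio. A sketch, assuming the checkpoints referenced above have been downloaded and using the reference-image path from the shell script below as an illustrative input (the image itself is not part of this commit):

```python
from PIL import Image

from inference.pipeline import RealCustomInferencePipeline

pipeline = RealCustomInferencePipeline(
    unet_config="configs/realcustom_sigdino_highres.json",
    unet_checkpoint="ckpts/sdxl/unet/sdxl-unet.bin",
    realcustom_checkpoint="ckpts/realcustom/RealCustom_highres.pth",
    vae_config="ckpts/sdxl/vae/sdxl.json",
    vae_checkpoint="ckpts/sdxl/vae/sdxl-vae.pth",
)

ref_image = Image.open("prompts/figurine.png").convert("RGB")  # illustrative path
grid, mask_grid = pipeline.generation(
    "the figurine is flying in the sky",  # prompt
    ref_image,                            # reference image (PIL)
    "figurine",                           # target phrase
    height=1024, width=1024,
    guidance_scale=3.5, seed=0, samples_per_prompt=4,
    mask_scope=0.2,
)
grid.save("output_grid.jpg")
mask_grid.save("output_mask_grid.jpg")
```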
inference/inference_single_image.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import torch
17
+ import json
18
+ import os
19
+ import torchvision
20
+ from torchvision.utils import make_grid
21
+ from torchvision.transforms.functional import to_pil_image
22
+ from tqdm import tqdm
23
+ from PIL import Image
24
+
25
+ from models.text import TextModel
26
+ from models.vae import AutoencoderKL
27
+
28
+ from models.unet_2d_condition_custom import UNet2DConditionModel as UNet2DConditionModelDiffusers
29
+ from schedulers.ddim import DDIMScheduler
30
+ from schedulers.dpm_s import DPMSolverSingleStepScheduler
31
+ from schedulers.utils import get_betas
32
+
33
+ from inference_utils import find_phrase_positions_in_text, classifier_free_guidance_image_prompt_cascade
34
+ from mask_generation import mask_generation
35
+ from utils import instantiate_from_config
36
+
37
+ # Argument parser
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument("--width", type=int, default=512)
40
+ parser.add_argument("--height", type=int, default=512)
41
+
42
+ parser.add_argument("--samples_per_prompt", type=int, required=True)
43
+ parser.add_argument("--nrow", type=int, default=4)
44
+ parser.add_argument("--sample_steps", type=int, required=True)
45
+ parser.add_argument("--schedule_type", type=str, default="squared_linear") # default, `squared_linear
46
+ parser.add_argument("--scheduler_type", type=str, default="dpm", choices=["ddim", "dpm"]) # default, "dpm"
47
+ parser.add_argument("--schedule_shift_snr", type=float, default=1) # default, 1
48
+
49
+ parser.add_argument("--text_encoder_variant", type=str, nargs="+")
50
+ parser.add_argument("--vae_config", type=str, default="configs/vae.json") # default
51
+ parser.add_argument("--vae_checkpoint", type=str, required=True)
52
+ parser.add_argument("--unet_config", type=str, required=True)
53
+ parser.add_argument("--unet_checkpoint", type=str, required=True)
54
+ parser.add_argument("--unet_checkpoint_base_model", type=str, default="")
55
+ parser.add_argument("--unet_prediction", type=str, choices=DDIMScheduler.prediction_types, default="epsilon") # default, "epsilon"
56
+
57
+ parser.add_argument("--negative_prompt", type=str, default="prompts/validation_negative.txt") # default
58
+
59
+ parser.add_argument("--compile", action="store_true", default=False)
60
+ parser.add_argument("--output_dir", type=str, required=True)
61
+
62
+ parser.add_argument("--guidance_weight", type=float, default=7.5)
63
+ parser.add_argument("--seed", type=int, default=666)
64
+ parser.add_argument("--device", type=str, default="cuda")
65
+
66
+ parser.add_argument("--text_prompt", type=str, required=True)
67
+ parser.add_argument("--image_prompt_path", type=str, required=True)
68
+ parser.add_argument("--target_phrase", type=str, required=True)
69
+ parser.add_argument("--mask_scope", type=float, default=0.20)
70
+ parser.add_argument("--mask_strategy", type=str, nargs="+", default=["max_norm"])
71
+ parser.add_argument("--mask_reused_step", type=int, default=12)
72
+
73
+ args = parser.parse_args()
74
+
75
+ # Initialize unet model
76
+ with open(args.unet_config) as unet_config_file:
77
+ unet_config = json.load(unet_config_file)
78
+
79
+ # Settings for image encoder
80
+ vision_model_config = unet_config.pop("vision_model_config", None)
81
+ args.vision_model_config = vision_model_config.pop("vision_model_config", None)
82
+
83
+ unet_type = unet_config.pop("type", None)
84
+ unet_model = UNet2DConditionModelDiffusers(**unet_config)
85
+
86
+ unet_model.eval().to(args.device)
87
+ unet_model.load_state_dict(torch.load(args.unet_checkpoint, map_location=args.device), strict=False)
88
+ print("loading unet model finished.")
89
+
90
+ if args.unet_checkpoint_base_model != "":
91
+ if "safetensors" in args.unet_checkpoint_base_model:
92
+ from safetensors import safe_open
93
+ tensors = {}
94
+ with safe_open(args.unet_checkpoint_base_model, framework="pt", device='cpu') as f:
95
+ for k in f.keys():
96
+ new_k = k.replace("model.diffusion_model.", "")
97
+ tensors[k] = f.get_tensor(k)
98
+ unet_model.load_state_dict(tensors, strict=False)
99
+ else:
100
+ unet_model.load_state_dict(torch.load(args.unet_checkpoint_base_model, map_location=args.device), strict=False)
101
+ unet_model = torch.compile(unet_model, disable=not args.compile)
102
+ print("loading unet base model finished.")
103
+
104
+ # Initialize vae model
105
+ with open(args.vae_config) as vae_config_file:
106
+ vae_config = json.load(vae_config_file)
107
+ vae_downsample_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) # 2 ** 3 = 8
108
+ vae_model = AutoencoderKL(**vae_config)
109
+ vae_model.eval().to(args.device)
110
+ vae_model.load_state_dict(torch.load(args.vae_checkpoint, map_location=args.device))
111
+ vae_decoder = torch.compile(lambda x: vae_model.decode(x / vae_model.scaling_factor).sample.clip(-1, 1), disable=not args.compile)
112
+ vae_encoder = torch.compile(lambda x: vae_model.encode(x).latent_dist.mode().mul_(vae_model.scaling_factor), disable=not args.compile)
113
+ print("loading vae finished.")
114
+
115
+ # Initialize ddim scheduler
116
+ ddim_train_steps = 1000
117
+ ddim_betas = get_betas(name=args.schedule_type, num_steps=ddim_train_steps, shift_snr=args.schedule_shift_snr, terminal_pure_noise=False)
118
+ scheduler_class = DPMSolverSingleStepScheduler if args.scheduler_type == 'dpm' else DDIMScheduler
119
+ scheduler = scheduler_class(betas=ddim_betas, num_train_timesteps=ddim_train_steps, num_inference_timesteps=args.sample_steps, device=args.device)
120
+ infer_timesteps = scheduler.timesteps
121
+
122
+ # Initialize text model
123
+ text_model = TextModel(args.text_encoder_variant, ["penultimate_nonorm"])
124
+ text_model.eval().to(args.device)
125
+ print("loading text model finished.")
126
+
127
+ # Initialize image model.
128
+ vision_model = instantiate_from_config(args.vision_model_config)
129
+ vision_model = vision_model.eval().to(args.device)
130
+ print("loading image model finished.")
131
+
132
+ negative_prompt = ""
133
+ if args.negative_prompt:
134
+ with open(args.negative_prompt) as f:
135
+ negative_prompt = f.read().strip()
136
+
137
+ image_metadata_validate = torch.tensor(
138
+ data=[
139
+ args.width, # original_height
140
+ args.height, # original_width
141
+ 0, # coordinate top
142
+ 0, # coordinate left
143
+ args.width, # target_height
144
+ args.height, # target_width
145
+ ],
146
+ device=args.device,
147
+ dtype=torch.float32
148
+ ).view(1, -1).repeat(args.samples_per_prompt, 1)
149
+
150
+ # Create output directory
151
+ os.makedirs(args.output_dir, exist_ok=True)
152
+ args.output_image_grid_dir = os.path.join(args.output_dir, "images_grid")
153
+ args.output_image_dir = os.path.join(args.output_dir, "images")
154
+ args.output_mask_grid_dir = os.path.join(args.output_dir, "masks_grid")
155
+ args.output_mask_dir = os.path.join(args.output_dir, "masks")
156
+ os.makedirs(args.output_image_grid_dir, exist_ok=True)
157
+ os.makedirs(args.output_image_dir, exist_ok=True)
158
+ os.makedirs(args.output_mask_grid_dir, exist_ok=True)
159
+ os.makedirs(args.output_mask_dir, exist_ok=True)
160
+
161
+ with torch.no_grad():
162
+ # Prepare negative prompt.
163
+ if args.guidance_weight != 1:
164
+ text_negative_output = text_model(negative_prompt)
165
+
166
+ positive_prompt = args.text_prompt
167
+ positive_promt_image_path = args.image_prompt_path
168
+ target_phrase = args.target_phrase
169
+
170
+ # Compute target phrases
171
+ target_token = torch.zeros(1, 77).to(args.device)
172
+ positions = find_phrase_positions_in_text(positive_prompt, target_phrase)
173
+ for position in positions:
174
+ prompt_before = positive_prompt[:position] # NOTE We do not need -1 here because the SDXL text encoder does not encode the trailing space.
175
+ prompt_include = positive_prompt[:position+len(target_phrase)]
176
+ print("prompt before: ", prompt_before, ", prompt_include: ", prompt_include)
177
+ prompt_before_length = text_model.get_vaild_token_length(prompt_before) + 1
178
+ prompt_include_length = text_model.get_vaild_token_length(prompt_include) + 1
179
+ print("prompt_before_length: ", prompt_before_length, ", prompt_include_length: ", prompt_include_length)
180
+ target_token[:, prompt_before_length:prompt_include_length] = 1
181
+
182
+ # Text used for progress bar
183
+ pbar_text = positive_prompt[:40]
184
+
185
+ # Compute text embeddings
186
+ text_positive_output = text_model(positive_prompt)
187
+ text_positive_embeddings = text_positive_output.embeddings.repeat_interleave(args.samples_per_prompt, dim=0)
188
+ text_positive_pooled = text_positive_output.pooled[-1].repeat_interleave(args.samples_per_prompt, dim=0)
189
+ if args.guidance_weight != 1:
190
+ text_negative_embeddings = text_negative_output.embeddings.repeat_interleave(args.samples_per_prompt, dim=0)
191
+ text_negative_pooled = text_negative_output.pooled[-1].repeat_interleave(args.samples_per_prompt, dim=0)
192
+
193
+ # Compute image embeddings
194
+ positive_image = Image.open(positive_promt_image_path).convert("RGB")
195
+ positive_image = torchvision.transforms.ToTensor()(positive_image)
196
+
197
+ positive_image = positive_image.unsqueeze(0).repeat_interleave(args.samples_per_prompt, dim=0)
198
+ positive_image = torch.nn.functional.interpolate(
199
+ positive_image,
200
+ size=(768, 768),
201
+ mode="bilinear",
202
+ align_corners=False
203
+ )
204
+ negative_image = torch.zeros_like(positive_image)
205
+ print(positive_image.size(), negative_image.size())
206
+ positive_image = positive_image.to(args.device)
207
+ negative_image = negative_image.to(args.device)
208
+
209
+ positive_image_dict = {"image_ref": positive_image}
210
+ positive_image_output = vision_model(positive_image_dict, device=args.device)
211
+
212
+ negative_image_dict = {"image_ref": negative_image}
213
+ negative_image_output = vision_model(negative_image_dict, device=args.device)
214
+
215
+ # Initialize latent with input latent + noise (i2i) / pure noise (t2i)
216
+ latent = torch.randn(
217
+ size=[
218
+ args.samples_per_prompt,
219
+ vae_config["latent_channels"],
220
+ args.height // vae_downsample_factor,
221
+ args.width // vae_downsample_factor
222
+ ],
223
+ device=args.device,
224
+ generator=torch.Generator(args.device).manual_seed(args.seed))
225
+ target_h = (args.height // vae_downsample_factor) // 2
226
+ target_w = (args.width // vae_downsample_factor) // 2
227
+
228
+ # Real Reverse diffusion process.
229
+ text2image_crossmap_2d_all_timesteps_list = []
230
+ current_step = 0
231
+ for timestep in tqdm(iterable=infer_timesteps, desc=f"[{pbar_text}]", dynamic_ncols=True):
232
+ if current_step < args.mask_reused_step:
233
+ pred_cond, pred_cond_dict = unet_model(
234
+ sample=latent,
235
+ timestep=timestep,
236
+ encoder_hidden_states=text_positive_embeddings,
237
+ encoder_attention_mask=None,
238
+ added_cond_kwargs=dict(
239
+ text_embeds=text_positive_pooled,
240
+ time_ids=image_metadata_validate
241
+ ),
242
+ vision_input_dict=None,
243
+ vision_guided_mask=None,
244
+ return_as_origin=False,
245
+ return_text2image_mask=True,
246
+ )
247
+ crossmap_2d_avg = mask_generation(
248
+ crossmap_2d_list=pred_cond_dict["text2image_crossmap_2d"], selfmap_2d_list=pred_cond_dict.get("self_attention_map", []),
249
+ target_token=target_token, mask_scope=args.mask_scope,
250
+ mask_target_h=target_h, mask_target_w=target_w, mask_mode=args.mask_strategy,
251
+ )
252
+ else:
253
+ # using previous step's mask
254
+ crossmap_2d_avg = text2image_crossmap_2d_all_timesteps_list[-1].squeeze(1)
255
+ if crossmap_2d_avg.dim() == 5: # Means that each layer uses a separate mask weight.
256
+ text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.mean(dim=2).unsqueeze(1))
257
+ else:
258
+ text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.unsqueeze(1))
259
+
260
+ pred_cond, pred_cond_dict = unet_model(
261
+ sample=latent,
262
+ timestep=timestep,
263
+ encoder_hidden_states=text_positive_embeddings,
264
+ encoder_attention_mask=None,
265
+ added_cond_kwargs=dict(
266
+ text_embeds=text_positive_pooled,
267
+ time_ids=image_metadata_validate
268
+ ),
269
+ vision_input_dict=positive_image_output,
270
+ vision_guided_mask=crossmap_2d_avg,
271
+ return_as_origin=False,
272
+ return_text2image_mask=True,
273
+ multiple_reference_image=False
274
+ )
275
+
276
+ crossmap_2d_avg_neg = crossmap_2d_avg.mean(dim=1, keepdim=True)
277
+ pred_negative, pred_negative_dict = unet_model(
278
+ sample=latent,
279
+ timestep=timestep,
280
+ encoder_hidden_states=text_negative_embeddings,
281
+ encoder_attention_mask=None,
282
+ added_cond_kwargs=dict(
283
+ text_embeds=text_negative_pooled,
284
+ time_ids=image_metadata_validate
285
+ ),
286
+ vision_input_dict=negative_image_output,
287
+ vision_guided_mask=crossmap_2d_avg,
288
+ return_as_origin=False,
289
+ return_text2image_mask=True,
290
+ multiple_reference_image=False
291
+ )
292
+
293
+ pred = classifier_free_guidance_image_prompt_cascade(
294
+ pred_t_cond=None, pred_ti_cond=pred_cond, pred_uncond=pred_negative,
295
+ guidance_weight_t=args.guidance_weight, guidance_weight_i=args.guidance_weight,
296
+ guidance_stdev_rescale_factor=0, cfg_rescale_mode="naive_global_direct"
297
+ )
298
+ step = scheduler.step(
299
+ model_output=pred,
300
+ model_output_type=args.unet_prediction,
301
+ timestep=timestep,
302
+ sample=latent)
303
+
304
+ latent = step.prev_sample
305
+
306
+ current_step += 1
307
+
308
+ sample = vae_decoder(step.pred_original_sample)
309
+
310
+ # save each image
311
+ for sample_i in range(sample.size(0)):
312
+ sample_i_image = torch.clamp(sample[sample_i] * 0.5 + 0.5, min=0, max=1).float()
313
+ to_pil_image(sample_i_image).save(args.output_image_dir + "/output_{}.jpg".format(sample_i))
314
+
315
+ # save grid images
316
+ sample = make_grid(sample, normalize=True, value_range=(-1, 1), nrow=args.nrow).float()
317
+ to_pil_image(sample).save(args.output_image_grid_dir + "/grid_image.jpg")
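
The reference-image branch above expects a `[0, 1]` tensor resized to 768×768 before it reaches the vision encoder, with an all-zeros image standing in for the unconditional branch. That preprocessing in isolation (the image path is illustrative, taken from the shell script below):

```python
import torch
import torchvision
from PIL import Image

# Reference-image preprocessing as performed above: RGB PIL image -> [0, 1]
# tensor, batched and bilinearly resized to the 768x768 input the encoder expects.
image = Image.open("prompts/figurine.png").convert("RGB")
image = torchvision.transforms.ToTensor()(image).unsqueeze(0)
image = torch.nn.functional.interpolate(image, size=(768, 768), mode="bilinear", align_corners=False)
negative_image = torch.zeros_like(image)  # "empty" reference for the unconditional pass
print(image.shape, negative_image.shape)  # torch.Size([1, 3, 768, 768]) for both
```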
inference/inference_single_image.sh ADDED
@@ -0,0 +1,55 @@
1
+ #!/bin/bash
2
+ # ----------------------------------------------------------------------------------------------------
3
+ HEIGHT="1024" # Base height.
4
+ WIDTH="1024" # Base width.
5
+ SAMPLES_PER_PROMPT="4" # Num of samples to generate per prompt.
6
+ NROW="2" # Grid images per row.
7
+
8
+ OUTPUT_DIR="outputs/test"
9
+
10
+ # ----------------------------------------------------------------------------------------------------
11
+ MASK_TYPE=("max_norm")
12
+ # usually:"max_norm" "crossmap_32" "selfmap_min_max_per_channel" "selfmap_64"
13
+ # [
14
+ # "max_norm", "min_max_norm", "binary", "min_max_per_channel", "decoder_map"
15
+ # "selfmap", "selfmap_min_max_per_channel" "selfmap_64"
16
+
17
+ # ]
18
+
19
+ CFG=7.5
20
+ STEPS=25
21
+ mask_reused_step=12
22
+
23
+ UNET_CONFIG="configs/realcustom_sigdino_highres.json"
24
+ UNET_CHECKPOINT="ckpts/realcustom/RealCustom_0025000_ema_highres.pth"
25
+ UNET_CHECKPOINT_BASE_MODEL="ckpts/sdxl/unet/general_v1-3_sdxl_03.pth"
26
+ # ----------------------------------------------------------------------------------------------------
27
+ CLIP1_DIR="ckpts/sdxl/clip-sdxl-1"
28
+ CLIP2_DIR="ckpts/sdxl/clip-sdxl-2"
29
+ VAE_CONFIG_PATH="ckpts/sdxl/vae/sdxl.json"
30
+ VAE_CHECKPOINT_PATH="ckpts/sdxl/vae/sdxl-vae.pth"
31
+
32
+
33
+ echo "Start inference"
34
+ python3 inference/inference_single_image.py \
35
+ --width $WIDTH \
36
+ --height $HEIGHT \
37
+ --samples_per_prompt $SAMPLES_PER_PROMPT \
38
+ --nrow $NROW \
39
+ --sample_steps $STEPS \
40
+ --guidance_weight $CFG \
41
+ --text_encoder_variant \
42
+ $CLIP1_DIR \
43
+ $CLIP2_DIR \
44
+ --unet_config $UNET_CONFIG \
45
+ --unet_checkpoint $UNET_CHECKPOINT \
46
+ --unet_checkpoint_base_model $UNET_CHECKPOINT_BASE_MODEL \
47
+ --vae_config $VAE_CONFIG_PATH \
48
+ --vae_checkpoint $VAE_CHECKPOINT_PATH \
49
+ --output_dir $OUTPUT_DIR \
50
+ --seed 2024 \
51
+ --text_prompt "the figurine is flying in the sky" \
52
+ --image_prompt_path "prompts/figurine.png" \
53
+ --target_phrase "figurine" \
54
+ --mask_scope 0.25 \
55
+ --mask_strategy ${MASK_TYPE[*]}
inference/inference_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ def find_phrase_positions_in_text(text, phrase):
16
+ """
17
+ Return the starting character positions of every occurrence of the phrase in the text.
18
+ """
19
+
20
+ position = -1
21
+ positions = []
22
+ while True:
23
+ position = text.find(phrase, position + 1)
24
+ if position == -1:
25
+ break
26
+ positions.append(position)
27
+ return positions
28
+
29
+ def classifier_free_guidance_image_prompt_cascade(
30
+ pred_t_cond, pred_ti_cond, pred_uncond, guidance_weight_t=7.5, guidance_weight_i=7.5,
31
+ guidance_stdev_rescale_factor=0.7, cfg_rescale_mode="none", super_cross_mask=None
32
+ ):
33
+
34
+ if cfg_rescale_mode == "none":
35
+ pred = pred_uncond + guidance_weight_t * (pred_t_cond - pred_uncond) + guidance_weight_i * (pred_ti_cond - pred_t_cond)
36
+ elif cfg_rescale_mode == "none_direct":
37
+ pred = pred_uncond + guidance_weight_i * (pred_ti_cond - pred_uncond)
38
+ elif cfg_rescale_mode == "naive":
39
+ assert super_cross_mask is not None
40
+ pred_std_t_before = pred_t_cond.std([1, 2, 3], keepdim=True)
41
+ pred_std_ti_before = pred_ti_cond.std([1, 2, 3], keepdim=True)
42
+
43
+ pred = pred_uncond + guidance_weight_t * (pred_t_cond - pred_uncond) + guidance_weight_i * (pred_ti_cond - pred_t_cond)
44
+
45
+ pred_std_after = pred.std([1, 2, 3], keepdim=True)
46
+
47
+ pred_rescale_t_factor = guidance_stdev_rescale_factor * (pred_std_t_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
48
+ pred_rescale_ti_factor = guidance_stdev_rescale_factor * (pred_std_ti_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
49
+
50
+ pred_ti = pred * super_cross_mask
51
+ pred_t = pred * (1 - super_cross_mask)
52
+ pred = pred_ti * pred_rescale_ti_factor + pred_t * pred_rescale_t_factor
53
+ elif cfg_rescale_mode == "naive_global":
54
+ pred_std_ti_before = pred_ti_cond.std([1, 2, 3], keepdim=True)
55
+
56
+ pred = pred_uncond + guidance_weight_t * (pred_t_cond - pred_uncond) + guidance_weight_i * (pred_ti_cond - pred_t_cond)
57
+
58
+ pred_std_after = pred.std([1, 2, 3], keepdim=True)
59
+
60
+ pred_rescale_ti_factor = guidance_stdev_rescale_factor * (pred_std_ti_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
61
+
62
+ pred = pred * pred_rescale_ti_factor
63
+ elif cfg_rescale_mode == "naive_global_direct":
64
+ pred_std_ti_before = pred_ti_cond.std([1, 2, 3], keepdim=True)
65
+
66
+ pred = pred_uncond + guidance_weight_i * (pred_ti_cond - pred_uncond)
67
+
68
+ pred_std_after = pred.std([1, 2, 3], keepdim=True)
69
+
70
+ pred_rescale_ti_factor = guidance_stdev_rescale_factor * (pred_std_ti_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
71
+
72
+ pred = pred * pred_rescale_ti_factor
73
+ else:
74
+ raise NotImplementedError()
75
+
76
+ return pred
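
Both helpers are pure string/tensor utilities, so they can be exercised in isolation. A short sanity check (import path assumes the repository root is on `PYTHONPATH`); in `"none"` mode the cascade reduces to `uncond + w_t * (t_cond - uncond) + w_i * (ti_cond - t_cond)`:

```python
import torch

from inference.inference_utils import (
    find_phrase_positions_in_text,
    classifier_free_guidance_image_prompt_cascade,
)

# Every character offset at which the phrase occurs in the prompt.
print(find_phrase_positions_in_text("the figurine is flying in the sky", "figurine"))  # [4]

# Dummy predictions: uncond = 0, text-cond = 0.5, text+image-cond = 1.0.
pred_uncond = torch.zeros(1, 4, 8, 8)
pred_t_cond = torch.full((1, 4, 8, 8), 0.5)
pred_ti_cond = torch.ones(1, 4, 8, 8)

pred = classifier_free_guidance_image_prompt_cascade(
    pred_t_cond, pred_ti_cond, pred_uncond,
    guidance_weight_t=7.5, guidance_weight_i=7.5,
    cfg_rescale_mode="none",
)
print(pred.mean().item())  # 0 + 7.5 * 0.5 + 7.5 * 0.5 = 7.5
```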
inference/mask_generation.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from einops import rearrange
18
+
19
+ def mask_generation(
20
+ crossmap_2d_list, selfmap_2d_list=None,
21
+ target_token=None, mask_scope=None,
22
+ mask_target_h=64, mask_target_w=64,
23
+ mask_mode=["binary"],
24
+ ):
25
+ if len(selfmap_2d_list) > 0:
26
+ target_hw_selfmap = mask_target_h * mask_target_w
27
+ selfmap_2ds = []
28
+ for i in range(len(selfmap_2d_list)):
29
+ selfmap_ = selfmap_2d_list[i]
30
+ selfmap_ = F.interpolate(selfmap_, size=(target_hw_selfmap, target_hw_selfmap), mode='bilinear')
31
+ selfmap_2ds.append(selfmap_ )
32
+ selfmap_2ds = torch.cat(selfmap_2ds, dim=1)
33
+ if "selfmap_min_max_per_channel" in mask_mode:
34
+ selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
35
+ channel_max_self = torch.max(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
36
+ channel_min_self = torch.min(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
37
+ selfmap_2ds = (selfmap_2ds - channel_min_self) / (channel_max_self - channel_min_self + 1e-6)
38
+ elif "selfmap_max_norm" in mask_mode:
39
+ selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
40
+ b = selfmap_1ds.size(0)
41
+ batch_max = torch.max(selfmap_1ds.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
42
+ selfmap_2ds = selfmap_2ds / (batch_max + 1e-10)
43
+
44
+ selfmap_2d = selfmap_2ds.mean(dim=1, keepdim=True)
45
+ else:
46
+ selfmap_2d = None
47
+
48
+ crossmap_2ds = []
49
+ for i in range(len(crossmap_2d_list)):
50
+ crossmap = crossmap_2d_list[i]
51
+ crossmap = crossmap.mean(dim=1) # average on head dim
52
+ crossmap = crossmap * target_token.unsqueeze(-1).unsqueeze(-1) # target token valid
53
+ crossmap = crossmap.sum(dim=1, keepdim=True)
54
+
55
+ crossmap = F.interpolate(crossmap, size=(mask_target_h, mask_target_w), mode='bilinear')
56
+ crossmap_2ds.append(crossmap)
57
+ crossmap_2ds = torch.cat(crossmap_2ds, dim=1)
58
+ crossmap_1ds = rearrange(crossmap_2ds, "b c h w -> b c (h w)")
59
+
60
+ if "max_norm" in mask_mode:
61
+ crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
62
+ if selfmap_2d is not None:
63
+ crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
64
+ b, c, n = crossmap_1ds.shape
65
+ batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
66
+ crossmap_1d_avg = crossmap_1d_avg / (batch_max + 1e-6)
67
+ elif "min_max_norm" in mask_mode:
68
+ crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
69
+ if selfmap_2d is not None:
70
+ crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
71
+ b, c, n = crossmap_1ds.shape
72
+ batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) # NOTE unsqueeze
73
+ batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) # NOTE unsqueeze
74
+ crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
75
+ elif "min_max_per_channel" in mask_mode:
76
+ channel_max = torch.max(crossmap_1ds, dim=-1, keepdim=True)[0]
77
+ channel_min = torch.min(crossmap_1ds, dim=-1, keepdim=True)[0]
78
+ crossmap_1ds = (crossmap_1ds - channel_min) / (channel_max - channel_min + 1e-6)
79
+ crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
80
+ if selfmap_2d is not None:
81
+ crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
82
+
83
+ # renormalize to 0-1
84
+ b, c, n = crossmap_1d_avg.shape
85
+ batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
86
+ batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
87
+ crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
88
+ else:
89
+ crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
90
+
91
+
92
+ if "threshold" in mask_mode:
93
+ threshold = 1 - mask_scope
94
+ crossmap_1d_avg[crossmap_1d_avg < threshold] = 0.0
95
+ if "binary" in mask_mode:
96
+ crossmap_1d_avg[crossmap_1d_avg > threshold] = 1.0
97
+ else:
98
+ # topk
99
+ topk_num = int(crossmap_1d_avg.size(-1) * mask_scope)
100
+ sort_score, sort_order = crossmap_1d_avg.sort(descending=True, dim=-1)
101
+ sort_topk = sort_order[:, :, :topk_num]
102
+ sort_topk_remain = sort_order[:, :, topk_num:]
103
+ crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk_remain, 0.)
104
+ if "binary" in mask_mode:
105
+ crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk, 1.0)
106
+
107
+ crossmap_2d_avg = rearrange(crossmap_1d_avg, "b c (h w) -> b c h w", h=mask_target_h, w=mask_target_w)
108
+ crossmap_2d_avg = crossmap_2d_avg
109
+
110
+ output = crossmap_2d_avg.unsqueeze(1) # torch.Size([4, 1, 60, 64, 64]), The second dimension is the dimension of the number of reference images.
111
+ if output.size(2) == 1: # The dimension of the layer.
112
+ output = output.squeeze(2) # If there is only a single dimension, then all layers will share the same mask.
113
+
114
+ return output
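
`mask_generation` only needs raw cross-attention maps and a token mask, so its output shape can be checked with random tensors. A minimal sketch (repository root assumed on `PYTHONPATH`); with `mask_scope=0.25` and the `"max_norm"` strategy, roughly a quarter of the spatial positions keep a non-zero weight:

```python
import torch

from inference.mask_generation import mask_generation

batch, heads, tokens = 2, 8, 77
# Dummy text-to-image cross-attention maps from two layers at 32x32 and 16x16.
crossmaps = [
    torch.rand(batch, heads, tokens, 32, 32),
    torch.rand(batch, heads, tokens, 16, 16),
]
target_token = torch.zeros(batch, tokens)
target_token[:, 5:7] = 1  # pretend tokens 5-6 belong to the target phrase

mask = mask_generation(
    crossmap_2d_list=crossmaps, selfmap_2d_list=[],
    target_token=target_token, mask_scope=0.25,
    mask_target_h=64, mask_target_w=64, mask_mode=["max_norm"],
)
print(mask.shape)                 # torch.Size([2, 1, 64, 64])
print((mask > 0).float().mean())  # ~0.25 of positions survive the top-k cut
```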
inference/pipeline.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import json
17
+ import torch
18
+ import torchvision
19
+ from torchvision.utils import make_grid
20
+ from torchvision.transforms.functional import to_pil_image
21
+ from PIL import Image
22
+
23
+ from models.text import TextModel
24
+ from models.vae import AutoencoderKL
25
+ from models.unet_2d_condition_custom import UNet2DConditionModel as UNet2DConditionModelDiffusers
26
+
27
+ from schedulers.ddim import DDIMScheduler
28
+ from schedulers.dpm_s import DPMSolverSingleStepScheduler
29
+ from schedulers.utils import get_betas
30
+
31
+ from inference_utils import find_phrase_positions_in_text, classifier_free_guidance_image_prompt_cascade
32
+ from mask_generation import mask_generation
33
+ from utils import instantiate_from_config
34
+
35
+ from tqdm import tqdm
36
+ from einops import rearrange
37
+
38
+ class RealCustomInferencePipeline:
39
+ def __init__(
40
+ self,
41
+ unet_config,
42
+ unet_checkpoint,
43
+ realcustom_checkpoint,
44
+ vae_config="ckpts/sdxl/vae/sdxl.json",
45
+ vae_checkpoint="ckpts/sdxl/vae/sdxl-vae.pth",
46
+ model_type="bf16",
47
+ device="cuda",
48
+ ):
49
+ if model_type == "bf16":
50
+ self.torch_dtype = torch.bfloat16
51
+ else:
52
+ self.torch_dtype = torch.float32
53
+
54
+ if not os.path.exists("ckpts/"):
55
+ from huggingface_hub import snapshot_download
56
+ print("Downloading RealCustom ...")
57
+ snapshot_download(
58
+ repo_id="bytedance-research/RealCustom",
59
+ repo_type="model",
60
+ local_dir="ckpts", # 指定本地目录
61
+ allow_patterns="ckpts/**", # 只下载 ckpts 文件夹内容
62
+ local_dir_use_symlinks=False # 直接存储文件而非符号链接
63
+ )
64
+
65
+ self.device = device
66
+ self.unet_checkpoint = unet_checkpoint
67
+ self.realcustom_checkpoint = realcustom_checkpoint
68
+ self._load_unet_checkpoint(unet_config, unet_checkpoint, realcustom_checkpoint)
69
+ self._load_vae_checkpoint(vae_config, vae_checkpoint)
70
+ self._load_encoder_checkpoint()
71
+ self._init_scheduler()
72
+ self._load_negative_prompt()
73
+
74
+
75
+ def _load_unet_checkpoint(self, unet_config, unet_checkpoint, realcustom_checkpoint):
76
+ # Initialize unet model
77
+ with open(unet_config) as unet_config_file:
78
+ unet_config = json.load(unet_config_file)
79
+ self.unet_prediction = "epsilon"
80
+
81
+ # Settings for image encoder
82
+ vision_model_config = unet_config.pop("vision_model_config", None)
83
+ self.vision_model_config = vision_model_config.pop("vision_model_config", None)
84
+
85
+ self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
86
+
87
+ self.unet_model.eval().to(self.device).to(self.torch_dtype)
88
+ self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
89
+ self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
90
+ print("loading unet model finished.")
91
+
92
+ def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
93
+ self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
94
+ self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
95
+ print("reloading unet model finished.")
96
+
97
+ def _load_vae_checkpoint(self, vae_config, vae_checkpoint):
98
+ # Initialize vae model
99
+ with open(vae_config) as vae_config_file:
100
+ vae_config = json.load(vae_config_file)
101
+ self.latent_channels = vae_config["latent_channels"]
102
+ self.vae_downsample_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) # 2 ** 3 = 8
103
+
104
+ vae_model = AutoencoderKL(**vae_config)
105
+ vae_model.eval().to(self.device).to(self.torch_dtype)
106
+ vae_model.load_state_dict(torch.load(vae_checkpoint, map_location=self.device))
107
+ self.vae_decoder = torch.compile(lambda x: vae_model.decode(x / vae_model.scaling_factor).sample.clip(-1, 1), disable=True)
108
+ self.vae_encoder = torch.compile(lambda x: vae_model.encode(x).latent_dist.mode().mul_(vae_model.scaling_factor), disable=True)
109
+
110
+ print("loading vae finished.")
111
+
112
+ def _load_encoder_checkpoint(self, ):
113
+ # Initialize text encoder
114
+ text_encoder_variant = ["ckpts/sdxl/clip-sdxl-1", "ckpts/sdxl/clip-sdxl-2"]
115
+ text_encoder_mode = ["penultimate_nonorm"]
116
+ self.text_model = TextModel(text_encoder_variant, text_encoder_mode)
117
+ self.text_model.eval().to(self.device).to(self.torch_dtype)
118
+ print("loading text model finished.")
119
+
120
+ # Initialize image encoder
121
+ self.vision_model = instantiate_from_config(self.vision_model_config)
122
+ self.vision_model.eval().to(self.device).to(self.torch_dtype)
123
+ print("loading image model finished.")
124
+
125
+ def _init_scheduler(self, ):
126
+ # Initialize ddim scheduler
127
+ ddim_train_steps = 1000
128
+ schedule_type = "squared_linear"
129
+ scheduler_type = "dpm"
130
+ schedule_shift_snr = 1
131
+ self.sample_steps = 25
132
+ ddim_betas = get_betas(name=schedule_type, num_steps=ddim_train_steps, shift_snr=schedule_shift_snr, terminal_pure_noise=False)
133
+ scheduler_class = DPMSolverSingleStepScheduler if scheduler_type == 'dpm' else DDIMScheduler
134
+
135
+ self.scheduler = scheduler_class(betas=ddim_betas, num_train_timesteps=ddim_train_steps, num_inference_timesteps=self.sample_steps, device=self.device)
136
+ self.infer_timesteps = self.scheduler.timesteps
137
+
138
+ def _load_negative_prompt(self, ):
139
+ with open("prompts/validation_negative.txt") as f:
140
+ self.negative_prompt = f.read().strip()
141
+ self.text_negative_output = self.text_model(self.negative_prompt)
142
+
143
+ def generation(
144
+ self,
145
+ text,
146
+ image_pil,
147
+ target_phrase,
148
+
149
+ height=1024,
150
+ width=1024,
151
+ guidance_scale=3.5,
152
+ seed=1234,
153
+ samples_per_prompt=4,
154
+
155
+ mask_scope=0.25,
156
+
157
+ new_unet_checkpoint="", # in case you want to change
158
+ new_realcustom_checkpoint="", # in case you want to change
159
+ mask_strategy=["min_max_per_channel"],
160
+ mask_reused_step=12,
161
+ return_each_image=False,
162
+ ):
163
+
164
+ if new_unet_checkpoint != "" and new_unet_checkpoint != self.unet_checkpoint:
165
+ self.unet_checkpoint = new_unet_checkpoint
166
+ self.unet_model.load_state_dict(torch.load(new_unet_checkpoint, map_location=self.device), strict=False)
167
+ print("Reloading Unet {} finised.".format(new_unet_checkpoint))
168
+ if new_realcustom_checkpoint != "" and new_realcustom_checkpoint != self.realcustom_checkpoint:
169
+ self.realcustom_checkpoint = new_realcustom_checkpoint
170
+ self.unet_model.load_state_dict(torch.load(new_realcustom_checkpoint, map_location=self.device), strict=False)
171
+ print("Reloading RealCustom {} finised.".format(new_realcustom_checkpoint))
172
+
173
+ samples_per_prompt = int(samples_per_prompt)
174
+ image_metadata_validate = self._get_metadata(height, width, samples_per_prompt)
175
+ if seed == -1:
176
+ seed = torch.randint(0, 1000000, (1,)).item()
177
+ seed = int(seed)
178
+
179
+ with torch.no_grad(), torch.autocast(self.device, self.torch_dtype):
180
+ target_token = self._find_phrase_positions_in_text(text, target_phrase)
181
+
182
+ # Compute text embeddings
183
+ text_positive_output = self.text_model(text)
184
+ text_positive_embeddings = text_positive_output.embeddings.repeat_interleave(samples_per_prompt, dim=0)
185
+ text_positive_pooled = text_positive_output.pooled[-1].repeat_interleave(samples_per_prompt, dim=0)
186
+ if guidance_scale != 1:
187
+ text_negative_embeddings = self.text_negative_output.embeddings.repeat_interleave(samples_per_prompt, dim=0)
188
+ text_negative_pooled = self.text_negative_output.pooled[-1].repeat_interleave(samples_per_prompt, dim=0)
189
+
190
+ # Compute image embeddings
191
+ # positive_image = Image.open(image_path).convert("RGB")
192
+ positive_image = image_pil
193
+ positive_image = torchvision.transforms.ToTensor()(positive_image)
194
+
195
+ positive_image = positive_image.unsqueeze(0).repeat_interleave(samples_per_prompt, dim=0)
196
+ positive_image = torch.nn.functional.interpolate(
197
+ positive_image,
198
+ size=(768, 768),
199
+ mode="bilinear",
200
+ align_corners=False
201
+ )
202
+ negative_image = torch.zeros_like(positive_image)
203
+ positive_image = positive_image.to(self.device).to(self.torch_dtype)
204
+ negative_image = negative_image.to(self.device).to(self.torch_dtype)
205
+
206
+ positive_image_dict = {"image_ref": positive_image}
207
+ positive_image_output = self.vision_model(positive_image_dict, device=self.device)
208
+
209
+ negative_image_dict = {"image_ref": negative_image}
210
+ negative_image_output = self.vision_model(negative_image_dict, device=self.device)
211
+
212
+ # Initialize latent with input latent
213
+ latent = torch.randn(
214
+ size=[
215
+ samples_per_prompt,
216
+ self.latent_channels,
217
+ height // self.vae_downsample_factor,
218
+ width // self.vae_downsample_factor
219
+ ],
220
+ device=self.device,
221
+ generator=torch.Generator(self.device).manual_seed(seed)).to(self.torch_dtype)
222
+ target_h = (height // self.vae_downsample_factor) // 2
223
+ target_w = (width // self.vae_downsample_factor) // 2
224
+
225
+ text2image_crossmap_2d_all_timesteps_list = []
226
+ current_step = 0
227
+ pbar_text = text[:40]
228
+ for timestep in tqdm(iterable=self.infer_timesteps, desc=f"[{pbar_text}]", dynamic_ncols=True):
229
+ if current_step < mask_reused_step:
230
+ pred_cond, pred_cond_dict = self.unet_model(
231
+ sample=latent,
232
+ timestep=timestep,
233
+ encoder_hidden_states=text_positive_embeddings,
234
+ encoder_attention_mask=None,
235
+ added_cond_kwargs=dict(
236
+ text_embeds=text_positive_pooled,
237
+ time_ids=image_metadata_validate
238
+ ),
239
+ vision_input_dict=None,
240
+ vision_guided_mask=None,
241
+ return_as_origin=False,
242
+ return_text2image_mask=True,
243
+ )
244
+ crossmap_2d_avg = mask_generation(
245
+ crossmap_2d_list=pred_cond_dict["text2image_crossmap_2d"], selfmap_2d_list=pred_cond_dict.get("self_attention_map", []),
246
+ target_token=target_token, mask_scope=mask_scope,
247
+ mask_target_h=target_h, mask_target_w=target_w, mask_mode=mask_strategy,
248
+ )
249
+ else:
250
+ # using previous step's mask
251
+ crossmap_2d_avg = text2image_crossmap_2d_all_timesteps_list[-1].squeeze(1)
252
+ if crossmap_2d_avg.dim() == 5: # Means that each layer uses a separate mask weight.
253
+ text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.mean(dim=2).unsqueeze(1))
254
+ else:
255
+ text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.unsqueeze(1))
256
+
257
+ pred_cond, pred_cond_dict = self.unet_model(
258
+ sample=latent,
259
+ timestep=timestep,
260
+ encoder_hidden_states=text_positive_embeddings,
261
+ encoder_attention_mask=None,
262
+ added_cond_kwargs=dict(
263
+ text_embeds=text_positive_pooled,
264
+ time_ids=image_metadata_validate
265
+ ),
266
+ vision_input_dict=positive_image_output,
267
+ vision_guided_mask=crossmap_2d_avg,
268
+ return_as_origin=False,
269
+ return_text2image_mask=True,
270
+ multiple_reference_image=False
271
+ )
272
+
273
+ pred_negative, pred_negative_dict = self.unet_model(
274
+ sample=latent,
275
+ timestep=timestep,
276
+ encoder_hidden_states=text_negative_embeddings,
277
+ encoder_attention_mask=None,
278
+ added_cond_kwargs=dict(
279
+ text_embeds=text_negative_pooled,
280
+ time_ids=image_metadata_validate
281
+ ),
282
+ vision_input_dict=negative_image_output,
283
+ vision_guided_mask=crossmap_2d_avg,
284
+ return_as_origin=False,
285
+ return_text2image_mask=True,
286
+ multiple_reference_image=False
287
+ )
288
+
289
+ pred = classifier_free_guidance_image_prompt_cascade(
290
+ pred_t_cond=None, pred_ti_cond=pred_cond, pred_uncond=pred_negative,
291
+ guidance_weight_t=guidance_scale, guidance_weight_i=guidance_scale,
292
+ guidance_stdev_rescale_factor=0, cfg_rescale_mode="naive_global_direct"
293
+ )
294
+ step = self.scheduler.step(
295
+ model_output=pred,
296
+ model_output_type=self.unet_prediction,
297
+ timestep=timestep,
298
+ sample=latent)
299
+
300
+ latent = step.prev_sample
301
+
302
+ current_step += 1
303
+ sample = self.vae_decoder(step.pred_original_sample)
304
+
305
+ # save each image
306
+ images_pil_list = []
307
+ for sample_i in range(sample.size(0)):
308
+ sample_i_image = torch.clamp(sample[sample_i] * 0.5 + 0.5, min=0, max=1).float()
309
+
310
+ images_pil_list.append(to_pil_image(sample_i_image))
311
+ # to_pil_image(sample_i_image).save("./test_{}.jpg".format(sample_i))
312
+
313
+ # save grid images
314
+ sample = make_grid(sample, normalize=True, value_range=(-1, 1), nrow=int(samples_per_prompt ** 0.5)).float()
315
+ # to_pil_image(sample).save("./output_grid_image.jpg")
316
+
317
+ # save all masks
318
+ text2image_crossmap_2d_all_timesteps = torch.cat(text2image_crossmap_2d_all_timesteps_list, dim=1)
319
+ text2image_crossmap_2d_all_timesteps = rearrange(text2image_crossmap_2d_all_timesteps, "b t c h w -> (b t) c h w")
320
+ c = text2image_crossmap_2d_all_timesteps.size(1)
321
+ text2image_crossmap_2d_all_timesteps = rearrange(text2image_crossmap_2d_all_timesteps, "B (c 1) h w -> (B c) 1 h w")
322
+ sample_mask = make_grid(text2image_crossmap_2d_all_timesteps, normalize=False, value_range=(-1, 1), nrow=int(self.sample_steps * c))
323
+ # to_pil_image(sample_mask).save("./output_grid_mask.jpg")
324
+
325
+ if return_each_image:
326
+ return images_pil_list, to_pil_image(sample), to_pil_image(sample_mask)
327
+ else:
328
+ return to_pil_image(sample), to_pil_image(sample_mask)
329
+
330
+ def _get_metadata(self, height, width, samples_per_prompt):
331
+ image_metadata_validate = torch.tensor(
332
+ data=[
333
+ width, # original_height
334
+ height, # original_width
335
+ 0, # coordinate top
336
+ 0, # coordinate left
337
+ width, # target_height
338
+ height, # target_width
339
+ ],
340
+ device=self.device,
341
+ dtype=self.torch_dtype
342
+ ).view(1, -1).repeat(samples_per_prompt, 1)
343
+
344
+ return image_metadata_validate
345
+
346
+ def _find_phrase_positions_in_text(self, text, target_phrase):
347
+ # Compute target phrases
348
+ target_token = torch.zeros(1, 77).to(self.device)
349
+ positions = find_phrase_positions_in_text(text, target_phrase)
350
+ for position in positions:
351
+ prompt_before = text[:position] # NOTE We do not need -1 here because the SDXL text encoder does not encode the trailing space.
352
+ prompt_include = text[:position+len(target_phrase)]
353
+ print("prompt before: ", prompt_before, ", prompt_include: ", prompt_include)
354
+ prompt_before_length = self.text_model.get_vaild_token_length(prompt_before) + 1
355
+ prompt_include_length = self.text_model.get_vaild_token_length(prompt_include) + 1
356
+ print("prompt_before_length: ", prompt_before_length, ", prompt_include_length: ", prompt_include_length)
357
+ target_token[:, prompt_before_length:prompt_include_length] = 1
358
+
359
+ return target_token
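
Beyond the grid/mask pair consumed by the Gradio demo, `generation` can also hand back the individual samples. A short usage sketch, reusing the `pipeline` and `ref_image` objects from the `app.py` example above:

```python
# return_each_image=True additionally yields the per-sample PIL images.
images, grid, mask_grid = pipeline.generation(
    "the figurine is flying in the sky",
    ref_image,
    "figurine",
    samples_per_prompt=4,
    return_each_image=True,
)
for i, image in enumerate(images):
    image.save(f"sample_{i}.jpg")
```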
models/__pycache__/attention_custom.cpython-310.pyc ADDED
Binary file (12.5 kB).
 
models/__pycache__/attention_processor_custom_cross.cpython-310.pyc ADDED
Binary file (38.7 kB).
 
models/__pycache__/base_vision.cpython-310.pyc ADDED
Binary file (8.28 kB).
 
models/__pycache__/dino.cpython-310.pyc ADDED
Binary file (7.08 kB).
 
models/__pycache__/image_encoder_siglipdino_shallowdeep.cpython-310.pyc ADDED
Binary file (4.23 kB). View file
 
models/__pycache__/projectors.cpython-310.pyc ADDED
Binary file (4.33 kB). View file
 
models/__pycache__/sigclip.cpython-310.pyc ADDED
Binary file (5.81 kB). View file
 
models/__pycache__/text.cpython-310.pyc ADDED
Binary file (2.88 kB). View file
 
models/__pycache__/transformer_2d_custom.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
models/__pycache__/unet_2d_blocks_custom.cpython-310.pyc ADDED
Binary file (52.5 kB). View file
 
models/__pycache__/unet_2d_condition_custom.cpython-310.pyc ADDED
Binary file (31.1 kB). View file
 
models/__pycache__/vae.cpython-310.pyc ADDED
Binary file (1.08 kB). View file
 
models/attention_custom.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ # from diffusers.utils import maybe_allow_in_graph
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import Attention
25
+ from models.attention_processor_custom_cross import Attention as CrossAttention
26
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
27
+
28
+ from utils import update_dict
29
+
30
+ # @maybe_allow_in_graph
31
+ class BasicTransformerBlock(nn.Module):
32
+ r"""
33
+ A basic Transformer block.
34
+
35
+ Parameters:
36
+ dim (`int`): The number of channels in the input and output.
37
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`): The number of channels in each head.
39
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
40
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
41
+ only_cross_attention (`bool`, *optional*):
42
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
43
+ double_self_attention (`bool`, *optional*):
44
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
45
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
46
+ num_embeds_ada_norm (:
47
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
48
+ attention_bias (:
49
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ dim: int,
55
+ num_attention_heads: int,
56
+ attention_head_dim: int,
57
+ dropout=0.0,
58
+ cross_attention_dim: Optional[int] = None,
59
+ activation_fn: str = "geglu",
60
+ num_embeds_ada_norm: Optional[int] = None,
61
+ attention_bias: bool = False,
62
+ only_cross_attention: bool = False,
63
+ double_self_attention: bool = False,
64
+ upcast_attention: bool = False,
65
+ norm_elementwise_affine: bool = True,
66
+ norm_type: str = "layer_norm",
67
+ final_dropout: bool = False,
68
+ image_prompt_settings: dict = {},
69
+ ):
70
+ super().__init__()
71
+ self.only_cross_attention = only_cross_attention
72
+
73
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
75
+
76
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
77
+ raise ValueError(
78
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
79
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
80
+ )
81
+
82
+ # Define 3 blocks. Each block has its own normalization layer.
83
+ # 1. Self-Attn
84
+ if self.use_ada_layer_norm:
85
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
86
+ elif self.use_ada_layer_norm_zero:
87
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
88
+ else:
89
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
90
+ self.attn1 = Attention(
91
+ query_dim=dim,
92
+ heads=num_attention_heads,
93
+ dim_head=attention_head_dim,
94
+ dropout=dropout,
95
+ bias=attention_bias,
96
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
97
+ upcast_attention=upcast_attention,
98
+ )
99
+
100
+ # 2. Cross-Attn
101
+ if cross_attention_dim is not None or double_self_attention:
102
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
103
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
104
+ # the second cross attention block.
105
+ self.norm2 = (
106
+ AdaLayerNorm(dim, num_embeds_ada_norm)
107
+ if self.use_ada_layer_norm
108
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
109
+ )
110
+ # self.attn2 = CrossAttention(
111
+ # self.attn2 = SelfAttention(
112
+ self.attn2 = CrossAttention(
113
+ query_dim=dim,
114
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
115
+ heads=num_attention_heads,
116
+ dim_head=attention_head_dim,
117
+ dropout=dropout,
118
+ bias=attention_bias,
119
+ upcast_attention=upcast_attention,
120
+ image_prompt_settings=image_prompt_settings,
121
+ )
122
+ else:
123
+ self.norm2 = None
124
+ self.attn2 = None
125
+
126
+ # 3. Feed-forward
127
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
128
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
129
+
130
+ # let chunk size default to None
131
+ self._chunk_size = None
132
+ self._chunk_dim = 0
133
+
134
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
135
+ # Sets chunk feed-forward
136
+ self._chunk_size = chunk_size
137
+ self._chunk_dim = dim
138
+
139
+ def forward(
140
+ self,
141
+ hidden_states: torch.FloatTensor,
142
+ attention_mask: Optional[torch.FloatTensor] = None,
143
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
144
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
145
+ timestep: Optional[torch.LongTensor] = None,
146
+ cross_attention_kwargs: Dict[str, Any] = None,
147
+ class_labels: Optional[torch.LongTensor] = None,
148
+ encoder_hidden_states_vision = None,
149
+ encoder_hidden_states_control = None,
150
+ vision_guided_mask = None,
151
+ extra_dict_inputs = {},
152
+ height = None,
153
+ width = None,
154
+ return_self_attn_map = False,
155
+ ):
156
+ extra_dict_outputs = {}
157
+ # Notice that normalization is always applied before the real computation in the following blocks.
158
+ # 1. Self-Attention
159
+ if self.use_ada_layer_norm:
160
+ norm_hidden_states = self.norm1(hidden_states, timestep)
161
+ elif self.use_ada_layer_norm_zero:
162
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
163
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
164
+ )
165
+ else:
166
+ norm_hidden_states = self.norm1(hidden_states)
167
+
168
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
169
+
170
+ # self attention in XL
171
+ attn_output = self.attn1(
172
+ norm_hidden_states,
173
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
174
+ attention_mask=attention_mask,
175
+ # vision_guided_mask=vision_guided_mask,
176
+ # height=height,
177
+ # width=width,
178
+ # return_self_attn_map=return_self_attn_map,
179
+ # **cross_attention_kwargs,
180
+ )
181
+ # extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_attn)
182
+ if self.use_ada_layer_norm_zero:
183
+ attn_output = gate_msa.unsqueeze(1) * attn_output
184
+ hidden_states = attn_output + hidden_states
185
+
186
+ # 2. Cross-Attention
187
+ if self.attn2 is not None:
188
+ norm_hidden_states = (
189
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
190
+ )
191
+
192
+ attn_output, extra_dict_output_attn = self.attn2(
193
+ norm_hidden_states,
194
+ encoder_hidden_states=encoder_hidden_states,
195
+ attention_mask=encoder_attention_mask,
196
+ encoder_hidden_states_vision=encoder_hidden_states_vision,
197
+ encoder_hidden_states_control=encoder_hidden_states_control,
198
+ vision_guided_mask=vision_guided_mask,
199
+ extra_dict_inputs=extra_dict_inputs,
200
+ height=height,
201
+ width=width,
202
+ **cross_attention_kwargs,
203
+ )
204
+ extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_attn)
205
+ # attn_output = self.attn2(
206
+ # norm_hidden_states,
207
+ # encoder_hidden_states=encoder_hidden_states,
208
+ # attention_mask=encoder_attention_mask,
209
+ # **cross_attention_kwargs,
210
+ # )
211
+ hidden_states = attn_output + hidden_states
212
+
213
+
214
+ # 3. Feed-forward
215
+ norm_hidden_states = self.norm3(hidden_states)
216
+
217
+ if self.use_ada_layer_norm_zero:
218
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
219
+
220
+ if self._chunk_size is not None:
221
+ # "feed_forward_chunk_size" can be used to save memory
222
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
223
+ raise ValueError(
224
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
225
+ )
226
+
227
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
228
+ ff_output = torch.cat(
229
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
230
+ dim=self._chunk_dim,
231
+ )
232
+ else:
233
+ ff_output = self.ff(norm_hidden_states)
234
+
235
+ if self.use_ada_layer_norm_zero:
236
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
237
+
238
+ hidden_states = ff_output + hidden_states
239
+
240
+ return hidden_states, extra_dict_outputs
241
+
242
+
243
+ class FeedForward(nn.Module):
244
+ r"""
245
+ A feed-forward layer.
246
+
247
+ Parameters:
248
+ dim (`int`): The number of channels in the input.
249
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
250
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
251
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
252
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
253
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
254
+ """
255
+
256
+ def __init__(
257
+ self,
258
+ dim: int,
259
+ dim_out: Optional[int] = None,
260
+ mult: int = 4,
261
+ dropout: float = 0.0,
262
+ activation_fn: str = "geglu",
263
+ final_dropout: bool = False,
264
+ ):
265
+ super().__init__()
266
+ inner_dim = int(dim * mult)
267
+ dim_out = dim_out if dim_out is not None else dim
268
+
269
+ if activation_fn == "gelu":
270
+ act_fn = GELU(dim, inner_dim)
271
+ if activation_fn == "gelu-approximate":
272
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
273
+ elif activation_fn == "geglu":
274
+ act_fn = GEGLU(dim, inner_dim)
275
+ elif activation_fn == "geglu-approximate":
276
+ act_fn = ApproximateGELU(dim, inner_dim)
277
+
278
+ self.net = nn.ModuleList([])
279
+ # project in
280
+ self.net.append(act_fn)
281
+ # project dropout
282
+ self.net.append(nn.Dropout(dropout))
283
+ # project out
284
+ self.net.append(nn.Linear(inner_dim, dim_out))
285
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
286
+ if final_dropout:
287
+ self.net.append(nn.Dropout(dropout))
288
+
289
+ def forward(self, hidden_states):
290
+ for module in self.net:
291
+ hidden_states = module(hidden_states)
292
+ return hidden_states
293
+
294
+
295
+ class GELU(nn.Module):
296
+ r"""
297
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
298
+ """
299
+
300
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
301
+ super().__init__()
302
+ self.proj = nn.Linear(dim_in, dim_out)
303
+ self.approximate = approximate
304
+
305
+ def gelu(self, gate):
306
+ if gate.device.type != "mps":
307
+ return F.gelu(gate, approximate=self.approximate)
308
+ # mps: gelu is not implemented for float16
309
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
310
+
311
+ def forward(self, hidden_states):
312
+ hidden_states = self.proj(hidden_states)
313
+ hidden_states = self.gelu(hidden_states)
314
+ return hidden_states
315
+
316
+
317
+ class GEGLU(nn.Module):
318
+ r"""
319
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
320
+
321
+ Parameters:
322
+ dim_in (`int`): The number of channels in the input.
323
+ dim_out (`int`): The number of channels in the output.
324
+ """
325
+
326
+ def __init__(self, dim_in: int, dim_out: int):
327
+ super().__init__()
328
+ self.proj = nn.Linear(dim_in, dim_out * 2)
329
+
330
+ def gelu(self, gate):
331
+ if gate.device.type != "mps":
332
+ return F.gelu(gate)
333
+ # mps: gelu is not implemented for float16
334
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
335
+
336
+ def forward(self, hidden_states):
337
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
338
+ return hidden_states * self.gelu(gate)
339
+
340
+
341
+ class ApproximateGELU(nn.Module):
342
+ """
343
+ The approximate form of Gaussian Error Linear Unit (GELU)
344
+
345
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
346
+ """
347
+
348
+ def __init__(self, dim_in: int, dim_out: int):
349
+ super().__init__()
350
+ self.proj = nn.Linear(dim_in, dim_out)
351
+
352
+ def forward(self, x):
353
+ x = self.proj(x)
354
+ return x * torch.sigmoid(1.702 * x)
355
+
356
+
357
+ class AdaLayerNorm(nn.Module):
358
+ """
359
+ Norm layer modified to incorporate timestep embeddings.
360
+ """
361
+
362
+ def __init__(self, embedding_dim, num_embeddings):
363
+ super().__init__()
364
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
365
+ self.silu = nn.SiLU()
366
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
367
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
368
+
369
+ def forward(self, x, timestep):
370
+ emb = self.linear(self.silu(self.emb(timestep)))
371
+ scale, shift = torch.chunk(emb, 2)
372
+ x = self.norm(x) * (1 + scale) + shift
373
+ return x
374
+
375
+
376
+ class AdaLayerNormZero(nn.Module):
377
+ """
378
+ Norm layer adaptive layer norm zero (adaLN-Zero).
379
+ """
380
+
381
+ def __init__(self, embedding_dim, num_embeddings):
382
+ super().__init__()
383
+
384
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
385
+
386
+ self.silu = nn.SiLU()
387
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
388
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
389
+
390
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
391
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
392
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
393
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
394
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
395
+
396
+
397
+ class AdaGroupNorm(nn.Module):
398
+ """
399
+ GroupNorm layer modified to incorporate timestep embeddings.
400
+ """
401
+
402
+ def __init__(
403
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
404
+ ):
405
+ super().__init__()
406
+ self.num_groups = num_groups
407
+ self.eps = eps
408
+
409
+ if act_fn is None:
410
+ self.act = None
411
+ else:
412
+ self.act = get_activation(act_fn)
413
+
414
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
415
+
416
+ def forward(self, x, emb):
417
+ if self.act:
418
+ emb = self.act(emb)
419
+ emb = self.linear(emb)
420
+ emb = emb[:, :, None, None]
421
+ scale, shift = emb.chunk(2, dim=1)
422
+
423
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
424
+ x = x * (1 + scale) + shift
425
+ return x
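
Note (not part of the commit): the image-prompt settings wired through BasicTransformerBlock above are consumed by the cross-attention processor defined in the next file, which gates the vision branch of the attention output with a spatial guidance mask. The snippet below is an illustrative sketch of that gating step only, under assumed shapes; `gate_vision_branch` is a hypothetical helper, not code from this repository.

import torch
import torch.nn.functional as F
from einops import rearrange

def gate_vision_branch(hidden_states_vision, vision_guided_mask, height, width):
    # hidden_states_vision: [B, H*W, C] output of the image-prompt attention branch
    # vision_guided_mask:   [B, 1, H0, W0] subject mask at an arbitrary resolution
    mask = F.interpolate(vision_guided_mask.float(), size=(height, width), mode="bilinear")
    mask_1d = rearrange(mask, "b c h w -> b (h w) c")  # broadcasts over the channel axis
    return hidden_states_vision * mask_1d

b, h, w, c = 2, 16, 16, 320
out = gate_vision_branch(torch.randn(b, h * w, c), torch.rand(b, 1, 64, 64), h, w)
print(out.shape)  # torch.Size([2, 256, 320])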
models/attention_processor_custom_cross.py ADDED
@@ -0,0 +1,1778 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Callable, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ # from diffusers.utils import deprecate, logging, maybe_allow_in_graph
23
+ from diffusers.utils import deprecate, logging
24
+ from diffusers.utils.import_utils import is_xformers_available
25
+ from einops import rearrange
26
+ import random
27
+
28
+ from utils import zero_module
29
+ # from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ if is_xformers_available():
34
+ import xformers
35
+ import xformers.ops
36
+ else:
37
+ xformers = None
38
+
39
+ # Cross Attention
40
+ # @maybe_allow_in_graph
41
+ class Attention(nn.Module):
42
+ r"""
43
+ A cross attention layer.
44
+
45
+ Parameters:
46
+ query_dim (`int`): The number of channels in the query.
47
+ cross_attention_dim (`int`, *optional*):
48
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
49
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
50
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
51
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
52
+ bias (`bool`, *optional*, defaults to False):
53
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ query_dim: int,
59
+ cross_attention_dim: Optional[int] = None,
60
+ heads: int = 8,
61
+ dim_head: int = 64,
62
+ dropout: float = 0.0,
63
+ bias=False,
64
+ upcast_attention: bool = False,
65
+ upcast_softmax: bool = False,
66
+ cross_attention_norm: Optional[str] = None,
67
+ cross_attention_norm_num_groups: int = 32,
68
+ added_kv_proj_dim: Optional[int] = None,
69
+ norm_num_groups: Optional[int] = None,
70
+ spatial_norm_dim: Optional[int] = None,
71
+ out_bias: bool = True,
72
+ scale_qk: bool = True,
73
+ only_cross_attention: bool = False,
74
+ eps: float = 1e-5,
75
+ rescale_output_factor: float = 1.0,
76
+ residual_connection: bool = False,
77
+ _from_deprecated_attn_block=False,
78
+ processor: Optional["AttnProcessor"] = None,
79
+ image_prompt_settings: dict = {}
80
+ ):
81
+ super().__init__()
82
+
83
+ inner_dim = dim_head * heads
84
+ self.inner_dim = inner_dim
85
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
86
+ self.upcast_attention = upcast_attention
87
+ self.upcast_softmax = upcast_softmax
88
+ self.rescale_output_factor = rescale_output_factor
89
+ self.residual_connection = residual_connection
90
+ self.dropout = dropout
91
+
92
+ # we make use of this private variable to know whether this class is loaded
93
+ # with a deprecated state dict so that we can convert it on the fly
94
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
95
+
96
+ self.scale_qk = scale_qk
97
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
98
+
99
+ self.heads = heads
100
+ # for slice_size > 0 the attention score computation
101
+ # is split across the batch axis to save memory
102
+ # You can set slice_size with `set_attention_slice`
103
+ self.sliceable_head_dim = heads
104
+
105
+ self.added_kv_proj_dim = added_kv_proj_dim
106
+ self.only_cross_attention = only_cross_attention
107
+
108
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
109
+ raise ValueError(
110
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
111
+ )
112
+
113
+ if norm_num_groups is not None:
114
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
115
+ else:
116
+ self.group_norm = None
117
+
118
+ if spatial_norm_dim is not None:
119
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
120
+ else:
121
+ self.spatial_norm = None
122
+
123
+ if cross_attention_norm is None:
124
+ self.norm_cross = None
125
+ elif cross_attention_norm == "layer_norm":
126
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
127
+ elif cross_attention_norm == "group_norm":
128
+ if self.added_kv_proj_dim is not None:
129
+ # The given `encoder_hidden_states` are initially of shape
130
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
131
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
132
+ # before the projection, so we need to use `added_kv_proj_dim` as
133
+ # the number of channels for the group norm.
134
+ norm_cross_num_channels = added_kv_proj_dim
135
+ else:
136
+ norm_cross_num_channels = cross_attention_dim
137
+
138
+ self.norm_cross = nn.GroupNorm(
139
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
140
+ )
141
+ else:
142
+ raise ValueError(
143
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
144
+ )
145
+
146
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
147
+
148
+ if not self.only_cross_attention:
149
+ # only relevant for the `AddedKVProcessor` classes
150
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
151
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
152
+ else:
153
+ self.to_k = None
154
+ self.to_v = None
155
+
156
+ if self.added_kv_proj_dim is not None:
157
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
158
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
159
+
160
+ # NOTE for custom dual branch settings
161
+ self.cross_attention_id = image_prompt_settings.get("cross_attention_id", 0)
162
+ self.use_cross_attention_id = image_prompt_settings.get("use_cross_attention_id", False)
163
+ self.image_prompt_mode = image_prompt_settings.get("image_prompt_mode", "none")
164
+
165
+ if self.image_prompt_mode == "naive": # only used in cross-attention, NOT self-attention
166
+ self.to_k_vision = nn.Linear(cross_attention_dim, inner_dim, bias=False)
167
+ self.to_v_vision = nn.Linear(cross_attention_dim, inner_dim, bias=False)
168
+ else:
169
+ if self.image_prompt_mode != "none":
170
+ print("Warning .... unknown self.image_prompt_mode")
171
+
172
+ self.to_out = nn.ModuleList([])
173
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
174
+ self.to_out.append(nn.Dropout(dropout))
175
+
176
+ self.to_out = nn.ModuleList([])
177
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
178
+ self.to_out.append(nn.Dropout(dropout))
179
+
180
+ # set attention processor
181
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
182
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
183
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
184
+ # <note> processor default to be None
185
+ if processor is None:
186
+ assert hasattr(F, "scaled_dot_product_attention") and self.scale_qk
187
+ processor = AttnProcessor2_0_image_prompt()
188
+
189
+ self.set_processor(processor)
190
+
191
+ def set_use_memory_efficient_attention_xformers(
192
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
193
+ ):
194
+ is_lora = hasattr(self, "processor") and isinstance(
195
+ self.processor,
196
+ (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
197
+ )
198
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
199
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
200
+ )
201
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
202
+ self.processor,
203
+ (
204
+ AttnAddedKVProcessor,
205
+ AttnAddedKVProcessor2_0,
206
+ SlicedAttnAddedKVProcessor,
207
+ XFormersAttnAddedKVProcessor,
208
+ LoRAAttnAddedKVProcessor,
209
+ ),
210
+ )
211
+
212
+ if use_memory_efficient_attention_xformers:
213
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
214
+ raise NotImplementedError(
215
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
216
+ )
217
+ if not is_xformers_available():
218
+ raise ModuleNotFoundError(
219
+ (
220
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
221
+ " xformers"
222
+ ),
223
+ name="xformers",
224
+ )
225
+ elif not torch.cuda.is_available():
226
+ raise ValueError(
227
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
228
+ " only available for GPU "
229
+ )
230
+ else:
231
+ try:
232
+ # Make sure we can run the memory efficient attention
233
+ _ = xformers.ops.memory_efficient_attention(
234
+ torch.randn((1, 2, 40), device="cuda"),
235
+ torch.randn((1, 2, 40), device="cuda"),
236
+ torch.randn((1, 2, 40), device="cuda"),
237
+ )
238
+ except Exception as e:
239
+ raise e
240
+
241
+ if is_lora:
242
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
243
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
244
+ processor = LoRAXFormersAttnProcessor(
245
+ hidden_size=self.processor.hidden_size,
246
+ cross_attention_dim=self.processor.cross_attention_dim,
247
+ rank=self.processor.rank,
248
+ attention_op=attention_op,
249
+ )
250
+ processor.load_state_dict(self.processor.state_dict())
251
+ processor.to(self.processor.to_q_lora.up.weight.device)
252
+ elif is_custom_diffusion:
253
+ processor = CustomDiffusionXFormersAttnProcessor(
254
+ train_kv=self.processor.train_kv,
255
+ train_q_out=self.processor.train_q_out,
256
+ hidden_size=self.processor.hidden_size,
257
+ cross_attention_dim=self.processor.cross_attention_dim,
258
+ attention_op=attention_op,
259
+ )
260
+ processor.load_state_dict(self.processor.state_dict())
261
+ if hasattr(self.processor, "to_k_custom_diffusion"):
262
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
263
+ elif is_added_kv_processor:
264
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
265
+ # which uses this type of cross attention ONLY because the attention mask of format
266
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
267
+ # throw warning
268
+ logger.info(
269
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
270
+ )
271
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
272
+ else:
273
+ processor = XFormersAttnProcessor(attention_op=attention_op)
274
+ else:
275
+ if is_lora:
276
+ attn_processor_class = (
277
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
278
+ )
279
+ processor = attn_processor_class(
280
+ hidden_size=self.processor.hidden_size,
281
+ cross_attention_dim=self.processor.cross_attention_dim,
282
+ rank=self.processor.rank,
283
+ )
284
+ processor.load_state_dict(self.processor.state_dict())
285
+ processor.to(self.processor.to_q_lora.up.weight.device)
286
+ elif is_custom_diffusion:
287
+ processor = CustomDiffusionAttnProcessor(
288
+ train_kv=self.processor.train_kv,
289
+ train_q_out=self.processor.train_q_out,
290
+ hidden_size=self.processor.hidden_size,
291
+ cross_attention_dim=self.processor.cross_attention_dim,
292
+ )
293
+ processor.load_state_dict(self.processor.state_dict())
294
+ if hasattr(self.processor, "to_k_custom_diffusion"):
295
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
296
+ else:
297
+ # set attention processor
298
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
299
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
300
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
301
+ processor = (
302
+ AttnProcessor2_0()
303
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
304
+ else AttnProcessor()
305
+ )
306
+
307
+ self.set_processor(processor)
308
+
309
+ def set_attention_slice(self, slice_size):
310
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
311
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
312
+
313
+ if slice_size is not None and self.added_kv_proj_dim is not None:
314
+ processor = SlicedAttnAddedKVProcessor(slice_size)
315
+ elif slice_size is not None:
316
+ processor = SlicedAttnProcessor(slice_size)
317
+ elif self.added_kv_proj_dim is not None:
318
+ processor = AttnAddedKVProcessor()
319
+ else:
320
+ # set attention processor
321
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
322
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
323
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
324
+ processor = (
325
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
326
+ )
327
+
328
+ self.set_processor(processor)
329
+
330
+ def set_processor(self, processor: "AttnProcessor"):
331
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
332
+ # pop `processor` from `self._modules`
333
+ if (
334
+ hasattr(self, "processor")
335
+ and isinstance(self.processor, torch.nn.Module)
336
+ and not isinstance(processor, torch.nn.Module)
337
+ ):
338
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
339
+ self._modules.pop("processor")
340
+
341
+ self.processor = processor
342
+
343
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
344
+ encoder_hidden_states_vision=None, encoder_hidden_states_control=None,
345
+ vision_guided_mask=None, extra_dict_inputs={}, height=None, width=None,
346
+ **cross_attention_kwargs):
347
+ # The `Attention` class can call different attention processors / attention functions
348
+ # here we simply pass along all tensors to the selected processor class
349
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
350
+ return self.processor(
351
+ self,
352
+ hidden_states,
353
+ encoder_hidden_states=encoder_hidden_states,
354
+ attention_mask=attention_mask,
355
+
356
+ encoder_hidden_states_vision=encoder_hidden_states_vision,
357
+ encoder_hidden_states_control=encoder_hidden_states_control,
358
+ vision_guided_mask=vision_guided_mask,
359
+ extra_dict_inputs=extra_dict_inputs,
360
+ image_prompt_mode=self.image_prompt_mode,
361
+
362
+ height=height, width=width,
363
+
364
+ **cross_attention_kwargs,
365
+ )
366
+
367
+ def batch_to_head_dim(self, tensor):
368
+ head_size = self.heads
369
+ batch_size, seq_len, dim = tensor.shape
370
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
371
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
372
+ return tensor
373
+
374
+ def head_to_batch_dim(self, tensor, out_dim=3):
375
+ head_size = self.heads
376
+ batch_size, seq_len, dim = tensor.shape
377
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
378
+ tensor = tensor.permute(0, 2, 1, 3)
379
+
380
+ if out_dim == 3:
381
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
382
+
383
+ return tensor
384
+
385
+ def get_attention_scores(self, query, key, attention_mask=None):
386
+ dtype = query.dtype
387
+ if self.upcast_attention:
388
+ query = query.float()
389
+ key = key.float()
390
+
391
+ if attention_mask is None:
392
+ baddbmm_input = torch.empty(
393
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
394
+ )
395
+ beta = 0
396
+ else:
397
+ baddbmm_input = attention_mask
398
+ beta = 1
399
+
400
+ attention_scores = torch.baddbmm(
401
+ baddbmm_input,
402
+ query,
403
+ key.transpose(-1, -2),
404
+ beta=beta,
405
+ alpha=self.scale,
406
+ )
407
+ del baddbmm_input
408
+
409
+ if self.upcast_softmax:
410
+ attention_scores = attention_scores.float()
411
+
412
+ attention_probs = attention_scores.softmax(dim=-1)
413
+ del attention_scores
414
+
415
+ attention_probs = attention_probs.to(dtype)
416
+
417
+ return attention_probs
418
+
419
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
420
+ if batch_size is None:
421
+ deprecate(
422
+ "batch_size=None",
423
+ "0.0.15",
424
+ (
425
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
426
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
427
+ " `prepare_attention_mask` when preparing the attention_mask."
428
+ ),
429
+ )
430
+ batch_size = 1
431
+
432
+ head_size = self.heads
433
+ if attention_mask is None:
434
+ return attention_mask
435
+
436
+ current_length: int = attention_mask.shape[-1]
437
+ if current_length != target_length:
438
+ if attention_mask.device.type == "mps":
439
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
440
+ # Instead, we can manually construct the padding tensor.
441
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
442
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
443
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
444
+ else:
445
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
446
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
447
+ # remaining_length: int = target_length - current_length
448
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
449
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
450
+
451
+ if out_dim == 3:
452
+ if attention_mask.shape[0] < batch_size * head_size:
453
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
454
+ elif out_dim == 4:
455
+ attention_mask = attention_mask.unsqueeze(1)
456
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
457
+
458
+ return attention_mask
459
+
460
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
461
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
462
+
463
+ if isinstance(self.norm_cross, nn.LayerNorm):
464
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
465
+ elif isinstance(self.norm_cross, nn.GroupNorm):
466
+ # Group norm norms along the channels dimension and expects
467
+ # input to be in the shape of (N, C, *). In this case, we want
468
+ # to norm along the hidden dimension, so we need to move
469
+ # (batch_size, sequence_length, hidden_size) ->
470
+ # (batch_size, hidden_size, sequence_length)
471
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
472
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
473
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
474
+ else:
475
+ assert False
476
+
477
+ return encoder_hidden_states
478
+
479
+
480
+ class AttnProcessor2_0_image_prompt:
481
+ r"""
482
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
483
+ """
484
+
485
+ def __init__(self):
486
+ if not hasattr(F, "scaled_dot_product_attention"):
487
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
488
+
489
+ def __call__(
490
+ self,
491
+ attn: Attention,
492
+ hidden_states,
493
+ encoder_hidden_states=None,
494
+ attention_mask=None,
495
+ temb=None,
496
+
497
+ encoder_hidden_states_vision=None,
498
+ encoder_hidden_states_control=None,
499
+ vision_guided_mask=None,
500
+ extra_dict_inputs={},
501
+ image_prompt_mode="none",
502
+
503
+ height=None, width=None,
504
+ ):
505
+ if "multiple_reference_image" in extra_dict_inputs.keys():
506
+ multiple_reference_image = extra_dict_inputs["multiple_reference_image"]
507
+ else:
508
+ multiple_reference_image = False
509
+
510
+ resampled_token = None
511
+ if encoder_hidden_states_vision is not None:
512
+ if attn.use_cross_attention_id:
513
+ encoder_hidden_states_vision = encoder_hidden_states_vision[:, attn.cross_attention_id, :, :]
514
+
515
+ extra_dict_outputs = {}
516
+ residual = hidden_states
517
+
518
+ if attn.spatial_norm is not None:
519
+ hidden_states = attn.spatial_norm(hidden_states, temb)
520
+
521
+ input_ndim = hidden_states.ndim
522
+
523
+ if input_ndim == 4:
524
+ batch_size, channel, height, width = hidden_states.shape
525
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
526
+
527
+ batch_size, sequence_length, _ = (
528
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
529
+ )
530
+ inner_dim = hidden_states.shape[-1]
531
+
532
+ if attention_mask is not None:
533
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
534
+ # scaled_dot_product_attention expects attention_mask shape to be
535
+ # (batch, heads, source_length, target_length)
536
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
537
+ if attn.group_norm is not None:
538
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
539
+
540
+ query = attn.to_q(hidden_states)
541
+
542
+ if encoder_hidden_states is None:
543
+ encoder_hidden_states = hidden_states
544
+ elif attn.norm_cross: # attn.norm_cross: None
545
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
546
+
547
+ # text-image cross attention
548
+ key = attn.to_k(encoder_hidden_states)
549
+ value = attn.to_v(encoder_hidden_states)
550
+
551
+ head_dim = inner_dim // attn.heads
552
+
553
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
554
+
555
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
556
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
557
+
558
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
559
+ # TODO: add support for attn.scale when we move to Torch 2.1
560
+ if attn.training:
561
+ # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
562
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
563
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
564
+ # hidden_states = flash_attn_func(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False)
565
+ else: # use vanilla attention during inference
566
+ with torch.autocast(enabled=True, device_type = 'cuda'):
567
+ q, k, v = query.float(), key.float(), value.float()
568
+ sim = (q @ k.transpose(-2, -1) * attn.scale)
569
+ if attention_mask is not None: # no mask in SDXL?
570
+ attention_mask = 1 + (attention_mask / -10000.0)
571
+ attention_mask = attention_mask.bool()
572
+ max_neg_value = -torch.finfo(sim.dtype).max
573
+ sim.masked_fill_(~attention_mask, max_neg_value)
574
+ sim = sim.softmax(dim=-1)
575
+ hidden_states = torch.einsum('b h i j, b h j d -> b h i d', sim, v)
576
+ extra_dict_outputs["text2image_crossmap_2d"] = rearrange(sim, "b head (h w) n -> b head n h w", h=height, w=width)
577
+
578
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
579
+ hidden_states = hidden_states.to(query.dtype)
580
+
581
+ if encoder_hidden_states_vision is not None and not multiple_reference_image: # single image
582
+ if image_prompt_mode == "naive":
583
+ key_vision = attn.to_k_vision(encoder_hidden_states_vision)
584
+ value_vision = attn.to_v_vision(encoder_hidden_states_vision)
585
+ key_vision = key_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
586
+ value_vision = value_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
587
+ hidden_states_vision = F.scaled_dot_product_attention(query, key_vision, value_vision, dropout_p=0.0, is_causal=False)
588
+ hidden_states_vision = hidden_states_vision.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
589
+ else:
590
+ hidden_states_vision = torch.zeros_like(hidden_states).to(hidden_states.device)
591
+
592
+ if vision_guided_mask is not None:
593
+ if vision_guided_mask.dim() == 4: # all layers share the same mask
594
+ target_h, target_w = vision_guided_mask.size(-2), vision_guided_mask.size(-1)
595
+ vision_guided_mask = F.interpolate(vision_guided_mask.float(), scale_factor=height/target_h, mode='bilinear')
596
+ vision_guided_mask_1d = rearrange(vision_guided_mask, "b c h w -> b (h w) c")
597
+ hidden_states_vision = hidden_states_vision * vision_guided_mask_1d
598
+ else: # each layer uses its own mask, indexed by self.cross_attention_id
599
+ vision_guided_mask = vision_guided_mask[:, :, attn.cross_attention_id, :, :]
600
+ target_h, target_w = vision_guided_mask.size(-2), vision_guided_mask.size(-1)
601
+ vision_guided_mask = F.interpolate(vision_guided_mask, scale_factor=height/target_h, mode='bilinear')
602
+ vision_guided_mask_1d = rearrange(vision_guided_mask, "b c h w -> b (h w) c")
603
+ hidden_states_vision = hidden_states_vision * vision_guided_mask_1d
604
+
605
+ elif encoder_hidden_states_vision is not None and multiple_reference_image: # multiple image
606
+ if image_prompt_mode == "naive":
607
+ image_num = encoder_hidden_states_vision.size(1)
608
+ encoder_hidden_states_vision_list = encoder_hidden_states_vision.chunk(image_num, dim=1)
609
+ if vision_guided_mask is not None:
610
+ vision_guided_mask_list = vision_guided_mask.chunk(image_num, dim=1)
611
+ else:
612
+ vision_guided_mask_list = [None] * image_num
613
+ hidden_states_vision_results = []
614
+ for encoder_hidden_states_vision_i, vision_guided_mask_i in zip(encoder_hidden_states_vision_list, vision_guided_mask_list):
615
+ encoder_hidden_states_vision_i = encoder_hidden_states_vision_i.squeeze(1)
616
+
617
+ key_vision = attn.to_k_vision(encoder_hidden_states_vision_i)
618
+ value_vision = attn.to_v_vision(encoder_hidden_states_vision_i)
619
+ key_vision = key_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
620
+ value_vision = value_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
621
+ hidden_states_vision = F.scaled_dot_product_attention(query, key_vision, value_vision, dropout_p=0.0, is_causal=False)
622
+ hidden_states_vision = hidden_states_vision.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
623
+
624
+ if vision_guided_mask_i is not None:
625
+ target_h, target_w = vision_guided_mask_i.size(-2), vision_guided_mask_i.size(-1)
626
+ vision_guided_mask_i = F.interpolate(vision_guided_mask_i, scale_factor=height/target_h, mode='bilinear')
627
+ vision_guided_mask_1d_i = rearrange(vision_guided_mask_i, "b c h w -> b (h w) c")
628
+ hidden_states_vision = hidden_states_vision * vision_guided_mask_1d_i
629
+
630
+ hidden_states_vision_results.append(hidden_states_vision.unsqueeze(1))
631
+
632
+ hidden_states_vision = torch.cat(hidden_states_vision_results, dim=1).sum(dim=1)
633
+
634
+ else:
635
+ hidden_states_vision = torch.zeros_like(hidden_states).to(hidden_states.device)
636
+ else:
637
+ hidden_states_vision = torch.zeros_like(hidden_states).to(hidden_states.device)
638
+
639
+ hidden_states = hidden_states + hidden_states_vision
640
+
641
+ # linear proj
642
+ hidden_states = attn.to_out[0](hidden_states)
643
+ # dropout
644
+ hidden_states = attn.to_out[1](hidden_states)
645
+
646
+ if input_ndim == 4:
647
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
648
+
649
+ if attn.residual_connection:
650
+ hidden_states = hidden_states + residual
651
+
652
+ hidden_states = hidden_states / attn.rescale_output_factor
653
+
654
+ return hidden_states, extra_dict_outputs
655
+
656
+
657
+ class AttnProcessor:
658
+ r"""
659
+ Default processor for performing attention-related computations.
660
+ """
661
+
662
+ def __call__(
663
+ self,
664
+ attn: Attention,
665
+ hidden_states,
666
+ encoder_hidden_states=None,
667
+ attention_mask=None,
668
+ temb=None,
669
+ ):
670
+ residual = hidden_states
671
+
672
+ if attn.spatial_norm is not None:
673
+ hidden_states = attn.spatial_norm(hidden_states, temb)
674
+
675
+ input_ndim = hidden_states.ndim
676
+
677
+ if input_ndim == 4:
678
+ batch_size, channel, height, width = hidden_states.shape
679
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
680
+
681
+ batch_size, sequence_length, _ = (
682
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
683
+ )
684
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
685
+
686
+ if attn.group_norm is not None:
687
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
688
+
689
+ query = attn.to_q(hidden_states)
690
+
691
+ if encoder_hidden_states is None:
692
+ encoder_hidden_states = hidden_states
693
+ elif attn.norm_cross:
694
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
695
+
696
+ key = attn.to_k(encoder_hidden_states)
697
+ value = attn.to_v(encoder_hidden_states)
698
+
699
+ query = attn.head_to_batch_dim(query)
700
+ key = attn.head_to_batch_dim(key)
701
+ value = attn.head_to_batch_dim(value)
702
+
703
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
704
+ hidden_states = torch.bmm(attention_probs, value)
705
+ hidden_states = attn.batch_to_head_dim(hidden_states)
706
+
707
+ # linear proj
708
+ hidden_states = attn.to_out[0](hidden_states)
709
+ # dropout
710
+ hidden_states = attn.to_out[1](hidden_states)
711
+
712
+ if input_ndim == 4:
713
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
714
+
715
+ if attn.residual_connection:
716
+ hidden_states = hidden_states + residual
717
+
718
+ hidden_states = hidden_states / attn.rescale_output_factor
719
+
720
+ return hidden_states
721
+
722
+
723
+ class LoRALinearLayer(nn.Module):
724
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None):
725
+ super().__init__()
726
+
727
+ if rank > min(in_features, out_features):
728
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
729
+
730
+ self.down = nn.Linear(in_features, rank, bias=False)
731
+ self.up = nn.Linear(rank, out_features, bias=False)
732
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
733
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
734
+ self.network_alpha = network_alpha
735
+ self.rank = rank
736
+
737
+ nn.init.normal_(self.down.weight, std=1 / rank)
738
+ nn.init.zeros_(self.up.weight)
739
+
740
+ def forward(self, hidden_states):
741
+ orig_dtype = hidden_states.dtype
742
+ dtype = self.down.weight.dtype
743
+
744
+ down_hidden_states = self.down(hidden_states.to(dtype))
745
+ up_hidden_states = self.up(down_hidden_states)
746
+
747
+ if self.network_alpha is not None:
748
+ up_hidden_states *= self.network_alpha / self.rank
749
+
750
+ return up_hidden_states.to(orig_dtype)
751
+
752
+
753
+ class LoRAAttnProcessor(nn.Module):
754
+ r"""
755
+ Processor for implementing the LoRA attention mechanism.
756
+
757
+ Args:
758
+ hidden_size (`int`, *optional*):
759
+ The hidden size of the attention layer.
760
+ cross_attention_dim (`int`, *optional*):
761
+ The number of channels in the `encoder_hidden_states`.
762
+ rank (`int`, defaults to 4):
763
+ The dimension of the LoRA update matrices.
764
+ network_alpha (`int`, *optional*):
765
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
766
+ """
767
+
768
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
769
+ super().__init__()
770
+
771
+ self.hidden_size = hidden_size
772
+ self.cross_attention_dim = cross_attention_dim
773
+ self.rank = rank
774
+
775
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
776
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
777
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
778
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
779
+
780
+ def __call__(
781
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
782
+ ):
783
+ residual = hidden_states
784
+
785
+ if attn.spatial_norm is not None:
786
+ hidden_states = attn.spatial_norm(hidden_states, temb)
787
+
788
+ input_ndim = hidden_states.ndim
789
+
790
+ if input_ndim == 4:
791
+ batch_size, channel, height, width = hidden_states.shape
792
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
793
+
794
+ batch_size, sequence_length, _ = (
795
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
796
+ )
797
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
798
+
799
+ if attn.group_norm is not None:
800
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
801
+
802
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
803
+ query = attn.head_to_batch_dim(query)
804
+
805
+ if encoder_hidden_states is None:
806
+ encoder_hidden_states = hidden_states
807
+ elif attn.norm_cross:
808
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
809
+
810
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
811
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
812
+
813
+ key = attn.head_to_batch_dim(key)
814
+ value = attn.head_to_batch_dim(value)
815
+
816
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
817
+ hidden_states = torch.bmm(attention_probs, value)
818
+ hidden_states = attn.batch_to_head_dim(hidden_states)
819
+
820
+ # linear proj
821
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
822
+ # dropout
823
+ hidden_states = attn.to_out[1](hidden_states)
824
+
825
+ if input_ndim == 4:
826
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
827
+
828
+ if attn.residual_connection:
829
+ hidden_states = hidden_states + residual
830
+
831
+ hidden_states = hidden_states / attn.rescale_output_factor
832
+
833
+ return hidden_states
834
+
835
+
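A hedged sketch of what the LoRA sub-layers above compute in isolation; inside `__call__` these low-rank corrections are added to the frozen `attn.to_q/to_k/to_v/to_out[0]` outputs, scaled by `scale`. The 320/2048 sizes below are illustrative, loosely following the UNet config in this repo.

import torch

proc = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=2048, rank=4)
text_ctx = torch.randn(1, 77, 2048)  # encoder_hidden_states-shaped input
delta_k = proc.to_k_lora(text_ctx)   # correction added to attn.to_k(encoder_hidden_states)
print(delta_k.shape)                 # torch.Size([1, 77, 320])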
836
+ class CustomDiffusionAttnProcessor(nn.Module):
837
+ r"""
838
+ Processor for implementing attention for the Custom Diffusion method.
839
+
840
+ Args:
841
+ train_kv (`bool`, defaults to `True`):
842
+ Whether to newly train the key and value matrices corresponding to the text features.
843
+ train_q_out (`bool`, defaults to `True`):
844
+ Whether to newly train query matrices corresponding to the latent image features.
845
+ hidden_size (`int`, *optional*, defaults to `None`):
846
+ The hidden size of the attention layer.
847
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
848
+ The number of channels in the `encoder_hidden_states`.
849
+ out_bias (`bool`, defaults to `True`):
850
+ Whether to include the bias parameter in `train_q_out`.
851
+ dropout (`float`, *optional*, defaults to 0.0):
852
+ The dropout probability to use.
853
+ """
854
+
855
+ def __init__(
856
+ self,
857
+ train_kv=True,
858
+ train_q_out=True,
859
+ hidden_size=None,
860
+ cross_attention_dim=None,
861
+ out_bias=True,
862
+ dropout=0.0,
863
+ ):
864
+ super().__init__()
865
+ self.train_kv = train_kv
866
+ self.train_q_out = train_q_out
867
+
868
+ self.hidden_size = hidden_size
869
+ self.cross_attention_dim = cross_attention_dim
870
+
871
+ # `_custom_diffusion` id for easy serialization and loading.
872
+ if self.train_kv:
873
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
874
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
875
+ if self.train_q_out:
876
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
877
+ self.to_out_custom_diffusion = nn.ModuleList([])
878
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
879
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
880
+
881
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
882
+ batch_size, sequence_length, _ = hidden_states.shape
883
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
884
+ if self.train_q_out:
885
+ query = self.to_q_custom_diffusion(hidden_states)
886
+ else:
887
+ query = attn.to_q(hidden_states)
888
+
889
+ if encoder_hidden_states is None:
890
+ crossattn = False
891
+ encoder_hidden_states = hidden_states
892
+ else:
893
+ crossattn = True
894
+ if attn.norm_cross:
895
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
896
+
897
+ if self.train_kv:
898
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
899
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
900
+ else:
901
+ key = attn.to_k(encoder_hidden_states)
902
+ value = attn.to_v(encoder_hidden_states)
903
+
904
+ if crossattn:
905
+ detach = torch.ones_like(key)
906
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
907
+ key = detach * key + (1 - detach) * key.detach()
908
+ value = detach * value + (1 - detach) * value.detach()
909
+
910
+ query = attn.head_to_batch_dim(query)
911
+ key = attn.head_to_batch_dim(key)
912
+ value = attn.head_to_batch_dim(value)
913
+
914
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
915
+ hidden_states = torch.bmm(attention_probs, value)
916
+ hidden_states = attn.batch_to_head_dim(hidden_states)
917
+
918
+ if self.train_q_out:
919
+ # linear proj
920
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
921
+ # dropout
922
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
923
+ else:
924
+ # linear proj
925
+ hidden_states = attn.to_out[0](hidden_states)
926
+ # dropout
927
+ hidden_states = attn.to_out[1](hidden_states)
928
+
929
+ return hidden_states
930
+
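A small sketch showing which projections become trainable parameters for a given configuration of the class above; with `train_q_out=False` only the new key/value projections carry weights (sizes illustrative).

proc = CustomDiffusionAttnProcessor(
    train_kv=True, train_q_out=False, hidden_size=320, cross_attention_dim=2048
)
print(sorted(name for name, _ in proc.named_parameters()))
# ['to_k_custom_diffusion.weight', 'to_v_custom_diffusion.weight']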
931
+
932
+ class AttnAddedKVProcessor:
933
+ r"""
934
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
935
+ encoder.
936
+ """
937
+
938
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
939
+ residual = hidden_states
940
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
941
+ batch_size, sequence_length, _ = hidden_states.shape
942
+
943
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
944
+
945
+ if encoder_hidden_states is None:
946
+ encoder_hidden_states = hidden_states
947
+ elif attn.norm_cross:
948
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
949
+
950
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
951
+
952
+ query = attn.to_q(hidden_states)
953
+ query = attn.head_to_batch_dim(query)
954
+
955
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
956
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
957
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
958
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
959
+
960
+ if not attn.only_cross_attention:
961
+ key = attn.to_k(hidden_states)
962
+ value = attn.to_v(hidden_states)
963
+ key = attn.head_to_batch_dim(key)
964
+ value = attn.head_to_batch_dim(value)
965
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
966
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
967
+ else:
968
+ key = encoder_hidden_states_key_proj
969
+ value = encoder_hidden_states_value_proj
970
+
971
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
972
+ hidden_states = torch.bmm(attention_probs, value)
973
+ hidden_states = attn.batch_to_head_dim(hidden_states)
974
+
975
+ # linear proj
976
+ hidden_states = attn.to_out[0](hidden_states)
977
+ # dropout
978
+ hidden_states = attn.to_out[1](hidden_states)
979
+
980
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
981
+ hidden_states = hidden_states + residual
982
+
983
+ return hidden_states
984
+
985
+
986
+ class AttnAddedKVProcessor2_0:
987
+ r"""
988
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
989
+ learnable key and value matrices for the text encoder.
990
+ """
991
+
992
+ def __init__(self):
993
+ if not hasattr(F, "scaled_dot_product_attention"):
994
+ raise ImportError(
995
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
996
+ )
997
+
998
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
999
+ residual = hidden_states
1000
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1001
+ batch_size, sequence_length, _ = hidden_states.shape
1002
+
1003
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
1004
+
1005
+ if encoder_hidden_states is None:
1006
+ encoder_hidden_states = hidden_states
1007
+ elif attn.norm_cross:
1008
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1009
+
1010
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1011
+
1012
+ query = attn.to_q(hidden_states)
1013
+ query = attn.head_to_batch_dim(query, out_dim=4)
1014
+
1015
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1016
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1017
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
1018
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
1019
+
1020
+ if not attn.only_cross_attention:
1021
+ key = attn.to_k(hidden_states)
1022
+ value = attn.to_v(hidden_states)
1023
+ key = attn.head_to_batch_dim(key, out_dim=4)
1024
+ value = attn.head_to_batch_dim(value, out_dim=4)
1025
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
1026
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
1027
+ else:
1028
+ key = encoder_hidden_states_key_proj
1029
+ value = encoder_hidden_states_value_proj
1030
+
1031
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1032
+ # TODO: add support for attn.scale when we move to Torch 2.1
1033
+ hidden_states = F.scaled_dot_product_attention(
1034
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1035
+ )
1036
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
1037
+
1038
+ # linear proj
1039
+ hidden_states = attn.to_out[0](hidden_states)
1040
+ # dropout
1041
+ hidden_states = attn.to_out[1](hidden_states)
1042
+
1043
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1044
+ hidden_states = hidden_states + residual
1045
+
1046
+ return hidden_states
1047
+
1048
+
1049
+ class LoRAAttnAddedKVProcessor(nn.Module):
1050
+ r"""
1051
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
1052
+ encoder.
1053
+
1054
+ Args:
1055
+ hidden_size (`int`, *optional*):
1056
+ The hidden size of the attention layer.
1057
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1058
+ The number of channels in the `encoder_hidden_states`.
1059
+ rank (`int`, defaults to 4):
1060
+ The dimension of the LoRA update matrices.
1061
+
1062
+ """
1063
+
1064
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1065
+ super().__init__()
1066
+
1067
+ self.hidden_size = hidden_size
1068
+ self.cross_attention_dim = cross_attention_dim
1069
+ self.rank = rank
1070
+
1071
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1072
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1073
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1074
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1075
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1076
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1077
+
1078
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1079
+ residual = hidden_states
1080
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1081
+ batch_size, sequence_length, _ = hidden_states.shape
1082
+
1083
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1084
+
1085
+ if encoder_hidden_states is None:
1086
+ encoder_hidden_states = hidden_states
1087
+ elif attn.norm_cross:
1088
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1089
+
1090
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1091
+
1092
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1093
+ query = attn.head_to_batch_dim(query)
1094
+
1095
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
1096
+ encoder_hidden_states
1097
+ )
1098
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
1099
+ encoder_hidden_states
1100
+ )
1101
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1102
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1103
+
1104
+ if not attn.only_cross_attention:
1105
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
1106
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
1107
+ key = attn.head_to_batch_dim(key)
1108
+ value = attn.head_to_batch_dim(value)
1109
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1110
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1111
+ else:
1112
+ key = encoder_hidden_states_key_proj
1113
+ value = encoder_hidden_states_value_proj
1114
+
1115
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1116
+ hidden_states = torch.bmm(attention_probs, value)
1117
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1118
+
1119
+ # linear proj
1120
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1121
+ # dropout
1122
+ hidden_states = attn.to_out[1](hidden_states)
1123
+
1124
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1125
+ hidden_states = hidden_states + residual
1126
+
1127
+ return hidden_states
1128
+
1129
+
1130
+ class XFormersAttnAddedKVProcessor:
1131
+ r"""
1132
+ Processor for implementing memory efficient attention using xFormers.
1133
+
1134
+ Args:
1135
+ attention_op (`Callable`, *optional*, defaults to `None`):
1136
+ The base
1137
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1138
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1139
+ operator.
1140
+ """
1141
+
1142
+ def __init__(self, attention_op: Optional[Callable] = None):
1143
+ self.attention_op = attention_op
1144
+
1145
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1146
+ residual = hidden_states
1147
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1148
+ batch_size, sequence_length, _ = hidden_states.shape
1149
+
1150
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1151
+
1152
+ if encoder_hidden_states is None:
1153
+ encoder_hidden_states = hidden_states
1154
+ elif attn.norm_cross:
1155
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1156
+
1157
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1158
+
1159
+ query = attn.to_q(hidden_states)
1160
+ query = attn.head_to_batch_dim(query)
1161
+
1162
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1163
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1164
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1165
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1166
+
1167
+ if not attn.only_cross_attention:
1168
+ key = attn.to_k(hidden_states)
1169
+ value = attn.to_v(hidden_states)
1170
+ key = attn.head_to_batch_dim(key)
1171
+ value = attn.head_to_batch_dim(value)
1172
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1173
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1174
+ else:
1175
+ key = encoder_hidden_states_key_proj
1176
+ value = encoder_hidden_states_value_proj
1177
+
1178
+ hidden_states = xformers.ops.memory_efficient_attention(
1179
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1180
+ )
1181
+ hidden_states = hidden_states.to(query.dtype)
1182
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1183
+
1184
+ # linear proj
1185
+ hidden_states = attn.to_out[0](hidden_states)
1186
+ # dropout
1187
+ hidden_states = attn.to_out[1](hidden_states)
1188
+
1189
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1190
+ hidden_states = hidden_states + residual
1191
+
1192
+ return hidden_states
1193
+
1194
+
1195
+ class XFormersAttnProcessor:
1196
+ r"""
1197
+ Processor for implementing memory efficient attention using xFormers.
1198
+
1199
+ Args:
1200
+ attention_op (`Callable`, *optional*, defaults to `None`):
1201
+ The base
1202
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1203
+ use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
1204
+ operator.
1205
+ """
1206
+
1207
+ def __init__(self, attention_op: Optional[Callable] = None):
1208
+ self.attention_op = attention_op
1209
+
1210
+ def __call__(
1211
+ self,
1212
+ attn: Attention,
1213
+ hidden_states: torch.FloatTensor,
1214
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1215
+ attention_mask: Optional[torch.FloatTensor] = None,
1216
+ temb: Optional[torch.FloatTensor] = None,
1217
+ ):
1218
+ residual = hidden_states
1219
+
1220
+ if attn.spatial_norm is not None:
1221
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1222
+
1223
+ input_ndim = hidden_states.ndim
1224
+
1225
+ if input_ndim == 4:
1226
+ batch_size, channel, height, width = hidden_states.shape
1227
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1228
+
1229
+ batch_size, key_tokens, _ = (
1230
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1231
+ )
1232
+
1233
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1234
+ if attention_mask is not None:
1235
+ # expand our mask's singleton query_tokens dimension:
1236
+ # [batch*heads, 1, key_tokens] ->
1237
+ # [batch*heads, query_tokens, key_tokens]
1238
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1239
+ # [batch*heads, query_tokens, key_tokens]
1240
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1241
+ _, query_tokens, _ = hidden_states.shape
1242
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1243
+
1244
+ if attn.group_norm is not None:
1245
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1246
+
1247
+ query = attn.to_q(hidden_states)
1248
+
1249
+ if encoder_hidden_states is None:
1250
+ encoder_hidden_states = hidden_states
1251
+ elif attn.norm_cross:
1252
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1253
+
1254
+ key = attn.to_k(encoder_hidden_states)
1255
+ value = attn.to_v(encoder_hidden_states)
1256
+
1257
+ query = attn.head_to_batch_dim(query).contiguous()
1258
+ key = attn.head_to_batch_dim(key).contiguous()
1259
+ value = attn.head_to_batch_dim(value).contiguous()
1260
+
1261
+ hidden_states = xformers.ops.memory_efficient_attention(
1262
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1263
+ )
1264
+ hidden_states = hidden_states.to(query.dtype)
1265
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1266
+
1267
+ # linear proj
1268
+ hidden_states = attn.to_out[0](hidden_states)
1269
+ # dropout
1270
+ hidden_states = attn.to_out[1](hidden_states)
1271
+
1272
+ if input_ndim == 4:
1273
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1274
+
1275
+ if attn.residual_connection:
1276
+ hidden_states = hidden_states + residual
1277
+
1278
+ hidden_states = hidden_states / attn.rescale_output_factor
1279
+
1280
+ return hidden_states
1281
+
1282
+
1283
+ class LoRAXFormersAttnProcessor(nn.Module):
1284
+ r"""
1285
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1286
+
1287
+ Args:
1288
+ hidden_size (`int`, *optional*):
1289
+ The hidden size of the attention layer.
1290
+ cross_attention_dim (`int`, *optional*):
1291
+ The number of channels in the `encoder_hidden_states`.
1292
+ rank (`int`, defaults to 4):
1293
+ The dimension of the LoRA update matrices.
1294
+ attention_op (`Callable`, *optional*, defaults to `None`):
1295
+ The base
1296
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1297
+ use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
1298
+ operator.
1299
+ network_alpha (`int`, *optional*):
1300
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1301
+
1302
+ """
1303
+
1304
+ def __init__(
1305
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
1306
+ ):
1307
+ super().__init__()
1308
+
1309
+ self.hidden_size = hidden_size
1310
+ self.cross_attention_dim = cross_attention_dim
1311
+ self.rank = rank
1312
+ self.attention_op = attention_op
1313
+
1314
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1315
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1316
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1317
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1318
+
1319
+ def __call__(
1320
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1321
+ ):
1322
+ residual = hidden_states
1323
+
1324
+ if attn.spatial_norm is not None:
1325
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1326
+
1327
+ input_ndim = hidden_states.ndim
1328
+
1329
+ if input_ndim == 4:
1330
+ batch_size, channel, height, width = hidden_states.shape
1331
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1332
+
1333
+ batch_size, sequence_length, _ = (
1334
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1335
+ )
1336
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1337
+
1338
+ if attn.group_norm is not None:
1339
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1340
+
1341
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1342
+ query = attn.head_to_batch_dim(query).contiguous()
1343
+
1344
+ if encoder_hidden_states is None:
1345
+ encoder_hidden_states = hidden_states
1346
+ elif attn.norm_cross:
1347
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1348
+
1349
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1350
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1351
+
1352
+ key = attn.head_to_batch_dim(key).contiguous()
1353
+ value = attn.head_to_batch_dim(value).contiguous()
1354
+
1355
+ hidden_states = xformers.ops.memory_efficient_attention(
1356
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1357
+ )
1358
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1359
+
1360
+ # linear proj
1361
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1362
+ # dropout
1363
+ hidden_states = attn.to_out[1](hidden_states)
1364
+
1365
+ if input_ndim == 4:
1366
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1367
+
1368
+ if attn.residual_connection:
1369
+ hidden_states = hidden_states + residual
1370
+
1371
+ hidden_states = hidden_states / attn.rescale_output_factor
1372
+
1373
+ return hidden_states
1374
+
1375
+
1376
+ class LoRAAttnProcessor2_0(nn.Module):
1377
+ r"""
1378
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1379
+ attention.
1380
+
1381
+ Args:
1382
+ hidden_size (`int`):
1383
+ The hidden size of the attention layer.
1384
+ cross_attention_dim (`int`, *optional*):
1385
+ The number of channels in the `encoder_hidden_states`.
1386
+ rank (`int`, defaults to 4):
1387
+ The dimension of the LoRA update matrices.
1388
+ network_alpha (`int`, *optional*):
1389
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1390
+ """
1391
+
1392
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1393
+ super().__init__()
1394
+ if not hasattr(F, "scaled_dot_product_attention"):
1395
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1396
+
1397
+ self.hidden_size = hidden_size
1398
+ self.cross_attention_dim = cross_attention_dim
1399
+ self.rank = rank
1400
+
1401
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1402
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1403
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1404
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1405
+
1406
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1407
+ residual = hidden_states
1408
+
1409
+ input_ndim = hidden_states.ndim
1410
+
1411
+ if input_ndim == 4:
1412
+ batch_size, channel, height, width = hidden_states.shape
1413
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1414
+
1415
+ batch_size, sequence_length, _ = (
1416
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1417
+ )
1418
+ inner_dim = hidden_states.shape[-1]
1419
+
1420
+ if attention_mask is not None:
1421
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1422
+ # scaled_dot_product_attention expects attention_mask shape to be
1423
+ # (batch, heads, source_length, target_length)
1424
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1425
+
1426
+ if attn.group_norm is not None:
1427
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1428
+
1429
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1430
+
1431
+ if encoder_hidden_states is None:
1432
+ encoder_hidden_states = hidden_states
1433
+ elif attn.norm_cross:
1434
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1435
+
1436
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1437
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1438
+
1439
+ head_dim = inner_dim // attn.heads
1440
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1441
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1442
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1443
+
1444
+ # TODO: add support for attn.scale when we move to Torch 2.1
1445
+ hidden_states = F.scaled_dot_product_attention(
1446
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1447
+ )
1448
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1449
+ hidden_states = hidden_states.to(query.dtype)
1450
+
1451
+ # linear proj
1452
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1453
+ # dropout
1454
+ hidden_states = attn.to_out[1](hidden_states)
1455
+
1456
+ if input_ndim == 4:
1457
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1458
+
1459
+ if attn.residual_connection:
1460
+ hidden_states = hidden_states + residual
1461
+
1462
+ hidden_states = hidden_states / attn.rescale_output_factor
1463
+
1464
+ return hidden_states
1465
+
1466
+
1467
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1468
+ r"""
1469
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1470
+
1471
+ Args:
1472
+ train_kv (`bool`, defaults to `True`):
1473
+ Whether to newly train the key and value matrices corresponding to the text features.
1474
+ train_q_out (`bool`, defaults to `True`):
1475
+ Whether to newly train query matrices corresponding to the latent image features.
1476
+ hidden_size (`int`, *optional*, defaults to `None`):
1477
+ The hidden size of the attention layer.
1478
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1479
+ The number of channels in the `encoder_hidden_states`.
1480
+ out_bias (`bool`, defaults to `True`):
1481
+ Whether to include the bias parameter in `train_q_out`.
1482
+ dropout (`float`, *optional*, defaults to 0.0):
1483
+ The dropout probability to use.
1484
+ attention_op (`Callable`, *optional*, defaults to `None`):
1485
+ The base
1486
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1487
+ as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best operator.
1488
+ """
1489
+
1490
+ def __init__(
1491
+ self,
1492
+ train_kv=True,
1493
+ train_q_out=False,
1494
+ hidden_size=None,
1495
+ cross_attention_dim=None,
1496
+ out_bias=True,
1497
+ dropout=0.0,
1498
+ attention_op: Optional[Callable] = None,
1499
+ ):
1500
+ super().__init__()
1501
+ self.train_kv = train_kv
1502
+ self.train_q_out = train_q_out
1503
+
1504
+ self.hidden_size = hidden_size
1505
+ self.cross_attention_dim = cross_attention_dim
1506
+ self.attention_op = attention_op
1507
+
1508
+ # `_custom_diffusion` id for easy serialization and loading.
1509
+ if self.train_kv:
1510
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1511
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1512
+ if self.train_q_out:
1513
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1514
+ self.to_out_custom_diffusion = nn.ModuleList([])
1515
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1516
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1517
+
1518
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1519
+ batch_size, sequence_length, _ = (
1520
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1521
+ )
1522
+
1523
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1524
+
1525
+ if self.train_q_out:
1526
+ query = self.to_q_custom_diffusion(hidden_states)
1527
+ else:
1528
+ query = attn.to_q(hidden_states)
1529
+
1530
+ if encoder_hidden_states is None:
1531
+ crossattn = False
1532
+ encoder_hidden_states = hidden_states
1533
+ else:
1534
+ crossattn = True
1535
+ if attn.norm_cross:
1536
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1537
+
1538
+ if self.train_kv:
1539
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
1540
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
1541
+ else:
1542
+ key = attn.to_k(encoder_hidden_states)
1543
+ value = attn.to_v(encoder_hidden_states)
1544
+
1545
+ if crossattn:
1546
+ detach = torch.ones_like(key)
1547
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1548
+ key = detach * key + (1 - detach) * key.detach()
1549
+ value = detach * value + (1 - detach) * value.detach()
1550
+
1551
+ query = attn.head_to_batch_dim(query).contiguous()
1552
+ key = attn.head_to_batch_dim(key).contiguous()
1553
+ value = attn.head_to_batch_dim(value).contiguous()
1554
+
1555
+ hidden_states = xformers.ops.memory_efficient_attention(
1556
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1557
+ )
1558
+ hidden_states = hidden_states.to(query.dtype)
1559
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1560
+
1561
+ if self.train_q_out:
1562
+ # linear proj
1563
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1564
+ # dropout
1565
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1566
+ else:
1567
+ # linear proj
1568
+ hidden_states = attn.to_out[0](hidden_states)
1569
+ # dropout
1570
+ hidden_states = attn.to_out[1](hidden_states)
1571
+ return hidden_states
1572
+
1573
+
1574
+ class SlicedAttnProcessor:
1575
+ r"""
1576
+ Processor for implementing sliced attention.
1577
+
1578
+ Args:
1579
+ slice_size (`int`, *optional*):
1580
+ The size of each attention slice; the batched attention is computed chunk by chunk to lower peak memory, and
1581
+ `attention_head_dim` must be a multiple of the `slice_size`.
1582
+ """
1583
+
1584
+ def __init__(self, slice_size):
1585
+ self.slice_size = slice_size
1586
+
1587
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1588
+ residual = hidden_states
1589
+
1590
+ input_ndim = hidden_states.ndim
1591
+
1592
+ if input_ndim == 4:
1593
+ batch_size, channel, height, width = hidden_states.shape
1594
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1595
+
1596
+ batch_size, sequence_length, _ = (
1597
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1598
+ )
1599
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1600
+
1601
+ if attn.group_norm is not None:
1602
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1603
+
1604
+ query = attn.to_q(hidden_states)
1605
+ dim = query.shape[-1]
1606
+ query = attn.head_to_batch_dim(query)
1607
+
1608
+ if encoder_hidden_states is None:
1609
+ encoder_hidden_states = hidden_states
1610
+ elif attn.norm_cross:
1611
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1612
+
1613
+ key = attn.to_k(encoder_hidden_states)
1614
+ value = attn.to_v(encoder_hidden_states)
1615
+ key = attn.head_to_batch_dim(key)
1616
+ value = attn.head_to_batch_dim(value)
1617
+
1618
+ batch_size_attention, query_tokens, _ = query.shape
1619
+ hidden_states = torch.zeros(
1620
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1621
+ )
1622
+
1623
+ for i in range(batch_size_attention // self.slice_size):
1624
+ start_idx = i * self.slice_size
1625
+ end_idx = (i + 1) * self.slice_size
1626
+
1627
+ query_slice = query[start_idx:end_idx]
1628
+ key_slice = key[start_idx:end_idx]
1629
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1630
+
1631
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1632
+
1633
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1634
+
1635
+ hidden_states[start_idx:end_idx] = attn_slice
1636
+
1637
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1638
+
1639
+ # linear proj
1640
+ hidden_states = attn.to_out[0](hidden_states)
1641
+ # dropout
1642
+ hidden_states = attn.to_out[1](hidden_states)
1643
+
1644
+ if input_ndim == 4:
1645
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1646
+
1647
+ if attn.residual_connection:
1648
+ hidden_states = hidden_states + residual
1649
+
1650
+ hidden_states = hidden_states / attn.rescale_output_factor
1651
+
1652
+ return hidden_states
1653
+
1654
+
1655
+ class SlicedAttnAddedKVProcessor:
1656
+ r"""
1657
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1658
+
1659
+ Args:
1660
+ slice_size (`int`, *optional*):
1661
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1662
+ `attention_head_dim` must be a multiple of the `slice_size`.
1663
+ """
1664
+
1665
+ def __init__(self, slice_size):
1666
+ self.slice_size = slice_size
1667
+
1668
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1669
+ residual = hidden_states
1670
+
1671
+ if attn.spatial_norm is not None:
1672
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1673
+
1674
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1675
+
1676
+ batch_size, sequence_length, _ = hidden_states.shape
1677
+
1678
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1679
+
1680
+ if encoder_hidden_states is None:
1681
+ encoder_hidden_states = hidden_states
1682
+ elif attn.norm_cross:
1683
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1684
+
1685
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1686
+
1687
+ query = attn.to_q(hidden_states)
1688
+ dim = query.shape[-1]
1689
+ query = attn.head_to_batch_dim(query)
1690
+
1691
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1692
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1693
+
1694
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1695
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1696
+
1697
+ if not attn.only_cross_attention:
1698
+ key = attn.to_k(hidden_states)
1699
+ value = attn.to_v(hidden_states)
1700
+ key = attn.head_to_batch_dim(key)
1701
+ value = attn.head_to_batch_dim(value)
1702
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1703
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1704
+ else:
1705
+ key = encoder_hidden_states_key_proj
1706
+ value = encoder_hidden_states_value_proj
1707
+
1708
+ batch_size_attention, query_tokens, _ = query.shape
1709
+ hidden_states = torch.zeros(
1710
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1711
+ )
1712
+
1713
+ for i in range(batch_size_attention // self.slice_size):
1714
+ start_idx = i * self.slice_size
1715
+ end_idx = (i + 1) * self.slice_size
1716
+
1717
+ query_slice = query[start_idx:end_idx]
1718
+ key_slice = key[start_idx:end_idx]
1719
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1720
+
1721
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1722
+
1723
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1724
+
1725
+ hidden_states[start_idx:end_idx] = attn_slice
1726
+
1727
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1728
+
1729
+ # linear proj
1730
+ hidden_states = attn.to_out[0](hidden_states)
1731
+ # dropout
1732
+ hidden_states = attn.to_out[1](hidden_states)
1733
+
1734
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1735
+ hidden_states = hidden_states + residual
1736
+
1737
+ return hidden_states
1738
+
1739
+
1740
+ AttentionProcessor = Union[
1741
+ AttnProcessor,
1742
+ AttnProcessor2_0_image_prompt,
1743
+ XFormersAttnProcessor,
1744
+ SlicedAttnProcessor,
1745
+ AttnAddedKVProcessor,
1746
+ SlicedAttnAddedKVProcessor,
1747
+ AttnAddedKVProcessor2_0,
1748
+ XFormersAttnAddedKVProcessor,
1749
+ LoRAAttnProcessor,
1750
+ LoRAXFormersAttnProcessor,
1751
+ LoRAAttnProcessor2_0,
1752
+ LoRAAttnAddedKVProcessor,
1753
+ CustomDiffusionAttnProcessor,
1754
+ CustomDiffusionXFormersAttnProcessor,
1755
+ ]
1756
+
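A hedged helper sketch: nothing in this module enforces it, but a caller might prefer the PyTorch 2.0 LoRA processor when fused scaled-dot-product attention is available and fall back to the plain one otherwise.

import torch.nn.functional as F

def pick_lora_processor(hidden_size, cross_attention_dim=None, rank=4):
    # Illustrative selection heuristic; callers are free to force either class.
    if hasattr(F, "scaled_dot_product_attention"):
        return LoRAAttnProcessor2_0(hidden_size, cross_attention_dim, rank)
    return LoRAAttnProcessor(hidden_size, cross_attention_dim, rank)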
1757
+
1758
+ class SpatialNorm(nn.Module):
1759
+ """
1760
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1761
+ """
1762
+
1763
+ def __init__(
1764
+ self,
1765
+ f_channels,
1766
+ zq_channels,
1767
+ ):
1768
+ super().__init__()
1769
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1770
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1771
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1772
+
1773
+ def forward(self, f, zq):
1774
+ f_size = f.shape[-2:]
1775
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1776
+ norm_f = self.norm_layer(f)
1777
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1778
+ return new_f
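A quick shape sketch for `SpatialNorm`: the conditioning map `zq` is resized to the spatial size of `f`, then used to scale and shift the group-normalized features (all sizes below are illustrative).

import torch

norm = SpatialNorm(f_channels=64, zq_channels=4)
f = torch.randn(1, 64, 32, 32)
zq = torch.randn(1, 4, 8, 8)  # lower-resolution conditioning, upsampled inside forward()
print(norm(f, zq).shape)      # torch.Size([1, 64, 32, 32])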
models/base_vision.py ADDED
@@ -0,0 +1,227 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ base_vision.py
17
+
18
+ Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
19
+ functions, and initialization logic.
20
+
21
+ We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
22
+ Transformer model for feature extraction.
23
+ """
24
+ from abc import ABC, abstractmethod
25
+ from dataclasses import dataclass
26
+ from functools import partial
27
+ from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
28
+
29
+ import timm
30
+ import torch
31
+ import torch.nn as nn
32
+ import torchvision.transforms.functional as TVF
33
+ from PIL.Image import Image
34
+ from timm.models.vision_transformer import Block, VisionTransformer
35
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
36
+ from torchvision.transforms import Compose, Resize
37
+
38
+
39
+ # === Utility Functions for Monkey-Patching ===
40
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
41
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
42
+ result = fn(*args, **kwargs)
43
+ return result[0] if (isinstance(result, tuple) or isinstance(result, list)) else result
44
+
45
+ return wrapper
46
+
47
+ def return_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
48
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
49
+ result = fn(*args, **kwargs)
50
+ return result
51
+
52
+ return wrapper
53
+
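The two wrappers differ only in whether a single-element tuple/list result is unwrapped, as this tiny sketch shows:

first_only = unpack_tuple(lambda: ("patches",))
keep_tuple = return_tuple(lambda: ("patches",))
print(first_only())  # patches
print(keep_tuple())  # ('patches',)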
54
+
55
+ # === Interface for an Image Transform ===
56
+ class ImageTransform(Protocol):
57
+ def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...
58
+
59
+
60
+ # === Custom Torchvision Image Transforms ===
61
+ @dataclass
62
+ class LetterboxPad:
63
+ padding_fill_value: Tuple[int, int, int]
64
+
65
+ def __call__(self, image: Image) -> Image:
66
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
67
+ (w, h), max_wh = image.size, max(image.size)
68
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
69
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
70
+ return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")
71
+
72
+
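A minimal sketch of the letterbox transform on a non-square image (the 300x200 size and gray fill are arbitrary):

from PIL import Image

pad = LetterboxPad(padding_fill_value=(127, 127, 127))
img = Image.new("RGB", (300, 200))
print(pad(img).size)  # (300, 300): 50 px of constant padding added above and below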
73
+ # === Abstract Base Class for arbitrary Vision Backbones ===
74
+ class VisionBackbone(nn.Module, ABC):
75
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
76
+ super().__init__()
77
+ self.identifier: str = vision_backbone_id
78
+ self.image_resize_strategy: str = image_resize_strategy
79
+ self.default_image_size: int = default_image_size
80
+
81
+ # Instance attributes for a Vision Backbone
82
+ self.featurizer: nn.Module = None
83
+ self.image_transform: ImageTransform = None
84
+
85
+ def get_image_transform(self) -> ImageTransform:
86
+ return self.image_transform
87
+
88
+ @abstractmethod
89
+ def get_fsdp_wrapping_policy(self) -> Callable: ...
90
+
91
+ @abstractmethod
92
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
93
+ """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
94
+ raise NotImplementedError
95
+
96
+ @property
97
+ @abstractmethod
98
+ def default_image_resolution(self) -> Tuple[int, int, int]: ...
99
+
100
+ @property
101
+ @abstractmethod
102
+ def embed_dim(self) -> int: ...
103
+
104
+ @property
105
+ @abstractmethod
106
+ def num_patches(self) -> int: ...
107
+
108
+ @property
109
+ @abstractmethod
110
+ def half_precision_dtype(self) -> torch.dtype: ...
111
+
112
+
113
+ # === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
114
+ class TimmViTBackbone(VisionBackbone, ABC):
115
+ def __init__(
116
+ self,
117
+ vision_backbone_id: str,
118
+ timm_path_or_url: str,
119
+ image_resize_strategy: str,
120
+ default_image_size: int = 224,
121
+ override_act_layer: Optional[str] = None,
122
+ ) -> None:
123
+ super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
124
+ self.timm_path_or_url = timm_path_or_url
125
+ self.override_act_layer = override_act_layer
126
+ self.dtype = torch.bfloat16
127
+
128
+ # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
129
+ if self.override_act_layer is None:
130
+ self.featurizer: VisionTransformer = timm.create_model(
131
+ self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size,
132
+ )
133
+ else:
134
+ self.featurizer: VisionTransformer = timm.create_model(
135
+ self.timm_path_or_url,
136
+ pretrained=True,
137
+ num_classes=0,
138
+ img_size=self.default_image_size,
139
+ act_layer=self.override_act_layer,
140
+ )
141
+ self.featurizer.eval()
142
+
143
+ # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
144
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
145
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
146
+ self.featurizer.forward = unpack_tuple(
147
+ partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
148
+ )
149
+
150
+ # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
151
+ assert isinstance(self.featurizer, VisionTransformer), (
152
+ "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
153
+ "file an issue or implement the requisite logic (see `cobra/models/backbones/vision/base_vision.py`)!"
154
+ )
155
+
156
+ # Get Config =>> Note :: Override default image size to ensure correct image transform
157
+ self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
158
+ self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
159
+
160
+ # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
161
+ default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)
162
+
163
+ # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
164
+ if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
165
+ assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
166
+ assert isinstance(resize_transform := default_image_transform.transforms[0], Resize)
167
+ default_image_transform = Compose(
168
+ [
169
+ Resize(self.default_image_size, interpolation=resize_transform.interpolation),
170
+ *default_image_transform.transforms[1:],
171
+ ]
172
+ )
173
+
174
+ # Switch on `image_resize_strategy`
175
+ if self.image_resize_strategy == "resize-naive":
176
+ assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
177
+ assert isinstance(resize_transform := default_image_transform.transforms[0], Resize)
178
+
179
+ target_size = (self.default_image_size, self.default_image_size)
180
+ self.image_transform = Compose(
181
+ [
182
+ Resize(target_size, interpolation=resize_transform.interpolation),
183
+ *default_image_transform.transforms[1:],
184
+ ]
185
+ )
186
+
187
+ elif self.image_resize_strategy == "resize-crop":
188
+ self.image_transform = default_image_transform
189
+
190
+ elif self.image_resize_strategy == "letterbox":
191
+ assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
192
+ assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
193
+
194
+ # Compute Padding Fill Value (rescaled normalization mean if applicable)
195
+ fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
196
+
197
+ # Build New Transform
198
+ self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])
199
+
200
+ else:
201
+ raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
202
+
203
+ def get_fsdp_wrapping_policy(self) -> Callable:
204
+ """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
205
+ vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
206
+ transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
207
+ return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
208
+
209
+ def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
210
+ """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
211
+ return self.featurizer(pixel_values)
212
+
213
+ @property
214
+ def default_image_resolution(self) -> Tuple[int, int, int]:
215
+ return self.data_cfg["input_size"]
216
+
217
+ @property
218
+ def embed_dim(self) -> int:
219
+ return self.featurizer.embed_dim
220
+
221
+ @property
222
+ def num_patches(self) -> int:
223
+ return self.featurizer.patch_embed.num_patches
224
+
225
+ @property
226
+ def half_precision_dtype(self) -> torch.dtype:
227
+ return self.dtype
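A hedged instantiation sketch for the TIMM-backed featurizer; the backbone id, TIMM model name, resize strategy, and dummy image below are assumptions, and the weights are fetched from the TIMM/HF hub on first use.

from PIL import Image

backbone = TimmViTBackbone(
    vision_backbone_id="siglip-vit-b16",
    timm_path_or_url="vit_base_patch16_siglip_224",
    image_resize_strategy="resize-naive",
    default_image_size=224,
)
pixels = backbone.get_image_transform()(Image.new("RGB", (640, 480)))
patches = backbone(pixels.unsqueeze(0))  # -> (1, num_patches, embed_dim)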
models/dino.py ADDED
@@ -0,0 +1,203 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ dino.py
17
+
18
+ Vision backbone that returns intermediate patch features from DINOv2 (this file has no SigLIP branch).
19
+ """
20
+ from dataclasses import dataclass
21
+ from functools import partial
22
+ from typing import Callable, Dict, Tuple
23
+ import os
24
+ import timm
25
+ import torch
26
+ from PIL import Image
27
+ from einops import rearrange
28
+ from timm.models.vision_transformer import Block, VisionTransformer
29
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
30
+ from torchvision.transforms import Compose, Resize
31
+
32
+ from models.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple, return_tuple
33
+
34
+ import torch.nn as nn
35
+ import torchvision
36
+
37
+ @dataclass
38
+ class DinoSigLIPImageTransform:
39
+ dino_image_transform: ImageTransform
40
+ siglip_image_transform: ImageTransform
41
+ is_cobra: bool = True
42
+
43
+ def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
44
+ return {"dino": self.dino_image_transform(img, **kwargs).unsqueeze(0), "siglip": self.siglip_image_transform(img, **kwargs).unsqueeze(0)}
45
+
46
+
47
+ class DinoViTBackbone(VisionBackbone):
48
+ def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, last_n = 2, feature_index = 22) -> None:
49
+ super().__init__(backbone_name_or_path, image_resize_strategy, default_image_size=default_image_size)
50
+ # load from local paths
51
+ dino_pretrained_cfg = timm.models.create_model(backbone_name_or_path).default_cfg
52
+ dino_pretrained_cfg['file'] = 'ckpts/vit_large_patch14_reg4_dinov2.lvd142m/pytorch_model.bin'
53
+
54
+ # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
55
+ self.dino_featurizer: VisionTransformer = timm.create_model(
56
+ backbone_name_or_path, pretrained=True, num_classes=0, img_size=self.default_image_size,
57
+ pretrained_cfg=dino_pretrained_cfg
58
+ )
59
+ self.dino_featurizer.eval()
60
+
61
+ # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
62
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
63
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
64
+ # return the output tokens from the `n` last blocks
65
+ print("dino has {} layer intermediate features. ".format(len(self.dino_featurizer.blocks))) # 24
66
+ # self.dino_featurizer.forward = unpack_tuple(
67
+ # partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - last_n})
68
+ # )
69
+ if isinstance(feature_index, tuple) or isinstance(feature_index, list):
70
+ feature_index = set(feature_index)
71
+ else:
72
+ feature_index = {feature_index}
73
+ self.dino_featurizer.forward = return_tuple(
74
+ partial(self.dino_featurizer.get_intermediate_layers, n=feature_index)
75
+ )
76
+
77
+ # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
78
+ self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
79
+ self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
80
+
81
+ # Initialize *both* Transforms
82
+ default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
83
+
84
+ if self.image_resize_strategy == "resize-naive":
85
+ assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
86
+ assert isinstance(dino_resize_transform := default_dino_transform.transforms[0], Resize)
87
+
88
+ target_size = (self.default_image_size, self.default_image_size)
89
+ dino_transform = Compose(
90
+ [
91
+ Resize(target_size, interpolation=dino_resize_transform.interpolation),
92
+ *default_dino_transform.transforms[1:],
93
+ ]
94
+ )
95
+
96
+ self.dino_transform = dino_transform
97
+ else:
98
+ raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
99
+
100
+ def get_fsdp_wrapping_policy(self) -> Callable:
101
+ """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
102
+ vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
103
+ transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
104
+ return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
105
+
106
+ def forward(self, pixel_values, device="cpu", input_dtype_new=None) -> torch.Tensor:
107
+ """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
108
+ # b, c , h , w : 0-1
109
+ t_tensors = []
110
+ for pixel_value in pixel_values:
111
+ t_tensors.append(self.dino_transform(pixel_value).unsqueeze(0))
112
+ t_tensors = torch.cat(t_tensors, dim=0).to(device)
113
+ if input_dtype_new is not None:
114
+ t_tensors = t_tensors.to(input_dtype_new)
115
+
116
+ t_tensors_list = self.dino_featurizer(t_tensors)
117
+ return t_tensors_list
118
+
119
+ @property
120
+ def default_image_resolution(self) -> Tuple[int, int, int]:
121
+ return self.dino_data_cfg["input_size"]
122
+
123
+ @property
124
+ def embed_dim(self) -> int:
125
+ return self.dino_featurizer.embed_dim
126
+
127
+ @property
128
+ def num_patches(self) -> int:
129
+ # no SigLIP featurizer in this backbone, so no cross-backbone patch-count check
130
+ return self.dino_featurizer.patch_embed.num_patches
131
+
132
+ @property
133
+ def half_precision_dtype(self) -> torch.dtype:
134
+ return torch.bfloat16
135
+
136
+
137
+ class DinoEncoder(nn.Module):
138
+ def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, feature_index = 22) -> None:
139
+ super().__init__()
140
+
141
+ self.image_encoder = DinoViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
142
+ self.to_pil = torchvision.transforms.ToPILImage()
143
+
144
+ def forward(self, image_tensor, device="cpu", input_dtype_new=torch.float32): # input image size = 768
145
+ pixel_values = []
146
+
147
+ for image_tensor_i in image_tensor:
148
+ pixel_values.append(self.to_pil(image_tensor_i))
149
+
150
+ embeddings_dino_list = self.image_encoder(pixel_values, device, input_dtype_new)
151
+ if len(embeddings_dino_list) == 1:
152
+ embeddings_dino_list = embeddings_dino_list[0]
153
+ return embeddings_dino_list
154
+
155
+ class DinoEncoderV2(nn.Module):
156
+ def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, feature_index = 22) -> None:
157
+ super().__init__()
158
+
159
+ self.image_encoder = DinoViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
160
+ self.to_pil = torchvision.transforms.ToPILImage()
161
+
162
+ def get_fsdp_wrapping_policy(self):
163
+ return self.image_encoder.get_fsdp_wrapping_policy()
164
+
165
+ def forward(self, image_tensor_dict, device="cpu", input_dtype_new=torch.float32):
166
+ image_tensor = image_tensor_dict["images_ref"]
167
+
168
+ output_dict = {}
169
+ pixel_values = []
170
+
171
+ for image_tensor_i in image_tensor:
172
+ pixel_values.append(self.to_pil(image_tensor_i))
173
+
174
+ embeddings_dino_list = self.image_encoder(pixel_values, device, input_dtype_new)
175
+ if len(embeddings_dino_list) == 1:
176
+ embeddings_dino_list = embeddings_dino_list[0]
177
+ output_dict["img_patch_features"] = embeddings_dino_list
178
+ return output_dict
179
+
180
+ class DinoEncoderV2_Canny(nn.Module):
181
+ def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, feature_index = 22) -> None:
182
+ super().__init__()
183
+
184
+ self.image_encoder = DinoViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
185
+ self.to_pil = torchvision.transforms.ToPILImage()
186
+
187
+ def get_fsdp_wrapping_policy(self):
188
+ return self.image_encoder.get_fsdp_wrapping_policy()
189
+
190
+ def forward(self, image_tensor_dict, device="cpu", input_dtype_new=torch.float32):
191
+ image_canny = image_tensor_dict["images_canny"]
192
+
193
+ output_dict = {}
194
+ pixel_values = []
195
+
196
+ for image_tensor_i in image_canny:
197
+ pixel_values.append(self.to_pil(image_tensor_i))
198
+
199
+ embeddings_dino_list = self.image_encoder(pixel_values, device, input_dtype_new)
200
+ if len(embeddings_dino_list) == 1:
201
+ embeddings_dino_list = embeddings_dino_list[0]
202
+ output_dict["img_patch_features"] = embeddings_dino_list
203
+ return output_dict
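For orientation, a minimal usage sketch of the `DinoEncoder` wrapper above (assumptions: the timm id `vit_large_patch14_reg4_dinov2.lvd142m`, and the local checkpoint at ckpts/vit_large_patch14_reg4_dinov2.lvd142m/pytorch_model.bin being present, since the backbone is created with `pretrained=True` and `pretrained_cfg['file']` pointing there):

import torch
from models.dino import DinoEncoder

# Input tensors are expected in [0, 1]; they are converted back to PIL images and
# resized to `default_image_size` by the timm transform inside the backbone.
encoder = DinoEncoder(
    backbone_name_or_path="vit_large_patch14_reg4_dinov2.lvd142m",
    image_resize_strategy="resize-naive",
    default_image_size=224,
    feature_index=22,
)
images = torch.rand(2, 3, 768, 768)
with torch.no_grad():
    patches = encoder(images, device="cpu")  # one feature index -> a single (B, num_patches, embed_dim) tensor
print(patches.shape)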
models/image_encoder_siglipdino_shallowdeep.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torchvision
17
+ import torch.nn as nn
18
+ from einops import rearrange
19
+
20
+ from models.sigclip import SigLIPViTBackbone
21
+ from models.dino import DinoViTBackbone
22
+
23
+ class ShallowDeepSiglipDinoEncoder(nn.Module):
24
+ def __init__(self, siglip_config={}, dino_config={}):
25
+ super().__init__()
26
+ self.to_pil = torchvision.transforms.ToPILImage()
27
+ self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
28
+ self.image_encoder_dino = DinoViTBackbone(**dino_config)
29
+
30
+ def forward(self, image_tensor, device="cpu"):
31
+ bs = image_tensor.size(0)
32
+ # convert tensors to PIL images
33
+ pixel_values = []
34
+ for image_tensor_i in image_tensor:
35
+ pixel_values.append(self.to_pil(image_tensor_i))
36
+
37
+ embeddings = []
38
+ embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
39
+ embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
40
+ for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
41
+ embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
42
+ embeddings.append(embeddings_i)
43
+
44
+ return embeddings
45
+
46
+ # The default is to use double the image size, i.e., 768x768.
47
+ class ShallowDeepPatchfySiglipDinoEncoder(nn.Module):
48
+ def __init__(self, siglip_config={}, dino_config={}, patchfy_scale=2, default_image_size=384):
49
+ super().__init__()
50
+ self.to_pil = torchvision.transforms.ToPILImage()
51
+ self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
52
+ self.image_encoder_dino = DinoViTBackbone(**dino_config)
53
+
54
+ self.patchfy = (patchfy_scale > 1)
55
+ self.patchfy_scale = patchfy_scale
56
+ self.default_image_size = default_image_size
57
+
58
+ def forward(self, image_tensor, device="cpu", **kwargs): # input image size = 768
59
+ image_tensor = image_tensor["image_ref"] # this is a dict
60
+ bs = image_tensor.size(0)
61
+
62
+ if self.patchfy:
63
+ image_local = rearrange(image_tensor, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=self.patchfy_scale, wp=self.patchfy_scale)
64
+ image_global = torch.nn.functional.interpolate(image_tensor, size=(self.default_image_size, self.default_image_size), mode='bilinear', align_corners=True)
65
+
66
+ # convert tensors to PIL images
67
+ pixel_values_local, pixel_values_global = [], []
68
+ for image_tensor_i in image_local:
69
+ pixel_values_local.append(self.to_pil(image_tensor_i.to(torch.float)))
70
+ for image_tensor_i in image_global:
71
+ pixel_values_global.append(self.to_pil(image_tensor_i.to(torch.float)))
72
+
73
+ embeddings = []
74
+ embeddings_siglip_list = self.image_encoder_siglip(pixel_values_global, device)
75
+ embeddings_dino_list = self.image_encoder_dino(pixel_values_global, device)
76
+ for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
77
+ embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
78
+ embeddings.append(embeddings_i)
79
+
80
+ embeddings_local_siglip_deep = self.image_encoder_siglip(pixel_values_local, device)[-1]
81
+ embeddings_local_dino_deep = self.image_encoder_dino(pixel_values_local, device)[-1]
82
+ embeddings_local_deep = torch.cat([embeddings_local_siglip_deep, embeddings_local_dino_deep], dim=-1)
83
+
84
+ embeddings_local_deep = rearrange(embeddings_local_deep, "(b hp wp) l c -> b (l hp wp) c", hp=self.patchfy_scale, wp=self.patchfy_scale)
85
+
86
+ embeddings.append(embeddings_local_deep)
87
+
88
+ else:
89
+ # convert tensors to PIL images
90
+ pixel_values = []
91
+ for image_tensor_i in image_tensor:
92
+ pixel_values.append(self.to_pil(image_tensor_i))
93
+
94
+ embeddings = []
95
+ embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
96
+ embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
97
+ for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
98
+ # concatenate per layer along the channel dimension
99
+ embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
100
+ embeddings.append(embeddings_i)
101
+
102
+ if len(embeddings) == 1:
103
+ embeddings = embeddings[0]
104
+ return embeddings
105
+
106
+
107
+ class ShallowDeepPatchfySiglipDinoEncoder_v2(nn.Module):
108
+ def __init__(self, siglip_config={}, dino_config={}, patchfy_scale=2, default_image_size=384):
109
+ super().__init__()
110
+ self.to_pil = torchvision.transforms.ToPILImage()
111
+ self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
112
+ self.image_encoder_dino = DinoViTBackbone(**dino_config)
113
+
114
+ self.patchfy = (patchfy_scale > 1)
115
+ self.patchfy_scale = patchfy_scale
116
+ self.default_image_size = default_image_size
117
+
118
+ def forward(self, image_tensor_dict, device="cpu", **kwargs): # input image size = 768
119
+ image_tensor = image_tensor_dict["image_ref"]
120
+ bs = image_tensor.size(0)
121
+
122
+ if self.patchfy:
123
+ image_local = rearrange(image_tensor, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=self.patchfy_scale, wp=self.patchfy_scale)
124
+ image_global = torch.nn.functional.interpolate(image_tensor, size=(self.default_image_size, self.default_image_size), mode='bilinear', align_corners=True)
125
+
126
+ pixel_values_local, pixel_values_global = [], []
127
+ for image_tensor_i in image_local:
128
+ pixel_values_local.append(self.to_pil(image_tensor_i.to(torch.float32)))
129
+ for image_tensor_i in image_global:
130
+ pixel_values_global.append(self.to_pil(image_tensor_i.to(torch.float32)))
131
+
132
+ embeddings = []
133
+ embeddings_siglip_list = self.image_encoder_siglip(pixel_values_global, device)
134
+ embeddings_dino_list = self.image_encoder_dino(pixel_values_global, device)
135
+ for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
136
+ embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
137
+ embeddings.append(embeddings_i)
138
+
139
+ embeddings_local_siglip_deep = self.image_encoder_siglip(pixel_values_local, device)[-1]
140
+ embeddings_local_dino_deep = self.image_encoder_dino(pixel_values_local, device)[-1]
141
+ embeddings_local_deep = torch.cat([embeddings_local_siglip_deep, embeddings_local_dino_deep], dim=-1)
142
+
143
+ embeddings_local_deep = rearrange(embeddings_local_deep, "(b hp wp) l c -> b (l hp wp) c", hp=self.patchfy_scale, wp=self.patchfy_scale)
144
+
145
+ embeddings.append(embeddings_local_deep)
146
+
147
+ else:
148
+ # convert tensors to PIL images
149
+ pixel_values = []
150
+ for image_tensor_i in image_tensor:
151
+ pixel_values.append(self.to_pil(image_tensor_i))
152
+
153
+ embeddings = []
154
+ embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
155
+ embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
156
+ for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
157
+ embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
158
+ embeddings.append(embeddings_i)
159
+
160
+ if len(embeddings) == 1:
161
+ embeddings = embeddings[0]
162
+ return embeddings
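The patchfy branch above splits a high-resolution reference into `patchfy_scale**2` local views plus one resized global view. A standalone sketch of just that split (no backbones needed); note that the rearrange pattern yields pixel-interleaved sub-grids (one pixel from every 2x2 block), not contiguous quadrant crops:

import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.rand(1, 3, 768, 768)   # high-resolution reference image
scale, size = 2, 384             # patchfy_scale, default_image_size

# "local" views: each output image takes one pixel from every 2x2 block of the original
local_views = rearrange(x, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=scale, wp=scale)
# "global" view: plain bilinear downsample of the full image
global_view = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=True)

print(local_views.shape)  # torch.Size([4, 3, 384, 384])
print(global_view.shape)  # torch.Size([1, 3, 384, 384])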
models/projectors.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ class MinAttention(nn.Module):
20
+ def __init__(self, q_dim: int, kv_dim: int, dim_head=64, heads=8):
21
+ super().__init__()
22
+ self.dim_head = dim_head
23
+ self.heads = heads
24
+ inner_dim = dim_head * heads
25
+
26
+ self.norm1 = nn.LayerNorm(q_dim)
27
+ self.norm2 = nn.LayerNorm(kv_dim)
28
+
29
+ self.to_q = nn.Linear(q_dim, inner_dim, bias=False)
30
+ self.to_k = nn.Linear(kv_dim, inner_dim, bias=False)
31
+ self.to_v = nn.Linear(kv_dim, inner_dim, bias=False)
32
+
33
+ def forward(self, local_fea, global_fea):
34
+ global_fea = self.norm1(global_fea)
35
+ local_fea = self.norm2(local_fea)
36
+ b, l, _ = global_fea.shape
37
+
38
+ q = self.to_q(global_fea)
39
+ k = self.to_k(local_fea)
40
+ v = self.to_v(local_fea)
41
+
42
+ q = q.view(b, -1, self.heads, self.dim_head).transpose(1, 2)
43
+ k = k.view(b, -1, self.heads, self.dim_head).transpose(1, 2)
44
+ v = v.view(b, -1, self.heads, self.dim_head).transpose(1, 2)
45
+ hidden_states = F.scaled_dot_product_attention(
46
+ q,k,v, dropout_p=0.0, is_causal=False
47
+ )
48
+ hidden_states = hidden_states.transpose(1, 2).reshape(b, -1, self.heads*self.dim_head)
49
+ hidden_states = hidden_states.to(q.dtype)
50
+ return hidden_states
51
+
52
+ class CustomParameter(nn.Module):
53
+ def __init__(self, init_value):
54
+ super().__init__()
55
+ self.init_value = init_value
56
+ self.value = nn.Parameter(torch.tensor(init_value))
57
+
58
+ def forward(self):
59
+ return self.value
60
+
61
+
62
+ class ProjectorHighResMinAttn(nn.Module):
63
+ def __init__(self, vision_dim, out_dim, dim_head=64, adaptive_scale=False, scale_value=1.0, **kwargs):
64
+ super().__init__()
65
+ self.initial_projection_dim = vision_dim * 4
66
+ heads = vision_dim // dim_head
67
+
68
+ self.min_attention = MinAttention(q_dim=vision_dim, kv_dim=vision_dim, dim_head=dim_head, heads=heads)
69
+ self.projector = nn.Sequential(
70
+ nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
71
+ nn.GELU(),
72
+ nn.Linear(self.initial_projection_dim, out_dim, bias=True),
73
+ nn.GELU(),
74
+ nn.Linear(out_dim, out_dim, bias=True),
75
+ nn.LayerNorm(out_dim)
76
+ )
77
+ self.projector_base = nn.Linear(vision_dim, out_dim, bias=True)
78
+
79
+ self.adaptive_scale = adaptive_scale
80
+ if self.adaptive_scale:
81
+ self.scale_value = CustomParameter(scale_value)
82
+
83
+ def forward(self, vision_input_dict, time_emb=None, **kwargs):
84
+ """
85
+ vision_input_dict: despite the name, this is not a dict but a (deep_features, deep_features_local) tuple; the name is kept for interface consistency
86
+ """
87
+ img_patch_features = vision_input_dict
88
+ deep_features, deep_features_local = img_patch_features
89
+
90
+ fused_img_features = self.min_attention(deep_features_local, deep_features)
91
+ fused_img_features = self.projector(fused_img_features)
92
+
93
+ deep_img_features = self.projector_base(deep_features)
94
+
95
+ if self.adaptive_scale:
96
+ output = deep_img_features + fused_img_features * self.scale_value()
97
+ else:
98
+ output = deep_img_features + fused_img_features
99
+ return output
100
+
101
+
102
+ class ProjectorHighResShallowMinAttnV1(nn.Module):
103
+ def __init__(self, vision_dim, out_dim, dim_head=64, **kwargs):
104
+ super().__init__()
105
+ self.initial_projection_dim = vision_dim * 4
106
+ heads = vision_dim // dim_head
107
+
108
+ self.min_attention = MinAttention(q_dim=vision_dim, kv_dim=vision_dim, dim_head=dim_head, heads=heads)
109
+ self.projector = nn.Sequential(
110
+ nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
111
+ nn.GELU(),
112
+ nn.Linear(self.initial_projection_dim, out_dim, bias=True),
113
+ nn.GELU(),
114
+ nn.Linear(out_dim, out_dim, bias=True),
115
+ nn.LayerNorm(out_dim)
116
+ )
117
+ self.projector_base = nn.Linear(vision_dim, out_dim, bias=True)
118
+
119
+ self.min_attention2 = MinAttention(q_dim=vision_dim, kv_dim=vision_dim, dim_head=dim_head, heads=heads)
120
+ self.projector2 = nn.Sequential(
121
+ nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
122
+ nn.GELU(),
123
+ nn.Linear(self.initial_projection_dim, out_dim, bias=True),
124
+ nn.GELU(),
125
+ nn.Linear(out_dim, out_dim, bias=True),
126
+ nn.LayerNorm(out_dim)
127
+ )
128
+
129
+ def forward(self, vision_input_dict, time_emb=None, **kwargs):
130
+ """
131
+ vision_input_dict: despite the name, this is not a dict but a tuple of (three shallow, one deep global, one deep local) feature maps; the name is kept for interface consistency
132
+ """
133
+ img_patch_features = vision_input_dict
134
+ shallow_features1, shallow_features2, shallow_features3, deep_features, deep_features_local = img_patch_features
135
+ shallow_features = torch.cat([shallow_features1, shallow_features2, shallow_features3], dim=1) # token concat
136
+
137
+ # original code
138
+ fused_img_features = self.min_attention(deep_features_local, deep_features)
139
+ fused_img_features = self.projector(fused_img_features)
140
+
141
+ deep_img_features = self.projector_base(deep_features)
142
+
143
+ output = deep_img_features + fused_img_features
144
+
145
+ # new code part
146
+ fused_img_features2 = self.min_attention2(shallow_features, deep_features)
147
+ fused_img_features2 = self.projector2(fused_img_features2)
148
+
149
+ output = torch.cat([deep_img_features, fused_img_features2], dim=1)
150
+ return output
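A quick shape check for `ProjectorHighResMinAttn`; the dimensions are illustrative assumptions (2176 = 1152 SigLIP + 1024 DINOv2 channels is one plausible `vision_dim`, 2048 matches the UNet `cross_attention_dim` in the configs, and the token counts are arbitrary):

import torch
from models.projectors import ProjectorHighResMinAttn

vision_dim, out_dim = 2176, 2048
proj = ProjectorHighResMinAttn(vision_dim=vision_dim, out_dim=out_dim, dim_head=64)

deep_global = torch.rand(2, 576, vision_dim)      # (B, tokens, C) global deep features
deep_local = torch.rand(2, 4 * 576, vision_dim)   # patchified local deep features
with torch.no_grad():
    fused = proj((deep_global, deep_local))       # queries come from the global tokens
print(fused.shape)                                # torch.Size([2, 576, 2048])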
models/sigclip.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ sigclip.py
17
+
18
+ SigLIP ViT vision backbone (adapted from a combined DINOv2 + SigLIP implementation; this backbone returns SigLIP features only).
19
+ """
20
+ from dataclasses import dataclass
21
+ from functools import partial
22
+ from typing import Callable, Dict, Tuple
23
+ import os
24
+ import timm
25
+ import torch
26
+ from PIL import Image
27
+ from timm.models.vision_transformer import Block, VisionTransformer
28
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
29
+ from torchvision.transforms import Compose, Resize
30
+
31
+ from models.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple, return_tuple
32
+
33
+ import torchvision
34
+ import torch.nn as nn
35
+
36
+ @dataclass
37
+ class DinoSigLIPImageTransform:
38
+ dino_image_transform: ImageTransform
39
+ siglip_image_transform: ImageTransform
40
+ is_cobra: bool = True
41
+
42
+ def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
43
+ return {"dino": self.dino_image_transform(img, **kwargs).unsqueeze(0), "siglip": self.siglip_image_transform(img, **kwargs).unsqueeze(0)}
44
+
45
+
46
+ class SigLIPViTBackbone(VisionBackbone):
47
+ def __init__(self, backbone_name_or_path: str, image_resize_strategy: str, default_image_size: int = 224, last_n = 2, feature_index = 25) -> None:
48
+ super().__init__(backbone_name_or_path, image_resize_strategy, default_image_size=default_image_size)
49
+ # load from local paths
50
+ sigclip_pretrained_cfg = timm.models.create_model(backbone_name_or_path).default_cfg
51
+ sigclip_pretrained_cfg['file'] = 'ckpts/vit_so400m_patch14_siglip_384/open_clip_pytorch_model.bin'
52
+
53
+ # Initialize the SigLIP featurizer (ViT), downloading from the HF / TIMM Hub if necessary
54
+ self.siglip_featurizer: VisionTransformer = timm.create_model(
55
+ backbone_name_or_path, pretrained=True, num_classes=0, img_size=self.default_image_size,
56
+ pretrained_cfg=sigclip_pretrained_cfg
57
+ )
58
+ self.siglip_featurizer.eval()
59
+
60
+ # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
61
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
62
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
63
+ # return the output tokens from the `n` last blocks
64
+ print("siglip has {} layer intermediate features. ".format(len(self.siglip_featurizer.blocks))) # 27
65
+ # self.siglip_featurizer.forward = unpack_tuple(
66
+ # partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - last_n})
67
+ # )
68
+ if isinstance(feature_index, tuple) or isinstance(feature_index, list):
69
+ feature_index = set(feature_index)
70
+ else:
71
+ feature_index = {feature_index}
72
+ self.siglip_featurizer.forward = return_tuple(
73
+ partial(self.siglip_featurizer.get_intermediate_layers, n=feature_index)
74
+ )
75
+
76
+ # Get the data config for the featurizer =>> Note :: Override default image size for larger resolution models
77
+
78
+ self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
79
+ self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
80
+
81
+ # Initialize the image transform
82
+ default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)
83
+
84
+ # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
85
+ assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!"
86
+ assert isinstance(sl_resize_transform := default_siglip_transform.transforms[0], Resize)
87
+ default_siglip_transform = Compose(
88
+ [
89
+ Resize(self.default_image_size, interpolation=sl_resize_transform.interpolation),
90
+ *default_siglip_transform.transforms[1:],
91
+ ]
92
+ )
93
+
94
+ if self.image_resize_strategy == "resize-naive":
95
+ assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!"
96
+ assert isinstance(siglip_resize_transform := default_siglip_transform.transforms[0], Resize)
97
+
98
+ target_size = (self.default_image_size, self.default_image_size)
99
+ siglip_transform = Compose(
100
+ [
101
+ Resize(target_size, interpolation=siglip_resize_transform.interpolation),
102
+ *default_siglip_transform.transforms[1:],
103
+ ]
104
+ )
105
+
106
+ self.siglip_transform = siglip_transform
107
+ else:
108
+ raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
109
+
110
+ def get_fsdp_wrapping_policy(self) -> Callable:
111
+ """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
112
+ vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
113
+ transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
114
+ return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
115
+
116
+ def forward(self, pixel_values, device="cpu") -> torch.Tensor:
117
+ """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
118
+ # b, c , h , w : 0-1
119
+ t_tensors = []
120
+ for pixel_value in pixel_values:
121
+ t_tensors.append(self.siglip_transform(pixel_value).unsqueeze(0))
122
+ t_tensors = torch.cat(t_tensors, dim=0).to(device)
123
+
124
+ t_tensors_list = self.siglip_featurizer(t_tensors)
125
+ return t_tensors_list
126
+
127
+ @property
128
+ def default_image_resolution(self) -> Tuple[int, int, int]:
129
+ return self.siglip_data_cfg["input_size"]
130
+
131
+ @property
132
+ def embed_dim(self) -> int:
133
+ return self.siglip_featurizer.embed_dim
134
+
135
+ @property
136
+ def num_patches(self) -> int:
137
+ # no DINO featurizer in this backbone, so no cross-backbone patch-count check
138
+ return self.siglip_featurizer.patch_embed.num_patches
139
+
140
+ @property
141
+ def half_precision_dtype(self) -> torch.dtype:
142
+ return torch.bfloat16
143
+
144
+
145
+ class SigLIPEncoder(nn.Module):
146
+ def __init__(self, backbone_name_or_path: str, image_resize_strategy: str, default_image_size: int = 224, feature_index = 25):
147
+ super().__init__()
148
+ self.image_encoder = SigLIPViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index=feature_index)
149
+ self.to_pil = torchvision.transforms.ToPILImage()
150
+
151
+ def forward(self, image_tensor, device="cpu"): # input image size = 768
152
+ pixel_values = []
153
+ for image_tensor_i in image_tensor:
154
+ pixel_values.append(self.to_pil(image_tensor_i))
155
+
156
+ embeddings_dino_list = self.image_encoder(pixel_values, device)
157
+ if len(embeddings_dino_list) == 1:
158
+ embeddings_dino_list = embeddings_dino_list[0]
159
+ return embeddings_dino_list
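Both backbones rely on monkey-patching `forward` onto timm's `get_intermediate_layers`; a standalone sketch of what that call returns, assuming a timm version that ships `vit_so400m_patch14_siglip_384` (here with `pretrained=False`, whereas the repo loads local weights via `pretrained_cfg['file']`):

import timm
import torch

vit = timm.create_model("vit_so400m_patch14_siglip_384", pretrained=False, num_classes=0)
x = torch.rand(1, 3, 384, 384)
with torch.no_grad():
    feats = vit.get_intermediate_layers(x, n=[25])  # one tensor per requested block index
print(len(feats), feats[0].shape)                   # 1, (B, num_patches, embed_dim)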
models/text.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from dataclasses import dataclass
17
+ from torch import nn
18
+ from transformers import AutoTokenizer, CLIPTokenizerFast, CLIPTextModel, T5EncoderModel
19
+ from typing import List
20
+
21
+ @dataclass
22
+ class TextModelOutput:
23
+ embeddings: torch.Tensor
24
+ masks: torch.Tensor
25
+ pooled: List
26
+
27
+
28
+ class TextModel(nn.Module):
29
+ available_modes = [
30
+ "last", # If present, use last layer.
31
+ "penultimate", # If present, use penultimate layer.
32
+ "penultimate_nonorm", # If present, use penultimate layer without final norm.
33
+ "token_cat", # If present, concat in token dimension, default concat in channel dimension.
34
+ "pad0", # If present, use 0 padding, default use EOT padding.
35
+ "masked", # If present, pass attention mask to encoder.
36
+ ]
37
+
38
+ def __init__(self, variant: List[str], mode: List[str]):
39
+ super().__init__()
40
+ self.mode = set(mode)
41
+ self.tokenizers = []
42
+ self.models = nn.ModuleList([])
43
+
44
+ for v in variant:
45
+ if "clip" in v.lower():
46
+ self.tokenizers.append(CLIPTokenizerFast.from_pretrained(v, model_max_length=77))
47
+ self.models.append(CLIPTextModel.from_pretrained(v))
48
+ elif "t5" in v.lower() or "ul2" in v.lower():
49
+ self.tokenizers.append(AutoTokenizer.from_pretrained(v, model_max_length=77))
50
+ self.models.append(T5EncoderModel.from_pretrained(v, torch_dtype=torch.bfloat16))
51
+ else:
52
+ raise NotImplementedError
53
+
54
+ def get_vaild_token_length(self, text): # Return the length of the BPE encoding of the text, excluding `<sos>` and `<eos>`.
55
+ lengths = []
56
+ for tokenizer, model in zip(self.tokenizers, self.models):
57
+
58
+ tokens = tokenizer(
59
+ text=text,
60
+ truncation=True,
61
+ padding="max_length",
62
+ return_tensors="pt"
63
+ ).to(model.device)
64
+ token_length = tokens["attention_mask"].sum() - 2 # In the attention mask, both the SOS and EOS (first PAD) have a value of 1.
65
+ lengths.append(token_length.item())
66
+ length = int(sum(lengths) / len(lengths))
67
+ return length
68
+
69
+ def forward(self, text: List[str]) -> TextModelOutput:
70
+ embeddings = []
71
+ masks = []
72
+ pooled = []
73
+
74
+ for tokenizer, model in zip(self.tokenizers, self.models):
75
+
76
+ tokens = tokenizer(
77
+ text=text,
78
+ truncation=True,
79
+ padding="max_length",
80
+ return_tensors="pt"
81
+ ).to(model.device)
82
+
83
+ if "pad0" in self.mode:
84
+ tokens.input_ids *= tokens.attention_mask
85
+
86
+ output = model(
87
+ input_ids=tokens.input_ids,
88
+ attention_mask=tokens.attention_mask if "masked" in self.mode else None,
89
+ output_hidden_states=True
90
+ )
91
+
92
+ if "last" in self.mode:
93
+ embeddings.append(output.last_hidden_state)
94
+ if "penultimate" in self.mode:
95
+ embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
96
+ if "penultimate_nonorm" in self.mode:
97
+ embeddings.append(output.hidden_states[-2])
98
+ masks.append(tokens.attention_mask)
99
+ if hasattr(output, "pooler_output"):
100
+ pooled.append(output.pooler_output)
101
+
102
+ if "token_cat" in self.mode:
103
+ return TextModelOutput(
104
+ embeddings=torch.cat(embeddings, dim=1),
105
+ masks=torch.cat(masks, dim=1),
106
+ pooled=pooled
107
+ )
108
+ else:
109
+ return TextModelOutput(
110
+ embeddings=torch.cat(embeddings, dim=2),
111
+ masks=torch.stack(masks, dim=2).sum(2).clamp_max(1),
112
+ pooled=pooled
113
+ )
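The default output path above concatenates per-encoder hidden states along the channel dimension and merges the attention masks. Sketched with dummy tensors in place of real encoders (the 768 + 1280 = 2048 channel split mirrors an SDXL-style CLIP pairing, which is an assumption here):

import torch

emb_a = torch.rand(2, 77, 768)    # e.g. CLIP ViT-L hidden states
emb_b = torch.rand(2, 77, 1280)   # e.g. OpenCLIP bigG hidden states
mask_a = torch.ones(2, 77, dtype=torch.long)
mask_b = torch.ones(2, 77, dtype=torch.long)

embeddings = torch.cat([emb_a, emb_b], dim=2)                     # (2, 77, 2048) channel concat (default)
masks = torch.stack([mask_a, mask_b], dim=2).sum(2).clamp_max(1)  # (2, 77) merged mask
print(embeddings.shape, masks.shape)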
models/transformer_2d_custom.py ADDED
@@ -0,0 +1,388 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
25
+ from diffusers.utils import BaseOutput, deprecate
26
+ # from diffusers.models.attention import BasicTransformerBlock
27
+ from models.attention_custom import BasicTransformerBlock
28
+ from diffusers.models.embeddings import PatchEmbed
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+
31
+ from utils import update_dict
32
+
33
+ @dataclass
34
+ class Transformer2DModelOutput(BaseOutput):
35
+ """
36
+ The output of [`Transformer2DModel`].
37
+
38
+ Args:
39
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
40
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
41
+ distributions for the unnoised latent pixels.
42
+ """
43
+
44
+ sample: torch.FloatTensor
45
+
46
+ # Transformer2DModel
47
+ class Transformer2DModel(ModelMixin, ConfigMixin):
48
+ """
49
+ A 2D Transformer model for image-like data.
50
+
51
+ Parameters:
52
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
53
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
54
+ in_channels (`int`, *optional*):
55
+ The number of channels in the input and output (specify if the input is **continuous**).
56
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
57
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
58
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
59
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
60
+ This is fixed during training since it is used to learn a number of position embeddings.
61
+ num_vector_embeds (`int`, *optional*):
62
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
63
+ Includes the class for the masked latent pixel.
64
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
65
+ num_embeds_ada_norm ( `int`, *optional*):
66
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
67
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
68
+ added to the hidden states.
69
+
70
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
71
+ attention_bias (`bool`, *optional*):
72
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
73
+ """
74
+
75
+ @register_to_config
76
+ def __init__(
77
+ self,
78
+ num_attention_heads: int = 16,
79
+ attention_head_dim: int = 88,
80
+ in_channels: Optional[int] = None,
81
+ out_channels: Optional[int] = None,
82
+ num_layers: int = 1,
83
+ dropout: float = 0.0,
84
+ norm_num_groups: int = 32,
85
+ cross_attention_dim: Optional[int] = None,
86
+ attention_bias: bool = False,
87
+ sample_size: Optional[int] = None,
88
+ num_vector_embeds: Optional[int] = None,
89
+ patch_size: Optional[int] = None,
90
+ activation_fn: str = "geglu",
91
+ num_embeds_ada_norm: Optional[int] = None,
92
+ use_linear_projection: bool = False,
93
+ only_cross_attention: bool = False,
94
+ upcast_attention: bool = False,
95
+ norm_type: str = "layer_norm",
96
+ norm_elementwise_affine: bool = True,
97
+ image_prompt_settings = {},
98
+ ):
99
+ super().__init__()
100
+ self.use_linear_projection = use_linear_projection
101
+ self.num_attention_heads = num_attention_heads
102
+ self.attention_head_dim = attention_head_dim
103
+ inner_dim = num_attention_heads * attention_head_dim
104
+
105
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
106
+ # Define whether input is continuous or discrete depending on configuration
107
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
108
+ self.is_input_vectorized = num_vector_embeds is not None
109
+ self.is_input_patches = in_channels is not None and patch_size is not None
110
+
111
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
112
+ deprecation_message = (
113
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
114
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
115
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
116
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
117
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
118
+ )
119
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
120
+ norm_type = "ada_norm"
121
+
122
+ if self.is_input_continuous and self.is_input_vectorized:
123
+ raise ValueError(
124
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
125
+ " sure that either `in_channels` or `num_vector_embeds` is None."
126
+ )
127
+ elif self.is_input_vectorized and self.is_input_patches:
128
+ raise ValueError(
129
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
130
+ " sure that either `num_vector_embeds` or `num_patches` is None."
131
+ )
132
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
133
+ raise ValueError(
134
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
135
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
136
+ )
137
+
138
+ # 2. Define input layers
139
+ if self.is_input_continuous:
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
143
+ if use_linear_projection:
144
+ self.proj_in = nn.Linear(in_channels, inner_dim)
145
+ else:
146
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
147
+ elif self.is_input_vectorized:
148
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
149
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
150
+
151
+ self.height = sample_size
152
+ self.width = sample_size
153
+ self.num_vector_embeds = num_vector_embeds
154
+ self.num_latent_pixels = self.height * self.width
155
+
156
+ self.latent_image_embedding = ImagePositionalEmbeddings(
157
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
158
+ )
159
+ elif self.is_input_patches:
160
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
161
+
162
+ self.height = sample_size
163
+ self.width = sample_size
164
+
165
+ self.patch_size = patch_size
166
+ self.pos_embed = PatchEmbed(
167
+ height=sample_size,
168
+ width=sample_size,
169
+ patch_size=patch_size,
170
+ in_channels=in_channels,
171
+ embed_dim=inner_dim,
172
+ )
173
+
174
+ # 3. Define transformers blocks, NOTE: we change the format
175
+ self.transformer_blocks = []
176
+ for d in range(num_layers):
177
+ self.transformer_blocks.append(
178
+ BasicTransformerBlock(
179
+ inner_dim,
180
+ num_attention_heads,
181
+ attention_head_dim,
182
+ dropout=dropout,
183
+ cross_attention_dim=cross_attention_dim,
184
+ activation_fn=activation_fn,
185
+ num_embeds_ada_norm=num_embeds_ada_norm,
186
+ attention_bias=attention_bias,
187
+ only_cross_attention=only_cross_attention,
188
+ upcast_attention=upcast_attention,
189
+ norm_type=norm_type,
190
+ norm_elementwise_affine=norm_elementwise_affine,
191
+ image_prompt_settings=image_prompt_settings,
192
+ )
193
+ )
194
+ image_prompt_settings["cross_attention_id"] += 1
195
+ self.transformer_blocks = nn.ModuleList(self.transformer_blocks)
196
+
197
+ # self.transformer_blocks = nn.ModuleList(
198
+ # [
199
+ # BasicTransformerBlock(
200
+ # inner_dim,
201
+ # num_attention_heads,
202
+ # attention_head_dim,
203
+ # dropout=dropout,
204
+ # cross_attention_dim=cross_attention_dim,
205
+ # activation_fn=activation_fn,
206
+ # num_embeds_ada_norm=num_embeds_ada_norm,
207
+ # attention_bias=attention_bias,
208
+ # only_cross_attention=only_cross_attention,
209
+ # upcast_attention=upcast_attention,
210
+ # norm_type=norm_type,
211
+ # norm_elementwise_affine=norm_elementwise_affine,
212
+ # image_prompt_settings=image_prompt_settings,
213
+ # )
214
+ # for d in range(num_layers)
215
+ # ]
216
+ # )
217
+
218
+ # 4. Define output layers
219
+ self.out_channels = in_channels if out_channels is None else out_channels
220
+ if self.is_input_continuous:
221
+ # TODO: should use out_channels for continuous projections
222
+ if use_linear_projection:
223
+ self.proj_out = nn.Linear(inner_dim, in_channels)
224
+ else:
225
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
226
+ elif self.is_input_vectorized:
227
+ self.norm_out = nn.LayerNorm(inner_dim)
228
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
229
+ elif self.is_input_patches:
230
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
231
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
232
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor,
237
+ encoder_hidden_states: Optional[torch.Tensor] = None,
238
+ timestep: Optional[torch.LongTensor] = None,
239
+ class_labels: Optional[torch.LongTensor] = None,
240
+ cross_attention_kwargs: Dict[str, Any] = None,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ encoder_attention_mask: Optional[torch.Tensor] = None,
243
+ return_dict: bool = True,
244
+ encoder_hidden_states_vision = None,
245
+ encoder_hidden_states_control = None,
246
+ vision_guided_mask = None,
247
+ extra_dict_inputs = {},
248
+ return_self_attn_map = False,
249
+ ):
250
+ """
251
+ The [`Transformer2DModel`] forward method.
252
+
253
+ Args:
254
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
255
+ Input `hidden_states`.
256
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
257
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
258
+ self-attention.
259
+ timestep ( `torch.LongTensor`, *optional*):
260
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
261
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
262
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
263
+ `AdaLayerZeroNorm`.
264
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
265
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
266
+
267
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
268
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
269
+
270
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
271
+ above. This bias will be added to the cross-attention scores.
272
+ return_dict (`bool`, *optional*, defaults to `True`):
273
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
274
+ tuple.
275
+
276
+ Returns:
277
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
278
+ `tuple` where the first element is the sample tensor.
279
+ """
280
+ # <notice>
281
+ extra_dict_outputs = {}
282
+ height, width = hidden_states.size(-2), hidden_states.size(-1)
283
+
284
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
285
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
286
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
287
+ # expects mask of shape:
288
+ # [batch, key_tokens]
289
+ # adds singleton query_tokens dimension:
290
+ # [batch, 1, key_tokens]
291
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
292
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
293
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
294
+ if attention_mask is not None and attention_mask.ndim == 2:
295
+ # assume that mask is expressed as:
296
+ # (1 = keep, 0 = discard)
297
+ # convert mask into a bias that can be added to attention scores:
298
+ # (keep = +0, discard = -10000.0)
299
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
300
+ attention_mask = attention_mask.unsqueeze(1)
301
+
302
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
303
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
304
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
305
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
306
+
307
+ # 1. Input
308
+ if self.is_input_continuous:
309
+ batch, _, height, width = hidden_states.shape
310
+ residual = hidden_states
311
+
312
+ hidden_states = self.norm(hidden_states)
313
+ if not self.use_linear_projection:
314
+ hidden_states = self.proj_in(hidden_states)
315
+ inner_dim = hidden_states.shape[1]
316
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
317
+ else:
318
+ inner_dim = hidden_states.shape[1]
319
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
320
+ hidden_states = self.proj_in(hidden_states)
321
+ elif self.is_input_vectorized:
322
+ hidden_states = self.latent_image_embedding(hidden_states)
323
+ elif self.is_input_patches:
324
+ hidden_states = self.pos_embed(hidden_states)
325
+
326
+ # 2. Blocks
327
+ for block in self.transformer_blocks:
328
+ hidden_states, extra_dict_output_transformer = block(
329
+ hidden_states,
330
+ attention_mask=attention_mask,
331
+ encoder_hidden_states=encoder_hidden_states,
332
+ encoder_attention_mask=encoder_attention_mask,
333
+ timestep=timestep,
334
+ cross_attention_kwargs=cross_attention_kwargs,
335
+ class_labels=class_labels,
336
+ encoder_hidden_states_vision=encoder_hidden_states_vision,
337
+ encoder_hidden_states_control=encoder_hidden_states_control,
338
+ vision_guided_mask=vision_guided_mask,
339
+ extra_dict_inputs=extra_dict_inputs,
340
+ height=height,
341
+ width=width,
342
+ return_self_attn_map=return_self_attn_map
343
+ )
344
+ extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_transformer)
345
+
346
+ # 3. Output
347
+ if self.is_input_continuous:
348
+ if not self.use_linear_projection:
349
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
350
+ hidden_states = self.proj_out(hidden_states)
351
+ else:
352
+ hidden_states = self.proj_out(hidden_states)
353
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
354
+
355
+ output = hidden_states + residual
356
+ elif self.is_input_vectorized:
357
+ hidden_states = self.norm_out(hidden_states)
358
+ logits = self.out(hidden_states)
359
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
360
+ logits = logits.permute(0, 2, 1)
361
+
362
+ # log(p(x_0))
363
+ output = F.log_softmax(logits.double(), dim=1).float()
364
+ elif self.is_input_patches:
365
+ # TODO: cleanup!
366
+ conditioning = self.transformer_blocks[0].norm1.emb(
367
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
368
+ )
369
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
370
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
371
+ hidden_states = self.proj_out_2(hidden_states)
372
+
373
+ # unpatchify
374
+ height = width = int(hidden_states.shape[1] ** 0.5)
375
+ hidden_states = hidden_states.reshape(
376
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
377
+ )
378
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
379
+ output = hidden_states.reshape(
380
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
381
+ )
382
+
383
+ if not return_dict: # return_dict=False
384
+ return output, extra_dict_outputs
385
+ # return (output,)
386
+
387
+ # return Transformer2DModelOutput(sample=output)
388
+ return Transformer2DModelOutput(sample=output)[0], extra_dict_outputs
models/unet_2d_blocks_custom.py ADDED
The diff for this file is too large to render. See raw diff
 
models/unet_2d_condition_custom.py ADDED
@@ -0,0 +1,1059 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from einops import rearrange
24
+
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ # from diffusers.models.unet_2d_blocks import (
27
+ from models.unet_2d_blocks_custom import (
28
+ CrossAttnDownBlock2D,
29
+ CrossAttnUpBlock2D,
30
+ DownBlock2D,
31
+ UNetMidBlock2DCrossAttn,
32
+ UNetMidBlock2DSimpleCrossAttn,
33
+ UpBlock2D,
34
+ get_down_block,
35
+ get_up_block,
36
+ )
37
+
38
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
39
+ from diffusers.loaders import UNet2DConditionLoadersMixin
40
+ from diffusers.utils import BaseOutput, logging
41
+ from diffusers.models.activations import get_activation
42
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
43
+ from diffusers.models.embeddings import (
44
+ GaussianFourierProjection,
45
+ ImageHintTimeEmbedding,
46
+ ImageProjection,
47
+ ImageTimeEmbedding,
48
+ TextImageProjection,
49
+ TextImageTimeEmbedding,
50
+ TextTimeEmbedding,
51
+ TimestepEmbedding,
52
+ Timesteps,
53
+ )
54
+
55
+ from utils import update_dict
56
+
57
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
58
+
59
+
60
+ @dataclass
61
+ class UNet2DConditionOutput(BaseOutput):
62
+ """
63
+ The output of [`UNet2DConditionModel`].
64
+
65
+ Args:
66
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
67
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
68
+ """
69
+
70
+ sample: torch.FloatTensor = None
71
+
72
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
73
+ r"""
74
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
75
+ shaped output.
76
+
77
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
78
+ for all models (such as downloading or saving).
79
+
80
+ Parameters:her
81
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
82
+ Height and width of input/output sample.
83
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
84
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
85
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
86
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
87
+ Whether to flip the sin to cos in the time embedding.
88
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
89
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
90
+ The tuple of downsample blocks to use.
91
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
92
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
93
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
94
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
95
+ The tuple of upsample blocks to use.
96
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
97
+ Whether to include self-attention in the basic transformer blocks, see
98
+ [`~models.attention.BasicTransformerBlock`].
99
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
100
+ The tuple of output channels for each block.
101
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
102
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
103
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
104
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
105
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
106
+ If `None`, normalization and activation layers is skipped in post-processing.
107
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
108
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
109
+ The dimension of the cross attention features.
110
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
111
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
112
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
113
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
114
+ encoder_hid_dim (`int`, *optional*, defaults to None):
115
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
116
+ dimension to `cross_attention_dim`.
117
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
118
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
119
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
120
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
121
+ num_attention_heads (`int`, *optional*):
122
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
123
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
124
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
125
+ class_embed_type (`str`, *optional*, defaults to `None`):
126
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
127
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
128
+ addition_embed_type (`str`, *optional*, defaults to `None`):
129
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
130
+ "text". "text" will use the `TextTimeEmbedding` layer.
131
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
132
+ Dimension for the timestep embeddings.
133
+ num_class_embeds (`int`, *optional*, defaults to `None`):
134
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
135
+ class conditioning with `class_embed_type` equal to `None`.
136
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
137
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
138
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
139
+ An optional override for the dimension of the projected time embedding.
140
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
141
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
142
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
143
+ timestep_post_act (`str`, *optional*, defaults to `None`):
144
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
145
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
146
+ The dimension of `cond_proj` layer in the timestep embedding.
147
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
148
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
149
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
150
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
151
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
152
+ embeddings with the class embeddings.
153
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
154
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
155
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
156
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
157
+ otherwise.
158
+ """
159
+
160
+ _supports_gradient_checkpointing = True
161
+
162
+ @register_to_config
163
+ def __init__(
164
+ self,
165
+ sample_size: Optional[int] = None,
166
+ in_channels: int = 4,
167
+ out_channels: int = 4,
168
+ center_input_sample: bool = False,
169
+ flip_sin_to_cos: bool = True,
170
+ freq_shift: int = 0,
171
+ down_block_types: Tuple[str] = (
172
+ "CrossAttnDownBlock2D",
173
+ "CrossAttnDownBlock2D",
174
+ "CrossAttnDownBlock2D",
175
+ "DownBlock2D",
176
+ ),
177
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
178
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
179
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
180
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
181
+ layers_per_block: Union[int, Tuple[int]] = 2,
182
+ downsample_padding: int = 1,
183
+ mid_block_scale_factor: float = 1,
184
+ act_fn: str = "silu",
185
+ norm_num_groups: Optional[int] = 32,
186
+ norm_eps: float = 1e-5,
187
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
188
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
189
+ encoder_hid_dim: Optional[int] = None,
190
+ encoder_hid_dim_type: Optional[str] = None,
191
+ attention_head_dim: Union[int, Tuple[int]] = 8,
192
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
193
+ dual_cross_attention: bool = False,
194
+ use_linear_projection: bool = False,
195
+ class_embed_type: Optional[str] = None,
196
+ addition_embed_type: Optional[str] = None,
197
+ addition_time_embed_dim: Optional[int] = None,
198
+ num_class_embeds: Optional[int] = None,
199
+ upcast_attention: bool = False,
200
+ resnet_time_scale_shift: str = "default",
201
+ resnet_skip_time_act: bool = False,
202
+ resnet_out_scale_factor: int = 1.0,
203
+ time_embedding_type: str = "positional",
204
+ time_embedding_dim: Optional[int] = None,
205
+ time_embedding_act_fn: Optional[str] = None,
206
+ timestep_post_act: Optional[str] = None,
207
+ time_cond_proj_dim: Optional[int] = None,
208
+ conv_in_kernel: int = 3,
209
+ conv_out_kernel: int = 3,
210
+ projection_class_embeddings_input_dim: Optional[int] = None,
211
+ class_embeddings_concat: bool = False,
212
+ mid_block_only_cross_attention: Optional[bool] = None,
213
+ cross_attention_norm: Optional[str] = None,
214
+ addition_embed_type_num_heads=64,
215
+
216
+ image_prompt_settings = {"dualbranch_mode": "none"},
217
+ **ignore_kwargs,
218
+ ):
219
+ super().__init__()
220
+
221
+ self.sample_size = sample_size
222
+
223
+ if num_attention_heads is not None:
224
+ raise ValueError(
225
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
226
+ )
227
+
228
+ # If `num_attention_heads` is not defined (which is the case for most models)
229
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
230
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
231
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
232
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
233
+ # which is why we correct for the naming here.
234
+ num_attention_heads = num_attention_heads or attention_head_dim
235
+
236
+ # Check inputs
237
+ if len(down_block_types) != len(up_block_types):
238
+ raise ValueError(
239
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
240
+ )
241
+
242
+ if len(block_out_channels) != len(down_block_types):
243
+ raise ValueError(
244
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
245
+ )
246
+
247
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
248
+ raise ValueError(
249
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
250
+ )
251
+
252
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
253
+ raise ValueError(
254
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
255
+ )
256
+
257
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
258
+ raise ValueError(
259
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
260
+ )
261
+
262
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
263
+ raise ValueError(
264
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
265
+ )
266
+
267
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
268
+ raise ValueError(
269
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
270
+ )
271
+
272
+ # input
273
+ conv_in_padding = (conv_in_kernel - 1) // 2
274
+ self.conv_in = nn.Conv2d(
275
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
276
+ )
277
+
278
+ # time
279
+ if time_embedding_type == "fourier":
280
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
281
+ if time_embed_dim % 2 != 0:
282
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
283
+ self.time_proj = GaussianFourierProjection(
284
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
285
+ )
286
+ timestep_input_dim = time_embed_dim
287
+ elif time_embedding_type == "positional":
288
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
289
+
290
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
291
+ timestep_input_dim = block_out_channels[0]
292
+ else:
293
+ raise ValueError(
294
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
295
+ )
296
+
297
+ self.time_embedding = TimestepEmbedding(
298
+ timestep_input_dim,
299
+ time_embed_dim,
300
+ act_fn=act_fn,
301
+ post_act_fn=timestep_post_act,
302
+ cond_proj_dim=time_cond_proj_dim,
303
+ )
304
+
305
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
306
+ encoder_hid_dim_type = "text_proj"
307
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
308
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
309
+
310
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
311
+ raise ValueError(
312
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
313
+ )
314
+
315
+ if encoder_hid_dim_type == "text_proj":
316
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
317
+ elif encoder_hid_dim_type == "text_image_proj":
318
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
319
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
320
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
321
+ self.encoder_hid_proj = TextImageProjection(
322
+ text_embed_dim=encoder_hid_dim,
323
+ image_embed_dim=cross_attention_dim,
324
+ cross_attention_dim=cross_attention_dim,
325
+ )
326
+ elif encoder_hid_dim_type == "image_proj":
327
+ # Kandinsky 2.2
328
+ self.encoder_hid_proj = ImageProjection(
329
+ image_embed_dim=encoder_hid_dim,
330
+ cross_attention_dim=cross_attention_dim,
331
+ )
332
+ elif encoder_hid_dim_type is not None:
333
+ raise ValueError(
334
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
335
+ )
336
+ else:
337
+ self.encoder_hid_proj = None
338
+
339
+ # class embedding
340
+ if class_embed_type is None and num_class_embeds is not None:
341
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
342
+ elif class_embed_type == "timestep":
343
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
344
+ elif class_embed_type == "identity":
345
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
346
+ elif class_embed_type == "projection":
347
+ if projection_class_embeddings_input_dim is None:
348
+ raise ValueError(
349
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
350
+ )
351
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
352
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
353
+ # 2. it projects from an arbitrary input dimension.
354
+ #
355
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
356
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
357
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
358
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
359
+ elif class_embed_type == "simple_projection":
360
+ if projection_class_embeddings_input_dim is None:
361
+ raise ValueError(
362
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
363
+ )
364
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
365
+ else:
366
+ self.class_embedding = None
367
+
368
+ if addition_embed_type == "text":
369
+ if encoder_hid_dim is not None:
370
+ text_time_embedding_from_dim = encoder_hid_dim
371
+ else:
372
+ text_time_embedding_from_dim = cross_attention_dim
373
+
374
+ self.add_embedding = TextTimeEmbedding(
375
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
376
+ )
377
+ elif addition_embed_type == "text_image":
378
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
379
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
380
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
381
+ self.add_embedding = TextImageTimeEmbedding(
382
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
383
+ )
384
+ elif addition_embed_type == "text_time":
385
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
386
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
387
+ elif addition_embed_type == "image":
388
+ # Kandinsky 2.2
389
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
390
+ elif addition_embed_type == "image_hint":
391
+ # Kandinsky 2.2 ControlNet
392
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
393
+ elif addition_embed_type is not None:
394
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
395
+
396
+ if time_embedding_act_fn is None:
397
+ self.time_embed_act = None
398
+ else:
399
+ self.time_embed_act = get_activation(time_embedding_act_fn)
400
+
401
+ self.down_blocks = nn.ModuleList([])
402
+ self.up_blocks = nn.ModuleList([])
403
+
404
+ if isinstance(only_cross_attention, bool):
405
+ if mid_block_only_cross_attention is None:
406
+ mid_block_only_cross_attention = only_cross_attention
407
+
408
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
409
+
410
+ if mid_block_only_cross_attention is None:
411
+ mid_block_only_cross_attention = False
412
+
413
+ if isinstance(num_attention_heads, int):
414
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
415
+
416
+ if isinstance(attention_head_dim, int):
417
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
418
+
419
+ if isinstance(cross_attention_dim, int):
420
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
421
+
422
+ if isinstance(layers_per_block, int):
423
+ layers_per_block = [layers_per_block] * len(down_block_types)
424
+
425
+ if isinstance(transformer_layers_per_block, int):
426
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
427
+
428
+ if class_embeddings_concat:
429
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
430
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
431
+ # regular time embeddings
432
+ blocks_time_embed_dim = time_embed_dim * 2
433
+ else:
434
+ blocks_time_embed_dim = time_embed_dim
435
+
436
+ # NOTE: we need to mark each cross attention id
437
+ image_prompt_settings["cross_attention_id"] = 0
438
+ # down
439
+ output_channel = block_out_channels[0]
440
+ # XL: ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D']
441
+ for i, down_block_type in enumerate(down_block_types):
442
+ input_channel = output_channel
443
+ output_channel = block_out_channels[i]
444
+ is_final_block = i == len(block_out_channels) - 1
445
+
446
+ down_block = get_down_block(
447
+ down_block_type,
448
+ num_layers=layers_per_block[i],
449
+ transformer_layers_per_block=transformer_layers_per_block[i],
450
+ in_channels=input_channel,
451
+ out_channels=output_channel,
452
+ temb_channels=blocks_time_embed_dim,
453
+ add_downsample=not is_final_block,
454
+ resnet_eps=norm_eps,
455
+ resnet_act_fn=act_fn,
456
+ resnet_groups=norm_num_groups,
457
+ cross_attention_dim=cross_attention_dim[i],
458
+ num_attention_heads=num_attention_heads[i],
459
+ downsample_padding=downsample_padding,
460
+ dual_cross_attention=dual_cross_attention,
461
+ use_linear_projection=use_linear_projection,
462
+ only_cross_attention=only_cross_attention[i],
463
+ upcast_attention=upcast_attention,
464
+ resnet_time_scale_shift=resnet_time_scale_shift,
465
+ resnet_skip_time_act=resnet_skip_time_act,
466
+ resnet_out_scale_factor=resnet_out_scale_factor,
467
+ cross_attention_norm=cross_attention_norm,
468
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
469
+ image_prompt_settings=image_prompt_settings,
470
+ )
471
+ self.down_blocks.append(down_block)
472
+
473
+ # mid, XL: UNetMidBlock2DCrossAttn
474
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
475
+ self.mid_block = UNetMidBlock2DCrossAttn(
476
+ transformer_layers_per_block=transformer_layers_per_block[-1],
477
+ in_channels=block_out_channels[-1],
478
+ temb_channels=blocks_time_embed_dim,
479
+ resnet_eps=norm_eps,
480
+ resnet_act_fn=act_fn,
481
+ output_scale_factor=mid_block_scale_factor,
482
+ resnet_time_scale_shift=resnet_time_scale_shift,
483
+ cross_attention_dim=cross_attention_dim[-1],
484
+ num_attention_heads=num_attention_heads[-1],
485
+ resnet_groups=norm_num_groups,
486
+ dual_cross_attention=dual_cross_attention,
487
+ use_linear_projection=use_linear_projection,
488
+ upcast_attention=upcast_attention,
489
+ image_prompt_settings=image_prompt_settings,
490
+ )
491
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
492
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
493
+ in_channels=block_out_channels[-1],
494
+ temb_channels=blocks_time_embed_dim,
495
+ resnet_eps=norm_eps,
496
+ resnet_act_fn=act_fn,
497
+ output_scale_factor=mid_block_scale_factor,
498
+ cross_attention_dim=cross_attention_dim[-1],
499
+ attention_head_dim=attention_head_dim[-1],
500
+ resnet_groups=norm_num_groups,
501
+ resnet_time_scale_shift=resnet_time_scale_shift,
502
+ skip_time_act=resnet_skip_time_act,
503
+ only_cross_attention=mid_block_only_cross_attention,
504
+ cross_attention_norm=cross_attention_norm,
505
+ )
506
+ elif mid_block_type is None:
507
+ self.mid_block = None
508
+ else:
509
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
510
+
511
+ # count how many layers upsample the images
512
+ self.num_upsamplers = 0
513
+
514
+ # up, XL: ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D']
515
+ reversed_block_out_channels = list(reversed(block_out_channels))
516
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
517
+ reversed_layers_per_block = list(reversed(layers_per_block))
518
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
519
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
520
+ only_cross_attention = list(reversed(only_cross_attention))
521
+
522
+ output_channel = reversed_block_out_channels[0]
523
+ for i, up_block_type in enumerate(up_block_types):
524
+ is_final_block = i == len(block_out_channels) - 1
525
+
526
+ prev_output_channel = output_channel
527
+ output_channel = reversed_block_out_channels[i]
528
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
529
+
530
+ # add upsample block for all BUT final layer
531
+ if not is_final_block:
532
+ add_upsample = True
533
+ self.num_upsamplers += 1
534
+ else:
535
+ add_upsample = False
536
+
537
+ up_block = get_up_block(
538
+ up_block_type,
539
+ num_layers=reversed_layers_per_block[i] + 1,
540
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
541
+ in_channels=input_channel,
542
+ out_channels=output_channel,
543
+ prev_output_channel=prev_output_channel,
544
+ temb_channels=blocks_time_embed_dim,
545
+ add_upsample=add_upsample,
546
+ resnet_eps=norm_eps,
547
+ resnet_act_fn=act_fn,
548
+ resnet_groups=norm_num_groups,
549
+ cross_attention_dim=reversed_cross_attention_dim[i],
550
+ num_attention_heads=reversed_num_attention_heads[i],
551
+ dual_cross_attention=dual_cross_attention,
552
+ use_linear_projection=use_linear_projection,
553
+ only_cross_attention=only_cross_attention[i],
554
+ upcast_attention=upcast_attention,
555
+ resnet_time_scale_shift=resnet_time_scale_shift,
556
+ resnet_skip_time_act=resnet_skip_time_act,
557
+ resnet_out_scale_factor=resnet_out_scale_factor,
558
+ cross_attention_norm=cross_attention_norm,
559
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
560
+ image_prompt_settings=image_prompt_settings
561
+ )
562
+ self.up_blocks.append(up_block)
563
+ prev_output_channel = output_channel
564
+
565
+ # out
566
+ if norm_num_groups is not None:
567
+ self.conv_norm_out = nn.GroupNorm(
568
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
569
+ )
570
+
571
+ self.conv_act = get_activation(act_fn)
572
+
573
+ else:
574
+ self.conv_norm_out = None
575
+ self.conv_act = None
576
+
577
+ conv_out_padding = (conv_out_kernel - 1) // 2
578
+ self.conv_out = nn.Conv2d(
579
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
580
+ )
581
+
582
+ # NOTE: settings for IP-consistent generation
583
+ from utils import instantiate_from_config
584
+ self.vision_projection_type = image_prompt_settings.get("vision_projection_type", "none")
585
+ self.cross_attention_dim = cross_attention_dim[0]
586
+ if self.vision_projection_type != "none":
587
+ self.encoder_hidden_states_vision_projection = instantiate_from_config(image_prompt_settings["vision_projection_config"])
588
+
589
+ @property
590
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
591
+ r"""
592
+ Returns:
593
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
595
+ indexed by their weight names.
595
+ """
596
+ # set recursively
597
+ processors = {}
598
+
599
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
600
+ if hasattr(module, "set_processor"):
601
+ processors[f"{name}.processor"] = module.processor
602
+
603
+ for sub_name, child in module.named_children():
604
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
605
+
606
+ return processors
607
+
608
+ for name, module in self.named_children():
609
+ fn_recursive_add_processors(name, module, processors)
610
+
611
+ return processors
612
+
613
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
614
+ r"""
615
+ Sets the attention processor to use to compute attention.
616
+
617
+ Parameters:
618
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
619
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
620
+ for **all** `Attention` layers.
621
+
622
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
623
+ processor. This is strongly recommended when setting trainable attention processors.
624
+
625
+ """
626
+ count = len(self.attn_processors.keys())
627
+
628
+ if isinstance(processor, dict) and len(processor) != count:
629
+ raise ValueError(
630
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
631
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
632
+ )
633
+
634
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
635
+ if hasattr(module, "set_processor"):
636
+ if not isinstance(processor, dict):
637
+ module.set_processor(processor)
638
+ else:
639
+ module.set_processor(processor.pop(f"{name}.processor"))
640
+
641
+ for sub_name, child in module.named_children():
642
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
643
+
644
+ for name, module in self.named_children():
645
+ fn_recursive_attn_processor(name, module, processor)
646
+
647
+ def set_default_attn_processor(self):
648
+ """
649
+ Disables custom attention processors and sets the default attention implementation.
650
+ """
651
+ self.set_attn_processor(AttnProcessor())
652
+
653
+ def set_attention_slice(self, slice_size):
654
+ r"""
655
+ Enable sliced attention computation.
656
+
657
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
658
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
659
+
660
+ Args:
661
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
662
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
663
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
664
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
665
+ must be a multiple of `slice_size`.
666
+ """
667
+ sliceable_head_dims = []
668
+
669
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
670
+ if hasattr(module, "set_attention_slice"):
671
+ sliceable_head_dims.append(module.sliceable_head_dim)
672
+
673
+ for child in module.children():
674
+ fn_recursive_retrieve_sliceable_dims(child)
675
+
676
+ # retrieve number of attention layers
677
+ for module in self.children():
678
+ fn_recursive_retrieve_sliceable_dims(module)
679
+
680
+ num_sliceable_layers = len(sliceable_head_dims)
681
+
682
+ if slice_size == "auto":
683
+ # half the attention head size is usually a good trade-off between
684
+ # speed and memory
685
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
686
+ elif slice_size == "max":
687
+ # make smallest slice possible
688
+ slice_size = num_sliceable_layers * [1]
689
+
690
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
691
+
692
+ if len(slice_size) != len(sliceable_head_dims):
693
+ raise ValueError(
694
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
695
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
696
+ )
697
+
698
+ for i in range(len(slice_size)):
699
+ size = slice_size[i]
700
+ dim = sliceable_head_dims[i]
701
+ if size is not None and size > dim:
702
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
703
+
704
+ # Recursively walk through all the children.
705
+ # Any children which exposes the set_attention_slice method
706
+ # gets the message
707
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
708
+ if hasattr(module, "set_attention_slice"):
709
+ module.set_attention_slice(slice_size.pop())
710
+
711
+ for child in module.children():
712
+ fn_recursive_set_attention_slice(child, slice_size)
713
+
714
+ reversed_slice_size = list(reversed(slice_size))
715
+ for module in self.children():
716
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
717
+
718
+ def _set_gradient_checkpointing(self, module, value=False):
719
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
720
+ module.gradient_checkpointing = value
721
+
722
+ def forward(
723
+ self,
724
+ sample: torch.FloatTensor,
725
+ timestep: Union[torch.Tensor, float, int],
726
+ encoder_hidden_states: torch.Tensor,
727
+ class_labels: Optional[torch.Tensor] = None,
728
+ timestep_cond: Optional[torch.Tensor] = None,
729
+ attention_mask: Optional[torch.Tensor] = None,
730
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
731
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
732
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
733
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
734
+ encoder_attention_mask: Optional[torch.Tensor] = None,
735
+ return_dict: bool = True,
736
+
737
+ vision_input_dict = None,
738
+ vision_guided_mask: Optional[torch.Tensor] = None,
739
+ return_text2image_mask: bool = False,
740
+ return_as_origin: bool = True,
741
+ return_self_attn_map = False,
742
+ multiple_reference_image = False,
743
+ ) -> Union[UNet2DConditionOutput, Tuple]:
744
+ r"""
745
+ The [`UNet2DConditionModel`] forward method.
746
+
747
+ Args:
748
+ sample (`torch.FloatTensor`):
749
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
750
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
751
+ encoder_hidden_states (`torch.FloatTensor`):
752
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
753
+ encoder_attention_mask (`torch.Tensor`):
754
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
755
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
756
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
757
+ return_dict (`bool`, *optional*, defaults to `True`):
758
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
759
+ tuple.
760
+ cross_attention_kwargs (`dict`, *optional*):
761
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
762
+ added_cond_kwargs: (`dict`, *optional*):
763
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
764
+ are passed along to the UNet blocks.
765
+
766
+ Returns:
767
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
768
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
769
+ a `tuple` is returned where the first element is the sample tensor.
770
+ """
771
+ extra_dict_outputs = {}
772
+ extra_dict_inputs = {}
773
+ extra_dict_inputs["multiple_reference_image"] = multiple_reference_image
774
+
775
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
776
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
777
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
778
+ # on the fly if necessary.
779
+ default_overall_up_factor = 2**self.num_upsamplers
780
+
781
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
782
+ forward_upsample_size = False
783
+ upsample_size = None
784
+
785
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
786
+ logger.info("Forward upsample size to force interpolation output size.")
787
+ forward_upsample_size = True
788
+
789
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
790
+ # expects mask of shape:
791
+ # [batch, key_tokens]
792
+ # adds singleton query_tokens dimension:
793
+ # [batch, 1, key_tokens]
794
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
795
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
796
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
797
+ if attention_mask is not None:
798
+ # assume that mask is expressed as:
799
+ # (1 = keep, 0 = discard)
800
+ # convert mask into a bias that can be added to attention scores:
801
+ # (keep = +0, discard = -10000.0)
802
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
803
+ attention_mask = attention_mask.unsqueeze(1)
804
+
805
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
806
+ if encoder_attention_mask is not None:
807
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
808
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
809
+
810
+ # 0. center input if necessary
811
+ if self.config.center_input_sample:
812
+ sample = 2 * sample - 1.0
813
+
814
+ # 1. time
815
+ timesteps = timestep
816
+ if not torch.is_tensor(timesteps):
817
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
818
+ # This would be a good case for the `match` statement (Python 3.10+)
819
+ is_mps = sample.device.type == "mps"
820
+ if isinstance(timestep, float):
821
+ dtype = torch.float32 if is_mps else torch.float64
822
+ else:
823
+ dtype = torch.int32 if is_mps else torch.int64
824
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
825
+ elif len(timesteps.shape) == 0:
826
+ timesteps = timesteps[None].to(sample.device)
827
+
828
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
829
+ timesteps = timesteps.expand(sample.shape[0])
830
+
831
+ t_emb = self.time_proj(timesteps)
832
+
833
+ # `Timesteps` does not contain any weights and will always return f32 tensors
834
+ # but time_embedding might actually be running in fp16. so we need to cast here.
835
+ # there might be better ways to encapsulate this.
836
+ t_emb = t_emb.to(dtype=sample.dtype)
837
+
838
+ emb = self.time_embedding(t_emb, timestep_cond)
839
+ aug_emb = None
840
+
841
+ if self.class_embedding is not None:
842
+ if class_labels is None:
843
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
844
+
845
+ if self.config.class_embed_type == "timestep":
846
+ class_labels = self.time_proj(class_labels)
847
+
848
+ # `Timesteps` does not contain any weights and will always return f32 tensors
849
+ # there might be better ways to encapsulate this.
850
+ class_labels = class_labels.to(dtype=sample.dtype)
851
+
852
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
853
+
854
+ if self.config.class_embeddings_concat:
855
+ emb = torch.cat([emb, class_emb], dim=-1)
856
+ else:
857
+ emb = emb + class_emb
858
+
859
+ if self.config.addition_embed_type == "text":
860
+ aug_emb = self.add_embedding(encoder_hidden_states)
861
+ elif self.config.addition_embed_type == "text_image":
862
+ # Kandinsky 2.1 - style
863
+ if "image_embeds" not in added_cond_kwargs:
864
+ raise ValueError(
865
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
866
+ )
867
+
868
+ image_embs = added_cond_kwargs.get("image_embeds")
869
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
870
+ aug_emb = self.add_embedding(text_embs, image_embs)
871
+ elif self.config.addition_embed_type == "text_time":
872
+ if "text_embeds" not in added_cond_kwargs:
873
+ raise ValueError(
874
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
875
+ )
876
+ text_embeds = added_cond_kwargs.get("text_embeds")
877
+ if "time_ids" not in added_cond_kwargs:
878
+ raise ValueError(
879
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
880
+ )
881
+ time_ids = added_cond_kwargs.get("time_ids")
882
+ time_embeds = self.add_time_proj(time_ids.flatten())
883
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
884
+
885
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
886
+ add_embeds = add_embeds.to(emb.dtype)
887
+ aug_emb = self.add_embedding(add_embeds)
888
+ elif self.config.addition_embed_type == "image":
889
+ # Kandinsky 2.2 - style
890
+ if "image_embeds" not in added_cond_kwargs:
891
+ raise ValueError(
892
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
893
+ )
894
+ image_embs = added_cond_kwargs.get("image_embeds")
895
+ aug_emb = self.add_embedding(image_embs)
896
+ elif self.config.addition_embed_type == "image_hint":
897
+ # Kandinsky 2.2 - style
898
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
899
+ raise ValueError(
900
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
901
+ )
902
+ image_embs = added_cond_kwargs.get("image_embeds")
903
+ hint = added_cond_kwargs.get("hint")
904
+ aug_emb, hint = self.add_embedding(image_embs, hint)
905
+ sample = torch.cat([sample, hint], dim=1)
906
+
907
+ emb = emb + aug_emb if aug_emb is not None else emb
908
+
909
+ if self.time_embed_act is not None:
910
+ emb = self.time_embed_act(emb)
911
+
912
+ # 2. pre-process
913
+ sample = self.conv_in(sample)
914
+
915
+ # NOTE: image condition
916
+ encoder_hidden_states_vision = None
917
+ if vision_input_dict is not None and self.vision_projection_type != "none":
918
+ if multiple_reference_image:
919
+ encoder_hidden_states_vision = []
920
+ for vision_input_dict_i in vision_input_dict:
921
+ encoder_hidden_states_vision_i = self.encoder_hidden_states_vision_projection(vision_input_dict=vision_input_dict_i, time_emb=emb, image_latent=sample)
922
+ encoder_hidden_states_vision.append(encoder_hidden_states_vision_i)
923
+ else:
924
+ encoder_hidden_states_vision = self.encoder_hidden_states_vision_projection(vision_input_dict=vision_input_dict, time_emb=emb, image_latent=sample)
925
+
926
+ if type(encoder_hidden_states_vision) == dict:
927
+ if "l_disen" in encoder_hidden_states_vision.keys():
928
+ extra_dict_outputs["l_disen"] = encoder_hidden_states_vision["l_disen"]
929
+
930
+
931
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
932
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
933
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
934
+ # Kandinsky 2.1 - style
935
+ if "image_embeds" not in added_cond_kwargs:
936
+ raise ValueError(
937
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
938
+ )
939
+
940
+ image_embeds = added_cond_kwargs.get("image_embeds")
941
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
942
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
943
+ # Kandinsky 2.2 - style
944
+ if "image_embeds" not in added_cond_kwargs:
945
+ raise ValueError(
946
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
947
+ )
948
+ image_embeds = added_cond_kwargs.get("image_embeds")
949
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
950
+
951
+ # 3. down
952
+ down_block_res_samples = (sample,)
953
+ additional_residuals = {}
954
+ for downsample_block in self.down_blocks:
955
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
956
+ sample, res_samples, extra_dict_output_down = downsample_block(
957
+ hidden_states=sample,
958
+ temb=emb,
959
+ encoder_hidden_states=encoder_hidden_states,
960
+ attention_mask=attention_mask,
961
+ cross_attention_kwargs=cross_attention_kwargs,
962
+ encoder_attention_mask=encoder_attention_mask,
963
+ encoder_hidden_states_vision=encoder_hidden_states_vision,
964
+ encoder_hidden_states_control=None,
965
+ vision_guided_mask=vision_guided_mask,
966
+ extra_dict_inputs=extra_dict_inputs,
967
+ return_self_attn_map=return_self_attn_map,
968
+ **additional_residuals,
969
+ )
970
+ else:
971
+ sample, res_samples, extra_dict_output_down = downsample_block(hidden_states=sample, temb=emb)
972
+ extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_down)
973
+
974
+ down_block_res_samples += res_samples
975
+
976
+ if down_block_additional_residuals is not None:
977
+ new_down_block_res_samples = ()
978
+
979
+ for down_block_res_sample, down_block_additional_residual in zip(
980
+ down_block_res_samples, down_block_additional_residuals
981
+ ):
982
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
983
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
984
+
985
+ down_block_res_samples = new_down_block_res_samples
986
+
987
+ # 4. mid
988
+ if self.mid_block is not None:
989
+ sample, extra_dict_output_middle = self.mid_block(
990
+ sample,
991
+ emb,
992
+ encoder_hidden_states=encoder_hidden_states,
993
+ attention_mask=attention_mask,
994
+ cross_attention_kwargs=cross_attention_kwargs,
995
+ encoder_attention_mask=encoder_attention_mask,
996
+ encoder_hidden_states_vision=encoder_hidden_states_vision,
997
+ encoder_hidden_states_control=None,
998
+ vision_guided_mask=vision_guided_mask,
999
+ extra_dict_inputs=extra_dict_inputs,
1000
+ return_self_attn_map=return_self_attn_map,
1001
+ )
1002
+ extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_middle)
1003
+
1004
+
1005
+ if mid_block_additional_residual is not None:
1006
+ sample = sample + mid_block_additional_residual
1007
+
1008
+ # 5. up
1009
+ for i, upsample_block in enumerate(self.up_blocks):
1010
+ is_final_block = i == len(self.up_blocks) - 1
1011
+
1012
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1013
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1014
+
1015
+ # if we have not reached the final block and need to forward the
1016
+ # upsample size, we do it here
1017
+ if not is_final_block and forward_upsample_size:
1018
+ upsample_size = down_block_res_samples[-1].shape[2:]
1019
+
1020
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1021
+ sample, extra_dict_output_up = upsample_block(
1022
+ hidden_states=sample,
1023
+ temb=emb,
1024
+ res_hidden_states_tuple=res_samples,
1025
+ encoder_hidden_states=encoder_hidden_states,
1026
+ cross_attention_kwargs=cross_attention_kwargs,
1027
+ upsample_size=upsample_size,
1028
+ attention_mask=attention_mask,
1029
+ encoder_attention_mask=encoder_attention_mask,
1030
+ encoder_hidden_states_vision=encoder_hidden_states_vision,
1031
+ encoder_hidden_states_control=None,
1032
+ vision_guided_mask=vision_guided_mask,
1033
+ extra_dict_inputs=extra_dict_inputs,
1034
+ return_self_attn_map=return_self_attn_map
1035
+ )
1036
+ else:
1037
+ sample, extra_dict_output_up = upsample_block(
1038
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1039
+ )
1040
+ extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_up)
1041
+
1042
+ # 6. post-process
1043
+ if self.conv_norm_out:
1044
+ sample = self.conv_norm_out(sample)
1045
+ sample = self.conv_act(sample)
1046
+ sample = self.conv_out(sample)
1047
+
1048
+ if not return_dict:
1049
+ return (sample,)
1050
+
1051
+ if return_as_origin:
1052
+ return UNet2DConditionOutput(sample=sample)
1053
+ else:
1054
+ if return_text2image_mask:
1055
+ return UNet2DConditionOutput(sample=sample).sample, extra_dict_outputs
1056
+ else:
1057
+ return UNet2DConditionOutput(sample=sample).sample
1058
+
1059
+
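A minimal call sketch for the customized forward pass above. It is illustrative only: `unet` is assumed to be an instance of this `UNet2DConditionModel` built from one of the JSON configs in `configs/`, the tensor shapes are placeholders, and the keys of `vision_input_dict` depend on the configured vision projector (the `image_embeds` key below is a made-up name).

import torch

# Assumption: `unet` is an instance of the UNet2DConditionModel defined above,
# created from one of the configs shipped with this repository.
batch = 1
sample = torch.randn(batch, 4, 128, 128)               # noisy latents
timestep = torch.tensor([999])
encoder_hidden_states = torch.randn(batch, 77, 2048)   # text states; last dim must match cross_attention_dim
added_cond_kwargs = {                                   # required when addition_embed_type == "text_time"
    "text_embeds": torch.randn(batch, 1280),            # pooled text embedding (illustrative size)
    "time_ids": torch.randn(batch, 6),                  # SDXL-style size/crop conditioning ids
}
# Placeholder image-prompt inputs; the real keys are defined by the vision projector config.
vision_input_dict = {"image_embeds": torch.randn(batch, 257, 1024)}

noise_pred, extra_outputs = unet(
    sample,
    timestep,
    encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
    vision_input_dict=vision_input_dict,
    return_as_origin=False,
    return_text2image_mask=True,   # with these flags the forward returns (sample, extra_dict_outputs)
)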
models/vae.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import diffusers
17
+ import torch
18
+
19
+ class AutoencoderKL(diffusers.AutoencoderKL):
20
+ """
21
+ We simply inherit the model code from diffusers
22
+ """
23
+ def __init__(self, attention=True, *args, **kwargs):
24
+ super().__init__(*args, **kwargs)
25
+ # A hacky way to remove attention.
26
+ if not attention:
27
+ self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
28
+ self.decoder.mid_block.attentions = torch.nn.ModuleList([None])
29
+
30
+ def load_state_dict(self, state_dict, strict=True):
31
+ # Newer version of diffusers changed the model keys, causing incompatibility with old checkpoints.
32
+ # They provided a method for conversion. We call conversion before loading state_dict.
33
+ convert_deprecated_attention_blocks = getattr(self, "_convert_deprecated_attention_blocks", None)
34
+ if callable(convert_deprecated_attention_blocks):
35
+ convert_deprecated_attention_blocks(state_dict)
36
+ return super().load_state_dict(state_dict, strict)
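A small, self-contained sketch of the wrapper above, useful for sanity-checking shapes. The hyper-parameters are deliberately tiny and illustrative; they are not the repository's production VAE configuration.

import torch
from models.vae import AutoencoderKL

# Tiny two-stage VAE: one downsampling step, latent resolution = input / 2.
vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    latent_channels=4,
)

image = torch.randn(1, 3, 64, 64)                 # pixel-space input, roughly in [-1, 1]
latents = vae.encode(image).latent_dist.sample()  # (1, 4, 32, 32)
decoded = vae.decode(latents).sample              # (1, 3, 64, 64)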
prompts/validation_negative.txt ADDED
@@ -0,0 +1 @@
1
+ (((naked))), deformed, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, ((((mutated hands and fingers)))), watermark, watermarked, oversaturated, censored, distorted hands, amputation, missing hands, obese, doubled face, double hands
requirements.txt ADDED
@@ -0,0 +1,34 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
+ torch==2.3.1
+ torchvision==0.18.1
+ torchaudio==2.3.1
2
+
3
+ numpy
4
+ ftfy
5
+
6
+ # Training
7
+ bs4==0.0.1 # Needed for text cleaning
8
+ bson==0.5.10
9
+ diffusers==0.19.3 # diffusers[torch]==0.19.3 in control
10
+ einops==0.6.0
11
+ ftfy==6.1.1 # Needed for text cleaning
12
+ kornia==0.6.12
13
+ lpips==0.1.4
14
+ sentencepiece==0.1.99 # Needed for T5 tokenizer
15
+ transformers==4.36.2
16
+ tqdm==4.64.1
17
+ torchgeometry # Needed for ssim loss
18
+ expecttest # Needed for compile
19
+ accelerate==0.24.1 # model saving bugs when accelerate==0.25.0
20
+
21
+ # Inference
22
+ av==10.0.0
23
+ pims==0.6.1
24
+ opencv-python-headless==4.6.0.66
25
+
26
+ gradio==3.42.0
27
+ httpx==0.23.3
28
+ opencv-python
29
+ open_clip_torch
30
+ protobuf==3.20.0
31
+ huggingface_hub==0.25.0
32
+
34
+ git+https://github.com/openai/CLIP.git
schedulers/__pycache__/base.cpython-310.pyc ADDED
Binary file (3.63 kB).
schedulers/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (1.98 kB).
schedulers/__pycache__/dpm_s.cpython-310.pyc ADDED
Binary file (6.09 kB).
schedulers/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.8 kB).
schedulers/base.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from dataclasses import dataclass
17
+ from abc import ABC
18
+ from typing import Optional, Union, List
19
+
20
+
21
+ @dataclass
22
+ class SchedulerConversionOutput:
23
+ pred_epsilon: torch.Tensor
24
+ pred_original_sample: torch.Tensor
25
+ pred_velocity: torch.Tensor
26
+
27
+
28
+ @dataclass
29
+ class SchedulerStepOutput:
30
+ prev_sample: torch.Tensor
31
+ pred_original_sample: Optional[torch.Tensor] = None
32
+
33
+
34
+ class Scheduler(ABC):
35
+ prediction_types = ["epsilon", "sample", "v_prediction"]
36
+ timesteps_types = ["leading", "linspace", "trailing"]
37
+
38
+ def __init__(
39
+ self,
40
+ num_train_timesteps: int,
41
+ num_inference_timesteps: int,
42
+ betas: torch.Tensor,
43
+ inference_timesteps: Union[str, List[int]] = "trailing",
44
+ set_alpha_to_one: bool = True,
45
+ device: Optional[Union[str, torch.device]] = None,
46
+ dtype: torch.dtype = torch.float32
47
+ ):
48
+ assert num_train_timesteps > 0
49
+ assert num_train_timesteps >= num_inference_timesteps
50
+ assert num_train_timesteps == betas.size(0)
51
+ assert betas.ndim == 1
52
+
53
+ self.device = device or betas.device
54
+ self.dtype = dtype
55
+
56
+ self.num_train_timesteps = num_train_timesteps
57
+ self.num_inference_timesteps = num_inference_timesteps
58
+
59
+ self.betas = betas.to(device=device, dtype=dtype)
60
+ self.alphas = 1.0 - self.betas
61
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
62
+ self.final_alpha_cumprod = torch.tensor(1.0, device=self.device, dtype=self.dtype) if set_alpha_to_one else self.alphas_cumprod[0]
63
+
64
+ if isinstance(inference_timesteps, list):
65
+ # If user defines a custom inference timestep, directly assign it.
66
+ assert len(inference_timesteps) == num_inference_timesteps
67
+ self.timesteps = torch.tensor(inference_timesteps, device=self.device, dtype=torch.int)
68
+ elif inference_timesteps == "trailing":
69
+ # Example 20 steps: [999, 949, 899, 849, 799, 749, 699, 649, 599, 549, 499, 449, 399, 349, 299, 249, 199, 149, 99, 49]
70
+ self.timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=self.device).round().int()
71
+ elif inference_timesteps == "linspace":
72
+ # Example 20 steps: [999, 946, 894, 841, 789, 736, 684, 631, 578, 526, 473, 421, 368, 315, 263, 210, 158, 105, 53, 0]
73
+ self.timesteps = torch.linspace(0, num_train_timesteps - 1, num_inference_timesteps, device=self.device).round().int().flip(0)
74
+ elif inference_timesteps == "leading":
75
+ # Original SD and DDIM paper may have a bug: <https://github.com/huggingface/diffusers/issues/2585>
76
+ # The inference timestep does not start from 999.
77
+ # Example 20 steps: [950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, 100, 50, 0]
78
+ self.timesteps = torch.arange(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps, device=self.device, dtype=torch.int).flip(0)
79
+ else:
80
+ raise NotImplementedError
81
+
82
+ def reset(self):
83
+ pass
84
+
85
+ def add_noise(
86
+ self,
87
+ original_samples: torch.Tensor,
88
+ noise: torch.Tensor,
89
+ timesteps: Union[torch.Tensor, int],
90
+ ) -> torch.Tensor:
91
+ alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (original_samples.ndim - 1)))
92
+ return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise
93
+
94
+ def convert_output(
95
+ self,
96
+ model_output: torch.Tensor,
97
+ model_output_type: str,
98
+ sample: torch.Tensor,
99
+ timesteps: Union[torch.Tensor, int]
100
+ ) -> SchedulerConversionOutput:
101
+ assert model_output_type in self.prediction_types
102
+
103
+ alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
104
+ beta_prod_t = 1 - alpha_prod_t
105
+
106
+ if model_output_type == "epsilon":
107
+ pred_epsilon = model_output
108
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
109
+ pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
110
+ elif model_output_type == "sample":
111
+ pred_original_sample = model_output
112
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
113
+ pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
114
+ elif model_output_type == "v_prediction":
115
+ pred_velocity = model_output
116
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
117
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
118
+ else:
119
+ raise ValueError("Unknown prediction type")
120
+
121
+ return SchedulerConversionOutput(
122
+ pred_epsilon=pred_epsilon,
123
+ pred_original_sample=pred_original_sample,
124
+ pred_velocity=pred_velocity)
125
+
126
+ def get_velocity(
127
+ self,
128
+ sample: torch.Tensor,
129
+ noise: torch.Tensor,
130
+ timesteps: torch.Tensor
131
+ ) -> torch.FloatTensor:
132
+ alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
133
+ return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample
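A self-contained worked example of the epsilon / x0 / v conversions that `add_noise`, `convert_output`, and `get_velocity` implement above. The linear beta schedule is only a toy stand-in, not the training schedule.

import torch

num_train_timesteps = 1000
betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)   # illustrative schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

t = 500
a, b = alphas_cumprod[t], 1.0 - alphas_cumprod[t]

x0 = torch.randn(1, 4, 8, 8)                  # clean sample
eps = torch.randn_like(x0)                    # noise
x_t = a.sqrt() * x0 + b.sqrt() * eps          # same formula as `add_noise`
v = a.sqrt() * eps - b.sqrt() * x0            # same formula as `get_velocity`

# Recover x0 and eps from the v-prediction, exactly as `convert_output` does.
x0_from_v = a.sqrt() * x_t - b.sqrt() * v
eps_from_v = a.sqrt() * v + b.sqrt() * x_t
assert torch.allclose(x0_from_v, x0, atol=1e-5)
assert torch.allclose(eps_from_v, eps, atol=1e-5)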
schedulers/ddim.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base import *
16
+
17
+
18
+ class DDIMScheduler(Scheduler):
19
+ def step(
20
+ self,
21
+ model_output: torch.Tensor,
22
+ model_output_type: str,
23
+ timestep: Union[torch.Tensor, int],
24
+ sample: torch.Tensor,
25
+ eta: float = 0.0,
26
+ clip_sample: bool = False,
27
+ dynamic_threshold: Optional[float] = None,
28
+ variance_noise: Optional[torch.Tensor] = None,
29
+ ) -> SchedulerStepOutput:
30
+ # 1. get previous step value (t-1)
31
+ if not isinstance(timestep, torch.Tensor):
32
+ timestep = torch.tensor(timestep, device=self.device, dtype=torch.int)
33
+
34
+ idx = timestep.reshape(-1, 1).eq(self.timesteps.reshape(1, -1)).nonzero()[:, 1]
35
+ prev_timestep = self.timesteps[idx.add(1).clamp_max(self.num_inference_timesteps - 1)]
36
+
37
+ # 2. compute alphas, betas
38
+ alpha_prod_t = self.alphas_cumprod[timestep].reshape(-1, *([1] * (sample.ndim - 1)))
39
+ alpha_prod_t_prev = torch.where(idx < self.num_inference_timesteps - 1, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod).reshape(-1, *([1] * (sample.ndim - 1)))
40
+ beta_prod_t = 1 - alpha_prod_t
41
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
42
+
43
+ # 3. compute predicted original sample from predicted noise (also called "predicted x_0" in formula (12) of https://arxiv.org/pdf/2010.02502.pdf)
44
+ model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
45
+ pred_original_sample = model_output_conversion.pred_original_sample
46
+ pred_epsilon = model_output_conversion.pred_epsilon
47
+
48
+ # 4. Clip or threshold "predicted x_0"
49
+ if clip_sample:
50
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
51
+ pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
52
+
53
+ if dynamic_threshold is not None:
54
+ # Dynamic thresholding in https://arxiv.org/abs/2205.11487
55
+ dynamic_max_val = pred_original_sample \
56
+ .flatten(1) \
57
+ .abs() \
58
+ .float() \
59
+ .quantile(dynamic_threshold, dim=1) \
60
+ .type_as(pred_original_sample) \
61
+ .clamp_min(1) \
62
+ .view(-1, *([1] * (pred_original_sample.ndim - 1)))
63
+ pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
64
+ pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
65
+
66
+ # 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
67
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
68
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
69
+ std_dev_t = eta * variance ** (0.5)
70
+
71
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
72
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
73
+
74
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
75
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
76
+
77
+ # 8. add "random noise" if needed.
78
+ if eta > 0:
79
+ if variance_noise is None:
80
+ variance_noise = torch.randn_like(model_output)
81
+ prev_sample = prev_sample + std_dev_t * variance_noise
82
+
83
+ return SchedulerStepOutput(
84
+ prev_sample=prev_sample,
85
+ pred_original_sample=pred_original_sample)
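A minimal sampling-loop sketch for the DDIM scheduler above. The constructor keywords are assumed to match those the DPM solvers below forward to the shared base class, and `denoiser` is a hypothetical stand-in for any epsilon-prediction model:

import torch
from schedulers.ddim import DDIMScheduler
from schedulers.utils import get_betas

scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    num_inference_timesteps=50,
    betas=get_betas("squared_linear", num_steps=1000),
)

def denoiser(x, t):                   # placeholder for a real UNet call
    return torch.zeros_like(x)

latent = torch.randn(1, 4, 64, 64)    # start from pure noise
for t in scheduler.timesteps:         # descending inference timesteps
    eps = denoiser(latent, t)
    latent = scheduler.step(eps, "epsilon", t, latent).prev_sample  # eta=0: deterministic DDIM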
schedulers/dpm_m.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from .base import *
24
+
25
+
26
+ class DPMSolverMultistepScheduler(Scheduler):
27
+ def __init__(
28
+ self,
29
+ # Generic scheduler settings
30
+ num_inference_timesteps: int,
31
+ betas: torch.Tensor,
32
+ num_train_timesteps: int = 1000,
33
+ inference_timesteps: Union[str, List[int]] = "trailing",
34
+ set_alpha_to_one: bool = True,
35
+ device: Optional[torch.device] = None,
36
+ dtype: torch.dtype = torch.float32,
37
+ # DPM scheduler settings
38
+ solver_order: int = 2,
39
+ algorithm_type: str = "dpmsolver++",
40
+ solver_type: str = "midpoint",
41
+ lower_order_final: bool = True,
42
+ use_karras_sigmas: bool = False,
43
+ ):
44
+ super().__init__(
45
+ num_train_timesteps=num_train_timesteps,
46
+ num_inference_timesteps=num_inference_timesteps,
47
+ betas=betas,
48
+ inference_timesteps=inference_timesteps,
49
+ set_alpha_to_one=set_alpha_to_one,
50
+ device=device,
51
+ dtype=dtype,
52
+ )
53
+
54
+ self.solver_order = solver_order
55
+ self.solver_type = solver_type
56
+ self.lower_order_final = lower_order_final
57
+ self.algorithm_type = algorithm_type
58
+
59
+ # Currently we only support VP-type noise schedule
60
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
61
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
62
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
63
+
64
+ sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)
65
+ if use_karras_sigmas:
66
+ log_sigmas = torch.log(sigmas)
67
+ sigmas = self._convert_to_karras(
68
+ in_sigmas=sigmas, num_inference_timesteps=num_inference_timesteps)
69
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas)
70
+ for sigma in sigmas]).round()
71
+ timesteps = np.flip(timesteps).copy().astype(np.int64)
72
+ self.timesteps = torch.from_numpy(timesteps).to(device)
73
+ sigmas = torch.from_numpy(sigmas).to(device)
74
+ self.sigmas = sigmas
75
+
76
+ # standard deviation of the initial noise distribution
77
+ self.init_noise_sigma = 1.0
78
+
79
+ # settings for DPM-Solver
80
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++", "deis"]:
81
+ raise NotImplementedError(
82
+ f"{algorithm_type} does is not implemented for {self.__class__}")
83
+
84
+ if solver_type not in ["midpoint", "heun", "logrho", "bh1", "bh2"]:
85
+ raise NotImplementedError(
86
+ f"{solver_type} does is not implemented for {self.__class__}")
87
+
88
+ # setable values
89
+ self.reset()
90
+
91
+ def reset(self):
92
+ self.model_outputs = [None] * self.solver_order
93
+ self.lower_order_nums = 0
94
+
95
+ def _sigma_to_t(self, sigma, log_sigmas):
96
+ # get log sigma
97
+ log_sigma = np.log(sigma)
98
+
99
+ # get distribution
100
+ dists = log_sigma - log_sigmas[:, np.newaxis]
101
+
102
+ # get sigmas range
103
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(
104
+ axis=0).clip(max=log_sigmas.shape[0] - 2)
105
+ high_idx = low_idx + 1
106
+
107
+ low = log_sigmas[low_idx]
108
+ high = log_sigmas[high_idx]
109
+
110
+ # interpolate sigmas
111
+ w = (low - log_sigma) / (low - high)
112
+ w = np.clip(w, 0, 1)
113
+
114
+ # transform interpolation to time range
115
+ t = (1 - w) * low_idx + w * high_idx
116
+ t = t.reshape(sigma.shape)
117
+ return t
118
+
119
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
120
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_timesteps) -> torch.FloatTensor:
121
+ """Constructs the noise schedule of Karras et al. (2022)."""
122
+
123
+ sigma_min: float = in_sigmas[-1].item()
124
+ sigma_max: float = in_sigmas[0].item()
125
+
126
+ rho = 7.0 # 7.0 is the value used in the paper
127
+ ramp = np.linspace(0, 1, num_inference_timesteps)
128
+ min_inv_rho = sigma_min ** (1 / rho)
129
+ max_inv_rho = sigma_max ** (1 / rho)
130
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
131
+ return sigmas
132
+
133
+ def dpm_solver_first_order_update(
134
+ self,
135
+ model_output: torch.FloatTensor,
136
+ timestep: int,
137
+ prev_timestep: int,
138
+ sample: torch.FloatTensor,
139
+ noise: Optional[torch.FloatTensor] = None,
140
+ ) -> torch.FloatTensor:
141
+ """
142
+ One step for the first-order DPM-Solver (equivalent to DDIM).
143
+
144
+ See https://arxiv.org/abs/2206.00927 for the detailed derivation.
145
+
146
+ Args:
147
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
148
+ timestep (`int`): current discrete timestep in the diffusion chain.
149
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
150
+ sample (`torch.FloatTensor`):
151
+ current instance of sample being created by diffusion process.
152
+
153
+ Returns:
154
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
155
+ """
156
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
157
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
158
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
159
+ h = lambda_t - lambda_s
160
+ if self.algorithm_type == "dpmsolver++":
161
+ x_t = (sigma_t / sigma_s) * sample - \
162
+ (alpha_t * (torch.exp(-h) - 1.0)) * model_output
163
+ elif self.algorithm_type == "dpmsolver":
164
+ x_t = (alpha_t / alpha_s) * sample - \
165
+ (sigma_t * (torch.exp(h) - 1.0)) * model_output
166
+ elif self.algorithm_type == "sde-dpmsolver++":
167
+ assert noise is not None
168
+ x_t = (
169
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
170
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
171
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
172
+ )
173
+ elif self.algorithm_type == "sde-dpmsolver":
174
+ assert noise is not None
175
+ x_t = (
176
+ (alpha_t / alpha_s) * sample
177
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
178
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
179
+ )
180
+ return x_t
181
+
182
+ def multistep_dpm_solver_second_order_update(
183
+ self,
184
+ model_output_list: List[torch.FloatTensor],
185
+ timestep_list: List[int],
186
+ prev_timestep: int,
187
+ sample: torch.FloatTensor,
188
+ noise: Optional[torch.FloatTensor] = None,
189
+ ) -> torch.FloatTensor:
190
+ """
191
+ One step for the second-order multistep DPM-Solver.
192
+
193
+ Args:
194
+ model_output_list (`List[torch.FloatTensor]`):
195
+ direct outputs from learned diffusion model at current and latter timesteps.
196
+ timestep (`int`): current and latter discrete timestep in the diffusion chain.
197
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
198
+ sample (`torch.FloatTensor`):
199
+ current instance of sample being created by diffusion process.
200
+
201
+ Returns:
202
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
203
+ """
204
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
205
+ m0, m1 = model_output_list[-1], model_output_list[-2]
206
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
207
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
208
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
209
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
210
+ r0 = h_0 / h
211
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
212
+ if self.algorithm_type == "dpmsolver++":
213
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
214
+ if self.solver_type == "midpoint":
215
+ x_t = (
216
+ (sigma_t / sigma_s0) * sample
217
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
218
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
219
+ )
220
+ elif self.solver_type == "heun":
221
+ x_t = (
222
+ (sigma_t / sigma_s0) * sample
223
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
224
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
225
+ )
226
+ elif self.algorithm_type == "dpmsolver":
227
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
228
+ if self.solver_type == "midpoint":
229
+ x_t = (
230
+ (alpha_t / alpha_s0) * sample
231
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
232
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
233
+ )
234
+ elif self.solver_type == "heun":
235
+ x_t = (
236
+ (alpha_t / alpha_s0) * sample
237
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
238
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
239
+ )
240
+ elif self.algorithm_type == "sde-dpmsolver++":
241
+ assert noise is not None
242
+ if self.solver_type == "midpoint":
243
+ x_t = (
244
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
245
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
246
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
247
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
248
+ )
249
+ elif self.solver_type == "heun":
250
+ x_t = (
251
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
252
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
253
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
254
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
255
+ )
256
+ elif self.algorithm_type == "sde-dpmsolver":
257
+ assert noise is not None
258
+ if self.solver_type == "midpoint":
259
+ x_t = (
260
+ (alpha_t / alpha_s0) * sample
261
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
262
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
263
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
264
+ )
265
+ elif self.solver_type == "heun":
266
+ x_t = (
267
+ (alpha_t / alpha_s0) * sample
268
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
269
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
270
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
271
+ )
272
+ return x_t
273
+
274
+ def multistep_dpm_solver_third_order_update(
275
+ self,
276
+ model_output_list: List[torch.FloatTensor],
277
+ timestep_list: List[int],
278
+ prev_timestep: int,
279
+ sample: torch.FloatTensor,
280
+ ) -> torch.FloatTensor:
281
+ """
282
+ One step for the third-order multistep DPM-Solver.
283
+
284
+ Args:
285
+ model_output_list (`List[torch.FloatTensor]`):
286
+ direct outputs from the learned diffusion model at the current and previous timesteps.
287
+ timestep_list (`List[int]`): the current and previous discrete timesteps in the diffusion chain.
288
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
289
+ sample (`torch.FloatTensor`):
290
+ current instance of sample being created by diffusion process.
291
+
292
+ Returns:
293
+ `torch.FloatTensor`: the sample tensor at the previous timestep.
294
+ """
295
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
296
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
297
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
298
+ self.lambda_t[t],
299
+ self.lambda_t[s0],
300
+ self.lambda_t[s1],
301
+ self.lambda_t[s2],
302
+ )
303
+ alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
304
+ sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
305
+ h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
306
+ r0, r1 = h_0 / h, h_1 / h
307
+ D0 = m0
308
+ D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
309
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
310
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
311
+ if self.algorithm_type == "dpmsolver++":
312
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
313
+ x_t = (
314
+ (sigma_t / sigma_s0) * sample
315
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
316
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
317
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
318
+ )
319
+ elif self.algorithm_type == "dpmsolver":
320
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
321
+ x_t = (
322
+ (alpha_t / alpha_s0) * sample
323
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
324
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
325
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
326
+ )
327
+ return x_t
328
+
329
+ def step(
330
+ self,
331
+ model_output: torch.FloatTensor,
332
+ model_output_type: str,
333
+ timestep: int,
334
+ sample: torch.FloatTensor,
335
+ ) -> SchedulerStepOutput:
336
+ """
337
+ Step function propagating the sample with the multistep DPM-Solver.
338
+
339
+ Args:
340
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
341
+ timestep (`int`): current discrete timestep in the diffusion chain.
342
+ sample (`torch.FloatTensor`):
343
+ current instance of sample being created by diffusion process.
344
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
345
+
346
+ Returns:
347
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
348
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
349
+
350
+ """
351
+ if self.num_inference_timesteps is None:
352
+ raise ValueError(
353
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
354
+ )
355
+
356
+ if isinstance(timestep, torch.Tensor):
357
+ timestep = timestep.to(self.timesteps.device)
358
+ step_index = (self.timesteps == timestep).nonzero()
359
+ if len(step_index) == 0:
360
+ step_index = len(self.timesteps) - 1
361
+ else:
362
+ step_index = step_index.item()
363
+ prev_timestep = 0 if step_index == len(
364
+ self.timesteps) - 1 else self.timesteps[step_index + 1]
365
+ lower_order_final = (
366
+ (step_index == len(self.timesteps) -
367
+ 1) and self.lower_order_final and len(self.timesteps) < 15
368
+ )
369
+ lower_order_second = (
370
+ (step_index == len(self.timesteps) -
371
+ 2) and self.lower_order_final and len(self.timesteps) < 15
372
+ )
373
+
374
+ model_output_convert = self.convert_output(
375
+ model_output, model_output_type=model_output_type, sample=sample, timesteps=timestep)
376
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
377
+ if self.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
378
+ model_output = model_output_convert.pred_original_sample
379
+ # DPM-Solver needs to solve an integral of the noise prediction model.
380
+ elif self.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
381
+ model_output = model_output_convert.pred_epsilon
382
+
383
+ for i in range(self.solver_order - 1):
384
+ self.model_outputs[i] = self.model_outputs[i + 1]
385
+ self.model_outputs[-1] = model_output
386
+
387
+ if self.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
388
+ noise = torch.randn_like(
389
+ model_output, device=model_output.device, dtype=model_output.dtype)
390
+ else:
391
+ noise = None
392
+
393
+ if self.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
394
+ prev_sample = self.dpm_solver_first_order_update(
395
+ model_output, timestep, prev_timestep, sample, noise=noise
396
+ )
397
+ elif self.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
398
+ timestep_list = [self.timesteps[step_index - 1], timestep]
399
+ prev_sample = self.multistep_dpm_solver_second_order_update(
400
+ self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
401
+ )
402
+ else:
403
+ timestep_list = [self.timesteps[step_index - 2],
404
+ self.timesteps[step_index - 1], timestep]
405
+ prev_sample = self.multistep_dpm_solver_third_order_update(
406
+ self.model_outputs, timestep_list, prev_timestep, sample
407
+ )
408
+
409
+ if self.lower_order_nums < self.solver_order:
410
+ self.lower_order_nums += 1
411
+
412
+ return SchedulerStepOutput(prev_sample=prev_sample, pred_original_sample=model_output_convert.pred_original_sample)
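Swapping in the multistep DPM-Solver only changes the scheduler construction; the sampling loop itself is unchanged. A sketch, assuming the constructor defaults declared above and using a zero prediction as a placeholder model output. Because the solver keeps a history of recent model outputs, reset() is called between independent sampling runs:

import torch
from schedulers.dpm_m import DPMSolverMultistepScheduler
from schedulers.utils import get_betas

scheduler = DPMSolverMultistepScheduler(
    num_inference_timesteps=25,
    betas=get_betas("squared_linear", num_steps=1000),
    solver_order=2,
    algorithm_type="dpmsolver++",
)

latent = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:
    eps = torch.zeros_like(latent)                                 # placeholder prediction
    latent = scheduler.step(eps, "epsilon", t, latent).prev_sample
scheduler.reset()                                                  # clear the multistep history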
schedulers/dpm_s.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base import *
16
+
17
+
18
+ class DPMSolverSingleStepScheduler(Scheduler):
19
+ def __init__(
20
+ self,
21
+ # Generic scheduler settings
22
+ num_train_timesteps: int,
23
+ num_inference_timesteps: int,
24
+ betas: torch.Tensor,
25
+ inference_timesteps: Union[str, List[int]] = "trailing",
26
+ set_alpha_to_one: bool = True,
27
+ device: Optional[Union[str, torch.device]] = None,
28
+ dtype: torch.dtype = torch.float32,
29
+ # DPM scheduler settings
30
+ algorithm_type: str = "dpmsolver++",
31
+ solver_type: str = "midpoint",
32
+ solver_order: int = 2,
33
+ lower_order_final: bool = True,
34
+ ):
35
+ super().__init__(
36
+ num_train_timesteps=num_train_timesteps,
37
+ num_inference_timesteps=num_inference_timesteps,
38
+ betas=betas,
39
+ inference_timesteps=inference_timesteps,
40
+ set_alpha_to_one=set_alpha_to_one,
41
+ device=device,
42
+ dtype=dtype,
43
+ )
44
+
45
+ self.solver_order = solver_order
46
+ self.solver_type = solver_type
47
+ self.lower_order_final = lower_order_final
48
+ self.algorithm_type = algorithm_type
49
+
50
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
51
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
52
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
53
+
54
+ self.reset()
55
+
56
+ def reset(self):
57
+ self.model_outputs = [None] * self.solver_order
58
+ self.sample = None
59
+ self.order_list = self.get_order_list()
60
+ self.last_step_index = None
61
+
62
+ def get_order_list(self) -> List[int]:
63
+ steps = self.num_inference_timesteps
64
+ order = self.solver_order
65
+ # First step must be order 1
66
+ # Second step must be order 1 in case of terminal zero SNR
67
+ orders = [1] + [(i % order) + 1 for i in range(steps - 1)] + [1]
68
+ # Last step should be order 1 for better quality.
69
+ if self.lower_order_final:
70
+ orders[-1] = 1
71
+ return orders
72
+
73
+ def dpm_solver_first_order_update(
74
+ self,
75
+ model_output: torch.FloatTensor,
76
+ timestep: int,
77
+ prev_timestep: int,
78
+ sample: torch.FloatTensor,
79
+ ) -> torch.FloatTensor:
80
+ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
81
+ alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
82
+ sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
83
+ h = lambda_t - lambda_s
84
+ if self.algorithm_type == "dpmsolver++":
85
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
86
+ elif self.algorithm_type == "dpmsolver":
87
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
88
+ return x_t
89
+
90
+ def singlestep_dpm_solver_second_order_update(
91
+ self,
92
+ model_output_list: List[torch.FloatTensor],
93
+ timestep_list: List[int],
94
+ prev_timestep: int,
95
+ sample: torch.FloatTensor,
96
+ ) -> torch.FloatTensor:
97
+ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
98
+ m0, m1 = model_output_list[-1], model_output_list[-2]
99
+ lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
100
+ alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
101
+ sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
102
+ h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
103
+ r0 = h_0 / h
104
+ D0, D1 = m1, (1.0 / r0) * (m0 - m1)
105
+ if self.algorithm_type == "dpmsolver++":
106
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
107
+ if self.solver_type == "midpoint":
108
+ x_t = (
109
+ (sigma_t / sigma_s1) * sample
110
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
111
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
112
+ )
113
+ elif self.solver_type == "heun":
114
+ x_t = (
115
+ (sigma_t / sigma_s1) * sample
116
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
117
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
118
+ )
119
+ elif self.algorithm_type == "dpmsolver":
120
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
121
+ if self.solver_type == "midpoint":
122
+ x_t = (
123
+ (alpha_t / alpha_s1) * sample
124
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
125
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
126
+ )
127
+ elif self.solver_type == "heun":
128
+ x_t = (
129
+ (alpha_t / alpha_s1) * sample
130
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
131
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
132
+ )
133
+ return x_t
134
+
135
+ def singlestep_dpm_solver_third_order_update(
136
+ self,
137
+ model_output_list: List[torch.FloatTensor],
138
+ timestep_list: List[int],
139
+ prev_timestep: int,
140
+ sample: torch.FloatTensor,
141
+ ) -> torch.FloatTensor:
142
+ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
143
+ m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
144
+ lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
145
+ self.lambda_t[t],
146
+ self.lambda_t[s0],
147
+ self.lambda_t[s1],
148
+ self.lambda_t[s2],
149
+ )
150
+ alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
151
+ sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
152
+ h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
153
+ r0, r1 = h_0 / h, h_1 / h
154
+ D0 = m2
155
+ D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2)
156
+ D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
157
+ D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
158
+ if self.algorithm_type == "dpmsolver++":
159
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
160
+ if self.solver_type == "midpoint":
161
+ x_t = (
162
+ (sigma_t / sigma_s2) * sample
163
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
164
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1
165
+ )
166
+ elif self.solver_type == "heun":
167
+ x_t = (
168
+ (sigma_t / sigma_s2) * sample
169
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
170
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
171
+ - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
172
+ )
173
+ elif self.algorithm_type == "dpmsolver":
174
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
175
+ if self.solver_type == "midpoint":
176
+ x_t = (
177
+ (alpha_t / alpha_s2) * sample
178
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
179
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1
180
+ )
181
+ elif self.solver_type == "heun":
182
+ x_t = (
183
+ (alpha_t / alpha_s2) * sample
184
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
185
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
186
+ - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
187
+ )
188
+ return x_t
189
+
190
+ def step(
191
+ self,
192
+ model_output: torch.FloatTensor,
193
+ model_output_type: str,
194
+ timestep: int,
195
+ sample: torch.FloatTensor,
196
+ ) -> SchedulerStepOutput:
197
+
198
+ step_index = (self.timesteps == timestep).nonzero().item()
199
+
200
+ # Check if this step is the follow-up of the previous step.
201
+ # If not, then we reset and treat it as a new run.
202
+ if self.last_step_index is not None and self.last_step_index != step_index - 1:
203
+ self.reset()
204
+ self.last_step_index = step_index
205
+
206
+ prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
207
+ model_output_convert = self.convert_output(model_output, model_output_type, sample, timestep)
208
+
209
+ if self.algorithm_type == "dpmsolver++":
210
+ model_output = model_output_convert.pred_original_sample
211
+ else:
212
+ model_output = model_output_convert.pred_epsilon
213
+
214
+ for i in range(self.solver_order - 1):
215
+ self.model_outputs[i] = self.model_outputs[i + 1]
216
+ self.model_outputs[-1] = model_output
217
+
218
+ order = self.order_list[step_index]
219
+
220
+ # For img2img denoising might start with order>1 which is not possible
221
+ # In this case make sure that the first two steps are both order=1
222
+ while self.model_outputs[-order] is None:
223
+ order -= 1
224
+
225
+ # For single-step solvers, we use the initial value at each time with order = 1.
226
+ if order == 1:
227
+ self.sample = sample
228
+
229
+ timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
230
+
231
+ if order == 1:
232
+ prev_sample = self.dpm_solver_first_order_update(self.model_outputs[-1], timestep_list[-1], prev_timestep, self.sample)
233
+ elif order == 2:
234
+ prev_sample = self.singlestep_dpm_solver_second_order_update(self.model_outputs, timestep_list, prev_timestep, self.sample)
235
+ elif order == 3:
236
+ prev_sample = self.singlestep_dpm_solver_third_order_update(self.model_outputs, timestep_list, prev_timestep, self.sample)
237
+ else:
238
+ raise NotImplementedError
239
+
240
+ return SchedulerStepOutput(
241
+ prev_sample=prev_sample,
242
+ pred_original_sample=model_output_convert.pred_original_sample
243
+ )
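For reference, the per-step order pattern used by the single-step solver, replicated here from the expression in get_order_list() for solver_order=2 and 6 inference steps (the list is indexed by step_index during sampling):

steps, order = 6, 2
orders = [1] + [(i % order) + 1 for i in range(steps - 1)] + [1]
orders[-1] = 1             # lower_order_final=True
print(orders)              # [1, 1, 2, 1, 2, 1, 1]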
schedulers/utils.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import List
17
+
18
+ import torch
19
+
20
+
21
+ def get_betas(name: str, num_steps: int = 1000, shift_snr: float = 1, terminal_pure_noise: bool = False):
22
+ # Get betas
23
+ max_beta = 1 if terminal_pure_noise else 0.999
24
+ if name == "squared_linear":
25
+ betas = torch.linspace(0.00085**0.5, 0.012**0.5, num_steps) ** 2
26
+ elif name == "cosine":
27
+ betas = get_cosine_betas(num_steps, max_beta=max_beta)
28
+ elif name == "alphas_cumprod_linear":
29
+ betas = get_alphas_cumprod_linear_betas(num_steps, max_beta=max_beta)
30
+ elif name == "sigmoid":
31
+ betas = get_sigmoid_betas(num_steps, max_beta=max_beta, square=True, slop=0.7)
32
+ else:
33
+ raise NotImplementedError
34
+
35
+ # Shift snr
36
+ betas = shift_betas_by_snr_factor(betas, shift_snr)
37
+
38
+ # Ensure terminal pure noise
39
+ # Only the non-cosine schedules use rescaling; the cosine schedule can directly set max_beta=1 to ensure terminal pure noise.
40
+ if name == "squared_linear" and terminal_pure_noise:
41
+ betas = rescale_betas_to_ensure_terminal_pure_noise(betas)
42
+
43
+ return betas
44
+
45
+
46
+ def validate_betas(betas: List[float]) -> bool:
47
+ """
48
+ Validate that betas are strictly increasing and within the (0, 1] range, i.e. 0 < beta_{t-1} < beta_{t} <= 1
49
+
50
+ Args:
51
+ betas (List[float]): betas
52
+
53
+ Returns:
54
+ bool: True if betas is correct
55
+ """
56
+ return all(b1 < b2 for b1, b2 in zip(betas, betas[1:])) and betas[0] > 0 and betas[-1] <= 1
57
+
58
+
59
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta=0.999):
60
+ betas = []
61
+ for i in range(num_diffusion_timesteps):
62
+ t1 = i / num_diffusion_timesteps
63
+ t2 = (i + 1) / num_diffusion_timesteps
64
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
65
+ if not validate_betas(betas):
66
+ import logging
67
+ logging.warning("No feasible betas for given alpha bar")
68
+ return torch.tensor(betas, dtype=torch.float32)
69
+
70
+
71
+ def get_cosine_betas(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
72
+ def alpha_bar_fn(time_step):
73
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
74
+ return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta)
75
+
76
+
77
+ def get_sigmoid_betas(num_diffusion_timesteps, max_beta, square=False, slop=1):
78
+ def alpha_bar_fn(t):
79
+ def sigmoid(x):
80
+ return 1 / (1 + math.exp(-x * slop))
81
+ s = 6 # (-6, 6) from geodiff
82
+ vb = sigmoid(-s)
83
+ ve = sigmoid(s)
84
+ return ((sigmoid(s - t * 2 * s) - vb) / (ve - vb))**(1 if not square else 2)
85
+ return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta)
86
+
87
+
88
+ def get_alphas_cumprod_linear_betas(num_diffusion_timesteps, max_beta):
89
+ def alpha_bar_fn(t):
90
+ return 1 - t
91
+ return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta=max_beta)
92
+
93
+
94
+ def shift_betas_by_snr_factor(betas: torch.Tensor, factor: float) -> torch.Tensor:
95
+ if factor == 1.0:
96
+ return betas
97
+ # Convert betas to snr
98
+ alphas = 1 - betas
99
+ alphas_cumprod = alphas.cumprod(dim=0)
100
+ snr = alphas_cumprod / (1 - alphas_cumprod)
101
+ # Shift snr
102
+ snr *= factor
103
+ # Convert snr to betas
104
+ alphas_cumprod = snr / (1 + snr)
105
+ alphas = torch.cat(
106
+ [alphas_cumprod[0:1], alphas_cumprod[1:] / alphas_cumprod[:-1]])
107
+ betas = 1 - alphas
108
+ return betas
109
+
110
+
111
+ def rescale_betas_to_ensure_terminal_pure_noise(betas: torch.Tensor) -> torch.Tensor:
112
+ # Convert betas to alphas_cumprod_sqrt
113
+ alphas = 1 - betas
114
+ alphas_cumprod = alphas.cumprod(0)
115
+ alphas_cumprod_sqrt = alphas_cumprod.sqrt()
116
+ # Rescale alphas_cumprod_sqrt such that alphas_cumprod_sqrt[0] remains unchanged but alphas_cumprod_sqrt[-1] = 0
117
+ alphas_cumprod_sqrt = (alphas_cumprod_sqrt - alphas_cumprod_sqrt[-1]) / (
118
+ alphas_cumprod_sqrt[0] - alphas_cumprod_sqrt[-1]) * alphas_cumprod_sqrt[0]
119
+ # Convert alphas_cumprod_sqrt to betas
120
+ alphas_cumprod = alphas_cumprod_sqrt ** 2
121
+ alphas = torch.cat(
122
+ [alphas_cumprod[0:1], alphas_cumprod[1:] / alphas_cumprod[:-1]])
123
+ betas = 1 - alphas
124
+ return betas
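A quick check of the terminal-pure-noise rescaling above: with terminal_pure_noise=True the cumulative alpha of the last training step is driven to zero, so x_T carries no signal. A sketch using get_betas as defined here:

import torch
from schedulers.utils import get_betas

betas = get_betas("squared_linear", num_steps=1000, terminal_pure_noise=True)
alphas_cumprod = (1 - betas).cumprod(dim=0)
print(alphas_cumprod[-1].item())   # ~0.0 -> pure noise at t = T
print(betas[-1].item())            # the final beta is rescaled to 1.0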
utils.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch.nn as nn
16
+ import importlib
17
+
18
+ def zero_module(module):
19
+ if isinstance(module, nn.Linear):
20
+ module.weight.data.zero_()
21
+ if module.bias is not None:
22
+ module.bias.data.zero_()
23
+ return module
24
+
25
+ def get_obj_from_str(string, reload=False):
26
+ module, cls = string.rsplit(".", 1)
27
+ if reload:
28
+ module_imp = importlib.import_module(module)
29
+ importlib.reload(module_imp)
30
+ return getattr(importlib.import_module(module, package=None), cls)
31
+
32
+ def instantiate_from_config(config):
33
+ if "target" not in config:
34
+ if config == '__is_first_stage__':
35
+ return None
36
+ elif config == "__is_unconditional__":
37
+ return None
38
+ raise KeyError("Expected key `target` to instantiate.")
39
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
40
+
41
+ def update_dict(old_dict, new_dict):
42
+ old_keys = old_dict.keys()
43
+ for new_key in new_dict.keys():
44
+ if new_key in old_keys:
45
+ if type(old_dict[new_key]) == list:
46
+ if type(new_dict[new_key]) == list:
47
+ old_dict[new_key].extend(new_dict[new_key])
48
+ else:
49
+ old_dict[new_key].append(new_dict[new_key])
50
+ else:
51
+ old_dict[new_key] = [old_dict[new_key]]
52
+ old_dict[new_key].append(new_dict[new_key])
53
+ else:
54
+ old_dict[new_key] = new_dict[new_key]
55
+ return old_dict
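A short example of the config-driven instantiation helper above; the target class and params below are illustrative only:

from utils import instantiate_from_config

config = {
    "target": "torch.nn.Linear",                 # any importable "module.Class" path
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)          # equivalent to torch.nn.Linear(8, 4)
print(layer)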