Spaces:
Runtime error
Runtime error
Add application file
Browse files- configs/realcustom_sigdino_highres.json +119 -0
- configs/realcustom_sigdino_highres_shallow.json +114 -0
- inference/__pycache__/inference_utils.cpython-310.pyc +0 -0
- inference/__pycache__/mask_generation.cpython-310.pyc +0 -0
- inference/__pycache__/pipeline.cpython-310.pyc +0 -0
- inference/app.py +82 -0
- inference/inference_single_image.py +317 -0
- inference/inference_single_image.sh +55 -0
- inference/inference_utils.py +76 -0
- inference/mask_generation.py +114 -0
- inference/pipeline.py +359 -0
- models/__pycache__/attention_custom.cpython-310.pyc +0 -0
- models/__pycache__/attention_processor_custom_cross.cpython-310.pyc +0 -0
- models/__pycache__/base_vision.cpython-310.pyc +0 -0
- models/__pycache__/dino.cpython-310.pyc +0 -0
- models/__pycache__/image_encoder_siglipdino_shallowdeep.cpython-310.pyc +0 -0
- models/__pycache__/projectors.cpython-310.pyc +0 -0
- models/__pycache__/sigclip.cpython-310.pyc +0 -0
- models/__pycache__/text.cpython-310.pyc +0 -0
- models/__pycache__/transformer_2d_custom.cpython-310.pyc +0 -0
- models/__pycache__/unet_2d_blocks_custom.cpython-310.pyc +0 -0
- models/__pycache__/unet_2d_condition_custom.cpython-310.pyc +0 -0
- models/__pycache__/vae.cpython-310.pyc +0 -0
- models/attention_custom.py +425 -0
- models/attention_processor_custom_cross.py +1778 -0
- models/base_vision.py +227 -0
- models/dino.py +203 -0
- models/image_encoder_siglipdino_shallowdeep.py +162 -0
- models/projectors.py +150 -0
- models/sigclip.py +159 -0
- models/text.py +113 -0
- models/transformer_2d_custom.py +388 -0
- models/unet_2d_blocks_custom.py +0 -0
- models/unet_2d_condition_custom.py +1059 -0
- models/vae.py +36 -0
- prompts/validation_negative.txt +1 -0
- requirements.txt +34 -0
- schedulers/__pycache__/base.cpython-310.pyc +0 -0
- schedulers/__pycache__/ddim.cpython-310.pyc +0 -0
- schedulers/__pycache__/dpm_s.cpython-310.pyc +0 -0
- schedulers/__pycache__/utils.cpython-310.pyc +0 -0
- schedulers/base.py +133 -0
- schedulers/ddim.py +85 -0
- schedulers/dpm_m.py +412 -0
- schedulers/dpm_s.py +243 -0
- schedulers/utils.py +124 -0
- utils.py +55 -0
configs/realcustom_sigdino_highres.json
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"act_fn": "silu",
|
3 |
+
"addition_embed_type": "text_time",
|
4 |
+
"addition_embed_type_num_heads": 64,
|
5 |
+
"addition_time_embed_dim": 256,
|
6 |
+
"attention_head_dim": [
|
7 |
+
5,
|
8 |
+
10,
|
9 |
+
20
|
10 |
+
],
|
11 |
+
"block_out_channels": [
|
12 |
+
320,
|
13 |
+
640,
|
14 |
+
1280
|
15 |
+
],
|
16 |
+
"center_input_sample": false,
|
17 |
+
"class_embed_type": null,
|
18 |
+
"class_embeddings_concat": false,
|
19 |
+
"conv_in_kernel": 3,
|
20 |
+
"conv_out_kernel": 3,
|
21 |
+
"cross_attention_dim": 2048,
|
22 |
+
"cross_attention_norm": null,
|
23 |
+
"down_block_types": [
|
24 |
+
"DownBlock2D",
|
25 |
+
"CrossAttnDownBlock2D",
|
26 |
+
"CrossAttnDownBlock2D"
|
27 |
+
],
|
28 |
+
"downsample_padding": 1,
|
29 |
+
"dual_cross_attention": false,
|
30 |
+
"encoder_hid_dim": null,
|
31 |
+
"encoder_hid_dim_type": null,
|
32 |
+
"flip_sin_to_cos": true,
|
33 |
+
"freq_shift": 0,
|
34 |
+
"in_channels": 4,
|
35 |
+
"layers_per_block": 2,
|
36 |
+
"mid_block_only_cross_attention": null,
|
37 |
+
"mid_block_scale_factor": 1,
|
38 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
39 |
+
"norm_eps": 1e-05,
|
40 |
+
"norm_num_groups": 32,
|
41 |
+
"num_attention_heads": null,
|
42 |
+
"num_class_embeds": null,
|
43 |
+
"only_cross_attention": false,
|
44 |
+
"out_channels": 4,
|
45 |
+
"projection_class_embeddings_input_dim": 2816,
|
46 |
+
"resnet_out_scale_factor": 1.0,
|
47 |
+
"resnet_skip_time_act": false,
|
48 |
+
"resnet_time_scale_shift": "default",
|
49 |
+
"sample_size": 128,
|
50 |
+
"time_cond_proj_dim": null,
|
51 |
+
"time_embedding_act_fn": null,
|
52 |
+
"time_embedding_dim": null,
|
53 |
+
"time_embedding_type": "positional",
|
54 |
+
"timestep_post_act": null,
|
55 |
+
"transformer_layers_per_block": [
|
56 |
+
1,
|
57 |
+
2,
|
58 |
+
10
|
59 |
+
],
|
60 |
+
"up_block_types": [
|
61 |
+
"CrossAttnUpBlock2D",
|
62 |
+
"CrossAttnUpBlock2D",
|
63 |
+
"UpBlock2D"
|
64 |
+
],
|
65 |
+
"upcast_attention": false,
|
66 |
+
"use_linear_projection": true,
|
67 |
+
"image_ref_processor_config": {
|
68 |
+
"target": "utils.image_ref_processor.default.NaiveResizeProcessor",
|
69 |
+
"params": {
|
70 |
+
"target_image_size": 768,
|
71 |
+
"resize_mode": "resize",
|
72 |
+
"crop_min_ratio": 1.0,
|
73 |
+
"crop_max_ratio": 1.0,
|
74 |
+
"image_dropout": 0.1
|
75 |
+
}
|
76 |
+
},
|
77 |
+
"image_ref_processor_input_keys": [
|
78 |
+
"image_ref"
|
79 |
+
],
|
80 |
+
"vision_model_config": {
|
81 |
+
"vision_model_config": {
|
82 |
+
"target": "models.image_encoder_siglipdino_shallowdeep.ShallowDeepPatchfySiglipDinoEncoder",
|
83 |
+
"params": {
|
84 |
+
"siglip_config": {
|
85 |
+
"backbone_name_or_path": "vit_so400m_patch14_siglip_384",
|
86 |
+
"image_resize_strategy": "resize-naive",
|
87 |
+
"default_image_size": 384,
|
88 |
+
"feature_index": [
|
89 |
+
25
|
90 |
+
]
|
91 |
+
},
|
92 |
+
"dino_config": {
|
93 |
+
"backbone_name_or_path": "vit_large_patch14_reg4_dinov2.lvd142m",
|
94 |
+
"image_resize_strategy": "resize-naive",
|
95 |
+
"default_image_size": 384,
|
96 |
+
"feature_index": [
|
97 |
+
22
|
98 |
+
]
|
99 |
+
},
|
100 |
+
"patchfy_scale": 2,
|
101 |
+
"default_image_size": 384
|
102 |
+
}
|
103 |
+
}
|
104 |
+
},
|
105 |
+
"image_prompt_settings": {
|
106 |
+
"vision_projection_type": "custom",
|
107 |
+
"vision_projection_config": {
|
108 |
+
"target": "models.projectors.ProjectorHighResMinAttn",
|
109 |
+
"params": {
|
110 |
+
"vision_dim": 2176,
|
111 |
+
"out_dim": 2048,
|
112 |
+
"dim_head": 64,
|
113 |
+
"adaptive_scale": false
|
114 |
+
}
|
115 |
+
},
|
116 |
+
"image_prompt_mode": "naive",
|
117 |
+
"cross_attention_id": 70
|
118 |
+
}
|
119 |
+
}
|
configs/realcustom_sigdino_highres_shallow.json
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"act_fn": "silu",
|
3 |
+
"addition_embed_type": "text_time",
|
4 |
+
"addition_embed_type_num_heads": 64,
|
5 |
+
"addition_time_embed_dim": 256,
|
6 |
+
"attention_head_dim": [
|
7 |
+
5,
|
8 |
+
10,
|
9 |
+
20
|
10 |
+
],
|
11 |
+
"block_out_channels": [
|
12 |
+
320,
|
13 |
+
640,
|
14 |
+
1280
|
15 |
+
],
|
16 |
+
"center_input_sample": false,
|
17 |
+
"class_embed_type": null,
|
18 |
+
"class_embeddings_concat": false,
|
19 |
+
"conv_in_kernel": 3,
|
20 |
+
"conv_out_kernel": 3,
|
21 |
+
"cross_attention_dim": 2048,
|
22 |
+
"cross_attention_norm": null,
|
23 |
+
"down_block_types": [
|
24 |
+
"DownBlock2D",
|
25 |
+
"CrossAttnDownBlock2D",
|
26 |
+
"CrossAttnDownBlock2D"
|
27 |
+
],
|
28 |
+
"downsample_padding": 1,
|
29 |
+
"dual_cross_attention": false,
|
30 |
+
"encoder_hid_dim": null,
|
31 |
+
"encoder_hid_dim_type": null,
|
32 |
+
"flip_sin_to_cos": true,
|
33 |
+
"freq_shift": 0,
|
34 |
+
"in_channels": 4,
|
35 |
+
"layers_per_block": 2,
|
36 |
+
"mid_block_only_cross_attention": null,
|
37 |
+
"mid_block_scale_factor": 1,
|
38 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
39 |
+
"norm_eps": 1e-05,
|
40 |
+
"norm_num_groups": 32,
|
41 |
+
"num_attention_heads": null,
|
42 |
+
"num_class_embeds": null,
|
43 |
+
"only_cross_attention": false,
|
44 |
+
"out_channels": 4,
|
45 |
+
"projection_class_embeddings_input_dim": 2816,
|
46 |
+
"resnet_out_scale_factor": 1.0,
|
47 |
+
"resnet_skip_time_act": false,
|
48 |
+
"resnet_time_scale_shift": "default",
|
49 |
+
"sample_size": 128,
|
50 |
+
"time_cond_proj_dim": null,
|
51 |
+
"time_embedding_act_fn": null,
|
52 |
+
"time_embedding_dim": null,
|
53 |
+
"time_embedding_type": "positional",
|
54 |
+
"timestep_post_act": null,
|
55 |
+
"transformer_layers_per_block": [
|
56 |
+
1,
|
57 |
+
2,
|
58 |
+
10
|
59 |
+
],
|
60 |
+
"up_block_types": [
|
61 |
+
"CrossAttnUpBlock2D",
|
62 |
+
"CrossAttnUpBlock2D",
|
63 |
+
"UpBlock2D"
|
64 |
+
],
|
65 |
+
"upcast_attention": false,
|
66 |
+
"use_linear_projection": true,
|
67 |
+
|
68 |
+
"image_ref_processor_input_keys": ["image_ref"],
|
69 |
+
"vision_model_config": {
|
70 |
+
"vision_model_config": {
|
71 |
+
"target": "models.image_encoder_siglipdino_shallowdeep.ShallowDeepPatchfySiglipDinoEncoder_v2",
|
72 |
+
"params": {
|
73 |
+
"siglip_config": {
|
74 |
+
"backbone_name_or_path": "vit_so400m_patch14_siglip_384",
|
75 |
+
"image_resize_strategy": "resize-naive",
|
76 |
+
"default_image_size": 384,
|
77 |
+
"feature_index": [
|
78 |
+
7,
|
79 |
+
13,
|
80 |
+
19,
|
81 |
+
25
|
82 |
+
]
|
83 |
+
},
|
84 |
+
"dino_config": {
|
85 |
+
"backbone_name_or_path": "vit_large_patch14_reg4_dinov2.lvd142m",
|
86 |
+
"image_resize_strategy": "resize-naive",
|
87 |
+
"default_image_size": 384,
|
88 |
+
"feature_index": [
|
89 |
+
4,
|
90 |
+
10,
|
91 |
+
16,
|
92 |
+
22
|
93 |
+
]
|
94 |
+
},
|
95 |
+
"patchfy_scale": 2,
|
96 |
+
"default_image_size": 384
|
97 |
+
}
|
98 |
+
}
|
99 |
+
},
|
100 |
+
|
101 |
+
"image_prompt_settings": {
|
102 |
+
"vision_projection_type": "custom",
|
103 |
+
"vision_projection_config": {
|
104 |
+
"target": "models.projectors.ProjectorHighResShallowMinAttnV1",
|
105 |
+
"params": {
|
106 |
+
"vision_dim": 2176,
|
107 |
+
"out_dim": 2048,
|
108 |
+
"dim_head": 64
|
109 |
+
}
|
110 |
+
},
|
111 |
+
|
112 |
+
"image_prompt_mode": "naive"
|
113 |
+
}
|
114 |
+
}
|
inference/__pycache__/inference_utils.cpython-310.pyc
ADDED
Binary file (1.58 kB). View file
|
|
inference/__pycache__/mask_generation.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
|
|
inference/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (10.2 kB). View file
|
|
inference/app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import gradio as gr
|
16 |
+
|
17 |
+
from inference.pipeline import RealCustomInferencePipeline
|
18 |
+
|
19 |
+
def create_demo():
|
20 |
+
pipeline = RealCustomInferencePipeline(
|
21 |
+
unet_config="configs/realcustom_sigdino_highres.json",
|
22 |
+
unet_checkpoint="ckpts/sdxl/unet/sdxl-unet.bin",
|
23 |
+
realcustom_checkpoint="ckpts/realcustom/RealCustom_highres.pth",
|
24 |
+
vae_config="ckpts/sdxl/vae/sdxl.json",
|
25 |
+
vae_checkpoint="ckpts/sdxl/vae/sdxl-vae.pth",
|
26 |
+
)
|
27 |
+
|
28 |
+
badges_text = r"""
|
29 |
+
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
|
30 |
+
<a href="https://corleone-huang.github.io/RealCustom_plus_plus/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-RealCustom-yellow"></a>
|
31 |
+
<a href="https://arxiv.org/pdf/2408.09744?"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-RealCustom-b31b1b.svg"></a>
|
32 |
+
</div>
|
33 |
+
""".strip()
|
34 |
+
|
35 |
+
with gr.Blocks() as demo:
|
36 |
+
gr.Markdown(f"# RealCustom")
|
37 |
+
gr.Markdown(badges_text)
|
38 |
+
with gr.Row():
|
39 |
+
with gr.Column():
|
40 |
+
prompt = gr.Textbox(label="Prompt", value="")
|
41 |
+
target_phrase = gr.Textbox(label="Target Phrase", value="")
|
42 |
+
with gr.Row():
|
43 |
+
image_prompt = gr.Image(label="Ref Img", visible=True, interactive=True, type="pil")
|
44 |
+
|
45 |
+
with gr.Row():
|
46 |
+
with gr.Column():
|
47 |
+
width = gr.Slider(512, 2048, 1024, step=16, label="Gneration Width")
|
48 |
+
height = gr.Slider(512, 2048, 1024, step=16, label="Gneration Height")
|
49 |
+
|
50 |
+
with gr.Accordion("Advanced Options", open=False):
|
51 |
+
with gr.Row():
|
52 |
+
guidance = gr.Slider(1.0, 15, 3.5, step=0.5, label="Guidance Scale", interactive=True)
|
53 |
+
mask_scope = gr.Slider(0.05, 1.0, 0.2, step=0.05, label="Mask Scope", interactive=True)
|
54 |
+
seed = gr.Number(0, label="Seed (-1 for random)")
|
55 |
+
num = gr.Number(4, label="Generation Number")
|
56 |
+
new_unet_local_path = gr.Textbox(label="New Unet Local Path", value="")
|
57 |
+
new_realcustom_local_path = gr.Textbox(label="New RealCustom Local Path", value="")
|
58 |
+
|
59 |
+
generate_btn = gr.Button("Generate")
|
60 |
+
|
61 |
+
with gr.Column():
|
62 |
+
output_image = gr.Image(label="Generated Image")
|
63 |
+
output_mask = gr.Image(label="Guidance Mask")
|
64 |
+
|
65 |
+
inputs = [
|
66 |
+
prompt, image_prompt, target_phrase,
|
67 |
+
height, width, guidance, seed, num,
|
68 |
+
mask_scope,
|
69 |
+
new_unet_local_path, new_realcustom_local_path,
|
70 |
+
]
|
71 |
+
generate_btn.click(
|
72 |
+
fn=pipeline.generation,
|
73 |
+
inputs=inputs,
|
74 |
+
outputs=[output_image, output_mask],
|
75 |
+
)
|
76 |
+
|
77 |
+
return demo
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
demo = create_demo()
|
82 |
+
demo.launch(server_name='0.0.0.0', server_port=7860)
|
inference/inference_single_image.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import torch
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import torchvision
|
20 |
+
from torchvision.utils import make_grid
|
21 |
+
from torchvision.transforms.functional import to_pil_image
|
22 |
+
from tqdm import tqdm
|
23 |
+
from PIL import Image
|
24 |
+
|
25 |
+
from models.text import TextModel
|
26 |
+
from models.vae import AutoencoderKL
|
27 |
+
|
28 |
+
from models.unet_2d_condition_custom import UNet2DConditionModel as UNet2DConditionModelDiffusers
|
29 |
+
from schedulers.ddim import DDIMScheduler
|
30 |
+
from schedulers.dpm_s import DPMSolverSingleStepScheduler
|
31 |
+
from schedulers.utils import get_betas
|
32 |
+
|
33 |
+
from inference_utils import find_phrase_positions_in_text, classifier_free_guidance_image_prompt_cascade
|
34 |
+
from mask_generation import mask_generation
|
35 |
+
from utils import instantiate_from_config
|
36 |
+
|
37 |
+
# Argument parser
|
38 |
+
parser = argparse.ArgumentParser()
|
39 |
+
parser.add_argument("--width", type=int, default=512)
|
40 |
+
parser.add_argument("--height", type=int, default=512)
|
41 |
+
|
42 |
+
parser.add_argument("--samples_per_prompt", type=int, required=True)
|
43 |
+
parser.add_argument("--nrow", type=int, default=4)
|
44 |
+
parser.add_argument("--sample_steps", type=int, required=True)
|
45 |
+
parser.add_argument("--schedule_type", type=str, default="squared_linear") # default, `squared_linear
|
46 |
+
parser.add_argument("--scheduler_type", type=str, default="dpm", choices=["ddim", "dpm"]) # default, "dpm"
|
47 |
+
parser.add_argument("--schedule_shift_snr", type=float, default=1) # default, 1
|
48 |
+
|
49 |
+
parser.add_argument("--text_encoder_variant", type=str, nargs="+")
|
50 |
+
parser.add_argument("--vae_config", type=str, default="configs/vae.json") # default
|
51 |
+
parser.add_argument("--vae_checkpoint", type=str, required=True)
|
52 |
+
parser.add_argument("--unet_config", type=str, required=True)
|
53 |
+
parser.add_argument("--unet_checkpoint", type=str, required=True)
|
54 |
+
parser.add_argument("--unet_checkpoint_base_model", type=str, default="")
|
55 |
+
parser.add_argument("--unet_prediction", type=str, choices=DDIMScheduler.prediction_types, default="epsilon") # default, "epsilon"
|
56 |
+
|
57 |
+
parser.add_argument("--negative_prompt", type=str, default="prompts/validation_negative.txt") # default
|
58 |
+
|
59 |
+
parser.add_argument("--compile", action="store_true", default=False)
|
60 |
+
parser.add_argument("--output_dir", type=str, required=True)
|
61 |
+
|
62 |
+
parser.add_argument("--guidance_weight", type=float, default=7.5)
|
63 |
+
parser.add_argument("--seed", type=int, default=666)
|
64 |
+
parser.add_argument("--device", type=str, default="cuda")
|
65 |
+
|
66 |
+
parser.add_argument("--text_prompt", type=str, required=True)
|
67 |
+
parser.add_argument("--image_prompt_path", type=str, required=True)
|
68 |
+
parser.add_argument("--target_phrase", type=str, required=True)
|
69 |
+
parser.add_argument("--mask_scope", type=float, default=0.20)
|
70 |
+
parser.add_argument("--mask_strategy", type=str, nargs="+", default=["max_norm"])
|
71 |
+
parser.add_argument("--mask_reused_step", type=int, default=12)
|
72 |
+
|
73 |
+
args = parser.parse_args()
|
74 |
+
|
75 |
+
# Initialize unet model
|
76 |
+
with open(args.unet_config) as unet_config_file:
|
77 |
+
unet_config = json.load(unet_config_file)
|
78 |
+
|
79 |
+
# Settings for image encoder
|
80 |
+
vision_model_config = unet_config.pop("vision_model_config", None)
|
81 |
+
args.vision_model_config = vision_model_config.pop("vision_model_config", None)
|
82 |
+
|
83 |
+
unet_type = unet_config.pop("type", None)
|
84 |
+
unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
85 |
+
|
86 |
+
unet_model.eval().to(args.device)
|
87 |
+
unet_model.load_state_dict(torch.load(args.unet_checkpoint, map_location=args.device), strict=False)
|
88 |
+
print("loading unet model finished.")
|
89 |
+
|
90 |
+
if args.unet_checkpoint_base_model != "":
|
91 |
+
if "safetensors" in args.unet_checkpoint_base_model:
|
92 |
+
from safetensors import safe_open
|
93 |
+
tensors = {}
|
94 |
+
with safe_open(args.unet_checkpoint_base_model, framework="pt", device='cpu') as f:
|
95 |
+
for k in f.keys():
|
96 |
+
new_k = k.replace("model.diffusion_model.", "")
|
97 |
+
tensors[k] = f.get_tensor(k)
|
98 |
+
unet_model.load_state_dict(tensors, strict=False)
|
99 |
+
else:
|
100 |
+
unet_model.load_state_dict(torch.load(args.unet_checkpoint_base_model, map_location=args.device), strict=False)
|
101 |
+
unet_model = torch.compile(unet_model, disable=not args.compile)
|
102 |
+
print("loading unet base model finished.")
|
103 |
+
|
104 |
+
# Initialize vae model
|
105 |
+
with open(args.vae_config) as vae_config_file:
|
106 |
+
vae_config = json.load(vae_config_file)
|
107 |
+
vae_downsample_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) # 2 ** 3 = 8
|
108 |
+
vae_model = AutoencoderKL(**vae_config)
|
109 |
+
vae_model.eval().to(args.device)
|
110 |
+
vae_model.load_state_dict(torch.load(args.vae_checkpoint, map_location=args.device))
|
111 |
+
vae_decoder = torch.compile(lambda x: vae_model.decode(x / vae_model.scaling_factor).sample.clip(-1, 1), disable=not args.compile)
|
112 |
+
vae_encoder = torch.compile(lambda x: vae_model.encode(x).latent_dist.mode().mul_(vae_model.scaling_factor), disable=not args.compile)
|
113 |
+
print("loading vae finished.")
|
114 |
+
|
115 |
+
# Initialize ddim scheduler
|
116 |
+
ddim_train_steps = 1000
|
117 |
+
ddim_betas = get_betas(name=args.schedule_type, num_steps=ddim_train_steps, shift_snr=args.schedule_shift_snr, terminal_pure_noise=False)
|
118 |
+
scheduler_class = DPMSolverSingleStepScheduler if args.scheduler_type == 'dpm' else DDIMScheduler
|
119 |
+
scheduler = scheduler_class(betas=ddim_betas, num_train_timesteps=ddim_train_steps, num_inference_timesteps=args.sample_steps, device=args.device)
|
120 |
+
infer_timesteps = scheduler.timesteps
|
121 |
+
|
122 |
+
# Initialize text model
|
123 |
+
text_model = TextModel(args.text_encoder_variant, ["penultimate_nonorm"])
|
124 |
+
text_model.eval().to(args.device)
|
125 |
+
print("loading text model finished.")
|
126 |
+
|
127 |
+
# Initialize image model.
|
128 |
+
vision_model = instantiate_from_config(args.vision_model_config)
|
129 |
+
vision_model = vision_model.eval().to(args.device)
|
130 |
+
print("loading image model finished.")
|
131 |
+
|
132 |
+
negative_prompt = ""
|
133 |
+
if args.negative_prompt:
|
134 |
+
with open(args.negative_prompt) as f:
|
135 |
+
negative_prompt = f.read().strip()
|
136 |
+
|
137 |
+
image_metadata_validate = torch.tensor(
|
138 |
+
data=[
|
139 |
+
args.width, # original_height
|
140 |
+
args.height, # original_width
|
141 |
+
0, # coordinate top
|
142 |
+
0, # coordinate left
|
143 |
+
args.width, # target_height
|
144 |
+
args.height, # target_width
|
145 |
+
],
|
146 |
+
device=args.device,
|
147 |
+
dtype=torch.float32
|
148 |
+
).view(1, -1).repeat(args.samples_per_prompt, 1)
|
149 |
+
|
150 |
+
# Create output directory
|
151 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
152 |
+
args.output_image_grid_dir = os.path.join(args.output_dir, "images_grid")
|
153 |
+
args.output_image_dir = os.path.join(args.output_dir, "images")
|
154 |
+
args.output_mask_grid_dir = os.path.join(args.output_dir, "masks_grid")
|
155 |
+
args.output_mask_dir = os.path.join(args.output_dir, "masks")
|
156 |
+
os.makedirs(args.output_image_grid_dir, exist_ok=True)
|
157 |
+
os.makedirs(args.output_image_dir, exist_ok=True)
|
158 |
+
os.makedirs(args.output_mask_grid_dir, exist_ok=True)
|
159 |
+
os.makedirs(args.output_mask_dir, exist_ok=True)
|
160 |
+
|
161 |
+
with torch.no_grad():
|
162 |
+
# Prepare negative prompt.
|
163 |
+
if args.guidance_weight != 1:
|
164 |
+
text_negative_output = text_model(negative_prompt)
|
165 |
+
|
166 |
+
positive_prompt = args.text_prompt
|
167 |
+
positive_promt_image_path = args.image_prompt_path
|
168 |
+
target_phrase = args.target_phrase
|
169 |
+
|
170 |
+
# Compute target phrases
|
171 |
+
target_token = torch.zeros(1, 77).to(args.device)
|
172 |
+
positions = find_phrase_positions_in_text(positive_prompt, target_phrase)
|
173 |
+
for position in positions:
|
174 |
+
prompt_before = positive_prompt[:position] # NOTE We do not need -1 here because the SDXL text encoder does not encode the trailing space.
|
175 |
+
prompt_include = positive_prompt[:position+len(target_phrase)]
|
176 |
+
print("prompt before: ", prompt_before, ", prompt_include: ", prompt_include)
|
177 |
+
prompt_before_length = text_model.get_vaild_token_length(prompt_before) + 1
|
178 |
+
prompt_include_length = text_model.get_vaild_token_length(prompt_include) + 1
|
179 |
+
print("prompt_before_length: ", prompt_before_length, ", prompt_include_length: ", prompt_include_length)
|
180 |
+
target_token[:, prompt_before_length:prompt_include_length] = 1
|
181 |
+
|
182 |
+
# Text used for progress bar
|
183 |
+
pbar_text = positive_prompt[:40]
|
184 |
+
|
185 |
+
# Compute text embeddings
|
186 |
+
text_positive_output = text_model(positive_prompt)
|
187 |
+
text_positive_embeddings = text_positive_output.embeddings.repeat_interleave(args.samples_per_prompt, dim=0)
|
188 |
+
text_positive_pooled = text_positive_output.pooled[-1].repeat_interleave(args.samples_per_prompt, dim=0)
|
189 |
+
if args.guidance_weight != 1:
|
190 |
+
text_negative_embeddings = text_negative_output.embeddings.repeat_interleave(args.samples_per_prompt, dim=0)
|
191 |
+
text_negative_pooled = text_negative_output.pooled[-1].repeat_interleave(args.samples_per_prompt, dim=0)
|
192 |
+
|
193 |
+
# Compute image embeddings
|
194 |
+
positive_image = Image.open(positive_promt_image_path).convert("RGB")
|
195 |
+
positive_image = torchvision.transforms.ToTensor()(positive_image)
|
196 |
+
|
197 |
+
positive_image = positive_image.unsqueeze(0).repeat_interleave(args.samples_per_prompt, dim=0)
|
198 |
+
positive_image = torch.nn.functional.interpolate(
|
199 |
+
positive_image,
|
200 |
+
size=(768, 768),
|
201 |
+
mode="bilinear",
|
202 |
+
align_corners=False
|
203 |
+
)
|
204 |
+
negative_image = torch.zeros_like(positive_image)
|
205 |
+
print(positive_image.size(), negative_image.size())
|
206 |
+
positive_image = positive_image.to(args.device)
|
207 |
+
negative_image = negative_image.to(args.device)
|
208 |
+
|
209 |
+
positive_image_dict = {"image_ref": positive_image}
|
210 |
+
positive_image_output = vision_model(positive_image_dict, device=args.device)
|
211 |
+
|
212 |
+
negative_image_dict = {"image_ref": negative_image}
|
213 |
+
negative_image_output = vision_model(negative_image_dict, device=args.device)
|
214 |
+
|
215 |
+
# Initialize latent with input latent + noise (i2i) / pure noise (t2i)
|
216 |
+
latent = torch.randn(
|
217 |
+
size=[
|
218 |
+
args.samples_per_prompt,
|
219 |
+
vae_config["latent_channels"],
|
220 |
+
args.height // vae_downsample_factor,
|
221 |
+
args.width // vae_downsample_factor
|
222 |
+
],
|
223 |
+
device=args.device,
|
224 |
+
generator=torch.Generator(args.device).manual_seed(args.seed))
|
225 |
+
target_h = (args.height // vae_downsample_factor) // 2
|
226 |
+
target_w = (args.width // vae_downsample_factor) // 2
|
227 |
+
|
228 |
+
# Real Reverse diffusion process.
|
229 |
+
text2image_crossmap_2d_all_timesteps_list = []
|
230 |
+
current_step = 0
|
231 |
+
for timestep in tqdm(iterable=infer_timesteps, desc=f"[{pbar_text}]", dynamic_ncols=True):
|
232 |
+
if current_step < args.mask_reused_step:
|
233 |
+
pred_cond, pred_cond_dict = unet_model(
|
234 |
+
sample=latent,
|
235 |
+
timestep=timestep,
|
236 |
+
encoder_hidden_states=text_positive_embeddings,
|
237 |
+
encoder_attention_mask=None,
|
238 |
+
added_cond_kwargs=dict(
|
239 |
+
text_embeds=text_positive_pooled,
|
240 |
+
time_ids=image_metadata_validate
|
241 |
+
),
|
242 |
+
vision_input_dict=None,
|
243 |
+
vision_guided_mask=None,
|
244 |
+
return_as_origin=False,
|
245 |
+
return_text2image_mask=True,
|
246 |
+
)
|
247 |
+
crossmap_2d_avg = mask_generation(
|
248 |
+
crossmap_2d_list=pred_cond_dict["text2image_crossmap_2d"], selfmap_2d_list=pred_cond_dict.get("self_attention_map", []),
|
249 |
+
target_token=target_token, mask_scope=args.mask_scope,
|
250 |
+
mask_target_h=target_h, mask_target_w=target_w, mask_mode=args.mask_strategy,
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
# using previous step's mask
|
254 |
+
crossmap_2d_avg = text2image_crossmap_2d_all_timesteps_list[-1].squeeze(1)
|
255 |
+
if crossmap_2d_avg.dim() == 5: # Means that each layer uses a separate mask weight.
|
256 |
+
text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.mean(dim=2).unsqueeze(1))
|
257 |
+
else:
|
258 |
+
text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.unsqueeze(1))
|
259 |
+
|
260 |
+
pred_cond, pred_cond_dict = unet_model(
|
261 |
+
sample=latent,
|
262 |
+
timestep=timestep,
|
263 |
+
encoder_hidden_states=text_positive_embeddings,
|
264 |
+
encoder_attention_mask=None,
|
265 |
+
added_cond_kwargs=dict(
|
266 |
+
text_embeds=text_positive_pooled,
|
267 |
+
time_ids=image_metadata_validate
|
268 |
+
),
|
269 |
+
vision_input_dict=positive_image_output,
|
270 |
+
vision_guided_mask=crossmap_2d_avg,
|
271 |
+
return_as_origin=False,
|
272 |
+
return_text2image_mask=True,
|
273 |
+
multiple_reference_image=False
|
274 |
+
)
|
275 |
+
|
276 |
+
crossmap_2d_avg_neg = crossmap_2d_avg.mean(dim=1, keepdim=True)
|
277 |
+
pred_negative, pred_negative_dict = unet_model(
|
278 |
+
sample=latent,
|
279 |
+
timestep=timestep,
|
280 |
+
encoder_hidden_states=text_negative_embeddings,
|
281 |
+
encoder_attention_mask=None,
|
282 |
+
added_cond_kwargs=dict(
|
283 |
+
text_embeds=text_negative_pooled,
|
284 |
+
time_ids=image_metadata_validate
|
285 |
+
),
|
286 |
+
vision_input_dict=negative_image_output,
|
287 |
+
vision_guided_mask=crossmap_2d_avg,
|
288 |
+
return_as_origin=False,
|
289 |
+
return_text2image_mask=True,
|
290 |
+
multiple_reference_image=False
|
291 |
+
)
|
292 |
+
|
293 |
+
pred = classifier_free_guidance_image_prompt_cascade(
|
294 |
+
pred_t_cond=None, pred_ti_cond=pred_cond, pred_uncond=pred_negative,
|
295 |
+
guidance_weight_t=args.guidance_weight, guidance_weight_i=args.guidance_weight,
|
296 |
+
guidance_stdev_rescale_factor=0, cfg_rescale_mode="naive_global_direct"
|
297 |
+
)
|
298 |
+
step = scheduler.step(
|
299 |
+
model_output=pred,
|
300 |
+
model_output_type=args.unet_prediction,
|
301 |
+
timestep=timestep,
|
302 |
+
sample=latent)
|
303 |
+
|
304 |
+
latent = step.prev_sample
|
305 |
+
|
306 |
+
current_step += 1
|
307 |
+
|
308 |
+
sample = vae_decoder(step.pred_original_sample)
|
309 |
+
|
310 |
+
# save each image
|
311 |
+
for sample_i in range(sample.size(0)):
|
312 |
+
sample_i_image = torch.clamp(sample[sample_i] * 0.5 + 0.5, min=0, max=1).float()
|
313 |
+
to_pil_image(sample_i_image).save(args.output_image_dir + "/output_{}.jpg".format(sample_i))
|
314 |
+
|
315 |
+
# save grid images
|
316 |
+
sample = make_grid(sample, normalize=True, value_range=(-1, 1), nrow=args.nrow).float()
|
317 |
+
to_pil_image(sample).save(args.output_image_grid_dir + "/grid_image.jpg")
|
inference/inference_single_image.sh
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# !/bin/bash
|
2 |
+
# ----------------------------------------------------------------------------------------------------
|
3 |
+
HEIGHT="1024" # Base height.
|
4 |
+
WIDTH="1024" # Base width.
|
5 |
+
SAMPLES_PER_PROMPT="4" # Num of samples to generate per prompt.
|
6 |
+
NROW="2" # Grid images per row.
|
7 |
+
|
8 |
+
OUTPUT_DIR="outputs/test"
|
9 |
+
|
10 |
+
# ----------------------------------------------------------------------------------------------------
|
11 |
+
MASK_TYPE=("max_norm")
|
12 |
+
# usually:"max_norm" "crossmap_32" "selfmap_min_max_per_channel" "selfmap_64"
|
13 |
+
# [
|
14 |
+
# "max_norm", "min_max_norm", "binary", "min_max_per_channel", "decoder_map"
|
15 |
+
# "selfmap", "selfmap_min_max_per_channel" "selfmap_64"
|
16 |
+
|
17 |
+
# ]
|
18 |
+
|
19 |
+
CFG=7.5
|
20 |
+
STEPS=25
|
21 |
+
mask_reused_step=12
|
22 |
+
|
23 |
+
UNET_CONFIG="configs/realcustom_sigdino_highres.json"
|
24 |
+
UNET_CHECKPOINT="ckpts/realcustom/RealCustom_0025000_ema_highres.pth"
|
25 |
+
UNET_CHECKPOINT_BASE_MODEL="ckpts/sdxl/unet/general_v1-3_sdxl_03.pth"
|
26 |
+
# ----------------------------------------------------------------------------------------------------
|
27 |
+
CLIP1_DIR="ckpts/sdxl/clip-sdxl-1"
|
28 |
+
CLIP2_DIR="ckpts/sdxl/clip-sdxl-2"
|
29 |
+
VAE_CONFIG_PATH="ckpts/sdxl/vae/sdxl.json"
|
30 |
+
VAE_CHECKPOINT_PATH="ckpts/sdxl/vae/sdxl-vae.pth"
|
31 |
+
|
32 |
+
|
33 |
+
echo "Start inference"
|
34 |
+
python3 inference/inference_single_image.py \
|
35 |
+
--width $WIDTH \
|
36 |
+
--height $HEIGHT \
|
37 |
+
--samples_per_prompt $SAMPLES_PER_PROMPT \
|
38 |
+
--nrow $NROW \
|
39 |
+
--sample_steps $STEPS \
|
40 |
+
--guidance_weight $CFG \
|
41 |
+
--text_encoder_variant \
|
42 |
+
$CLIP1_DIR \
|
43 |
+
$CLIP2_DIR \
|
44 |
+
--unet_config $UNET_CONFIG \
|
45 |
+
--unet_checkpoint $UNET_CHECKPOINT \
|
46 |
+
--unet_checkpoint_base_model $UNET_CHECKPOINT_BASE_MODEL \
|
47 |
+
--vae_config $VAE_CONFIG_PATH \
|
48 |
+
--vae_checkpoint $VAE_CHECKPOINT_PATH \
|
49 |
+
--output_dir $OUTPUT_DIR \
|
50 |
+
--seed 2024 \
|
51 |
+
--text_prompt "the figurine is flying in the sky" \
|
52 |
+
--image_prompt_path "prompts/figurine.png" \
|
53 |
+
--target_phrase "figurine" \
|
54 |
+
--mask_scope 0.25 \
|
55 |
+
--mask_strategy ${MASK_TYPE[*]}
|
inference/inference_utils.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
def find_phrase_positions_in_text(text, phrase):
|
16 |
+
"""
|
17 |
+
Return the position of the first character of the phrase in the text.
|
18 |
+
"""
|
19 |
+
|
20 |
+
position = -1
|
21 |
+
positions = []
|
22 |
+
while True:
|
23 |
+
position = text.find(phrase, position + 1)
|
24 |
+
if position == -1:
|
25 |
+
break
|
26 |
+
positions.append(position)
|
27 |
+
return positions
|
28 |
+
|
29 |
+
def classifier_free_guidance_image_prompt_cascade(
|
30 |
+
pred_t_cond, pred_ti_cond, pred_uncond, guidance_weight_t=7.5, guidance_weight_i=7.5,
|
31 |
+
guidance_stdev_rescale_factor=0.7, cfg_rescale_mode="none", super_cross_mask=None
|
32 |
+
):
|
33 |
+
|
34 |
+
if cfg_rescale_mode == "none":
|
35 |
+
pred = pred_uncond + guidance_weight_t * (pred_t_cond - pred_uncond) + guidance_weight_i * (pred_ti_cond - pred_t_cond)
|
36 |
+
elif cfg_rescale_mode == "none_direct":
|
37 |
+
pred = pred_uncond + guidance_weight_i * (pred_ti_cond - pred_uncond)
|
38 |
+
elif cfg_rescale_mode == "naive":
|
39 |
+
assert super_cross_mask is not None
|
40 |
+
pred_std_t_before = pred_t_cond.std([1, 2, 3], keepdim=True)
|
41 |
+
pred_std_ti_before = pred_ti_cond.std([1, 2, 3], keepdim=True)
|
42 |
+
|
43 |
+
pred = pred_uncond + guidance_weight_t * (pred_t_cond - pred_uncond) + guidance_weight_i * (pred_ti_cond - pred_t_cond)
|
44 |
+
|
45 |
+
pred_std_after = pred.std([1, 2, 3], keepdim=True)
|
46 |
+
|
47 |
+
pred_rescale_t_factor = guidance_stdev_rescale_factor * (pred_std_t_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
|
48 |
+
pred_rescale_ti_factor = guidance_stdev_rescale_factor * (pred_std_ti_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
|
49 |
+
|
50 |
+
pred_ti = pred * super_cross_mask
|
51 |
+
pred_t = pred * (1 - super_cross_mask)
|
52 |
+
pred = pred_ti * pred_rescale_ti_factor + pred_t * pred_rescale_t_factor
|
53 |
+
elif cfg_rescale_mode == "naive_global":
|
54 |
+
pred_std_ti_before = pred_ti_cond.std([1, 2, 3], keepdim=True)
|
55 |
+
|
56 |
+
pred = pred_uncond + guidance_weight_t * (pred_t_cond - pred_uncond) + guidance_weight_i * (pred_ti_cond - pred_t_cond)
|
57 |
+
|
58 |
+
pred_std_after = pred.std([1, 2, 3], keepdim=True)
|
59 |
+
|
60 |
+
pred_rescale_ti_factor = guidance_stdev_rescale_factor * (pred_std_ti_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
|
61 |
+
|
62 |
+
pred = pred * pred_rescale_ti_factor
|
63 |
+
elif cfg_rescale_mode == "naive_global_direct":
|
64 |
+
pred_std_ti_before = pred_ti_cond.std([1, 2, 3], keepdim=True)
|
65 |
+
|
66 |
+
pred = pred_uncond + guidance_weight_i * (pred_ti_cond - pred_uncond)
|
67 |
+
|
68 |
+
pred_std_after = pred.std([1, 2, 3], keepdim=True)
|
69 |
+
|
70 |
+
pred_rescale_ti_factor = guidance_stdev_rescale_factor * (pred_std_ti_before / pred_std_after) + (1 - guidance_stdev_rescale_factor)
|
71 |
+
|
72 |
+
pred = pred * pred_rescale_ti_factor
|
73 |
+
else:
|
74 |
+
raise NotImplementedError()
|
75 |
+
|
76 |
+
return pred
|
inference/mask_generation.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from einops import rearrange
|
18 |
+
|
19 |
+
def mask_generation(
|
20 |
+
crossmap_2d_list, selfmap_2d_list=None,
|
21 |
+
target_token=None, mask_scope=None,
|
22 |
+
mask_target_h=64, mask_target_w=64,
|
23 |
+
mask_mode=["binary"],
|
24 |
+
):
|
25 |
+
if len(selfmap_2d_list) > 0:
|
26 |
+
target_hw_selfmap = mask_target_h * mask_target_w
|
27 |
+
selfmap_2ds = []
|
28 |
+
for i in range(len(selfmap_2d_list)):
|
29 |
+
selfmap_ = selfmap_2d_list[i]
|
30 |
+
selfmap_ = F.interpolate(selfmap_, size=(target_hw_selfmap, target_hw_selfmap), mode='bilinear')
|
31 |
+
selfmap_2ds.append(selfmap_ )
|
32 |
+
selfmap_2ds = torch.cat(selfmap_2ds, dim=1)
|
33 |
+
if "selfmap_min_max_per_channel" in mask_mode:
|
34 |
+
selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
|
35 |
+
channel_max_self = torch.max(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
|
36 |
+
channel_min_self = torch.min(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
|
37 |
+
selfmap_2ds = (selfmap_2ds - channel_min_self) / (channel_max_self - channel_min_self + 1e-6)
|
38 |
+
elif "selfmap_max_norm" in mask_mode:
|
39 |
+
selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
|
40 |
+
b = selfmap_1ds.size(0)
|
41 |
+
batch_max = torch.max(selfmap_1ds.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
|
42 |
+
selfmap_2ds = selfmap_2ds / (batch_max + 1e-10)
|
43 |
+
|
44 |
+
selfmap_2d = selfmap_2ds.mean(dim=1, keepdim=True)
|
45 |
+
else:
|
46 |
+
selfmap_2d = None
|
47 |
+
|
48 |
+
crossmap_2ds = []
|
49 |
+
for i in range(len(crossmap_2d_list)):
|
50 |
+
crossmap = crossmap_2d_list[i]
|
51 |
+
crossmap = crossmap.mean(dim=1) # average on head dim
|
52 |
+
crossmap = crossmap * target_token.unsqueeze(-1).unsqueeze(-1) # target token valid
|
53 |
+
crossmap = crossmap.sum(dim=1, keepdim=True)
|
54 |
+
|
55 |
+
crossmap = F.interpolate(crossmap, size=(mask_target_h, mask_target_w), mode='bilinear')
|
56 |
+
crossmap_2ds.append(crossmap)
|
57 |
+
crossmap_2ds = torch.cat(crossmap_2ds, dim=1)
|
58 |
+
crossmap_1ds = rearrange(crossmap_2ds, "b c h w -> b c (h w)")
|
59 |
+
|
60 |
+
if "max_norm" in mask_mode:
|
61 |
+
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
|
62 |
+
if selfmap_2d is not None:
|
63 |
+
crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
|
64 |
+
b, c, n = crossmap_1ds.shape
|
65 |
+
batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
|
66 |
+
crossmap_1d_avg = crossmap_1d_avg / (batch_max + 1e-6)
|
67 |
+
elif "min_max_norm" in mask_mode:
|
68 |
+
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
|
69 |
+
if selfmap_2d is not None:
|
70 |
+
crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
|
71 |
+
b, c, n = crossmap_1ds.shape
|
72 |
+
batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) # NOTE unsqueeze
|
73 |
+
batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) # NOTE unsqueeze
|
74 |
+
crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
|
75 |
+
elif "min_max_per_channel" in mask_mode:
|
76 |
+
channel_max = torch.max(crossmap_1ds, dim=-1, keepdim=True)[0]
|
77 |
+
channel_min = torch.min(crossmap_1ds, dim=-1, keepdim=True)[0]
|
78 |
+
crossmap_1ds = (crossmap_1ds - channel_min) / (channel_max - channel_min + 1e-6)
|
79 |
+
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
|
80 |
+
if selfmap_2d is not None:
|
81 |
+
crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)
|
82 |
+
|
83 |
+
# renormalize to 0-1
|
84 |
+
b, c, n = crossmap_1d_avg.shape
|
85 |
+
batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
|
86 |
+
batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1)
|
87 |
+
crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
|
88 |
+
else:
|
89 |
+
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)]
|
90 |
+
|
91 |
+
|
92 |
+
if "threshold" in mask_mode:
|
93 |
+
threshold = 1 - mask_scope
|
94 |
+
crossmap_1d_avg[crossmap_1d_avg < threshold] = 0.0
|
95 |
+
if "binary" in mask_mode:
|
96 |
+
crossmap_1d_avg[crossmap_1d_avg > threshold] = 1.0
|
97 |
+
else:
|
98 |
+
# topk
|
99 |
+
topk_num = int(crossmap_1d_avg.size(-1) * mask_scope)
|
100 |
+
sort_score, sort_order = crossmap_1d_avg.sort(descending=True, dim=-1)
|
101 |
+
sort_topk = sort_order[:, :, :topk_num]
|
102 |
+
sort_topk_remain = sort_order[:, :, topk_num:]
|
103 |
+
crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk_remain, 0.)
|
104 |
+
if "binary" in mask_mode:
|
105 |
+
crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk, 1.0)
|
106 |
+
|
107 |
+
crossmap_2d_avg = rearrange(crossmap_1d_avg, "b c (h w) -> b c h w", h=mask_target_h, w=mask_target_w)
|
108 |
+
crossmap_2d_avg = crossmap_2d_avg
|
109 |
+
|
110 |
+
output = crossmap_2d_avg.unsqueeze(1) # torch.Size([4, 1, 60, 64, 64]), The second dimension is the dimension of the number of reference images.
|
111 |
+
if output.size(2) == 1: # The dimension of the layer.
|
112 |
+
output = output.squeeze(2) # If there is only a single dimension, then all layers will share the same mask.
|
113 |
+
|
114 |
+
return output
|
inference/pipeline.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import json
|
17 |
+
import torch
|
18 |
+
import torchvision
|
19 |
+
from torchvision.utils import make_grid
|
20 |
+
from torchvision.transforms.functional import to_pil_image
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from models.text import TextModel
|
24 |
+
from models.vae import AutoencoderKL
|
25 |
+
from models.unet_2d_condition_custom import UNet2DConditionModel as UNet2DConditionModelDiffusers
|
26 |
+
|
27 |
+
from schedulers.ddim import DDIMScheduler
|
28 |
+
from schedulers.dpm_s import DPMSolverSingleStepScheduler
|
29 |
+
from schedulers.utils import get_betas
|
30 |
+
|
31 |
+
from inference_utils import find_phrase_positions_in_text, classifier_free_guidance_image_prompt_cascade
|
32 |
+
from mask_generation import mask_generation
|
33 |
+
from utils import instantiate_from_config
|
34 |
+
|
35 |
+
from tqdm import tqdm
|
36 |
+
from einops import rearrange
|
37 |
+
|
38 |
+
class RealCustomInferencePipeline:
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
unet_config,
|
42 |
+
unet_checkpoint,
|
43 |
+
realcustom_checkpoint,
|
44 |
+
vae_config="ckpts/sdxl/vae/sdxl.json",
|
45 |
+
vae_checkpoint="ckpts/sdxl/vae/sdxl-vae.pth",
|
46 |
+
model_type="bf16",
|
47 |
+
device="cuda",
|
48 |
+
):
|
49 |
+
if model_type == "bf16":
|
50 |
+
self.torch_dtype = torch.bfloat16
|
51 |
+
else:
|
52 |
+
self.torch_dtype = torch.float32
|
53 |
+
|
54 |
+
if not os.path.exists("ckpts/"):
|
55 |
+
from huggingface_hub import snapshot_download
|
56 |
+
print("Downloading RealCustom ...")
|
57 |
+
snapshot_download(
|
58 |
+
repo_id="bytedance-research/RealCustom",
|
59 |
+
repo_type="model",
|
60 |
+
local_dir="ckpts", # 指定本地目录
|
61 |
+
allow_patterns="ckpts/**", # 只下载 ckpts 文件夹内容
|
62 |
+
local_dir_use_symlinks=False # 直接存储文件而非符号链接
|
63 |
+
)
|
64 |
+
|
65 |
+
self.device = device
|
66 |
+
self.unet_checkpoint = unet_checkpoint
|
67 |
+
self.realcustom_checkpoint = realcustom_checkpoint
|
68 |
+
self._load_unet_checkpoint(unet_config, unet_checkpoint, realcustom_checkpoint)
|
69 |
+
self._load_vae_checkpoint(vae_config, vae_checkpoint)
|
70 |
+
self._load_encoder_checkpoint()
|
71 |
+
self._init_scheduler()
|
72 |
+
self._load_negative_prompt()
|
73 |
+
|
74 |
+
|
75 |
+
def _load_unet_checkpoint(self, unet_config, unet_checkpoint, realcustom_checkpoint):
|
76 |
+
# Initialize unet model
|
77 |
+
with open(unet_config) as unet_config_file:
|
78 |
+
unet_config = json.load(unet_config_file)
|
79 |
+
self.unet_prediction = "epsilon"
|
80 |
+
|
81 |
+
# Settings for image encoder
|
82 |
+
vision_model_config = unet_config.pop("vision_model_config", None)
|
83 |
+
self.vision_model_config = vision_model_config.pop("vision_model_config", None)
|
84 |
+
|
85 |
+
self.unet_model = UNet2DConditionModelDiffusers(**unet_config)
|
86 |
+
|
87 |
+
self.unet_model.eval().to(self.device).to(self.torch_dtype)
|
88 |
+
self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
|
89 |
+
self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
|
90 |
+
print("loading unet model finished.")
|
91 |
+
|
92 |
+
def _reload_unet_checkpoint(self, unet_checkpoint, realcustom_checkpoint):
|
93 |
+
self.unet_model.load_state_dict(torch.load(unet_checkpoint, map_location=self.device), strict=False)
|
94 |
+
self.unet_model.load_state_dict(torch.load(realcustom_checkpoint, map_location=self.device), strict=False)
|
95 |
+
print("reloading unet model finished.")
|
96 |
+
|
97 |
+
def _load_vae_checkpoint(self, vae_config, vae_checkpoint):
|
98 |
+
# Initialize vae model
|
99 |
+
with open(vae_config) as vae_config_file:
|
100 |
+
vae_config = json.load(vae_config_file)
|
101 |
+
self.latent_channels = vae_config["latent_channels"]
|
102 |
+
self.vae_downsample_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) # 2 ** 3 = 8
|
103 |
+
|
104 |
+
vae_model = AutoencoderKL(**vae_config)
|
105 |
+
vae_model.eval().to(self.device).to(self.torch_dtype)
|
106 |
+
vae_model.load_state_dict(torch.load(vae_checkpoint, map_location=self.device))
|
107 |
+
self.vae_decoder = torch.compile(lambda x: vae_model.decode(x / vae_model.scaling_factor).sample.clip(-1, 1), disable=True)
|
108 |
+
self.vae_encoder = torch.compile(lambda x: vae_model.encode(x).latent_dist.mode().mul_(vae_model.scaling_factor), disable=True)
|
109 |
+
|
110 |
+
print("loading vae finished.")
|
111 |
+
|
112 |
+
def _load_encoder_checkpoint(self, ):
|
113 |
+
# Initialize text encoder
|
114 |
+
text_encoder_variant = ["ckpts/sdxl/clip-sdxl-1", "ckpts/sdxl/clip-sdxl-2"]
|
115 |
+
text_encoder_mode = ["penultimate_nonorm"]
|
116 |
+
self.text_model = TextModel(text_encoder_variant, text_encoder_mode)
|
117 |
+
self.text_model.eval().to(self.device).to(self.torch_dtype)
|
118 |
+
print("loading text model finished.")
|
119 |
+
|
120 |
+
# Initialize image encoder
|
121 |
+
self.vision_model = instantiate_from_config(self.vision_model_config)
|
122 |
+
self.vision_model.eval().to(self.device).to(self.torch_dtype)
|
123 |
+
print("loading image model finished.")
|
124 |
+
|
125 |
+
def _init_scheduler(self, ):
|
126 |
+
# Initialize ddim scheduler
|
127 |
+
ddim_train_steps = 1000
|
128 |
+
schedule_type = "squared_linear"
|
129 |
+
scheduler_type = "dpm"
|
130 |
+
schedule_shift_snr = 1
|
131 |
+
self.sample_steps = 25
|
132 |
+
ddim_betas = get_betas(name=schedule_type, num_steps=ddim_train_steps, shift_snr=schedule_shift_snr, terminal_pure_noise=False)
|
133 |
+
scheduler_class = DPMSolverSingleStepScheduler if scheduler_type == 'dpm' else DDIMScheduler
|
134 |
+
|
135 |
+
self.scheduler = scheduler_class(betas=ddim_betas, num_train_timesteps=ddim_train_steps, num_inference_timesteps=self.sample_steps, device=self.device)
|
136 |
+
self.infer_timesteps = self.scheduler.timesteps
|
137 |
+
|
138 |
+
def _load_negative_prompt(self, ):
|
139 |
+
with open("prompts/validation_negative.txt") as f:
|
140 |
+
self.negative_prompt = f.read().strip()
|
141 |
+
self.text_negative_output = self.text_model(self.negative_prompt)
|
142 |
+
|
143 |
+
def generation(
|
144 |
+
self,
|
145 |
+
text,
|
146 |
+
image_pil,
|
147 |
+
target_phrase,
|
148 |
+
|
149 |
+
height=1024,
|
150 |
+
width=1024,
|
151 |
+
guidance_scale=3.5,
|
152 |
+
seed=1234,
|
153 |
+
samples_per_prompt=4,
|
154 |
+
|
155 |
+
mask_scope=0.25,
|
156 |
+
|
157 |
+
new_unet_checkpoint="", # in case you want to change
|
158 |
+
new_realcustom_checkpoint="", # in case you want to change
|
159 |
+
mask_strategy=["min_max_per_channel"],
|
160 |
+
mask_reused_step=12,
|
161 |
+
return_each_image=False,
|
162 |
+
):
|
163 |
+
|
164 |
+
if new_unet_checkpoint != "" and new_unet_checkpoint != self.unet_checkpoint:
|
165 |
+
self.unet_checkpoint = new_unet_checkpoint
|
166 |
+
self.unet_model.load_state_dict(torch.load(new_unet_checkpoint, map_location=self.device), strict=False)
|
167 |
+
print("Reloading Unet {} finised.".format(new_unet_checkpoint))
|
168 |
+
if new_realcustom_checkpoint != "" and new_realcustom_checkpoint != self.realcustom_checkpoint:
|
169 |
+
self.realcustom_checkpoint = new_realcustom_checkpoint
|
170 |
+
self.unet_model.load_state_dict(torch.load(new_realcustom_checkpoint, map_location=self.device), strict=False)
|
171 |
+
print("Reloading RealCustom {} finised.".format(new_realcustom_checkpoint))
|
172 |
+
|
173 |
+
samples_per_prompt = int(samples_per_prompt)
|
174 |
+
image_metadata_validate = self._get_metadata(height, width, samples_per_prompt)
|
175 |
+
if seed == -1:
|
176 |
+
seed = torch.randint(0, 1000000, (1,)).item()
|
177 |
+
seed = int(seed)
|
178 |
+
|
179 |
+
with torch.no_grad(), torch.autocast(self.device, self.torch_dtype):
|
180 |
+
target_token = self._find_phrase_positions_in_text(text, target_phrase)
|
181 |
+
|
182 |
+
# Compute text embeddings
|
183 |
+
text_positive_output = self.text_model(text)
|
184 |
+
text_positive_embeddings = text_positive_output.embeddings.repeat_interleave(samples_per_prompt, dim=0)
|
185 |
+
text_positive_pooled = text_positive_output.pooled[-1].repeat_interleave(samples_per_prompt, dim=0)
|
186 |
+
if guidance_scale != 1:
|
187 |
+
text_negative_embeddings = self.text_negative_output.embeddings.repeat_interleave(samples_per_prompt, dim=0)
|
188 |
+
text_negative_pooled = self.text_negative_output.pooled[-1].repeat_interleave(samples_per_prompt, dim=0)
|
189 |
+
|
190 |
+
# Compute image embeddings
|
191 |
+
# positive_image = Image.open(image_path).convert("RGB")
|
192 |
+
positive_image = image_pil
|
193 |
+
positive_image = torchvision.transforms.ToTensor()(positive_image)
|
194 |
+
|
195 |
+
positive_image = positive_image.unsqueeze(0).repeat_interleave(samples_per_prompt, dim=0)
|
196 |
+
positive_image = torch.nn.functional.interpolate(
|
197 |
+
positive_image,
|
198 |
+
size=(768, 768),
|
199 |
+
mode="bilinear",
|
200 |
+
align_corners=False
|
201 |
+
)
|
202 |
+
negative_image = torch.zeros_like(positive_image)
|
203 |
+
positive_image = positive_image.to(self.device).to(self.torch_dtype)
|
204 |
+
negative_image = negative_image.to(self.device).to(self.torch_dtype)
|
205 |
+
|
206 |
+
positive_image_dict = {"image_ref": positive_image}
|
207 |
+
positive_image_output = self.vision_model(positive_image_dict, device=self.device)
|
208 |
+
|
209 |
+
negative_image_dict = {"image_ref": negative_image}
|
210 |
+
negative_image_output = self.vision_model(negative_image_dict, device=self.device)
|
211 |
+
|
212 |
+
# Initialize latent with input latent
|
213 |
+
latent = torch.randn(
|
214 |
+
size=[
|
215 |
+
samples_per_prompt,
|
216 |
+
self.latent_channels,
|
217 |
+
height // self.vae_downsample_factor,
|
218 |
+
width // self.vae_downsample_factor
|
219 |
+
],
|
220 |
+
device=self.device,
|
221 |
+
generator=torch.Generator(self.device).manual_seed(seed)).to(self.torch_dtype)
|
222 |
+
target_h = (height // self.vae_downsample_factor) // 2
|
223 |
+
target_w = (width // self.vae_downsample_factor) // 2
|
224 |
+
|
225 |
+
text2image_crossmap_2d_all_timesteps_list = []
|
226 |
+
current_step = 0
|
227 |
+
pbar_text = text[:40]
|
228 |
+
for timestep in tqdm(iterable=self.infer_timesteps, desc=f"[{pbar_text}]", dynamic_ncols=True):
|
229 |
+
if current_step < mask_reused_step:
|
230 |
+
pred_cond, pred_cond_dict = self.unet_model(
|
231 |
+
sample=latent,
|
232 |
+
timestep=timestep,
|
233 |
+
encoder_hidden_states=text_positive_embeddings,
|
234 |
+
encoder_attention_mask=None,
|
235 |
+
added_cond_kwargs=dict(
|
236 |
+
text_embeds=text_positive_pooled,
|
237 |
+
time_ids=image_metadata_validate
|
238 |
+
),
|
239 |
+
vision_input_dict=None,
|
240 |
+
vision_guided_mask=None,
|
241 |
+
return_as_origin=False,
|
242 |
+
return_text2image_mask=True,
|
243 |
+
)
|
244 |
+
crossmap_2d_avg = mask_generation(
|
245 |
+
crossmap_2d_list=pred_cond_dict["text2image_crossmap_2d"], selfmap_2d_list=pred_cond_dict.get("self_attention_map", []),
|
246 |
+
target_token=target_token, mask_scope=mask_scope,
|
247 |
+
mask_target_h=target_h, mask_target_w=target_w, mask_mode=mask_strategy,
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
# using previous step's mask
|
251 |
+
crossmap_2d_avg = text2image_crossmap_2d_all_timesteps_list[-1].squeeze(1)
|
252 |
+
if crossmap_2d_avg.dim() == 5: # Means that each layer uses a separate mask weight.
|
253 |
+
text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.mean(dim=2).unsqueeze(1))
|
254 |
+
else:
|
255 |
+
text2image_crossmap_2d_all_timesteps_list.append(crossmap_2d_avg.unsqueeze(1))
|
256 |
+
|
257 |
+
pred_cond, pred_cond_dict = self.unet_model(
|
258 |
+
sample=latent,
|
259 |
+
timestep=timestep,
|
260 |
+
encoder_hidden_states=text_positive_embeddings,
|
261 |
+
encoder_attention_mask=None,
|
262 |
+
added_cond_kwargs=dict(
|
263 |
+
text_embeds=text_positive_pooled,
|
264 |
+
time_ids=image_metadata_validate
|
265 |
+
),
|
266 |
+
vision_input_dict=positive_image_output,
|
267 |
+
vision_guided_mask=crossmap_2d_avg,
|
268 |
+
return_as_origin=False,
|
269 |
+
return_text2image_mask=True,
|
270 |
+
multiple_reference_image=False
|
271 |
+
)
|
272 |
+
|
273 |
+
pred_negative, pred_negative_dict = self.unet_model(
|
274 |
+
sample=latent,
|
275 |
+
timestep=timestep,
|
276 |
+
encoder_hidden_states=text_negative_embeddings,
|
277 |
+
encoder_attention_mask=None,
|
278 |
+
added_cond_kwargs=dict(
|
279 |
+
text_embeds=text_negative_pooled,
|
280 |
+
time_ids=image_metadata_validate
|
281 |
+
),
|
282 |
+
vision_input_dict=negative_image_output,
|
283 |
+
vision_guided_mask=crossmap_2d_avg,
|
284 |
+
return_as_origin=False,
|
285 |
+
return_text2image_mask=True,
|
286 |
+
multiple_reference_image=False
|
287 |
+
)
|
288 |
+
|
289 |
+
pred = classifier_free_guidance_image_prompt_cascade(
|
290 |
+
pred_t_cond=None, pred_ti_cond=pred_cond, pred_uncond=pred_negative,
|
291 |
+
guidance_weight_t=guidance_scale, guidance_weight_i=guidance_scale,
|
292 |
+
guidance_stdev_rescale_factor=0, cfg_rescale_mode="naive_global_direct"
|
293 |
+
)
|
294 |
+
step = self.scheduler.step(
|
295 |
+
model_output=pred,
|
296 |
+
model_output_type=self.unet_prediction,
|
297 |
+
timestep=timestep,
|
298 |
+
sample=latent)
|
299 |
+
|
300 |
+
latent = step.prev_sample
|
301 |
+
|
302 |
+
current_step += 1
|
303 |
+
sample = self.vae_decoder(step.pred_original_sample)
|
304 |
+
|
305 |
+
# save each image
|
306 |
+
images_pil_list = []
|
307 |
+
for sample_i in range(sample.size(0)):
|
308 |
+
sample_i_image = torch.clamp(sample[sample_i] * 0.5 + 0.5, min=0, max=1).float()
|
309 |
+
|
310 |
+
images_pil_list.append(to_pil_image(sample_i_image))
|
311 |
+
# to_pil_image(sample_i_image).save("./test_{}.jpg".format(sample_i))
|
312 |
+
|
313 |
+
# save grid images
|
314 |
+
sample = make_grid(sample, normalize=True, value_range=(-1, 1), nrow=int(samples_per_prompt ** 0.5)).float()
|
315 |
+
# to_pil_image(sample).save("./output_grid_image.jpg")
|
316 |
+
|
317 |
+
# save all masks
|
318 |
+
text2image_crossmap_2d_all_timesteps = torch.cat(text2image_crossmap_2d_all_timesteps_list, dim=1)
|
319 |
+
text2image_crossmap_2d_all_timesteps = rearrange(text2image_crossmap_2d_all_timesteps, "b t c h w -> (b t) c h w")
|
320 |
+
c = text2image_crossmap_2d_all_timesteps.size(1)
|
321 |
+
text2image_crossmap_2d_all_timesteps = rearrange(text2image_crossmap_2d_all_timesteps, "B (c 1) h w -> (B c) 1 h w")
|
322 |
+
sample_mask = make_grid(text2image_crossmap_2d_all_timesteps, normalize=False, value_range=(-1, 1), nrow=int(self.sample_steps * c))
|
323 |
+
# to_pil_image(sample_mask).save("./output_grid_mask.jpg")
|
324 |
+
|
325 |
+
if return_each_image:
|
326 |
+
return images_pil_list, to_pil_image(sample), to_pil_image(sample_mask)
|
327 |
+
else:
|
328 |
+
return to_pil_image(sample), to_pil_image(sample_mask)
|
329 |
+
|
330 |
+
def _get_metadata(self, height, width, samples_per_prompt):
|
331 |
+
image_metadata_validate = torch.tensor(
|
332 |
+
data=[
|
333 |
+
width, # original_height
|
334 |
+
height, # original_width
|
335 |
+
0, # coordinate top
|
336 |
+
0, # coordinate left
|
337 |
+
width, # target_height
|
338 |
+
height, # target_width
|
339 |
+
],
|
340 |
+
device=self.device,
|
341 |
+
dtype=self.torch_dtype
|
342 |
+
).view(1, -1).repeat(samples_per_prompt, 1)
|
343 |
+
|
344 |
+
return image_metadata_validate
|
345 |
+
|
346 |
+
def _find_phrase_positions_in_text(self, text, target_phrase):
|
347 |
+
# Compute target phrases
|
348 |
+
target_token = torch.zeros(1, 77).to(self.device)
|
349 |
+
positions = find_phrase_positions_in_text(text, target_phrase)
|
350 |
+
for position in positions:
|
351 |
+
prompt_before = text[:position] # NOTE We do not need -1 here because the SDXL text encoder does not encode the trailing space.
|
352 |
+
prompt_include = text[:position+len(target_phrase)]
|
353 |
+
print("prompt before: ", prompt_before, ", prompt_include: ", prompt_include)
|
354 |
+
prompt_before_length = self.text_model.get_vaild_token_length(prompt_before) + 1
|
355 |
+
prompt_include_length = self.text_model.get_vaild_token_length(prompt_include) + 1
|
356 |
+
print("prompt_before_length: ", prompt_before_length, ", prompt_include_length: ", prompt_include_length)
|
357 |
+
target_token[:, prompt_before_length:prompt_include_length] = 1
|
358 |
+
|
359 |
+
return target_token
|
models/__pycache__/attention_custom.cpython-310.pyc
ADDED
Binary file (12.5 kB). View file
|
|
models/__pycache__/attention_processor_custom_cross.cpython-310.pyc
ADDED
Binary file (38.7 kB). View file
|
|
models/__pycache__/base_vision.cpython-310.pyc
ADDED
Binary file (8.28 kB). View file
|
|
models/__pycache__/dino.cpython-310.pyc
ADDED
Binary file (7.08 kB). View file
|
|
models/__pycache__/image_encoder_siglipdino_shallowdeep.cpython-310.pyc
ADDED
Binary file (4.23 kB). View file
|
|
models/__pycache__/projectors.cpython-310.pyc
ADDED
Binary file (4.33 kB). View file
|
|
models/__pycache__/sigclip.cpython-310.pyc
ADDED
Binary file (5.81 kB). View file
|
|
models/__pycache__/text.cpython-310.pyc
ADDED
Binary file (2.88 kB). View file
|
|
models/__pycache__/transformer_2d_custom.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
models/__pycache__/unet_2d_blocks_custom.cpython-310.pyc
ADDED
Binary file (52.5 kB). View file
|
|
models/__pycache__/unet_2d_condition_custom.cpython-310.pyc
ADDED
Binary file (31.1 kB). View file
|
|
models/__pycache__/vae.cpython-310.pyc
ADDED
Binary file (1.08 kB). View file
|
|
models/attention_custom.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, Dict, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
# from diffusers.utils import maybe_allow_in_graph
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention_processor import Attention
|
25 |
+
from models.attention_processor_custom_cross import Attention as CrossAttention
|
26 |
+
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
|
27 |
+
|
28 |
+
from utils import update_dict
|
29 |
+
|
30 |
+
# @maybe_allow_in_graph
|
31 |
+
class BasicTransformerBlock(nn.Module):
|
32 |
+
r"""
|
33 |
+
A basic Transformer block.
|
34 |
+
|
35 |
+
Parameters:
|
36 |
+
dim (`int`): The number of channels in the input and output.
|
37 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
38 |
+
attention_head_dim (`int`): The number of channels in each head.
|
39 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
40 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
41 |
+
only_cross_attention (`bool`, *optional*):
|
42 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
43 |
+
double_self_attention (`bool`, *optional*):
|
44 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
45 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
46 |
+
num_embeds_ada_norm (:
|
47 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
48 |
+
attention_bias (:
|
49 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
dim: int,
|
55 |
+
num_attention_heads: int,
|
56 |
+
attention_head_dim: int,
|
57 |
+
dropout=0.0,
|
58 |
+
cross_attention_dim: Optional[int] = None,
|
59 |
+
activation_fn: str = "geglu",
|
60 |
+
num_embeds_ada_norm: Optional[int] = None,
|
61 |
+
attention_bias: bool = False,
|
62 |
+
only_cross_attention: bool = False,
|
63 |
+
double_self_attention: bool = False,
|
64 |
+
upcast_attention: bool = False,
|
65 |
+
norm_elementwise_affine: bool = True,
|
66 |
+
norm_type: str = "layer_norm",
|
67 |
+
final_dropout: bool = False,
|
68 |
+
image_prompt_settings: dict = {},
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
self.only_cross_attention = only_cross_attention
|
72 |
+
|
73 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
74 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
75 |
+
|
76 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
77 |
+
raise ValueError(
|
78 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
79 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
80 |
+
)
|
81 |
+
|
82 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
83 |
+
# 1. Self-Attn
|
84 |
+
if self.use_ada_layer_norm:
|
85 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
86 |
+
elif self.use_ada_layer_norm_zero:
|
87 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
88 |
+
else:
|
89 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
90 |
+
self.attn1 = Attention(
|
91 |
+
query_dim=dim,
|
92 |
+
heads=num_attention_heads,
|
93 |
+
dim_head=attention_head_dim,
|
94 |
+
dropout=dropout,
|
95 |
+
bias=attention_bias,
|
96 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
97 |
+
upcast_attention=upcast_attention,
|
98 |
+
)
|
99 |
+
|
100 |
+
# 2. Cross-Attn
|
101 |
+
if cross_attention_dim is not None or double_self_attention:
|
102 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
103 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
104 |
+
# the second cross attention block.
|
105 |
+
self.norm2 = (
|
106 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
107 |
+
if self.use_ada_layer_norm
|
108 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
109 |
+
)
|
110 |
+
# self.attn2 = CrossAttention(
|
111 |
+
# self.attn2 = SelfAttention(
|
112 |
+
self.attn2 = CrossAttention(
|
113 |
+
query_dim=dim,
|
114 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
115 |
+
heads=num_attention_heads,
|
116 |
+
dim_head=attention_head_dim,
|
117 |
+
dropout=dropout,
|
118 |
+
bias=attention_bias,
|
119 |
+
upcast_attention=upcast_attention,
|
120 |
+
image_prompt_settings=image_prompt_settings,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
self.norm2 = None
|
124 |
+
self.attn2 = None
|
125 |
+
|
126 |
+
# 3. Feed-forward
|
127 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
128 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
129 |
+
|
130 |
+
# let chunk size default to None
|
131 |
+
self._chunk_size = None
|
132 |
+
self._chunk_dim = 0
|
133 |
+
|
134 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
135 |
+
# Sets chunk feed-forward
|
136 |
+
self._chunk_size = chunk_size
|
137 |
+
self._chunk_dim = dim
|
138 |
+
|
139 |
+
def forward(
|
140 |
+
self,
|
141 |
+
hidden_states: torch.FloatTensor,
|
142 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
143 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
144 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
145 |
+
timestep: Optional[torch.LongTensor] = None,
|
146 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
147 |
+
class_labels: Optional[torch.LongTensor] = None,
|
148 |
+
encoder_hidden_states_vision = None,
|
149 |
+
encoder_hidden_states_control = None,
|
150 |
+
vision_guided_mask = None,
|
151 |
+
extra_dict_inputs = {},
|
152 |
+
height = None,
|
153 |
+
width = None,
|
154 |
+
return_self_attn_map = False,
|
155 |
+
):
|
156 |
+
extra_dict_outputs = {}
|
157 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
158 |
+
# 1. Self-Attention
|
159 |
+
if self.use_ada_layer_norm:
|
160 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
161 |
+
elif self.use_ada_layer_norm_zero:
|
162 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
163 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
164 |
+
)
|
165 |
+
else:
|
166 |
+
norm_hidden_states = self.norm1(hidden_states)
|
167 |
+
|
168 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
169 |
+
|
170 |
+
# self attention in XL
|
171 |
+
attn_output = self.attn1(
|
172 |
+
norm_hidden_states,
|
173 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
174 |
+
attention_mask=attention_mask,
|
175 |
+
# vision_guided_mask=vision_guided_mask,
|
176 |
+
# height=height,
|
177 |
+
# width=width,
|
178 |
+
# return_self_attn_map=return_self_attn_map,
|
179 |
+
# **cross_attention_kwargs,
|
180 |
+
)
|
181 |
+
# extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_attn)
|
182 |
+
if self.use_ada_layer_norm_zero:
|
183 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
184 |
+
hidden_states = attn_output + hidden_states
|
185 |
+
|
186 |
+
# 2. Cross-Attention
|
187 |
+
if self.attn2 is not None:
|
188 |
+
norm_hidden_states = (
|
189 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
190 |
+
)
|
191 |
+
|
192 |
+
attn_output, extra_dict_output_attn = self.attn2(
|
193 |
+
norm_hidden_states,
|
194 |
+
encoder_hidden_states=encoder_hidden_states,
|
195 |
+
attention_mask=encoder_attention_mask,
|
196 |
+
encoder_hidden_states_vision=encoder_hidden_states_vision,
|
197 |
+
encoder_hidden_states_control=encoder_hidden_states_control,
|
198 |
+
vision_guided_mask=vision_guided_mask,
|
199 |
+
extra_dict_inputs=extra_dict_inputs,
|
200 |
+
height=height,
|
201 |
+
width=width,
|
202 |
+
**cross_attention_kwargs,
|
203 |
+
)
|
204 |
+
extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_attn)
|
205 |
+
# attn_output = self.attn2(
|
206 |
+
# norm_hidden_states,
|
207 |
+
# encoder_hidden_states=encoder_hidden_states,
|
208 |
+
# attention_mask=encoder_attention_mask,
|
209 |
+
# **cross_attention_kwargs,
|
210 |
+
# )
|
211 |
+
hidden_states = attn_output + hidden_states
|
212 |
+
|
213 |
+
|
214 |
+
# 3. Feed-forward
|
215 |
+
norm_hidden_states = self.norm3(hidden_states)
|
216 |
+
|
217 |
+
if self.use_ada_layer_norm_zero:
|
218 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
219 |
+
|
220 |
+
if self._chunk_size is not None:
|
221 |
+
# "feed_forward_chunk_size" can be used to save memory
|
222 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
223 |
+
raise ValueError(
|
224 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
225 |
+
)
|
226 |
+
|
227 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
228 |
+
ff_output = torch.cat(
|
229 |
+
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
230 |
+
dim=self._chunk_dim,
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
ff_output = self.ff(norm_hidden_states)
|
234 |
+
|
235 |
+
if self.use_ada_layer_norm_zero:
|
236 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
237 |
+
|
238 |
+
hidden_states = ff_output + hidden_states
|
239 |
+
|
240 |
+
return hidden_states, extra_dict_outputs
|
241 |
+
|
242 |
+
|
243 |
+
class FeedForward(nn.Module):
|
244 |
+
r"""
|
245 |
+
A feed-forward layer.
|
246 |
+
|
247 |
+
Parameters:
|
248 |
+
dim (`int`): The number of channels in the input.
|
249 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
250 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
251 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
252 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
253 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
254 |
+
"""
|
255 |
+
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
dim: int,
|
259 |
+
dim_out: Optional[int] = None,
|
260 |
+
mult: int = 4,
|
261 |
+
dropout: float = 0.0,
|
262 |
+
activation_fn: str = "geglu",
|
263 |
+
final_dropout: bool = False,
|
264 |
+
):
|
265 |
+
super().__init__()
|
266 |
+
inner_dim = int(dim * mult)
|
267 |
+
dim_out = dim_out if dim_out is not None else dim
|
268 |
+
|
269 |
+
if activation_fn == "gelu":
|
270 |
+
act_fn = GELU(dim, inner_dim)
|
271 |
+
if activation_fn == "gelu-approximate":
|
272 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
273 |
+
elif activation_fn == "geglu":
|
274 |
+
act_fn = GEGLU(dim, inner_dim)
|
275 |
+
elif activation_fn == "geglu-approximate":
|
276 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
277 |
+
|
278 |
+
self.net = nn.ModuleList([])
|
279 |
+
# project in
|
280 |
+
self.net.append(act_fn)
|
281 |
+
# project dropout
|
282 |
+
self.net.append(nn.Dropout(dropout))
|
283 |
+
# project out
|
284 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
285 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
286 |
+
if final_dropout:
|
287 |
+
self.net.append(nn.Dropout(dropout))
|
288 |
+
|
289 |
+
def forward(self, hidden_states):
|
290 |
+
for module in self.net:
|
291 |
+
hidden_states = module(hidden_states)
|
292 |
+
return hidden_states
|
293 |
+
|
294 |
+
|
295 |
+
class GELU(nn.Module):
|
296 |
+
r"""
|
297 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
298 |
+
"""
|
299 |
+
|
300 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
301 |
+
super().__init__()
|
302 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
303 |
+
self.approximate = approximate
|
304 |
+
|
305 |
+
def gelu(self, gate):
|
306 |
+
if gate.device.type != "mps":
|
307 |
+
return F.gelu(gate, approximate=self.approximate)
|
308 |
+
# mps: gelu is not implemented for float16
|
309 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
310 |
+
|
311 |
+
def forward(self, hidden_states):
|
312 |
+
hidden_states = self.proj(hidden_states)
|
313 |
+
hidden_states = self.gelu(hidden_states)
|
314 |
+
return hidden_states
|
315 |
+
|
316 |
+
|
317 |
+
class GEGLU(nn.Module):
|
318 |
+
r"""
|
319 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
320 |
+
|
321 |
+
Parameters:
|
322 |
+
dim_in (`int`): The number of channels in the input.
|
323 |
+
dim_out (`int`): The number of channels in the output.
|
324 |
+
"""
|
325 |
+
|
326 |
+
def __init__(self, dim_in: int, dim_out: int):
|
327 |
+
super().__init__()
|
328 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
329 |
+
|
330 |
+
def gelu(self, gate):
|
331 |
+
if gate.device.type != "mps":
|
332 |
+
return F.gelu(gate)
|
333 |
+
# mps: gelu is not implemented for float16
|
334 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
335 |
+
|
336 |
+
def forward(self, hidden_states):
|
337 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
338 |
+
return hidden_states * self.gelu(gate)
|
339 |
+
|
340 |
+
|
341 |
+
class ApproximateGELU(nn.Module):
|
342 |
+
"""
|
343 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
344 |
+
|
345 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
346 |
+
"""
|
347 |
+
|
348 |
+
def __init__(self, dim_in: int, dim_out: int):
|
349 |
+
super().__init__()
|
350 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
351 |
+
|
352 |
+
def forward(self, x):
|
353 |
+
x = self.proj(x)
|
354 |
+
return x * torch.sigmoid(1.702 * x)
|
355 |
+
|
356 |
+
|
357 |
+
class AdaLayerNorm(nn.Module):
|
358 |
+
"""
|
359 |
+
Norm layer modified to incorporate timestep embeddings.
|
360 |
+
"""
|
361 |
+
|
362 |
+
def __init__(self, embedding_dim, num_embeddings):
|
363 |
+
super().__init__()
|
364 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
365 |
+
self.silu = nn.SiLU()
|
366 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
367 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
368 |
+
|
369 |
+
def forward(self, x, timestep):
|
370 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
371 |
+
scale, shift = torch.chunk(emb, 2)
|
372 |
+
x = self.norm(x) * (1 + scale) + shift
|
373 |
+
return x
|
374 |
+
|
375 |
+
|
376 |
+
class AdaLayerNormZero(nn.Module):
|
377 |
+
"""
|
378 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
379 |
+
"""
|
380 |
+
|
381 |
+
def __init__(self, embedding_dim, num_embeddings):
|
382 |
+
super().__init__()
|
383 |
+
|
384 |
+
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
385 |
+
|
386 |
+
self.silu = nn.SiLU()
|
387 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
388 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
389 |
+
|
390 |
+
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
391 |
+
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
392 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
393 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
394 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
395 |
+
|
396 |
+
|
397 |
+
class AdaGroupNorm(nn.Module):
|
398 |
+
"""
|
399 |
+
GroupNorm layer modified to incorporate timestep embeddings.
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(
|
403 |
+
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
|
404 |
+
):
|
405 |
+
super().__init__()
|
406 |
+
self.num_groups = num_groups
|
407 |
+
self.eps = eps
|
408 |
+
|
409 |
+
if act_fn is None:
|
410 |
+
self.act = None
|
411 |
+
else:
|
412 |
+
self.act = get_activation(act_fn)
|
413 |
+
|
414 |
+
self.linear = nn.Linear(embedding_dim, out_dim * 2)
|
415 |
+
|
416 |
+
def forward(self, x, emb):
|
417 |
+
if self.act:
|
418 |
+
emb = self.act(emb)
|
419 |
+
emb = self.linear(emb)
|
420 |
+
emb = emb[:, :, None, None]
|
421 |
+
scale, shift = emb.chunk(2, dim=1)
|
422 |
+
|
423 |
+
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
424 |
+
x = x * (1 + scale) + shift
|
425 |
+
return x
|
models/attention_processor_custom_cross.py
ADDED
@@ -0,0 +1,1778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Callable, Optional, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
# from diffusers.utils import deprecate, logging, maybe_allow_in_graph
|
23 |
+
from diffusers.utils import deprecate, logging
|
24 |
+
from diffusers.utils.import_utils import is_xformers_available
|
25 |
+
from einops import rearrange
|
26 |
+
import random
|
27 |
+
|
28 |
+
from utils import zero_module
|
29 |
+
# from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
if is_xformers_available():
|
34 |
+
import xformers
|
35 |
+
import xformers.ops
|
36 |
+
else:
|
37 |
+
xformers = None
|
38 |
+
|
39 |
+
# Cross Attention
|
40 |
+
# @maybe_allow_in_graph
|
41 |
+
class Attention(nn.Module):
|
42 |
+
r"""
|
43 |
+
A cross attention layer.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
query_dim (`int`): The number of channels in the query.
|
47 |
+
cross_attention_dim (`int`, *optional*):
|
48 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
49 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
50 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
51 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
52 |
+
bias (`bool`, *optional*, defaults to False):
|
53 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
query_dim: int,
|
59 |
+
cross_attention_dim: Optional[int] = None,
|
60 |
+
heads: int = 8,
|
61 |
+
dim_head: int = 64,
|
62 |
+
dropout: float = 0.0,
|
63 |
+
bias=False,
|
64 |
+
upcast_attention: bool = False,
|
65 |
+
upcast_softmax: bool = False,
|
66 |
+
cross_attention_norm: Optional[str] = None,
|
67 |
+
cross_attention_norm_num_groups: int = 32,
|
68 |
+
added_kv_proj_dim: Optional[int] = None,
|
69 |
+
norm_num_groups: Optional[int] = None,
|
70 |
+
spatial_norm_dim: Optional[int] = None,
|
71 |
+
out_bias: bool = True,
|
72 |
+
scale_qk: bool = True,
|
73 |
+
only_cross_attention: bool = False,
|
74 |
+
eps: float = 1e-5,
|
75 |
+
rescale_output_factor: float = 1.0,
|
76 |
+
residual_connection: bool = False,
|
77 |
+
_from_deprecated_attn_block=False,
|
78 |
+
processor: Optional["AttnProcessor"] = None,
|
79 |
+
image_prompt_settings: dict = {}
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
inner_dim = dim_head * heads
|
84 |
+
self.inner_dim = inner_dim
|
85 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
86 |
+
self.upcast_attention = upcast_attention
|
87 |
+
self.upcast_softmax = upcast_softmax
|
88 |
+
self.rescale_output_factor = rescale_output_factor
|
89 |
+
self.residual_connection = residual_connection
|
90 |
+
self.dropout = dropout
|
91 |
+
|
92 |
+
# we make use of this private variable to know whether this class is loaded
|
93 |
+
# with an deprecated state dict so that we can convert it on the fly
|
94 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
95 |
+
|
96 |
+
self.scale_qk = scale_qk
|
97 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
98 |
+
|
99 |
+
self.heads = heads
|
100 |
+
# for slice_size > 0 the attention score computation
|
101 |
+
# is split across the batch axis to save memory
|
102 |
+
# You can set slice_size with `set_attention_slice`
|
103 |
+
self.sliceable_head_dim = heads
|
104 |
+
|
105 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
106 |
+
self.only_cross_attention = only_cross_attention
|
107 |
+
|
108 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
109 |
+
raise ValueError(
|
110 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
111 |
+
)
|
112 |
+
|
113 |
+
if norm_num_groups is not None:
|
114 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
115 |
+
else:
|
116 |
+
self.group_norm = None
|
117 |
+
|
118 |
+
if spatial_norm_dim is not None:
|
119 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
120 |
+
else:
|
121 |
+
self.spatial_norm = None
|
122 |
+
|
123 |
+
if cross_attention_norm is None:
|
124 |
+
self.norm_cross = None
|
125 |
+
elif cross_attention_norm == "layer_norm":
|
126 |
+
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
127 |
+
elif cross_attention_norm == "group_norm":
|
128 |
+
if self.added_kv_proj_dim is not None:
|
129 |
+
# The given `encoder_hidden_states` are initially of shape
|
130 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
131 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
132 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
133 |
+
# the number of channels for the group norm.
|
134 |
+
norm_cross_num_channels = added_kv_proj_dim
|
135 |
+
else:
|
136 |
+
norm_cross_num_channels = cross_attention_dim
|
137 |
+
|
138 |
+
self.norm_cross = nn.GroupNorm(
|
139 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
raise ValueError(
|
143 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
144 |
+
)
|
145 |
+
|
146 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
147 |
+
|
148 |
+
if not self.only_cross_attention:
|
149 |
+
# only relevant for the `AddedKVProcessor` classes
|
150 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
151 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
152 |
+
else:
|
153 |
+
self.to_k = None
|
154 |
+
self.to_v = None
|
155 |
+
|
156 |
+
if self.added_kv_proj_dim is not None:
|
157 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
158 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
159 |
+
|
160 |
+
# NOTE for custom dual branch settings
|
161 |
+
self.cross_attention_id = image_prompt_settings.get("cross_attention_id", 0)
|
162 |
+
self.use_cross_attention_id = image_prompt_settings.get("use_cross_attention_id", False)
|
163 |
+
self.image_prompt_mode = image_prompt_settings.get("image_prompt_mode", "none")
|
164 |
+
|
165 |
+
if self.image_prompt_mode == "naive": # only used in cross-attention, NOT self-attention
|
166 |
+
self.to_k_vision = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
167 |
+
self.to_v_vision = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
168 |
+
else:
|
169 |
+
if self.image_prompt_mode != "none":
|
170 |
+
print("Warning .... unknown self.image_prompt_mode")
|
171 |
+
|
172 |
+
self.to_out = nn.ModuleList([])
|
173 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
174 |
+
self.to_out.append(nn.Dropout(dropout))
|
175 |
+
|
176 |
+
self.to_out = nn.ModuleList([])
|
177 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
178 |
+
self.to_out.append(nn.Dropout(dropout))
|
179 |
+
|
180 |
+
# set attention processor
|
181 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
182 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
183 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
184 |
+
# <note> processor default to be None
|
185 |
+
if processor is None:
|
186 |
+
assert hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
187 |
+
processor = AttnProcessor2_0_image_prompt()
|
188 |
+
|
189 |
+
self.set_processor(processor)
|
190 |
+
|
191 |
+
def set_use_memory_efficient_attention_xformers(
|
192 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
193 |
+
):
|
194 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
195 |
+
self.processor,
|
196 |
+
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
|
197 |
+
)
|
198 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
199 |
+
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
200 |
+
)
|
201 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
202 |
+
self.processor,
|
203 |
+
(
|
204 |
+
AttnAddedKVProcessor,
|
205 |
+
AttnAddedKVProcessor2_0,
|
206 |
+
SlicedAttnAddedKVProcessor,
|
207 |
+
XFormersAttnAddedKVProcessor,
|
208 |
+
LoRAAttnAddedKVProcessor,
|
209 |
+
),
|
210 |
+
)
|
211 |
+
|
212 |
+
if use_memory_efficient_attention_xformers:
|
213 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
214 |
+
raise NotImplementedError(
|
215 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
216 |
+
)
|
217 |
+
if not is_xformers_available():
|
218 |
+
raise ModuleNotFoundError(
|
219 |
+
(
|
220 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
221 |
+
" xformers"
|
222 |
+
),
|
223 |
+
name="xformers",
|
224 |
+
)
|
225 |
+
elif not torch.cuda.is_available():
|
226 |
+
raise ValueError(
|
227 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
228 |
+
" only available for GPU "
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
try:
|
232 |
+
# Make sure we can run the memory efficient attention
|
233 |
+
_ = xformers.ops.memory_efficient_attention(
|
234 |
+
torch.randn((1, 2, 40), device="cuda"),
|
235 |
+
torch.randn((1, 2, 40), device="cuda"),
|
236 |
+
torch.randn((1, 2, 40), device="cuda"),
|
237 |
+
)
|
238 |
+
except Exception as e:
|
239 |
+
raise e
|
240 |
+
|
241 |
+
if is_lora:
|
242 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
243 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
244 |
+
processor = LoRAXFormersAttnProcessor(
|
245 |
+
hidden_size=self.processor.hidden_size,
|
246 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
247 |
+
rank=self.processor.rank,
|
248 |
+
attention_op=attention_op,
|
249 |
+
)
|
250 |
+
processor.load_state_dict(self.processor.state_dict())
|
251 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
252 |
+
elif is_custom_diffusion:
|
253 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
254 |
+
train_kv=self.processor.train_kv,
|
255 |
+
train_q_out=self.processor.train_q_out,
|
256 |
+
hidden_size=self.processor.hidden_size,
|
257 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
258 |
+
attention_op=attention_op,
|
259 |
+
)
|
260 |
+
processor.load_state_dict(self.processor.state_dict())
|
261 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
262 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
263 |
+
elif is_added_kv_processor:
|
264 |
+
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
265 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
266 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
267 |
+
# throw warning
|
268 |
+
logger.info(
|
269 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
270 |
+
)
|
271 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
272 |
+
else:
|
273 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
274 |
+
else:
|
275 |
+
if is_lora:
|
276 |
+
attn_processor_class = (
|
277 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
278 |
+
)
|
279 |
+
processor = attn_processor_class(
|
280 |
+
hidden_size=self.processor.hidden_size,
|
281 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
282 |
+
rank=self.processor.rank,
|
283 |
+
)
|
284 |
+
processor.load_state_dict(self.processor.state_dict())
|
285 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
286 |
+
elif is_custom_diffusion:
|
287 |
+
processor = CustomDiffusionAttnProcessor(
|
288 |
+
train_kv=self.processor.train_kv,
|
289 |
+
train_q_out=self.processor.train_q_out,
|
290 |
+
hidden_size=self.processor.hidden_size,
|
291 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
292 |
+
)
|
293 |
+
processor.load_state_dict(self.processor.state_dict())
|
294 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
295 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
296 |
+
else:
|
297 |
+
# set attention processor
|
298 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
299 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
300 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
301 |
+
processor = (
|
302 |
+
AttnProcessor2_0()
|
303 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
304 |
+
else AttnProcessor()
|
305 |
+
)
|
306 |
+
|
307 |
+
self.set_processor(processor)
|
308 |
+
|
309 |
+
def set_attention_slice(self, slice_size):
|
310 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
311 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
312 |
+
|
313 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
314 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
315 |
+
elif slice_size is not None:
|
316 |
+
processor = SlicedAttnProcessor(slice_size)
|
317 |
+
elif self.added_kv_proj_dim is not None:
|
318 |
+
processor = AttnAddedKVProcessor()
|
319 |
+
else:
|
320 |
+
# set attention processor
|
321 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
322 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
323 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
324 |
+
processor = (
|
325 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
326 |
+
)
|
327 |
+
|
328 |
+
self.set_processor(processor)
|
329 |
+
|
330 |
+
def set_processor(self, processor: "AttnProcessor"):
|
331 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
332 |
+
# pop `processor` from `self._modules`
|
333 |
+
if (
|
334 |
+
hasattr(self, "processor")
|
335 |
+
and isinstance(self.processor, torch.nn.Module)
|
336 |
+
and not isinstance(processor, torch.nn.Module)
|
337 |
+
):
|
338 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
339 |
+
self._modules.pop("processor")
|
340 |
+
|
341 |
+
self.processor = processor
|
342 |
+
|
343 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
344 |
+
encoder_hidden_states_vision=None, encoder_hidden_states_control=None,
|
345 |
+
vision_guided_mask=None, extra_dict_inputs={}, height=None, width=None,
|
346 |
+
**cross_attention_kwargs):
|
347 |
+
# The `Attention` class can call different attention processors / attention functions
|
348 |
+
# here we simply pass along all tensors to the selected processor class
|
349 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
350 |
+
return self.processor(
|
351 |
+
self,
|
352 |
+
hidden_states,
|
353 |
+
encoder_hidden_states=encoder_hidden_states,
|
354 |
+
attention_mask=attention_mask,
|
355 |
+
|
356 |
+
encoder_hidden_states_vision=encoder_hidden_states_vision,
|
357 |
+
encoder_hidden_states_control=encoder_hidden_states_control,
|
358 |
+
vision_guided_mask=vision_guided_mask,
|
359 |
+
extra_dict_inputs=extra_dict_inputs,
|
360 |
+
image_prompt_mode=self.image_prompt_mode,
|
361 |
+
|
362 |
+
height=height, width=width,
|
363 |
+
|
364 |
+
**cross_attention_kwargs,
|
365 |
+
)
|
366 |
+
|
367 |
+
def batch_to_head_dim(self, tensor):
|
368 |
+
head_size = self.heads
|
369 |
+
batch_size, seq_len, dim = tensor.shape
|
370 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
371 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
372 |
+
return tensor
|
373 |
+
|
374 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
375 |
+
head_size = self.heads
|
376 |
+
batch_size, seq_len, dim = tensor.shape
|
377 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
378 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
379 |
+
|
380 |
+
if out_dim == 3:
|
381 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
382 |
+
|
383 |
+
return tensor
|
384 |
+
|
385 |
+
def get_attention_scores(self, query, key, attention_mask=None):
|
386 |
+
dtype = query.dtype
|
387 |
+
if self.upcast_attention:
|
388 |
+
query = query.float()
|
389 |
+
key = key.float()
|
390 |
+
|
391 |
+
if attention_mask is None:
|
392 |
+
baddbmm_input = torch.empty(
|
393 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
394 |
+
)
|
395 |
+
beta = 0
|
396 |
+
else:
|
397 |
+
baddbmm_input = attention_mask
|
398 |
+
beta = 1
|
399 |
+
|
400 |
+
attention_scores = torch.baddbmm(
|
401 |
+
baddbmm_input,
|
402 |
+
query,
|
403 |
+
key.transpose(-1, -2),
|
404 |
+
beta=beta,
|
405 |
+
alpha=self.scale,
|
406 |
+
)
|
407 |
+
del baddbmm_input
|
408 |
+
|
409 |
+
if self.upcast_softmax:
|
410 |
+
attention_scores = attention_scores.float()
|
411 |
+
|
412 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
413 |
+
del attention_scores
|
414 |
+
|
415 |
+
attention_probs = attention_probs.to(dtype)
|
416 |
+
|
417 |
+
return attention_probs
|
418 |
+
|
419 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
420 |
+
if batch_size is None:
|
421 |
+
deprecate(
|
422 |
+
"batch_size=None",
|
423 |
+
"0.0.15",
|
424 |
+
(
|
425 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
426 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
427 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
428 |
+
),
|
429 |
+
)
|
430 |
+
batch_size = 1
|
431 |
+
|
432 |
+
head_size = self.heads
|
433 |
+
if attention_mask is None:
|
434 |
+
return attention_mask
|
435 |
+
|
436 |
+
current_length: int = attention_mask.shape[-1]
|
437 |
+
if current_length != target_length:
|
438 |
+
if attention_mask.device.type == "mps":
|
439 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
440 |
+
# Instead, we can manually construct the padding tensor.
|
441 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
442 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
443 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
444 |
+
else:
|
445 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
446 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
447 |
+
# remaining_length: int = target_length - current_length
|
448 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
449 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
450 |
+
|
451 |
+
if out_dim == 3:
|
452 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
453 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
454 |
+
elif out_dim == 4:
|
455 |
+
attention_mask = attention_mask.unsqueeze(1)
|
456 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
457 |
+
|
458 |
+
return attention_mask
|
459 |
+
|
460 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
461 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
462 |
+
|
463 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
464 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
465 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
466 |
+
# Group norm norms along the channels dimension and expects
|
467 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
468 |
+
# to norm along the hidden dimension, so we need to move
|
469 |
+
# (batch_size, sequence_length, hidden_size) ->
|
470 |
+
# (batch_size, hidden_size, sequence_length)
|
471 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
472 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
473 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
474 |
+
else:
|
475 |
+
assert False
|
476 |
+
|
477 |
+
return encoder_hidden_states
|
478 |
+
|
479 |
+
|
480 |
+
class AttnProcessor2_0_image_prompt:
|
481 |
+
r"""
|
482 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
483 |
+
"""
|
484 |
+
|
485 |
+
def __init__(self):
|
486 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
487 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
488 |
+
|
489 |
+
def __call__(
|
490 |
+
self,
|
491 |
+
attn: Attention,
|
492 |
+
hidden_states,
|
493 |
+
encoder_hidden_states=None,
|
494 |
+
attention_mask=None,
|
495 |
+
temb=None,
|
496 |
+
|
497 |
+
encoder_hidden_states_vision=None,
|
498 |
+
encoder_hidden_states_control=None,
|
499 |
+
vision_guided_mask=None,
|
500 |
+
extra_dict_inputs={},
|
501 |
+
image_prompt_mode="none",
|
502 |
+
|
503 |
+
height=None, width=None,
|
504 |
+
):
|
505 |
+
if "multiple_reference_image" in extra_dict_inputs.keys():
|
506 |
+
multiple_reference_image = extra_dict_inputs["multiple_reference_image"]
|
507 |
+
else:
|
508 |
+
multiple_reference_image = False
|
509 |
+
|
510 |
+
resampled_token = None
|
511 |
+
if encoder_hidden_states_vision is not None:
|
512 |
+
if attn.use_cross_attention_id:
|
513 |
+
encoder_hidden_states_vision = encoder_hidden_states_vision[:, attn.cross_attention_id, :, :]
|
514 |
+
|
515 |
+
extra_dict_outputs = {}
|
516 |
+
residual = hidden_states
|
517 |
+
|
518 |
+
if attn.spatial_norm is not None:
|
519 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
520 |
+
|
521 |
+
input_ndim = hidden_states.ndim
|
522 |
+
|
523 |
+
if input_ndim == 4:
|
524 |
+
batch_size, channel, height, width = hidden_states.shape
|
525 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
526 |
+
|
527 |
+
batch_size, sequence_length, _ = (
|
528 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
529 |
+
)
|
530 |
+
inner_dim = hidden_states.shape[-1]
|
531 |
+
|
532 |
+
if attention_mask is not None:
|
533 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
534 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
535 |
+
# (batch, heads, source_length, target_length)
|
536 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
537 |
+
if attn.group_norm is not None:
|
538 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
539 |
+
|
540 |
+
query = attn.to_q(hidden_states)
|
541 |
+
|
542 |
+
if encoder_hidden_states is None:
|
543 |
+
encoder_hidden_states = hidden_states
|
544 |
+
elif attn.norm_cross: # attn.norm_cross: None
|
545 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
546 |
+
|
547 |
+
# text-imgae cross attention
|
548 |
+
key = attn.to_k(encoder_hidden_states)
|
549 |
+
value = attn.to_v(encoder_hidden_states)
|
550 |
+
|
551 |
+
head_dim = inner_dim // attn.heads
|
552 |
+
|
553 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
554 |
+
|
555 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
556 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
557 |
+
|
558 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
559 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
560 |
+
if attn.training:
|
561 |
+
# with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
562 |
+
# hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
563 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
564 |
+
# hidden_states = flash_attn_func(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False)
|
565 |
+
else: # use vanilla attention during inference
|
566 |
+
with torch.autocast(enabled=True, device_type = 'cuda'):
|
567 |
+
q, k, v = query.float(), key.float(), value.float()
|
568 |
+
sim = (q @ k.transpose(-2, -1) * attn.scale)
|
569 |
+
if attention_mask is not None: # no mask in SDXL?
|
570 |
+
attention_mask = 1 + (attention_mask / -10000.0)
|
571 |
+
attention_mask = attention_mask.bool()
|
572 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
573 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
574 |
+
sim = sim.softmax(dim=-1)
|
575 |
+
hidden_states = torch.einsum('b h i j, b h j d -> b h i d', sim, v)
|
576 |
+
extra_dict_outputs["text2image_crossmap_2d"] = rearrange(sim, "b head (h w) n -> b head n h w", h=height, w=width)
|
577 |
+
|
578 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
579 |
+
hidden_states = hidden_states.to(query.dtype)
|
580 |
+
|
581 |
+
if encoder_hidden_states_vision is not None and not multiple_reference_image: # single image
|
582 |
+
if image_prompt_mode == "naive":
|
583 |
+
key_vision = attn.to_k_vision(encoder_hidden_states_vision)
|
584 |
+
value_vision = attn.to_v_vision(encoder_hidden_states_vision)
|
585 |
+
key_vision = key_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
586 |
+
value_vision = value_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
587 |
+
hidden_states_vision = F.scaled_dot_product_attention(query, key_vision, value_vision, dropout_p=0.0, is_causal=False)
|
588 |
+
hidden_states_vision = hidden_states_vision.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
589 |
+
else:
|
590 |
+
hidden_states_vision = torch.zeros_like(hidden_states).to(hidden_states.device)
|
591 |
+
|
592 |
+
if vision_guided_mask is not None:
|
593 |
+
if vision_guided_mask.dim() == 4: # 所有层共用相同的mask
|
594 |
+
target_h, target_w = vision_guided_mask.size(-2), vision_guided_mask.size(-1)
|
595 |
+
vision_guided_mask = F.interpolate(vision_guided_mask.float(), scale_factor=height/target_h, mode='bilinear')
|
596 |
+
vision_guided_mask_1d = rearrange(vision_guided_mask, "b c h w -> b (h w) c")
|
597 |
+
hidden_states_vision = hidden_states_vision * vision_guided_mask_1d
|
598 |
+
else: # according to different self.cross_attention_id, 每一层用单独的mask
|
599 |
+
vision_guided_mask = vision_guided_mask[:, :, attn.cross_attention_id, :, :]
|
600 |
+
target_h, target_w = vision_guided_mask.size(-2), vision_guided_mask.size(-1)
|
601 |
+
vision_guided_mask = F.interpolate(vision_guided_mask, scale_factor=height/target_h, mode='bilinear')
|
602 |
+
vision_guided_mask_1d = rearrange(vision_guided_mask, "b c h w -> b (h w) c")
|
603 |
+
hidden_states_vision = hidden_states_vision * vision_guided_mask_1d
|
604 |
+
|
605 |
+
elif encoder_hidden_states_vision is not None and multiple_reference_image: # multiple image
|
606 |
+
if image_prompt_mode == "naive":
|
607 |
+
image_num = encoder_hidden_states_vision.size(1)
|
608 |
+
encoder_hidden_states_vision_list = encoder_hidden_states_vision.chunk(image_num, dim=1)
|
609 |
+
if vision_guided_mask is not None:
|
610 |
+
vision_guided_mask_list = vision_guided_mask.chunk(image_num, dim=1)
|
611 |
+
else:
|
612 |
+
vision_guided_mask_list = [None] * image_num
|
613 |
+
hidden_states_vision_results = []
|
614 |
+
for encoder_hidden_states_vision_i, vision_guided_mask_i in zip(encoder_hidden_states_vision_list, vision_guided_mask_list):
|
615 |
+
encoder_hidden_states_vision_i = encoder_hidden_states_vision_i.squeeze(1)
|
616 |
+
|
617 |
+
key_vision = attn.to_k_vision(encoder_hidden_states_vision_i)
|
618 |
+
value_vision = attn.to_v_vision(encoder_hidden_states_vision_i)
|
619 |
+
key_vision = key_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
620 |
+
value_vision = value_vision.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
621 |
+
hidden_states_vision = F.scaled_dot_product_attention(query, key_vision, value_vision, dropout_p=0.0, is_causal=False)
|
622 |
+
hidden_states_vision = hidden_states_vision.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
623 |
+
|
624 |
+
if vision_guided_mask_i is not None:
|
625 |
+
target_h, target_w = vision_guided_mask_i.size(-2), vision_guided_mask_i.size(-1)
|
626 |
+
vision_guided_mask_i = F.interpolate(vision_guided_mask_i, scale_factor=height/target_h, mode='bilinear')
|
627 |
+
vision_guided_mask_1d_i = rearrange(vision_guided_mask_i, "b c h w -> b (h w) c")
|
628 |
+
hidden_states_vision = hidden_states_vision * vision_guided_mask_1d_i
|
629 |
+
|
630 |
+
hidden_states_vision_results.append(hidden_states_vision.unsqueeze(1))
|
631 |
+
|
632 |
+
hidden_states_vision = torch.cat(hidden_states_vision_results, dim=1).sum(dim=1)
|
633 |
+
|
634 |
+
else:
|
635 |
+
hidden_states_vision = torch.zeros_like(hidden_states).to(hidden_states.device)
|
636 |
+
else:
|
637 |
+
hidden_states_vision = torch.zeros_like(hidden_states).to(hidden_states.device)
|
638 |
+
|
639 |
+
hidden_states = hidden_states + hidden_states_vision
|
640 |
+
|
641 |
+
# linear proj
|
642 |
+
hidden_states = attn.to_out[0](hidden_states)
|
643 |
+
# dropout
|
644 |
+
hidden_states = attn.to_out[1](hidden_states)
|
645 |
+
|
646 |
+
if input_ndim == 4:
|
647 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
648 |
+
|
649 |
+
if attn.residual_connection:
|
650 |
+
hidden_states = hidden_states + residual
|
651 |
+
|
652 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
653 |
+
|
654 |
+
return hidden_states, extra_dict_outputs
|
655 |
+
|
656 |
+
|
657 |
+
class AttnProcessor:
|
658 |
+
r"""
|
659 |
+
Default processor for performing attention-related computations.
|
660 |
+
"""
|
661 |
+
|
662 |
+
def __call__(
|
663 |
+
self,
|
664 |
+
attn: Attention,
|
665 |
+
hidden_states,
|
666 |
+
encoder_hidden_states=None,
|
667 |
+
attention_mask=None,
|
668 |
+
temb=None,
|
669 |
+
):
|
670 |
+
residual = hidden_states
|
671 |
+
|
672 |
+
if attn.spatial_norm is not None:
|
673 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
674 |
+
|
675 |
+
input_ndim = hidden_states.ndim
|
676 |
+
|
677 |
+
if input_ndim == 4:
|
678 |
+
batch_size, channel, height, width = hidden_states.shape
|
679 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
680 |
+
|
681 |
+
batch_size, sequence_length, _ = (
|
682 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
683 |
+
)
|
684 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
685 |
+
|
686 |
+
if attn.group_norm is not None:
|
687 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
688 |
+
|
689 |
+
query = attn.to_q(hidden_states)
|
690 |
+
|
691 |
+
if encoder_hidden_states is None:
|
692 |
+
encoder_hidden_states = hidden_states
|
693 |
+
elif attn.norm_cross:
|
694 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
695 |
+
|
696 |
+
key = attn.to_k(encoder_hidden_states)
|
697 |
+
value = attn.to_v(encoder_hidden_states)
|
698 |
+
|
699 |
+
query = attn.head_to_batch_dim(query)
|
700 |
+
key = attn.head_to_batch_dim(key)
|
701 |
+
value = attn.head_to_batch_dim(value)
|
702 |
+
|
703 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
704 |
+
hidden_states = torch.bmm(attention_probs, value)
|
705 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
706 |
+
|
707 |
+
# linear proj
|
708 |
+
hidden_states = attn.to_out[0](hidden_states)
|
709 |
+
# dropout
|
710 |
+
hidden_states = attn.to_out[1](hidden_states)
|
711 |
+
|
712 |
+
if input_ndim == 4:
|
713 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
714 |
+
|
715 |
+
if attn.residual_connection:
|
716 |
+
hidden_states = hidden_states + residual
|
717 |
+
|
718 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
719 |
+
|
720 |
+
return hidden_states
|
721 |
+
|
722 |
+
|
723 |
+
class LoRALinearLayer(nn.Module):
|
724 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
|
725 |
+
super().__init__()
|
726 |
+
|
727 |
+
if rank > min(in_features, out_features):
|
728 |
+
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
729 |
+
|
730 |
+
self.down = nn.Linear(in_features, rank, bias=False)
|
731 |
+
self.up = nn.Linear(rank, out_features, bias=False)
|
732 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
733 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
734 |
+
self.network_alpha = network_alpha
|
735 |
+
self.rank = rank
|
736 |
+
|
737 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
738 |
+
nn.init.zeros_(self.up.weight)
|
739 |
+
|
740 |
+
def forward(self, hidden_states):
|
741 |
+
orig_dtype = hidden_states.dtype
|
742 |
+
dtype = self.down.weight.dtype
|
743 |
+
|
744 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
745 |
+
up_hidden_states = self.up(down_hidden_states)
|
746 |
+
|
747 |
+
if self.network_alpha is not None:
|
748 |
+
up_hidden_states *= self.network_alpha / self.rank
|
749 |
+
|
750 |
+
return up_hidden_states.to(orig_dtype)
|
751 |
+
|
752 |
+
|
753 |
+
class LoRAAttnProcessor(nn.Module):
|
754 |
+
r"""
|
755 |
+
Processor for implementing the LoRA attention mechanism.
|
756 |
+
|
757 |
+
Args:
|
758 |
+
hidden_size (`int`, *optional*):
|
759 |
+
The hidden size of the attention layer.
|
760 |
+
cross_attention_dim (`int`, *optional*):
|
761 |
+
The number of channels in the `encoder_hidden_states`.
|
762 |
+
rank (`int`, defaults to 4):
|
763 |
+
The dimension of the LoRA update matrices.
|
764 |
+
network_alpha (`int`, *optional*):
|
765 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
766 |
+
"""
|
767 |
+
|
768 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
769 |
+
super().__init__()
|
770 |
+
|
771 |
+
self.hidden_size = hidden_size
|
772 |
+
self.cross_attention_dim = cross_attention_dim
|
773 |
+
self.rank = rank
|
774 |
+
|
775 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
776 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
777 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
778 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
779 |
+
|
780 |
+
def __call__(
|
781 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
782 |
+
):
|
783 |
+
residual = hidden_states
|
784 |
+
|
785 |
+
if attn.spatial_norm is not None:
|
786 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
787 |
+
|
788 |
+
input_ndim = hidden_states.ndim
|
789 |
+
|
790 |
+
if input_ndim == 4:
|
791 |
+
batch_size, channel, height, width = hidden_states.shape
|
792 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
793 |
+
|
794 |
+
batch_size, sequence_length, _ = (
|
795 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
796 |
+
)
|
797 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
798 |
+
|
799 |
+
if attn.group_norm is not None:
|
800 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
801 |
+
|
802 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
803 |
+
query = attn.head_to_batch_dim(query)
|
804 |
+
|
805 |
+
if encoder_hidden_states is None:
|
806 |
+
encoder_hidden_states = hidden_states
|
807 |
+
elif attn.norm_cross:
|
808 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
809 |
+
|
810 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
811 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
812 |
+
|
813 |
+
key = attn.head_to_batch_dim(key)
|
814 |
+
value = attn.head_to_batch_dim(value)
|
815 |
+
|
816 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
817 |
+
hidden_states = torch.bmm(attention_probs, value)
|
818 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
819 |
+
|
820 |
+
# linear proj
|
821 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
822 |
+
# dropout
|
823 |
+
hidden_states = attn.to_out[1](hidden_states)
|
824 |
+
|
825 |
+
if input_ndim == 4:
|
826 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
827 |
+
|
828 |
+
if attn.residual_connection:
|
829 |
+
hidden_states = hidden_states + residual
|
830 |
+
|
831 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
832 |
+
|
833 |
+
return hidden_states
|
834 |
+
|
835 |
+
|
836 |
+
class CustomDiffusionAttnProcessor(nn.Module):
|
837 |
+
r"""
|
838 |
+
Processor for implementing attention for the Custom Diffusion method.
|
839 |
+
|
840 |
+
Args:
|
841 |
+
train_kv (`bool`, defaults to `True`):
|
842 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
843 |
+
train_q_out (`bool`, defaults to `True`):
|
844 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
845 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
846 |
+
The hidden size of the attention layer.
|
847 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
848 |
+
The number of channels in the `encoder_hidden_states`.
|
849 |
+
out_bias (`bool`, defaults to `True`):
|
850 |
+
Whether to include the bias parameter in `train_q_out`.
|
851 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
852 |
+
The dropout probability to use.
|
853 |
+
"""
|
854 |
+
|
855 |
+
def __init__(
|
856 |
+
self,
|
857 |
+
train_kv=True,
|
858 |
+
train_q_out=True,
|
859 |
+
hidden_size=None,
|
860 |
+
cross_attention_dim=None,
|
861 |
+
out_bias=True,
|
862 |
+
dropout=0.0,
|
863 |
+
):
|
864 |
+
super().__init__()
|
865 |
+
self.train_kv = train_kv
|
866 |
+
self.train_q_out = train_q_out
|
867 |
+
|
868 |
+
self.hidden_size = hidden_size
|
869 |
+
self.cross_attention_dim = cross_attention_dim
|
870 |
+
|
871 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
872 |
+
if self.train_kv:
|
873 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
874 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
875 |
+
if self.train_q_out:
|
876 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
877 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
878 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
879 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
880 |
+
|
881 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
882 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
883 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
884 |
+
if self.train_q_out:
|
885 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
886 |
+
else:
|
887 |
+
query = attn.to_q(hidden_states)
|
888 |
+
|
889 |
+
if encoder_hidden_states is None:
|
890 |
+
crossattn = False
|
891 |
+
encoder_hidden_states = hidden_states
|
892 |
+
else:
|
893 |
+
crossattn = True
|
894 |
+
if attn.norm_cross:
|
895 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
896 |
+
|
897 |
+
if self.train_kv:
|
898 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
899 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
900 |
+
else:
|
901 |
+
key = attn.to_k(encoder_hidden_states)
|
902 |
+
value = attn.to_v(encoder_hidden_states)
|
903 |
+
|
904 |
+
if crossattn:
|
905 |
+
detach = torch.ones_like(key)
|
906 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
907 |
+
key = detach * key + (1 - detach) * key.detach()
|
908 |
+
value = detach * value + (1 - detach) * value.detach()
|
909 |
+
|
910 |
+
query = attn.head_to_batch_dim(query)
|
911 |
+
key = attn.head_to_batch_dim(key)
|
912 |
+
value = attn.head_to_batch_dim(value)
|
913 |
+
|
914 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
915 |
+
hidden_states = torch.bmm(attention_probs, value)
|
916 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
917 |
+
|
918 |
+
if self.train_q_out:
|
919 |
+
# linear proj
|
920 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
921 |
+
# dropout
|
922 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
923 |
+
else:
|
924 |
+
# linear proj
|
925 |
+
hidden_states = attn.to_out[0](hidden_states)
|
926 |
+
# dropout
|
927 |
+
hidden_states = attn.to_out[1](hidden_states)
|
928 |
+
|
929 |
+
return hidden_states
|
930 |
+
|
931 |
+
|
932 |
+
class AttnAddedKVProcessor:
|
933 |
+
r"""
|
934 |
+
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
935 |
+
encoder.
|
936 |
+
"""
|
937 |
+
|
938 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
939 |
+
residual = hidden_states
|
940 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
941 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
942 |
+
|
943 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
944 |
+
|
945 |
+
if encoder_hidden_states is None:
|
946 |
+
encoder_hidden_states = hidden_states
|
947 |
+
elif attn.norm_cross:
|
948 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
949 |
+
|
950 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
951 |
+
|
952 |
+
query = attn.to_q(hidden_states)
|
953 |
+
query = attn.head_to_batch_dim(query)
|
954 |
+
|
955 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
956 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
957 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
958 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
959 |
+
|
960 |
+
if not attn.only_cross_attention:
|
961 |
+
key = attn.to_k(hidden_states)
|
962 |
+
value = attn.to_v(hidden_states)
|
963 |
+
key = attn.head_to_batch_dim(key)
|
964 |
+
value = attn.head_to_batch_dim(value)
|
965 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
966 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
967 |
+
else:
|
968 |
+
key = encoder_hidden_states_key_proj
|
969 |
+
value = encoder_hidden_states_value_proj
|
970 |
+
|
971 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
972 |
+
hidden_states = torch.bmm(attention_probs, value)
|
973 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
974 |
+
|
975 |
+
# linear proj
|
976 |
+
hidden_states = attn.to_out[0](hidden_states)
|
977 |
+
# dropout
|
978 |
+
hidden_states = attn.to_out[1](hidden_states)
|
979 |
+
|
980 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
981 |
+
hidden_states = hidden_states + residual
|
982 |
+
|
983 |
+
return hidden_states
|
984 |
+
|
985 |
+
|
986 |
+
class AttnAddedKVProcessor2_0:
|
987 |
+
r"""
|
988 |
+
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
989 |
+
learnable key and value matrices for the text encoder.
|
990 |
+
"""
|
991 |
+
|
992 |
+
def __init__(self):
|
993 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
994 |
+
raise ImportError(
|
995 |
+
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
996 |
+
)
|
997 |
+
|
998 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
999 |
+
residual = hidden_states
|
1000 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1001 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1002 |
+
|
1003 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
1004 |
+
|
1005 |
+
if encoder_hidden_states is None:
|
1006 |
+
encoder_hidden_states = hidden_states
|
1007 |
+
elif attn.norm_cross:
|
1008 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1009 |
+
|
1010 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1011 |
+
|
1012 |
+
query = attn.to_q(hidden_states)
|
1013 |
+
query = attn.head_to_batch_dim(query, out_dim=4)
|
1014 |
+
|
1015 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1016 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1017 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
1018 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
1019 |
+
|
1020 |
+
if not attn.only_cross_attention:
|
1021 |
+
key = attn.to_k(hidden_states)
|
1022 |
+
value = attn.to_v(hidden_states)
|
1023 |
+
key = attn.head_to_batch_dim(key, out_dim=4)
|
1024 |
+
value = attn.head_to_batch_dim(value, out_dim=4)
|
1025 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
1026 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
1027 |
+
else:
|
1028 |
+
key = encoder_hidden_states_key_proj
|
1029 |
+
value = encoder_hidden_states_value_proj
|
1030 |
+
|
1031 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1032 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1033 |
+
hidden_states = F.scaled_dot_product_attention(
|
1034 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1035 |
+
)
|
1036 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
1037 |
+
|
1038 |
+
# linear proj
|
1039 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1040 |
+
# dropout
|
1041 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1042 |
+
|
1043 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1044 |
+
hidden_states = hidden_states + residual
|
1045 |
+
|
1046 |
+
return hidden_states
|
1047 |
+
|
1048 |
+
|
1049 |
+
class LoRAAttnAddedKVProcessor(nn.Module):
|
1050 |
+
r"""
|
1051 |
+
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
1052 |
+
encoder.
|
1053 |
+
|
1054 |
+
Args:
|
1055 |
+
hidden_size (`int`, *optional*):
|
1056 |
+
The hidden size of the attention layer.
|
1057 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1058 |
+
The number of channels in the `encoder_hidden_states`.
|
1059 |
+
rank (`int`, defaults to 4):
|
1060 |
+
The dimension of the LoRA update matrices.
|
1061 |
+
|
1062 |
+
"""
|
1063 |
+
|
1064 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
1065 |
+
super().__init__()
|
1066 |
+
|
1067 |
+
self.hidden_size = hidden_size
|
1068 |
+
self.cross_attention_dim = cross_attention_dim
|
1069 |
+
self.rank = rank
|
1070 |
+
|
1071 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1072 |
+
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1073 |
+
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1074 |
+
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1075 |
+
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1076 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1077 |
+
|
1078 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
1079 |
+
residual = hidden_states
|
1080 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1081 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1082 |
+
|
1083 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1084 |
+
|
1085 |
+
if encoder_hidden_states is None:
|
1086 |
+
encoder_hidden_states = hidden_states
|
1087 |
+
elif attn.norm_cross:
|
1088 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1089 |
+
|
1090 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1091 |
+
|
1092 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1093 |
+
query = attn.head_to_batch_dim(query)
|
1094 |
+
|
1095 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
|
1096 |
+
encoder_hidden_states
|
1097 |
+
)
|
1098 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
|
1099 |
+
encoder_hidden_states
|
1100 |
+
)
|
1101 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1102 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1103 |
+
|
1104 |
+
if not attn.only_cross_attention:
|
1105 |
+
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
|
1106 |
+
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
|
1107 |
+
key = attn.head_to_batch_dim(key)
|
1108 |
+
value = attn.head_to_batch_dim(value)
|
1109 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1110 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1111 |
+
else:
|
1112 |
+
key = encoder_hidden_states_key_proj
|
1113 |
+
value = encoder_hidden_states_value_proj
|
1114 |
+
|
1115 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
1116 |
+
hidden_states = torch.bmm(attention_probs, value)
|
1117 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1118 |
+
|
1119 |
+
# linear proj
|
1120 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1121 |
+
# dropout
|
1122 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1123 |
+
|
1124 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1125 |
+
hidden_states = hidden_states + residual
|
1126 |
+
|
1127 |
+
return hidden_states
|
1128 |
+
|
1129 |
+
|
1130 |
+
class XFormersAttnAddedKVProcessor:
|
1131 |
+
r"""
|
1132 |
+
Processor for implementing memory efficient attention using xFormers.
|
1133 |
+
|
1134 |
+
Args:
|
1135 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1136 |
+
The base
|
1137 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1138 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1139 |
+
operator.
|
1140 |
+
"""
|
1141 |
+
|
1142 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1143 |
+
self.attention_op = attention_op
|
1144 |
+
|
1145 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1146 |
+
residual = hidden_states
|
1147 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1148 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1149 |
+
|
1150 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1151 |
+
|
1152 |
+
if encoder_hidden_states is None:
|
1153 |
+
encoder_hidden_states = hidden_states
|
1154 |
+
elif attn.norm_cross:
|
1155 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1156 |
+
|
1157 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1158 |
+
|
1159 |
+
query = attn.to_q(hidden_states)
|
1160 |
+
query = attn.head_to_batch_dim(query)
|
1161 |
+
|
1162 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1163 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1164 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1165 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1166 |
+
|
1167 |
+
if not attn.only_cross_attention:
|
1168 |
+
key = attn.to_k(hidden_states)
|
1169 |
+
value = attn.to_v(hidden_states)
|
1170 |
+
key = attn.head_to_batch_dim(key)
|
1171 |
+
value = attn.head_to_batch_dim(value)
|
1172 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1173 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1174 |
+
else:
|
1175 |
+
key = encoder_hidden_states_key_proj
|
1176 |
+
value = encoder_hidden_states_value_proj
|
1177 |
+
|
1178 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1179 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1180 |
+
)
|
1181 |
+
hidden_states = hidden_states.to(query.dtype)
|
1182 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1183 |
+
|
1184 |
+
# linear proj
|
1185 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1186 |
+
# dropout
|
1187 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1188 |
+
|
1189 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1190 |
+
hidden_states = hidden_states + residual
|
1191 |
+
|
1192 |
+
return hidden_states
|
1193 |
+
|
1194 |
+
|
1195 |
+
class XFormersAttnProcessor:
|
1196 |
+
r"""
|
1197 |
+
Processor for implementing memory efficient attention using xFormers.
|
1198 |
+
|
1199 |
+
Args:
|
1200 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1201 |
+
The base
|
1202 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1203 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1204 |
+
operator.
|
1205 |
+
"""
|
1206 |
+
|
1207 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1208 |
+
self.attention_op = attention_op
|
1209 |
+
|
1210 |
+
def __call__(
|
1211 |
+
self,
|
1212 |
+
attn: Attention,
|
1213 |
+
hidden_states: torch.FloatTensor,
|
1214 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1215 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1216 |
+
temb: Optional[torch.FloatTensor] = None,
|
1217 |
+
):
|
1218 |
+
residual = hidden_states
|
1219 |
+
|
1220 |
+
if attn.spatial_norm is not None:
|
1221 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1222 |
+
|
1223 |
+
input_ndim = hidden_states.ndim
|
1224 |
+
|
1225 |
+
if input_ndim == 4:
|
1226 |
+
batch_size, channel, height, width = hidden_states.shape
|
1227 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1228 |
+
|
1229 |
+
batch_size, key_tokens, _ = (
|
1230 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1234 |
+
if attention_mask is not None:
|
1235 |
+
# expand our mask's singleton query_tokens dimension:
|
1236 |
+
# [batch*heads, 1, key_tokens] ->
|
1237 |
+
# [batch*heads, query_tokens, key_tokens]
|
1238 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1239 |
+
# [batch*heads, query_tokens, key_tokens]
|
1240 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1241 |
+
_, query_tokens, _ = hidden_states.shape
|
1242 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1243 |
+
|
1244 |
+
if attn.group_norm is not None:
|
1245 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1246 |
+
|
1247 |
+
query = attn.to_q(hidden_states)
|
1248 |
+
|
1249 |
+
if encoder_hidden_states is None:
|
1250 |
+
encoder_hidden_states = hidden_states
|
1251 |
+
elif attn.norm_cross:
|
1252 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1253 |
+
|
1254 |
+
key = attn.to_k(encoder_hidden_states)
|
1255 |
+
value = attn.to_v(encoder_hidden_states)
|
1256 |
+
|
1257 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1258 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1259 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1260 |
+
|
1261 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1262 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1263 |
+
)
|
1264 |
+
hidden_states = hidden_states.to(query.dtype)
|
1265 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1266 |
+
|
1267 |
+
# linear proj
|
1268 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1269 |
+
# dropout
|
1270 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1271 |
+
|
1272 |
+
if input_ndim == 4:
|
1273 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1274 |
+
|
1275 |
+
if attn.residual_connection:
|
1276 |
+
hidden_states = hidden_states + residual
|
1277 |
+
|
1278 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1279 |
+
|
1280 |
+
return hidden_states
|
1281 |
+
|
1282 |
+
|
1283 |
+
class LoRAXFormersAttnProcessor(nn.Module):
|
1284 |
+
r"""
|
1285 |
+
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
|
1286 |
+
|
1287 |
+
Args:
|
1288 |
+
hidden_size (`int`, *optional*):
|
1289 |
+
The hidden size of the attention layer.
|
1290 |
+
cross_attention_dim (`int`, *optional*):
|
1291 |
+
The number of channels in the `encoder_hidden_states`.
|
1292 |
+
rank (`int`, defaults to 4):
|
1293 |
+
The dimension of the LoRA update matrices.
|
1294 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1295 |
+
The base
|
1296 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1297 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1298 |
+
operator.
|
1299 |
+
network_alpha (`int`, *optional*):
|
1300 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1301 |
+
|
1302 |
+
"""
|
1303 |
+
|
1304 |
+
def __init__(
|
1305 |
+
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
|
1306 |
+
):
|
1307 |
+
super().__init__()
|
1308 |
+
|
1309 |
+
self.hidden_size = hidden_size
|
1310 |
+
self.cross_attention_dim = cross_attention_dim
|
1311 |
+
self.rank = rank
|
1312 |
+
self.attention_op = attention_op
|
1313 |
+
|
1314 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1315 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1316 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1317 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1318 |
+
|
1319 |
+
def __call__(
|
1320 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
1321 |
+
):
|
1322 |
+
residual = hidden_states
|
1323 |
+
|
1324 |
+
if attn.spatial_norm is not None:
|
1325 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1326 |
+
|
1327 |
+
input_ndim = hidden_states.ndim
|
1328 |
+
|
1329 |
+
if input_ndim == 4:
|
1330 |
+
batch_size, channel, height, width = hidden_states.shape
|
1331 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1332 |
+
|
1333 |
+
batch_size, sequence_length, _ = (
|
1334 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1335 |
+
)
|
1336 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1337 |
+
|
1338 |
+
if attn.group_norm is not None:
|
1339 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1340 |
+
|
1341 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1342 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1343 |
+
|
1344 |
+
if encoder_hidden_states is None:
|
1345 |
+
encoder_hidden_states = hidden_states
|
1346 |
+
elif attn.norm_cross:
|
1347 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1348 |
+
|
1349 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1350 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1351 |
+
|
1352 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1353 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1354 |
+
|
1355 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1356 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1357 |
+
)
|
1358 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1359 |
+
|
1360 |
+
# linear proj
|
1361 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1362 |
+
# dropout
|
1363 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1364 |
+
|
1365 |
+
if input_ndim == 4:
|
1366 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1367 |
+
|
1368 |
+
if attn.residual_connection:
|
1369 |
+
hidden_states = hidden_states + residual
|
1370 |
+
|
1371 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1372 |
+
|
1373 |
+
return hidden_states
|
1374 |
+
|
1375 |
+
|
1376 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
1377 |
+
r"""
|
1378 |
+
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
1379 |
+
attention.
|
1380 |
+
|
1381 |
+
Args:
|
1382 |
+
hidden_size (`int`):
|
1383 |
+
The hidden size of the attention layer.
|
1384 |
+
cross_attention_dim (`int`, *optional*):
|
1385 |
+
The number of channels in the `encoder_hidden_states`.
|
1386 |
+
rank (`int`, defaults to 4):
|
1387 |
+
The dimension of the LoRA update matrices.
|
1388 |
+
network_alpha (`int`, *optional*):
|
1389 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1390 |
+
"""
|
1391 |
+
|
1392 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
1393 |
+
super().__init__()
|
1394 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1395 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1396 |
+
|
1397 |
+
self.hidden_size = hidden_size
|
1398 |
+
self.cross_attention_dim = cross_attention_dim
|
1399 |
+
self.rank = rank
|
1400 |
+
|
1401 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1402 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1403 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1404 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1405 |
+
|
1406 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
1407 |
+
residual = hidden_states
|
1408 |
+
|
1409 |
+
input_ndim = hidden_states.ndim
|
1410 |
+
|
1411 |
+
if input_ndim == 4:
|
1412 |
+
batch_size, channel, height, width = hidden_states.shape
|
1413 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1414 |
+
|
1415 |
+
batch_size, sequence_length, _ = (
|
1416 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1417 |
+
)
|
1418 |
+
inner_dim = hidden_states.shape[-1]
|
1419 |
+
|
1420 |
+
if attention_mask is not None:
|
1421 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1422 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1423 |
+
# (batch, heads, source_length, target_length)
|
1424 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1425 |
+
|
1426 |
+
if attn.group_norm is not None:
|
1427 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1428 |
+
|
1429 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1430 |
+
|
1431 |
+
if encoder_hidden_states is None:
|
1432 |
+
encoder_hidden_states = hidden_states
|
1433 |
+
elif attn.norm_cross:
|
1434 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1435 |
+
|
1436 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1437 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1438 |
+
|
1439 |
+
head_dim = inner_dim // attn.heads
|
1440 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1441 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1442 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1443 |
+
|
1444 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1445 |
+
hidden_states = F.scaled_dot_product_attention(
|
1446 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1447 |
+
)
|
1448 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1449 |
+
hidden_states = hidden_states.to(query.dtype)
|
1450 |
+
|
1451 |
+
# linear proj
|
1452 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1453 |
+
# dropout
|
1454 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1455 |
+
|
1456 |
+
if input_ndim == 4:
|
1457 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1458 |
+
|
1459 |
+
if attn.residual_connection:
|
1460 |
+
hidden_states = hidden_states + residual
|
1461 |
+
|
1462 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1463 |
+
|
1464 |
+
return hidden_states
|
1465 |
+
|
1466 |
+
|
1467 |
+
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1468 |
+
r"""
|
1469 |
+
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
1470 |
+
|
1471 |
+
Args:
|
1472 |
+
train_kv (`bool`, defaults to `True`):
|
1473 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1474 |
+
train_q_out (`bool`, defaults to `True`):
|
1475 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1476 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1477 |
+
The hidden size of the attention layer.
|
1478 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1479 |
+
The number of channels in the `encoder_hidden_states`.
|
1480 |
+
out_bias (`bool`, defaults to `True`):
|
1481 |
+
Whether to include the bias parameter in `train_q_out`.
|
1482 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1483 |
+
The dropout probability to use.
|
1484 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1485 |
+
The base
|
1486 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
|
1487 |
+
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
|
1488 |
+
"""
|
1489 |
+
|
1490 |
+
def __init__(
|
1491 |
+
self,
|
1492 |
+
train_kv=True,
|
1493 |
+
train_q_out=False,
|
1494 |
+
hidden_size=None,
|
1495 |
+
cross_attention_dim=None,
|
1496 |
+
out_bias=True,
|
1497 |
+
dropout=0.0,
|
1498 |
+
attention_op: Optional[Callable] = None,
|
1499 |
+
):
|
1500 |
+
super().__init__()
|
1501 |
+
self.train_kv = train_kv
|
1502 |
+
self.train_q_out = train_q_out
|
1503 |
+
|
1504 |
+
self.hidden_size = hidden_size
|
1505 |
+
self.cross_attention_dim = cross_attention_dim
|
1506 |
+
self.attention_op = attention_op
|
1507 |
+
|
1508 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1509 |
+
if self.train_kv:
|
1510 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1511 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1512 |
+
if self.train_q_out:
|
1513 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1514 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1515 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1516 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1517 |
+
|
1518 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1519 |
+
batch_size, sequence_length, _ = (
|
1520 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1521 |
+
)
|
1522 |
+
|
1523 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1524 |
+
|
1525 |
+
if self.train_q_out:
|
1526 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
1527 |
+
else:
|
1528 |
+
query = attn.to_q(hidden_states)
|
1529 |
+
|
1530 |
+
if encoder_hidden_states is None:
|
1531 |
+
crossattn = False
|
1532 |
+
encoder_hidden_states = hidden_states
|
1533 |
+
else:
|
1534 |
+
crossattn = True
|
1535 |
+
if attn.norm_cross:
|
1536 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1537 |
+
|
1538 |
+
if self.train_kv:
|
1539 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
1540 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
1541 |
+
else:
|
1542 |
+
key = attn.to_k(encoder_hidden_states)
|
1543 |
+
value = attn.to_v(encoder_hidden_states)
|
1544 |
+
|
1545 |
+
if crossattn:
|
1546 |
+
detach = torch.ones_like(key)
|
1547 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
1548 |
+
key = detach * key + (1 - detach) * key.detach()
|
1549 |
+
value = detach * value + (1 - detach) * value.detach()
|
1550 |
+
|
1551 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1552 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1553 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1554 |
+
|
1555 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1556 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1557 |
+
)
|
1558 |
+
hidden_states = hidden_states.to(query.dtype)
|
1559 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1560 |
+
|
1561 |
+
if self.train_q_out:
|
1562 |
+
# linear proj
|
1563 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
1564 |
+
# dropout
|
1565 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
1566 |
+
else:
|
1567 |
+
# linear proj
|
1568 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1569 |
+
# dropout
|
1570 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1571 |
+
return hidden_states
|
1572 |
+
|
1573 |
+
|
1574 |
+
class SlicedAttnProcessor:
|
1575 |
+
r"""
|
1576 |
+
Processor for implementing sliced attention.
|
1577 |
+
|
1578 |
+
Args:
|
1579 |
+
slice_size (`int`, *optional*):
|
1580 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1581 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1582 |
+
"""
|
1583 |
+
|
1584 |
+
def __init__(self, slice_size):
|
1585 |
+
self.slice_size = slice_size
|
1586 |
+
|
1587 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1588 |
+
residual = hidden_states
|
1589 |
+
|
1590 |
+
input_ndim = hidden_states.ndim
|
1591 |
+
|
1592 |
+
if input_ndim == 4:
|
1593 |
+
batch_size, channel, height, width = hidden_states.shape
|
1594 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1595 |
+
|
1596 |
+
batch_size, sequence_length, _ = (
|
1597 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1598 |
+
)
|
1599 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1600 |
+
|
1601 |
+
if attn.group_norm is not None:
|
1602 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1603 |
+
|
1604 |
+
query = attn.to_q(hidden_states)
|
1605 |
+
dim = query.shape[-1]
|
1606 |
+
query = attn.head_to_batch_dim(query)
|
1607 |
+
|
1608 |
+
if encoder_hidden_states is None:
|
1609 |
+
encoder_hidden_states = hidden_states
|
1610 |
+
elif attn.norm_cross:
|
1611 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1612 |
+
|
1613 |
+
key = attn.to_k(encoder_hidden_states)
|
1614 |
+
value = attn.to_v(encoder_hidden_states)
|
1615 |
+
key = attn.head_to_batch_dim(key)
|
1616 |
+
value = attn.head_to_batch_dim(value)
|
1617 |
+
|
1618 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1619 |
+
hidden_states = torch.zeros(
|
1620 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1621 |
+
)
|
1622 |
+
|
1623 |
+
for i in range(batch_size_attention // self.slice_size):
|
1624 |
+
start_idx = i * self.slice_size
|
1625 |
+
end_idx = (i + 1) * self.slice_size
|
1626 |
+
|
1627 |
+
query_slice = query[start_idx:end_idx]
|
1628 |
+
key_slice = key[start_idx:end_idx]
|
1629 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1630 |
+
|
1631 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1632 |
+
|
1633 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1634 |
+
|
1635 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1636 |
+
|
1637 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1638 |
+
|
1639 |
+
# linear proj
|
1640 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1641 |
+
# dropout
|
1642 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1643 |
+
|
1644 |
+
if input_ndim == 4:
|
1645 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1646 |
+
|
1647 |
+
if attn.residual_connection:
|
1648 |
+
hidden_states = hidden_states + residual
|
1649 |
+
|
1650 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1651 |
+
|
1652 |
+
return hidden_states
|
1653 |
+
|
1654 |
+
|
1655 |
+
class SlicedAttnAddedKVProcessor:
|
1656 |
+
r"""
|
1657 |
+
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
|
1658 |
+
|
1659 |
+
Args:
|
1660 |
+
slice_size (`int`, *optional*):
|
1661 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1662 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1663 |
+
"""
|
1664 |
+
|
1665 |
+
def __init__(self, slice_size):
|
1666 |
+
self.slice_size = slice_size
|
1667 |
+
|
1668 |
+
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
|
1669 |
+
residual = hidden_states
|
1670 |
+
|
1671 |
+
if attn.spatial_norm is not None:
|
1672 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1673 |
+
|
1674 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1675 |
+
|
1676 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1677 |
+
|
1678 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1679 |
+
|
1680 |
+
if encoder_hidden_states is None:
|
1681 |
+
encoder_hidden_states = hidden_states
|
1682 |
+
elif attn.norm_cross:
|
1683 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1684 |
+
|
1685 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1686 |
+
|
1687 |
+
query = attn.to_q(hidden_states)
|
1688 |
+
dim = query.shape[-1]
|
1689 |
+
query = attn.head_to_batch_dim(query)
|
1690 |
+
|
1691 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1692 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1693 |
+
|
1694 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1695 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1696 |
+
|
1697 |
+
if not attn.only_cross_attention:
|
1698 |
+
key = attn.to_k(hidden_states)
|
1699 |
+
value = attn.to_v(hidden_states)
|
1700 |
+
key = attn.head_to_batch_dim(key)
|
1701 |
+
value = attn.head_to_batch_dim(value)
|
1702 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1703 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1704 |
+
else:
|
1705 |
+
key = encoder_hidden_states_key_proj
|
1706 |
+
value = encoder_hidden_states_value_proj
|
1707 |
+
|
1708 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1709 |
+
hidden_states = torch.zeros(
|
1710 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1711 |
+
)
|
1712 |
+
|
1713 |
+
for i in range(batch_size_attention // self.slice_size):
|
1714 |
+
start_idx = i * self.slice_size
|
1715 |
+
end_idx = (i + 1) * self.slice_size
|
1716 |
+
|
1717 |
+
query_slice = query[start_idx:end_idx]
|
1718 |
+
key_slice = key[start_idx:end_idx]
|
1719 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1720 |
+
|
1721 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1722 |
+
|
1723 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1724 |
+
|
1725 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1726 |
+
|
1727 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1728 |
+
|
1729 |
+
# linear proj
|
1730 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1731 |
+
# dropout
|
1732 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1733 |
+
|
1734 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1735 |
+
hidden_states = hidden_states + residual
|
1736 |
+
|
1737 |
+
return hidden_states
|
1738 |
+
|
1739 |
+
|
1740 |
+
AttentionProcessor = Union[
|
1741 |
+
AttnProcessor,
|
1742 |
+
AttnProcessor2_0_image_prompt,
|
1743 |
+
XFormersAttnProcessor,
|
1744 |
+
SlicedAttnProcessor,
|
1745 |
+
AttnAddedKVProcessor,
|
1746 |
+
SlicedAttnAddedKVProcessor,
|
1747 |
+
AttnAddedKVProcessor2_0,
|
1748 |
+
XFormersAttnAddedKVProcessor,
|
1749 |
+
LoRAAttnProcessor,
|
1750 |
+
LoRAXFormersAttnProcessor,
|
1751 |
+
LoRAAttnProcessor2_0,
|
1752 |
+
LoRAAttnAddedKVProcessor,
|
1753 |
+
CustomDiffusionAttnProcessor,
|
1754 |
+
CustomDiffusionXFormersAttnProcessor,
|
1755 |
+
]
|
1756 |
+
|
1757 |
+
|
1758 |
+
class SpatialNorm(nn.Module):
|
1759 |
+
"""
|
1760 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
|
1761 |
+
"""
|
1762 |
+
|
1763 |
+
def __init__(
|
1764 |
+
self,
|
1765 |
+
f_channels,
|
1766 |
+
zq_channels,
|
1767 |
+
):
|
1768 |
+
super().__init__()
|
1769 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
1770 |
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1771 |
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1772 |
+
|
1773 |
+
def forward(self, f, zq):
|
1774 |
+
f_size = f.shape[-2:]
|
1775 |
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
1776 |
+
norm_f = self.norm_layer(f)
|
1777 |
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
1778 |
+
return new_f
|
models/base_vision.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
base_vision.py
|
17 |
+
|
18 |
+
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
|
19 |
+
functions, and initialization logic.
|
20 |
+
|
21 |
+
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
|
22 |
+
Transformer model for feature extraction.
|
23 |
+
"""
|
24 |
+
from abc import ABC, abstractmethod
|
25 |
+
from dataclasses import dataclass
|
26 |
+
from functools import partial
|
27 |
+
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
28 |
+
|
29 |
+
import timm
|
30 |
+
import torch
|
31 |
+
import torch.nn as nn
|
32 |
+
import torchvision.transforms.functional as TVF
|
33 |
+
from PIL.Image import Image
|
34 |
+
from timm.models.vision_transformer import Block, VisionTransformer
|
35 |
+
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
36 |
+
from torchvision.transforms import Compose, Resize
|
37 |
+
|
38 |
+
|
39 |
+
# === Utility Functions for Monkey-Patching ===
|
40 |
+
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
41 |
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
42 |
+
result = fn(*args, **kwargs)
|
43 |
+
return result[0] if (isinstance(result, tuple) or isinstance(result, list)) else result
|
44 |
+
|
45 |
+
return wrapper
|
46 |
+
|
47 |
+
def return_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
48 |
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
49 |
+
result = fn(*args, **kwargs)
|
50 |
+
return result
|
51 |
+
|
52 |
+
return wrapper
|
53 |
+
|
54 |
+
|
55 |
+
# === Interface for an Image Transform ===
|
56 |
+
class ImageTransform(Protocol):
|
57 |
+
def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...
|
58 |
+
|
59 |
+
|
60 |
+
# === Custom Torchvision Image Transforms ===
|
61 |
+
@dataclass
|
62 |
+
class LetterboxPad:
|
63 |
+
padding_fill_value: Tuple[int, int, int]
|
64 |
+
|
65 |
+
def __call__(self, image: Image) -> Image:
|
66 |
+
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
|
67 |
+
(w, h), max_wh = image.size, max(image.size)
|
68 |
+
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
|
69 |
+
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
|
70 |
+
return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")
|
71 |
+
|
72 |
+
|
73 |
+
# === Abstract Base Class for arbitrary Vision Backbones ===
|
74 |
+
class VisionBackbone(nn.Module, ABC):
|
75 |
+
def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
|
76 |
+
super().__init__()
|
77 |
+
self.identifier: str = vision_backbone_id
|
78 |
+
self.image_resize_strategy: str = image_resize_strategy
|
79 |
+
self.default_image_size: int = default_image_size
|
80 |
+
|
81 |
+
# Instance attributes for a Vision Backbone
|
82 |
+
self.featurizer: nn.Module = None
|
83 |
+
self.image_transform: ImageTransform = None
|
84 |
+
|
85 |
+
def get_image_transform(self) -> ImageTransform:
|
86 |
+
return self.image_transform
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def get_fsdp_wrapping_policy(self) -> Callable: ...
|
90 |
+
|
91 |
+
@abstractmethod
|
92 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
93 |
+
"""Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
|
94 |
+
raise NotImplementedError
|
95 |
+
|
96 |
+
@property
|
97 |
+
@abstractmethod
|
98 |
+
def default_image_resolution(self) -> Tuple[int, int, int]: ...
|
99 |
+
|
100 |
+
@property
|
101 |
+
@abstractmethod
|
102 |
+
def embed_dim(self) -> int: ...
|
103 |
+
|
104 |
+
@property
|
105 |
+
@abstractmethod
|
106 |
+
def num_patches(self) -> int: ...
|
107 |
+
|
108 |
+
@property
|
109 |
+
@abstractmethod
|
110 |
+
def half_precision_dtype(self) -> torch.dtype: ...
|
111 |
+
|
112 |
+
|
113 |
+
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
|
114 |
+
class TimmViTBackbone(VisionBackbone, ABC):
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
vision_backbone_id: str,
|
118 |
+
timm_path_or_url: str,
|
119 |
+
image_resize_strategy: str,
|
120 |
+
default_image_size: int = 224,
|
121 |
+
override_act_layer: Optional[str] = None,
|
122 |
+
) -> None:
|
123 |
+
super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
|
124 |
+
self.timm_path_or_url = timm_path_or_url
|
125 |
+
self.override_act_layer = override_act_layer
|
126 |
+
self.dtype = torch.bfloat16
|
127 |
+
|
128 |
+
# Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
|
129 |
+
if self.override_act_layer is None:
|
130 |
+
self.featurizer: VisionTransformer = timm.create_model(
|
131 |
+
self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size,
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
self.featurizer: VisionTransformer = timm.create_model(
|
135 |
+
self.timm_path_or_url,
|
136 |
+
pretrained=True,
|
137 |
+
num_classes=0,
|
138 |
+
img_size=self.default_image_size,
|
139 |
+
act_layer=self.override_act_layer,
|
140 |
+
)
|
141 |
+
self.featurizer.eval()
|
142 |
+
|
143 |
+
# Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
|
144 |
+
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
|
145 |
+
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
|
146 |
+
self.featurizer.forward = unpack_tuple(
|
147 |
+
partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
|
148 |
+
)
|
149 |
+
|
150 |
+
# Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
|
151 |
+
assert isinstance(self.featurizer, VisionTransformer), (
|
152 |
+
"Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
|
153 |
+
"file an issue or implement the requisite logic (see `cobra/models/backbones/vision/base_vision.py`)!"
|
154 |
+
)
|
155 |
+
|
156 |
+
# Get Config =>> Note :: Override default image size to ensure correct image transform
|
157 |
+
self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
|
158 |
+
self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
|
159 |
+
|
160 |
+
# Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
|
161 |
+
default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)
|
162 |
+
|
163 |
+
# Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
|
164 |
+
if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
|
165 |
+
assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
|
166 |
+
assert isinstance(resize_transform := default_image_transform.transforms[0], Resize)
|
167 |
+
default_image_transform = Compose(
|
168 |
+
[
|
169 |
+
Resize(self.default_image_size, interpolation=resize_transform.interpolation),
|
170 |
+
*default_image_transform.transforms[1:],
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
# Switch on `image_resize_strategy`
|
175 |
+
if self.image_resize_strategy == "resize-naive":
|
176 |
+
assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
|
177 |
+
assert isinstance(resize_transform := default_image_transform.transforms[0], Resize)
|
178 |
+
|
179 |
+
target_size = (self.default_image_size, self.default_image_size)
|
180 |
+
self.image_transform = Compose(
|
181 |
+
[
|
182 |
+
Resize(target_size, interpolation=resize_transform.interpolation),
|
183 |
+
*default_image_transform.transforms[1:],
|
184 |
+
]
|
185 |
+
)
|
186 |
+
|
187 |
+
elif self.image_resize_strategy == "resize-crop":
|
188 |
+
self.image_transform = default_image_transform
|
189 |
+
|
190 |
+
elif self.image_resize_strategy == "letterbox":
|
191 |
+
assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
|
192 |
+
assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
|
193 |
+
|
194 |
+
# Compute Padding Fill Value (rescaled normalization mean if applicable)
|
195 |
+
fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
|
196 |
+
|
197 |
+
# Build New Transform
|
198 |
+
self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])
|
199 |
+
|
200 |
+
else:
|
201 |
+
raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
|
202 |
+
|
203 |
+
def get_fsdp_wrapping_policy(self) -> Callable:
|
204 |
+
"""Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
|
205 |
+
vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
|
206 |
+
transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
|
207 |
+
return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
|
208 |
+
|
209 |
+
def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
|
210 |
+
"""Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
|
211 |
+
return self.featurizer(pixel_values)
|
212 |
+
|
213 |
+
@property
|
214 |
+
def default_image_resolution(self) -> Tuple[int, int, int]:
|
215 |
+
return self.data_cfg["input_size"]
|
216 |
+
|
217 |
+
@property
|
218 |
+
def embed_dim(self) -> int:
|
219 |
+
return self.featurizer.embed_dim
|
220 |
+
|
221 |
+
@property
|
222 |
+
def num_patches(self) -> int:
|
223 |
+
return self.featurizer.patch_embed.num_patches
|
224 |
+
|
225 |
+
@property
|
226 |
+
def half_precision_dtype(self) -> torch.dtype:
|
227 |
+
return self.dtype
|
models/dino.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
dinosiglip_vit.py
|
17 |
+
|
18 |
+
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
|
19 |
+
"""
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from functools import partial
|
22 |
+
from typing import Callable, Dict, Tuple
|
23 |
+
import os
|
24 |
+
import timm
|
25 |
+
import torch
|
26 |
+
from PIL import Image
|
27 |
+
from einops import rearrange
|
28 |
+
from timm.models.vision_transformer import Block, VisionTransformer
|
29 |
+
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
30 |
+
from torchvision.transforms import Compose, Resize
|
31 |
+
|
32 |
+
from models.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple, return_tuple
|
33 |
+
|
34 |
+
import torch.nn as nn
|
35 |
+
import torchvision
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class DinoSigLIPImageTransform:
|
39 |
+
dino_image_transform: ImageTransform
|
40 |
+
siglip_image_transform: ImageTransform
|
41 |
+
is_cobra: bool = True
|
42 |
+
|
43 |
+
def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
|
44 |
+
return {"dino": self.dino_image_transform(img, **kwargs).unsqueeze(0), "siglip": self.siglip_image_transform(img, **kwargs).unsqueeze(0)}
|
45 |
+
|
46 |
+
|
47 |
+
class DinoViTBackbone(VisionBackbone):
|
48 |
+
def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, last_n = 2, feature_index = 22) -> None:
|
49 |
+
super().__init__(backbone_name_or_path, image_resize_strategy, default_image_size=default_image_size)
|
50 |
+
# load from local paths
|
51 |
+
dino_pretrained_cfg = timm.models.create_model(backbone_name_or_path).default_cfg
|
52 |
+
dino_pretrained_cfg['file'] = 'ckpts/vit_large_patch14_reg4_dinov2.lvd142m/pytorch_model.bin'
|
53 |
+
|
54 |
+
# Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
|
55 |
+
self.dino_featurizer: VisionTransformer = timm.create_model(
|
56 |
+
backbone_name_or_path, pretrained=True, num_classes=0, img_size=self.default_image_size,
|
57 |
+
pretrained_cfg=dino_pretrained_cfg
|
58 |
+
)
|
59 |
+
self.dino_featurizer.eval()
|
60 |
+
|
61 |
+
# Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
|
62 |
+
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
|
63 |
+
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
|
64 |
+
# return the output tokens from the `n` last blocks
|
65 |
+
print("dino has {} layer intermediate features. ".format(len(self.dino_featurizer.blocks))) # 24
|
66 |
+
# self.dino_featurizer.forward = unpack_tuple(
|
67 |
+
# partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - last_n})
|
68 |
+
# )
|
69 |
+
if isinstance(feature_index, tuple) or isinstance(feature_index, list):
|
70 |
+
feature_index = set(feature_index)
|
71 |
+
else:
|
72 |
+
feature_index = {feature_index}
|
73 |
+
self.dino_featurizer.forward = return_tuple(
|
74 |
+
partial(self.dino_featurizer.get_intermediate_layers, n=feature_index)
|
75 |
+
)
|
76 |
+
|
77 |
+
# Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
|
78 |
+
self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
|
79 |
+
self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
|
80 |
+
|
81 |
+
# Initialize *both* Transforms
|
82 |
+
default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
|
83 |
+
|
84 |
+
if self.image_resize_strategy == "resize-naive":
|
85 |
+
assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
|
86 |
+
assert isinstance(dino_resize_transform := default_dino_transform.transforms[0], Resize)
|
87 |
+
|
88 |
+
target_size = (self.default_image_size, self.default_image_size)
|
89 |
+
dino_transform = Compose(
|
90 |
+
[
|
91 |
+
Resize(target_size, interpolation=dino_resize_transform.interpolation),
|
92 |
+
*default_dino_transform.transforms[1:],
|
93 |
+
]
|
94 |
+
)
|
95 |
+
|
96 |
+
self.dino_transform = dino_transform
|
97 |
+
else:
|
98 |
+
raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
|
99 |
+
|
100 |
+
def get_fsdp_wrapping_policy(self) -> Callable:
|
101 |
+
"""Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
|
102 |
+
vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
|
103 |
+
transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
|
104 |
+
return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
|
105 |
+
|
106 |
+
def forward(self, pixel_values, device="cpu", input_dtype_new=None) -> torch.Tensor:
|
107 |
+
"""Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
|
108 |
+
# b, c , h , w : 0-1
|
109 |
+
t_tensors = []
|
110 |
+
for pixel_value in pixel_values:
|
111 |
+
t_tensors.append(self.dino_transform(pixel_value).unsqueeze(0))
|
112 |
+
t_tensors = torch.cat(t_tensors, dim=0).to(device)
|
113 |
+
if input_dtype_new is not None:
|
114 |
+
t_tensors = t_tensors.to(input_dtype_new)
|
115 |
+
|
116 |
+
t_tensors_list = self.dino_featurizer(t_tensors)
|
117 |
+
return t_tensors_list
|
118 |
+
|
119 |
+
@property
|
120 |
+
def default_image_resolution(self) -> Tuple[int, int, int]:
|
121 |
+
return self.dino_data_cfg["input_size"]
|
122 |
+
|
123 |
+
@property
|
124 |
+
def embed_dim(self) -> int:
|
125 |
+
return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
|
126 |
+
|
127 |
+
@property
|
128 |
+
def num_patches(self) -> int:
|
129 |
+
assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
|
130 |
+
return self.dino_featurizer.patch_embed.num_patches
|
131 |
+
|
132 |
+
@property
|
133 |
+
def half_precision_dtype(self) -> torch.dtype:
|
134 |
+
return torch.bfloat16
|
135 |
+
|
136 |
+
|
137 |
+
class DinoEncoder(nn.Module):
|
138 |
+
def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, feature_index = 22) -> None:
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
self.image_encoder = DinoViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
|
142 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
143 |
+
|
144 |
+
def forward(self, image_tensor, device="cpu", input_dtype_new=torch.float32): # input image size = 768
|
145 |
+
pixel_values = []
|
146 |
+
|
147 |
+
for image_tensor_i in image_tensor:
|
148 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
149 |
+
|
150 |
+
embeddings_dino_list = self.image_encoder(pixel_values, device, input_dtype_new)
|
151 |
+
if len(embeddings_dino_list) == 1:
|
152 |
+
embeddings_dino_list = embeddings_dino_list[0]
|
153 |
+
return embeddings_dino_list
|
154 |
+
|
155 |
+
class DinoEncoderV2(nn.Module):
|
156 |
+
def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, feature_index = 22) -> None:
|
157 |
+
super().__init__()
|
158 |
+
|
159 |
+
self.image_encoder = DinoViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
|
160 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
161 |
+
|
162 |
+
def get_fsdp_wrapping_policy(self):
|
163 |
+
return self.image_encoder.get_fsdp_wrapping_policy()
|
164 |
+
|
165 |
+
def forward(self, image_tensor_dict, device="cpu", input_dtype_new=torch.float32):
|
166 |
+
image_tensor = image_tensor_dict["images_ref"]
|
167 |
+
|
168 |
+
output_dict = {}
|
169 |
+
pixel_values = []
|
170 |
+
|
171 |
+
for image_tensor_i in image_tensor:
|
172 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
173 |
+
|
174 |
+
embeddings_dino_list = self.image_encoder(pixel_values, device, input_dtype_new)
|
175 |
+
if len(embeddings_dino_list) == 1:
|
176 |
+
embeddings_dino_list = embeddings_dino_list[0]
|
177 |
+
output_dict["img_patch_features"] = embeddings_dino_list
|
178 |
+
return output_dict
|
179 |
+
|
180 |
+
class DinoEncoderV2_Canny(nn.Module):
|
181 |
+
def __init__(self, backbone_name_or_path, image_resize_strategy: str, default_image_size: int = 224, feature_index = 22) -> None:
|
182 |
+
super().__init__()
|
183 |
+
|
184 |
+
self.image_encoder = DinoViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
|
185 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
186 |
+
|
187 |
+
def get_fsdp_wrapping_policy(self):
|
188 |
+
return self.image_encoder.get_fsdp_wrapping_policy()
|
189 |
+
|
190 |
+
def forward(self, image_tensor_dict, device="cpu", input_dtype_new=torch.float32):
|
191 |
+
image_canny = image_tensor_dict["images_canny"]
|
192 |
+
|
193 |
+
output_dict = {}
|
194 |
+
pixel_values = []
|
195 |
+
|
196 |
+
for image_tensor_i in image_canny:
|
197 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
198 |
+
|
199 |
+
embeddings_dino_list = self.image_encoder(pixel_values, device, input_dtype_new)
|
200 |
+
if len(embeddings_dino_list) == 1:
|
201 |
+
embeddings_dino_list = embeddings_dino_list[0]
|
202 |
+
output_dict["img_patch_features"] = embeddings_dino_list
|
203 |
+
return output_dict
|
models/image_encoder_siglipdino_shallowdeep.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torchvision
|
17 |
+
import torch.nn as nn
|
18 |
+
from einops import rearrange
|
19 |
+
|
20 |
+
from models.sigclip import SigLIPViTBackbone
|
21 |
+
from models.dino import DinoViTBackbone
|
22 |
+
|
23 |
+
class ShallowDeepSiglipDinoEncoder(nn.Module):
|
24 |
+
def __init__(self, siglip_config={}, dino_config={}):
|
25 |
+
super().__init__()
|
26 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
27 |
+
self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
|
28 |
+
self.image_encoder_dino = DinoViTBackbone(**dino_config)
|
29 |
+
|
30 |
+
def forward(self, image_tensor, device="cpu"):
|
31 |
+
bs = image_tensor.size(0)
|
32 |
+
# tensor 转 PIL
|
33 |
+
pixel_values = []
|
34 |
+
for image_tensor_i in image_tensor:
|
35 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
36 |
+
|
37 |
+
embeddings = []
|
38 |
+
embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
|
39 |
+
embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
|
40 |
+
for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
|
41 |
+
embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
|
42 |
+
embeddings.append(embeddings_i)
|
43 |
+
|
44 |
+
return embeddings
|
45 |
+
|
46 |
+
# The default is to use double the image size, i.e., 768x768.
|
47 |
+
class ShallowDeepPatchfySiglipDinoEncoder(nn.Module):
|
48 |
+
def __init__(self, siglip_config={}, dino_config={}, patchfy_scale=2, default_image_size=384):
|
49 |
+
super().__init__()
|
50 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
51 |
+
self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
|
52 |
+
self.image_encoder_dino = DinoViTBackbone(**dino_config)
|
53 |
+
|
54 |
+
self.patchfy = (patchfy_scale > 1)
|
55 |
+
self.patchfy_scale = patchfy_scale
|
56 |
+
self.default_image_size = default_image_size
|
57 |
+
|
58 |
+
def forward(self, image_tensor, device="cpu", **kwargs): # input image size = 768
|
59 |
+
image_tensor = image_tensor["image_ref"] # this is a dict
|
60 |
+
bs = image_tensor.size(0)
|
61 |
+
|
62 |
+
if self.patchfy:
|
63 |
+
image_local = rearrange(image_tensor, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=self.patchfy_scale, wp=self.patchfy_scale)
|
64 |
+
image_global = torch.nn.functional.interpolate(image_tensor, size=(self.default_image_size, self.default_image_size), mode='bilinear', align_corners=True)
|
65 |
+
|
66 |
+
# tensor 转 PIL
|
67 |
+
pixel_values_local, pixel_values_global = [], []
|
68 |
+
for image_tensor_i in image_local:
|
69 |
+
pixel_values_local.append(self.to_pil(image_tensor_i.to(torch.float)))
|
70 |
+
for image_tensor_i in image_global:
|
71 |
+
pixel_values_global.append(self.to_pil(image_tensor_i.to(torch.float)))
|
72 |
+
|
73 |
+
embeddings = []
|
74 |
+
embeddings_siglip_list = self.image_encoder_siglip(pixel_values_global, device)
|
75 |
+
embeddings_dino_list = self.image_encoder_dino(pixel_values_global, device)
|
76 |
+
for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
|
77 |
+
embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
|
78 |
+
embeddings.append(embeddings_i)
|
79 |
+
|
80 |
+
embeddings_local_siglip_deep = self.image_encoder_siglip(pixel_values_local, device)[-1]
|
81 |
+
embeddings_local_dino_deep = self.image_encoder_dino(pixel_values_local, device)[-1]
|
82 |
+
embeddings_local_deep = torch.cat([embeddings_local_siglip_deep, embeddings_local_dino_deep], dim=-1)
|
83 |
+
|
84 |
+
embeddings_local_deep = rearrange(embeddings_local_deep, "(b hp wp) l c -> b (l hp wp) c", hp=self.patchfy_scale, wp=self.patchfy_scale)
|
85 |
+
|
86 |
+
embeddings.append(embeddings_local_deep)
|
87 |
+
|
88 |
+
else:
|
89 |
+
# tensor 转 PIL
|
90 |
+
pixel_values = []
|
91 |
+
for image_tensor_i in image_tensor:
|
92 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
93 |
+
|
94 |
+
embeddings = []
|
95 |
+
embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
|
96 |
+
embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
|
97 |
+
for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
|
98 |
+
# 逐层concat的方式
|
99 |
+
embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
|
100 |
+
embeddings.append(embeddings_i)
|
101 |
+
|
102 |
+
if len(embeddings) == 1:
|
103 |
+
embeddings = embeddings[0]
|
104 |
+
return embeddings
|
105 |
+
|
106 |
+
|
107 |
+
class ShallowDeepPatchfySiglipDinoEncoder_v2(nn.Module):
|
108 |
+
def __init__(self, siglip_config={}, dino_config={}, patchfy_scale=2, default_image_size=384):
|
109 |
+
super().__init__()
|
110 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
111 |
+
self.image_encoder_siglip = SigLIPViTBackbone(**siglip_config)
|
112 |
+
self.image_encoder_dino = DinoViTBackbone(**dino_config)
|
113 |
+
|
114 |
+
self.patchfy = (patchfy_scale > 1)
|
115 |
+
self.patchfy_scale = patchfy_scale
|
116 |
+
self.default_image_size = default_image_size
|
117 |
+
|
118 |
+
def forward(self, image_tensor_dict, device="cpu", **kwargs): # input image size = 768
|
119 |
+
image_tensor = image_tensor_dict["image_ref"]
|
120 |
+
bs = image_tensor.size(0)
|
121 |
+
|
122 |
+
if self.patchfy:
|
123 |
+
image_local = rearrange(image_tensor, "b c (h hp) (w wp) -> (b hp wp) c h w", hp=self.patchfy_scale, wp=self.patchfy_scale)
|
124 |
+
image_global = torch.nn.functional.interpolate(image_tensor, size=(self.default_image_size, self.default_image_size), mode='bilinear', align_corners=True)
|
125 |
+
|
126 |
+
pixel_values_local, pixel_values_global = [], []
|
127 |
+
for image_tensor_i in image_local:
|
128 |
+
pixel_values_local.append(self.to_pil(image_tensor_i.to(torch.float32)))
|
129 |
+
for image_tensor_i in image_global:
|
130 |
+
pixel_values_global.append(self.to_pil(image_tensor_i.to(torch.float32)))
|
131 |
+
|
132 |
+
embeddings = []
|
133 |
+
embeddings_siglip_list = self.image_encoder_siglip(pixel_values_global, device)
|
134 |
+
embeddings_dino_list = self.image_encoder_dino(pixel_values_global, device)
|
135 |
+
for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
|
136 |
+
embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
|
137 |
+
embeddings.append(embeddings_i)
|
138 |
+
|
139 |
+
embeddings_local_siglip_deep = self.image_encoder_siglip(pixel_values_local, device)[-1]
|
140 |
+
embeddings_local_dino_deep = self.image_encoder_dino(pixel_values_local, device)[-1]
|
141 |
+
embeddings_local_deep = torch.cat([embeddings_local_siglip_deep, embeddings_local_dino_deep], dim=-1)
|
142 |
+
|
143 |
+
embeddings_local_deep = rearrange(embeddings_local_deep, "(b hp wp) l c -> b (l hp wp) c", hp=self.patchfy_scale, wp=self.patchfy_scale)
|
144 |
+
|
145 |
+
embeddings.append(embeddings_local_deep)
|
146 |
+
|
147 |
+
else:
|
148 |
+
# tensor 转 PIL
|
149 |
+
pixel_values = []
|
150 |
+
for image_tensor_i in image_tensor:
|
151 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
152 |
+
|
153 |
+
embeddings = []
|
154 |
+
embeddings_siglip_list = self.image_encoder_siglip(pixel_values, device)
|
155 |
+
embeddings_dino_list = self.image_encoder_dino(pixel_values, device)
|
156 |
+
for embeddings_siglip_i, embeddings_dino_i in zip(embeddings_siglip_list, embeddings_dino_list):
|
157 |
+
embeddings_i = torch.cat([embeddings_siglip_i, embeddings_dino_i], dim=-1) # channel concat
|
158 |
+
embeddings.append(embeddings_i)
|
159 |
+
|
160 |
+
if len(embeddings) == 1:
|
161 |
+
embeddings = embeddings[0]
|
162 |
+
return embeddings
|
models/projectors.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
class MinAttention(nn.Module):
|
20 |
+
def __init__(self, q_dim: int, kv_dim: int, dim_head=64, heads=8):
|
21 |
+
super().__init__()
|
22 |
+
self.dim_head = dim_head
|
23 |
+
self.heads = heads
|
24 |
+
inner_dim = dim_head * heads
|
25 |
+
|
26 |
+
self.norm1 = nn.LayerNorm(q_dim)
|
27 |
+
self.norm2 = nn.LayerNorm(kv_dim)
|
28 |
+
|
29 |
+
self.to_q = nn.Linear(q_dim, inner_dim, bias=False)
|
30 |
+
self.to_k = nn.Linear(kv_dim, inner_dim, bias=False)
|
31 |
+
self.to_v = nn.Linear(kv_dim, inner_dim, bias=False)
|
32 |
+
|
33 |
+
def forward(self, local_fea, global_fea):
|
34 |
+
global_fea = self.norm1(global_fea)
|
35 |
+
local_fea = self.norm2(local_fea)
|
36 |
+
b, l, _ = global_fea.shape
|
37 |
+
|
38 |
+
q = self.to_q(global_fea)
|
39 |
+
k = self.to_k(local_fea)
|
40 |
+
v = self.to_v(local_fea)
|
41 |
+
|
42 |
+
q = q.view(b, -1, self.heads, self.dim_head).transpose(1, 2)
|
43 |
+
k = k.view(b, -1, self.heads, self.dim_head).transpose(1, 2)
|
44 |
+
v = v.view(b, -1, self.heads, self.dim_head).transpose(1, 2)
|
45 |
+
hidden_states = F.scaled_dot_product_attention(
|
46 |
+
q,k,v, dropout_p=0.0, is_causal=False
|
47 |
+
)
|
48 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(b, -1, self.heads*self.dim_head)
|
49 |
+
hidden_states = hidden_states.to(q.dtype)
|
50 |
+
return hidden_states
|
51 |
+
|
52 |
+
class CustomParameter(nn.Module):
|
53 |
+
def __init__(self, init_value):
|
54 |
+
super().__init__()
|
55 |
+
self.init_value = init_value
|
56 |
+
self.value = nn.Parameter(torch.tensor(init_value))
|
57 |
+
|
58 |
+
def forward(self):
|
59 |
+
return self.value
|
60 |
+
|
61 |
+
|
62 |
+
class ProjectorHighResMinAttn(nn.Module):
|
63 |
+
def __init__(self, vision_dim, out_dim, dim_head=64, adaptive_scale=False, scale_value=1.0, **kwargs):
|
64 |
+
super().__init__()
|
65 |
+
self.initial_projection_dim = vision_dim * 4
|
66 |
+
heads = vision_dim // dim_head
|
67 |
+
|
68 |
+
self.min_attention = MinAttention(q_dim=vision_dim, kv_dim=vision_dim, dim_head=dim_head, heads=heads)
|
69 |
+
self.projector = nn.Sequential(
|
70 |
+
nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
|
71 |
+
nn.GELU(),
|
72 |
+
nn.Linear(self.initial_projection_dim, out_dim, bias=True),
|
73 |
+
nn.GELU(),
|
74 |
+
nn.Linear(out_dim, out_dim, bias=True),
|
75 |
+
nn.LayerNorm(out_dim)
|
76 |
+
)
|
77 |
+
self.projector_base = nn.Linear(vision_dim, out_dim, bias=True)
|
78 |
+
|
79 |
+
self.adaptive_scale = adaptive_scale
|
80 |
+
if self.adaptive_scale:
|
81 |
+
self.scale_value = CustomParameter(scale_value)
|
82 |
+
|
83 |
+
def forward(self, vision_input_dict, time_emb=None, **kwargs):
|
84 |
+
"""
|
85 |
+
vision_input_dict: here, this is not a dict, just for the unity of naming
|
86 |
+
"""
|
87 |
+
img_patch_features = vision_input_dict
|
88 |
+
deep_features, deep_features_local = img_patch_features
|
89 |
+
|
90 |
+
fused_img_features = self.min_attention(deep_features_local, deep_features)
|
91 |
+
fused_img_features = self.projector(fused_img_features)
|
92 |
+
|
93 |
+
deep_img_features = self.projector_base(deep_features)
|
94 |
+
|
95 |
+
if self.adaptive_scale:
|
96 |
+
output = deep_img_features + fused_img_features * self.scale_value()
|
97 |
+
else:
|
98 |
+
output = deep_img_features + fused_img_features
|
99 |
+
return output
|
100 |
+
|
101 |
+
|
102 |
+
class ProjectorHighResShallowMinAttnV1(nn.Module):
|
103 |
+
def __init__(self, vision_dim, out_dim, dim_head=64, **kwargs):
|
104 |
+
super().__init__()
|
105 |
+
self.initial_projection_dim = vision_dim * 4
|
106 |
+
heads = vision_dim // dim_head
|
107 |
+
|
108 |
+
self.min_attention = MinAttention(q_dim=vision_dim, kv_dim=vision_dim, dim_head=dim_head, heads=heads)
|
109 |
+
self.projector = nn.Sequential(
|
110 |
+
nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
|
111 |
+
nn.GELU(),
|
112 |
+
nn.Linear(self.initial_projection_dim, out_dim, bias=True),
|
113 |
+
nn.GELU(),
|
114 |
+
nn.Linear(out_dim, out_dim, bias=True),
|
115 |
+
nn.LayerNorm(out_dim)
|
116 |
+
)
|
117 |
+
self.projector_base = nn.Linear(vision_dim, out_dim, bias=True)
|
118 |
+
|
119 |
+
self.min_attention2 = MinAttention(q_dim=vision_dim, kv_dim=vision_dim, dim_head=dim_head, heads=heads)
|
120 |
+
self.projector2 = nn.Sequential(
|
121 |
+
nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
|
122 |
+
nn.GELU(),
|
123 |
+
nn.Linear(self.initial_projection_dim, out_dim, bias=True),
|
124 |
+
nn.GELU(),
|
125 |
+
nn.Linear(out_dim, out_dim, bias=True),
|
126 |
+
nn.LayerNorm(out_dim)
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(self, vision_input_dict, time_emb=None, **kwargs):
|
130 |
+
"""
|
131 |
+
vision_input_dict: here, this is not a dict, just for the unity of naming
|
132 |
+
"""
|
133 |
+
img_patch_features = vision_input_dict
|
134 |
+
shallow_features1, shallow_features2, shallow_features3, deep_features, deep_features_local = img_patch_features
|
135 |
+
shallow_features = torch.cat([shallow_features1, shallow_features2, shallow_features3], dim=1) # token concat
|
136 |
+
|
137 |
+
# original code
|
138 |
+
fused_img_features = self.min_attention(deep_features_local, deep_features)
|
139 |
+
fused_img_features = self.projector(fused_img_features)
|
140 |
+
|
141 |
+
deep_img_features = self.projector_base(deep_features)
|
142 |
+
|
143 |
+
output = deep_img_features + fused_img_features
|
144 |
+
|
145 |
+
# new code part
|
146 |
+
fused_img_features2 = self.min_attention2(shallow_features, deep_features)
|
147 |
+
fused_img_features2 = self.projector2(fused_img_features2)
|
148 |
+
|
149 |
+
output = torch.cat([deep_img_features, fused_img_features2], dim=1)
|
150 |
+
return output
|
models/sigclip.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""
|
16 |
+
dinosiglip_vit.py
|
17 |
+
|
18 |
+
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
|
19 |
+
"""
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from functools import partial
|
22 |
+
from typing import Callable, Dict, Tuple
|
23 |
+
import os
|
24 |
+
import timm
|
25 |
+
import torch
|
26 |
+
from PIL import Image
|
27 |
+
from timm.models.vision_transformer import Block, VisionTransformer
|
28 |
+
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
29 |
+
from torchvision.transforms import Compose, Resize
|
30 |
+
|
31 |
+
from models.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple, return_tuple
|
32 |
+
|
33 |
+
import torchvision
|
34 |
+
import torch.nn as nn
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class DinoSigLIPImageTransform:
|
38 |
+
dino_image_transform: ImageTransform
|
39 |
+
siglip_image_transform: ImageTransform
|
40 |
+
is_cobra: bool = True
|
41 |
+
|
42 |
+
def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
|
43 |
+
return {"dino": self.dino_image_transform(img, **kwargs).unsqueeze(0), "siglip": self.siglip_image_transform(img, **kwargs).unsqueeze(0)}
|
44 |
+
|
45 |
+
|
46 |
+
class SigLIPViTBackbone(VisionBackbone):
|
47 |
+
def __init__(self, backbone_name_or_path: str, image_resize_strategy: str, default_image_size: int = 224, last_n = 2, feature_index = 25) -> None:
|
48 |
+
super().__init__(backbone_name_or_path, image_resize_strategy, default_image_size=default_image_size)
|
49 |
+
# load from local paths
|
50 |
+
sigclip_pretrained_cfg = timm.models.create_model(backbone_name_or_path).default_cfg
|
51 |
+
sigclip_pretrained_cfg['file'] = 'ckpts/vit_so400m_patch14_siglip_384/open_clip_pytorch_model.bin'
|
52 |
+
|
53 |
+
# Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
|
54 |
+
self.siglip_featurizer: VisionTransformer = timm.create_model(
|
55 |
+
backbone_name_or_path, pretrained=True, num_classes=0, img_size=self.default_image_size,
|
56 |
+
pretrained_cfg=sigclip_pretrained_cfg
|
57 |
+
)
|
58 |
+
self.siglip_featurizer.eval()
|
59 |
+
|
60 |
+
# Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
|
61 |
+
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
|
62 |
+
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
|
63 |
+
# return the output tokens from the `n` last blocks
|
64 |
+
print("siglip has {} layer intermediate features. ".format(len(self.siglip_featurizer.blocks))) # 27
|
65 |
+
# self.siglip_featurizer.forward = unpack_tuple(
|
66 |
+
# partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - last_n})
|
67 |
+
# )
|
68 |
+
if isinstance(feature_index, tuple) or isinstance(feature_index, list):
|
69 |
+
feature_index = set(feature_index)
|
70 |
+
else:
|
71 |
+
feature_index = {feature_index}
|
72 |
+
self.siglip_featurizer.forward = return_tuple(
|
73 |
+
partial(self.siglip_featurizer.get_intermediate_layers, n=feature_index)
|
74 |
+
)
|
75 |
+
|
76 |
+
# Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
|
77 |
+
|
78 |
+
self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
|
79 |
+
self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
|
80 |
+
|
81 |
+
# Initialize *both* Transforms
|
82 |
+
default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)
|
83 |
+
|
84 |
+
# Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
|
85 |
+
assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!"
|
86 |
+
assert isinstance(sl_resize_transform := default_siglip_transform.transforms[0], Resize)
|
87 |
+
default_siglip_transform = Compose(
|
88 |
+
[
|
89 |
+
Resize(self.default_image_size, interpolation=sl_resize_transform.interpolation),
|
90 |
+
*default_siglip_transform.transforms[1:],
|
91 |
+
]
|
92 |
+
)
|
93 |
+
|
94 |
+
if self.image_resize_strategy == "resize-naive":
|
95 |
+
assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!"
|
96 |
+
assert isinstance(siglip_resize_transform := default_siglip_transform.transforms[0], Resize)
|
97 |
+
|
98 |
+
target_size = (self.default_image_size, self.default_image_size)
|
99 |
+
siglip_transform = Compose(
|
100 |
+
[
|
101 |
+
Resize(target_size, interpolation=siglip_resize_transform.interpolation),
|
102 |
+
*default_siglip_transform.transforms[1:],
|
103 |
+
]
|
104 |
+
)
|
105 |
+
|
106 |
+
self.siglip_transform = siglip_transform
|
107 |
+
else:
|
108 |
+
raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
|
109 |
+
|
110 |
+
def get_fsdp_wrapping_policy(self) -> Callable:
|
111 |
+
"""Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
|
112 |
+
vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
|
113 |
+
transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
|
114 |
+
return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
|
115 |
+
|
116 |
+
def forward(self, pixel_values, device="cpu") -> torch.Tensor:
|
117 |
+
"""Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
|
118 |
+
# b, c , h , w : 0-1
|
119 |
+
t_tensors = []
|
120 |
+
for pixel_value in pixel_values:
|
121 |
+
t_tensors.append(self.siglip_transform(pixel_value).unsqueeze(0))
|
122 |
+
t_tensors = torch.cat(t_tensors, dim=0).to(device)
|
123 |
+
|
124 |
+
t_tensors_list = self.siglip_featurizer(t_tensors)
|
125 |
+
return t_tensors_list
|
126 |
+
|
127 |
+
@property
|
128 |
+
def default_image_resolution(self) -> Tuple[int, int, int]:
|
129 |
+
return self.dino_data_cfg["input_size"]
|
130 |
+
|
131 |
+
@property
|
132 |
+
def embed_dim(self) -> int:
|
133 |
+
return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
|
134 |
+
|
135 |
+
@property
|
136 |
+
def num_patches(self) -> int:
|
137 |
+
assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
|
138 |
+
return self.dino_featurizer.patch_embed.num_patches
|
139 |
+
|
140 |
+
@property
|
141 |
+
def half_precision_dtype(self) -> torch.dtype:
|
142 |
+
return torch.bfloat16
|
143 |
+
|
144 |
+
|
145 |
+
class SigLIPEncoder(nn.Module):
|
146 |
+
def __init__(self, backbone_name_or_path: str, image_resize_strategy: str, default_image_size: int = 224, feature_index = 25):
|
147 |
+
super().__init__()
|
148 |
+
self.image_encoder = SigLIPViTBackbone(backbone_name_or_path, image_resize_strategy, default_image_size, feature_index)
|
149 |
+
self.to_pil = torchvision.transforms.ToPILImage()
|
150 |
+
|
151 |
+
def forward(self, image_tensor, device="cpu"): # input image size = 768
|
152 |
+
pixel_values = []
|
153 |
+
for image_tensor_i in image_tensor:
|
154 |
+
pixel_values.append(self.to_pil(image_tensor_i))
|
155 |
+
|
156 |
+
embeddings_dino_list = self.image_encoder(pixel_values, device)
|
157 |
+
if len(embeddings_dino_list) == 1:
|
158 |
+
embeddings_dino_list = embeddings_dino_list[0]
|
159 |
+
return embeddings_dino_list
|
models/text.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from torch import nn
|
18 |
+
from transformers import AutoTokenizer, CLIPTokenizerFast, CLIPTextModel, T5EncoderModel
|
19 |
+
from typing import List
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class TextModelOutput:
|
23 |
+
embeddings: torch.Tensor
|
24 |
+
masks: torch.Tensor
|
25 |
+
pooled: List
|
26 |
+
|
27 |
+
|
28 |
+
class TextModel(nn.Module):
|
29 |
+
available_modes = [
|
30 |
+
"last", # If present, use last layer.
|
31 |
+
"penultimate", # If present, use penultimate layer.
|
32 |
+
"penultimate_nonorm", # If present, use penultimate layer without final norm.
|
33 |
+
"token_cat", # If present, concat in token dimension, default concat in channel dimension.
|
34 |
+
"pad0", # If present, use 0 padding, default use EOT padding.
|
35 |
+
"masked", # If present, pass attention mask to encoder.
|
36 |
+
]
|
37 |
+
|
38 |
+
def __init__(self, variant: List[str], mode: List[str]):
|
39 |
+
super().__init__()
|
40 |
+
self.mode = set(mode)
|
41 |
+
self.tokenizers = []
|
42 |
+
self.models = nn.ModuleList([])
|
43 |
+
|
44 |
+
for v in variant:
|
45 |
+
if "clip" in v.lower():
|
46 |
+
self.tokenizers.append(CLIPTokenizerFast.from_pretrained(v, model_max_length=77))
|
47 |
+
self.models.append(CLIPTextModel.from_pretrained(v))
|
48 |
+
elif "t5" in v.lower() or "ul2" in v.lower():
|
49 |
+
self.tokenizers.append(AutoTokenizer.from_pretrained(v, model_max_length=77))
|
50 |
+
self.models.append(T5EncoderModel.from_pretrained(v, torch_dtype=torch.bfloat16))
|
51 |
+
else:
|
52 |
+
raise NotImplementedError
|
53 |
+
|
54 |
+
def get_vaild_token_length(self, text): # Return the length of the BPE encoding of the text, excluding `<sos>` and `<eos>`.
|
55 |
+
lengths = []
|
56 |
+
for tokenizer, model in zip(self.tokenizers, self.models):
|
57 |
+
|
58 |
+
tokens = tokenizer(
|
59 |
+
text=text,
|
60 |
+
truncation=True,
|
61 |
+
padding="max_length",
|
62 |
+
return_tensors="pt"
|
63 |
+
).to(model.device)
|
64 |
+
token_length = tokens["attention_mask"].sum() - 2 # In the attention mask, both the SOS and EOS (first PAD) have a value of 1.
|
65 |
+
lengths.append(token_length.item())
|
66 |
+
length = int(sum(lengths) / len(lengths))
|
67 |
+
return length
|
68 |
+
|
69 |
+
def forward(self, text: List[str]) -> TextModelOutput:
|
70 |
+
embeddings = []
|
71 |
+
masks = []
|
72 |
+
pooled = []
|
73 |
+
|
74 |
+
for tokenizer, model in zip(self.tokenizers, self.models):
|
75 |
+
|
76 |
+
tokens = tokenizer(
|
77 |
+
text=text,
|
78 |
+
truncation=True,
|
79 |
+
padding="max_length",
|
80 |
+
return_tensors="pt"
|
81 |
+
).to(model.device)
|
82 |
+
|
83 |
+
if "pad0" in self.mode:
|
84 |
+
tokens.input_ids *= tokens.attention_mask
|
85 |
+
|
86 |
+
output = model(
|
87 |
+
input_ids=tokens.input_ids,
|
88 |
+
attention_mask=tokens.attention_mask if "masked" in self.mode else None,
|
89 |
+
output_hidden_states=True
|
90 |
+
)
|
91 |
+
|
92 |
+
if "last" in self.mode:
|
93 |
+
embeddings.append(output.last_hidden_state)
|
94 |
+
if "penultimate" in self.mode:
|
95 |
+
embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
|
96 |
+
if "penultimate_nonorm" in self.mode:
|
97 |
+
embeddings.append(output.hidden_states[-2])
|
98 |
+
masks.append(tokens.attention_mask)
|
99 |
+
if hasattr(output, "pooler_output"):
|
100 |
+
pooled.append(output.pooler_output)
|
101 |
+
|
102 |
+
if "token_cat" in self.mode:
|
103 |
+
return TextModelOutput(
|
104 |
+
embeddings=torch.cat(embeddings, dim=1),
|
105 |
+
masks=torch.cat(masks, dim=1),
|
106 |
+
pooled=pooled
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
return TextModelOutput(
|
110 |
+
embeddings=torch.cat(embeddings, dim=2),
|
111 |
+
masks=torch.stack(masks, dim=2).sum(2).clamp_max(1),
|
112 |
+
pooled=pooled
|
113 |
+
)
|
models/transformer_2d_custom.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
25 |
+
from diffusers.utils import BaseOutput, deprecate
|
26 |
+
# from diffusers.models.attention import BasicTransformerBlock
|
27 |
+
from models.attention_custom import BasicTransformerBlock
|
28 |
+
from diffusers.models.embeddings import PatchEmbed
|
29 |
+
from diffusers.models.modeling_utils import ModelMixin
|
30 |
+
|
31 |
+
from utils import update_dict
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class Transformer2DModelOutput(BaseOutput):
|
35 |
+
"""
|
36 |
+
The output of [`Transformer2DModel`].
|
37 |
+
|
38 |
+
Args:
|
39 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
40 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
41 |
+
distributions for the unnoised latent pixels.
|
42 |
+
"""
|
43 |
+
|
44 |
+
sample: torch.FloatTensor
|
45 |
+
|
46 |
+
# Transformer2DModel
|
47 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
48 |
+
"""
|
49 |
+
A 2D Transformer model for image-like data.
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
53 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
54 |
+
in_channels (`int`, *optional*):
|
55 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
56 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
57 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
58 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
59 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
60 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
61 |
+
num_vector_embeds (`int`, *optional*):
|
62 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
63 |
+
Includes the class for the masked latent pixel.
|
64 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
65 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
66 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
67 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
68 |
+
added to the hidden states.
|
69 |
+
|
70 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
71 |
+
attention_bias (`bool`, *optional*):
|
72 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
73 |
+
"""
|
74 |
+
|
75 |
+
@register_to_config
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
num_attention_heads: int = 16,
|
79 |
+
attention_head_dim: int = 88,
|
80 |
+
in_channels: Optional[int] = None,
|
81 |
+
out_channels: Optional[int] = None,
|
82 |
+
num_layers: int = 1,
|
83 |
+
dropout: float = 0.0,
|
84 |
+
norm_num_groups: int = 32,
|
85 |
+
cross_attention_dim: Optional[int] = None,
|
86 |
+
attention_bias: bool = False,
|
87 |
+
sample_size: Optional[int] = None,
|
88 |
+
num_vector_embeds: Optional[int] = None,
|
89 |
+
patch_size: Optional[int] = None,
|
90 |
+
activation_fn: str = "geglu",
|
91 |
+
num_embeds_ada_norm: Optional[int] = None,
|
92 |
+
use_linear_projection: bool = False,
|
93 |
+
only_cross_attention: bool = False,
|
94 |
+
upcast_attention: bool = False,
|
95 |
+
norm_type: str = "layer_norm",
|
96 |
+
norm_elementwise_affine: bool = True,
|
97 |
+
image_prompt_settings = {},
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
self.use_linear_projection = use_linear_projection
|
101 |
+
self.num_attention_heads = num_attention_heads
|
102 |
+
self.attention_head_dim = attention_head_dim
|
103 |
+
inner_dim = num_attention_heads * attention_head_dim
|
104 |
+
|
105 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
106 |
+
# Define whether input is continuous or discrete depending on configuration
|
107 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
108 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
109 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
110 |
+
|
111 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
112 |
+
deprecation_message = (
|
113 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
114 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
115 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
116 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
117 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
118 |
+
)
|
119 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
120 |
+
norm_type = "ada_norm"
|
121 |
+
|
122 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
123 |
+
raise ValueError(
|
124 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
125 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
126 |
+
)
|
127 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
128 |
+
raise ValueError(
|
129 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
130 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
131 |
+
)
|
132 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
133 |
+
raise ValueError(
|
134 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
135 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
136 |
+
)
|
137 |
+
|
138 |
+
# 2. Define input layers
|
139 |
+
if self.is_input_continuous:
|
140 |
+
self.in_channels = in_channels
|
141 |
+
|
142 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
143 |
+
if use_linear_projection:
|
144 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
145 |
+
else:
|
146 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
147 |
+
elif self.is_input_vectorized:
|
148 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
149 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
150 |
+
|
151 |
+
self.height = sample_size
|
152 |
+
self.width = sample_size
|
153 |
+
self.num_vector_embeds = num_vector_embeds
|
154 |
+
self.num_latent_pixels = self.height * self.width
|
155 |
+
|
156 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
157 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
158 |
+
)
|
159 |
+
elif self.is_input_patches:
|
160 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
161 |
+
|
162 |
+
self.height = sample_size
|
163 |
+
self.width = sample_size
|
164 |
+
|
165 |
+
self.patch_size = patch_size
|
166 |
+
self.pos_embed = PatchEmbed(
|
167 |
+
height=sample_size,
|
168 |
+
width=sample_size,
|
169 |
+
patch_size=patch_size,
|
170 |
+
in_channels=in_channels,
|
171 |
+
embed_dim=inner_dim,
|
172 |
+
)
|
173 |
+
|
174 |
+
# 3. Define transformers blocks, NOTE: we change the format
|
175 |
+
self.transformer_blocks = []
|
176 |
+
for d in range(num_layers):
|
177 |
+
self.transformer_blocks.append(
|
178 |
+
BasicTransformerBlock(
|
179 |
+
inner_dim,
|
180 |
+
num_attention_heads,
|
181 |
+
attention_head_dim,
|
182 |
+
dropout=dropout,
|
183 |
+
cross_attention_dim=cross_attention_dim,
|
184 |
+
activation_fn=activation_fn,
|
185 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
186 |
+
attention_bias=attention_bias,
|
187 |
+
only_cross_attention=only_cross_attention,
|
188 |
+
upcast_attention=upcast_attention,
|
189 |
+
norm_type=norm_type,
|
190 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
191 |
+
image_prompt_settings=image_prompt_settings,
|
192 |
+
)
|
193 |
+
)
|
194 |
+
image_prompt_settings["cross_attention_id"] += 1
|
195 |
+
self.transformer_blocks = nn.ModuleList(self.transformer_blocks)
|
196 |
+
|
197 |
+
# self.transformer_blocks = nn.ModuleList(
|
198 |
+
# [
|
199 |
+
# BasicTransformerBlock(
|
200 |
+
# inner_dim,
|
201 |
+
# num_attention_heads,
|
202 |
+
# attention_head_dim,
|
203 |
+
# dropout=dropout,
|
204 |
+
# cross_attention_dim=cross_attention_dim,
|
205 |
+
# activation_fn=activation_fn,
|
206 |
+
# num_embeds_ada_norm=num_embeds_ada_norm,
|
207 |
+
# attention_bias=attention_bias,
|
208 |
+
# only_cross_attention=only_cross_attention,
|
209 |
+
# upcast_attention=upcast_attention,
|
210 |
+
# norm_type=norm_type,
|
211 |
+
# norm_elementwise_affine=norm_elementwise_affine,
|
212 |
+
# image_prompt_settings=image_prompt_settings,
|
213 |
+
# )
|
214 |
+
# for d in range(num_layers)
|
215 |
+
# ]
|
216 |
+
# )
|
217 |
+
|
218 |
+
# 4. Define output layers
|
219 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
220 |
+
if self.is_input_continuous:
|
221 |
+
# TODO: should use out_channels for continuous projections
|
222 |
+
if use_linear_projection:
|
223 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
224 |
+
else:
|
225 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
226 |
+
elif self.is_input_vectorized:
|
227 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
228 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
229 |
+
elif self.is_input_patches:
|
230 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
231 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
232 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
233 |
+
|
234 |
+
def forward(
|
235 |
+
self,
|
236 |
+
hidden_states: torch.Tensor,
|
237 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
238 |
+
timestep: Optional[torch.LongTensor] = None,
|
239 |
+
class_labels: Optional[torch.LongTensor] = None,
|
240 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
241 |
+
attention_mask: Optional[torch.Tensor] = None,
|
242 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
243 |
+
return_dict: bool = True,
|
244 |
+
encoder_hidden_states_vision = None,
|
245 |
+
encoder_hidden_states_control = None,
|
246 |
+
vision_guided_mask = None,
|
247 |
+
extra_dict_inputs = {},
|
248 |
+
return_self_attn_map = False,
|
249 |
+
):
|
250 |
+
"""
|
251 |
+
The [`Transformer2DModel`] forward method.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
255 |
+
Input `hidden_states`.
|
256 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
257 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
258 |
+
self-attention.
|
259 |
+
timestep ( `torch.LongTensor`, *optional*):
|
260 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
261 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
262 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
263 |
+
`AdaLayerZeroNorm`.
|
264 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
265 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
266 |
+
|
267 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
268 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
269 |
+
|
270 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
271 |
+
above. This bias will be added to the cross-attention scores.
|
272 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
273 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
274 |
+
tuple.
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
278 |
+
`tuple` where the first element is the sample tensor.
|
279 |
+
"""
|
280 |
+
# <notice>
|
281 |
+
extra_dict_outputs = {}
|
282 |
+
height, width = hidden_states.size(-2), hidden_states.size(-1)
|
283 |
+
|
284 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
285 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
286 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
287 |
+
# expects mask of shape:
|
288 |
+
# [batch, key_tokens]
|
289 |
+
# adds singleton query_tokens dimension:
|
290 |
+
# [batch, 1, key_tokens]
|
291 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
292 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
293 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
294 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
295 |
+
# assume that mask is expressed as:
|
296 |
+
# (1 = keep, 0 = discard)
|
297 |
+
# convert mask into a bias that can be added to attention scores:
|
298 |
+
# (keep = +0, discard = -10000.0)
|
299 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
300 |
+
attention_mask = attention_mask.unsqueeze(1)
|
301 |
+
|
302 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
303 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
304 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
305 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
306 |
+
|
307 |
+
# 1. Input
|
308 |
+
if self.is_input_continuous:
|
309 |
+
batch, _, height, width = hidden_states.shape
|
310 |
+
residual = hidden_states
|
311 |
+
|
312 |
+
hidden_states = self.norm(hidden_states)
|
313 |
+
if not self.use_linear_projection:
|
314 |
+
hidden_states = self.proj_in(hidden_states)
|
315 |
+
inner_dim = hidden_states.shape[1]
|
316 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
317 |
+
else:
|
318 |
+
inner_dim = hidden_states.shape[1]
|
319 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
320 |
+
hidden_states = self.proj_in(hidden_states)
|
321 |
+
elif self.is_input_vectorized:
|
322 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
323 |
+
elif self.is_input_patches:
|
324 |
+
hidden_states = self.pos_embed(hidden_states)
|
325 |
+
|
326 |
+
# 2. Blocks
|
327 |
+
for block in self.transformer_blocks:
|
328 |
+
hidden_states, extra_dict_output_transformer = block(
|
329 |
+
hidden_states,
|
330 |
+
attention_mask=attention_mask,
|
331 |
+
encoder_hidden_states=encoder_hidden_states,
|
332 |
+
encoder_attention_mask=encoder_attention_mask,
|
333 |
+
timestep=timestep,
|
334 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
335 |
+
class_labels=class_labels,
|
336 |
+
encoder_hidden_states_vision=encoder_hidden_states_vision,
|
337 |
+
encoder_hidden_states_control=encoder_hidden_states_control,
|
338 |
+
vision_guided_mask=vision_guided_mask,
|
339 |
+
extra_dict_inputs=extra_dict_inputs,
|
340 |
+
height=height,
|
341 |
+
width=width,
|
342 |
+
return_self_attn_map=return_self_attn_map
|
343 |
+
)
|
344 |
+
extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_transformer)
|
345 |
+
|
346 |
+
# 3. Output
|
347 |
+
if self.is_input_continuous:
|
348 |
+
if not self.use_linear_projection:
|
349 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
350 |
+
hidden_states = self.proj_out(hidden_states)
|
351 |
+
else:
|
352 |
+
hidden_states = self.proj_out(hidden_states)
|
353 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
354 |
+
|
355 |
+
output = hidden_states + residual
|
356 |
+
elif self.is_input_vectorized:
|
357 |
+
hidden_states = self.norm_out(hidden_states)
|
358 |
+
logits = self.out(hidden_states)
|
359 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
360 |
+
logits = logits.permute(0, 2, 1)
|
361 |
+
|
362 |
+
# log(p(x_0))
|
363 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
364 |
+
elif self.is_input_patches:
|
365 |
+
# TODO: cleanup!
|
366 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
367 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
368 |
+
)
|
369 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
370 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
371 |
+
hidden_states = self.proj_out_2(hidden_states)
|
372 |
+
|
373 |
+
# unpatchify
|
374 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
375 |
+
hidden_states = hidden_states.reshape(
|
376 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
377 |
+
)
|
378 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
379 |
+
output = hidden_states.reshape(
|
380 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
381 |
+
)
|
382 |
+
|
383 |
+
if not return_dict: # return_dict=False
|
384 |
+
return output, extra_dict_outputs
|
385 |
+
# return (output,)
|
386 |
+
|
387 |
+
# return Transformer2DModelOutput(sample=output)
|
388 |
+
return Transformer2DModelOutput(sample=output)[0], extra_dict_outputs
|
models/unet_2d_blocks_custom.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/unet_2d_condition_custom.py
ADDED
@@ -0,0 +1,1059 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
|
23 |
+
from einops import rearrange
|
24 |
+
|
25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
26 |
+
# from diffusers.models.unet_2d_blocks import (
|
27 |
+
from models.unet_2d_blocks_custom import (
|
28 |
+
CrossAttnDownBlock2D,
|
29 |
+
CrossAttnUpBlock2D,
|
30 |
+
DownBlock2D,
|
31 |
+
UNetMidBlock2DCrossAttn,
|
32 |
+
UNetMidBlock2DSimpleCrossAttn,
|
33 |
+
UpBlock2D,
|
34 |
+
get_down_block,
|
35 |
+
get_up_block,
|
36 |
+
)
|
37 |
+
|
38 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
39 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
40 |
+
from diffusers.utils import BaseOutput, logging
|
41 |
+
from diffusers.models.activations import get_activation
|
42 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
43 |
+
from diffusers.models.embeddings import (
|
44 |
+
GaussianFourierProjection,
|
45 |
+
ImageHintTimeEmbedding,
|
46 |
+
ImageProjection,
|
47 |
+
ImageTimeEmbedding,
|
48 |
+
TextImageProjection,
|
49 |
+
TextImageTimeEmbedding,
|
50 |
+
TextTimeEmbedding,
|
51 |
+
TimestepEmbedding,
|
52 |
+
Timesteps,
|
53 |
+
)
|
54 |
+
|
55 |
+
from utils import update_dict
|
56 |
+
|
57 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
58 |
+
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class UNet2DConditionOutput(BaseOutput):
|
62 |
+
"""
|
63 |
+
The output of [`UNet2DConditionModel`].
|
64 |
+
|
65 |
+
Args:
|
66 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
67 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
68 |
+
"""
|
69 |
+
|
70 |
+
sample: torch.FloatTensor = None
|
71 |
+
|
72 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
73 |
+
r"""
|
74 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
75 |
+
shaped output.
|
76 |
+
|
77 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
78 |
+
for all models (such as downloading or saving).
|
79 |
+
|
80 |
+
Parameters:her
|
81 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
82 |
+
Height and width of input/output sample.
|
83 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
84 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
85 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
86 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
87 |
+
Whether to flip the sin to cos in the time embedding.
|
88 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
89 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
90 |
+
The tuple of downsample blocks to use.
|
91 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
92 |
+
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
|
93 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
94 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
95 |
+
The tuple of upsample blocks to use.
|
96 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
97 |
+
Whether to include self-attention in the basic transformer blocks, see
|
98 |
+
[`~models.attention.BasicTransformerBlock`].
|
99 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
100 |
+
The tuple of output channels for each block.
|
101 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
102 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
103 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
104 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
105 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
106 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
107 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
108 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
109 |
+
The dimension of the cross attention features.
|
110 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
111 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
112 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
113 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
114 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
115 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
116 |
+
dimension to `cross_attention_dim`.
|
117 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
118 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
119 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
120 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
121 |
+
num_attention_heads (`int`, *optional*):
|
122 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
123 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
124 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
125 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
126 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
127 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
128 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
129 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
130 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
131 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
132 |
+
Dimension for the timestep embeddings.
|
133 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
134 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
135 |
+
class conditioning with `class_embed_type` equal to `None`.
|
136 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
137 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
138 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
139 |
+
An optional override for the dimension of the projected time embedding.
|
140 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
141 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
142 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
143 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
144 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
145 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
146 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
147 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
148 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
149 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
150 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
151 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
152 |
+
embeddings with the class embeddings.
|
153 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
154 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
155 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
156 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
157 |
+
otherwise.
|
158 |
+
"""
|
159 |
+
|
160 |
+
_supports_gradient_checkpointing = True
|
161 |
+
|
162 |
+
@register_to_config
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
sample_size: Optional[int] = None,
|
166 |
+
in_channels: int = 4,
|
167 |
+
out_channels: int = 4,
|
168 |
+
center_input_sample: bool = False,
|
169 |
+
flip_sin_to_cos: bool = True,
|
170 |
+
freq_shift: int = 0,
|
171 |
+
down_block_types: Tuple[str] = (
|
172 |
+
"CrossAttnDownBlock2D",
|
173 |
+
"CrossAttnDownBlock2D",
|
174 |
+
"CrossAttnDownBlock2D",
|
175 |
+
"DownBlock2D",
|
176 |
+
),
|
177 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
178 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
179 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
180 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
181 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
182 |
+
downsample_padding: int = 1,
|
183 |
+
mid_block_scale_factor: float = 1,
|
184 |
+
act_fn: str = "silu",
|
185 |
+
norm_num_groups: Optional[int] = 32,
|
186 |
+
norm_eps: float = 1e-5,
|
187 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
188 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
189 |
+
encoder_hid_dim: Optional[int] = None,
|
190 |
+
encoder_hid_dim_type: Optional[str] = None,
|
191 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
192 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
193 |
+
dual_cross_attention: bool = False,
|
194 |
+
use_linear_projection: bool = False,
|
195 |
+
class_embed_type: Optional[str] = None,
|
196 |
+
addition_embed_type: Optional[str] = None,
|
197 |
+
addition_time_embed_dim: Optional[int] = None,
|
198 |
+
num_class_embeds: Optional[int] = None,
|
199 |
+
upcast_attention: bool = False,
|
200 |
+
resnet_time_scale_shift: str = "default",
|
201 |
+
resnet_skip_time_act: bool = False,
|
202 |
+
resnet_out_scale_factor: int = 1.0,
|
203 |
+
time_embedding_type: str = "positional",
|
204 |
+
time_embedding_dim: Optional[int] = None,
|
205 |
+
time_embedding_act_fn: Optional[str] = None,
|
206 |
+
timestep_post_act: Optional[str] = None,
|
207 |
+
time_cond_proj_dim: Optional[int] = None,
|
208 |
+
conv_in_kernel: int = 3,
|
209 |
+
conv_out_kernel: int = 3,
|
210 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
211 |
+
class_embeddings_concat: bool = False,
|
212 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
213 |
+
cross_attention_norm: Optional[str] = None,
|
214 |
+
addition_embed_type_num_heads=64,
|
215 |
+
|
216 |
+
image_prompt_settings = {"dualbranch_mode": "none"},
|
217 |
+
**ignore_kwargs,
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
|
221 |
+
self.sample_size = sample_size
|
222 |
+
|
223 |
+
if num_attention_heads is not None:
|
224 |
+
raise ValueError(
|
225 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
226 |
+
)
|
227 |
+
|
228 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
229 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
230 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
231 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
232 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
233 |
+
# which is why we correct for the naming here.
|
234 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
235 |
+
|
236 |
+
# Check inputs
|
237 |
+
if len(down_block_types) != len(up_block_types):
|
238 |
+
raise ValueError(
|
239 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
240 |
+
)
|
241 |
+
|
242 |
+
if len(block_out_channels) != len(down_block_types):
|
243 |
+
raise ValueError(
|
244 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
245 |
+
)
|
246 |
+
|
247 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
248 |
+
raise ValueError(
|
249 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
250 |
+
)
|
251 |
+
|
252 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
253 |
+
raise ValueError(
|
254 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
255 |
+
)
|
256 |
+
|
257 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
258 |
+
raise ValueError(
|
259 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
260 |
+
)
|
261 |
+
|
262 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
263 |
+
raise ValueError(
|
264 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
265 |
+
)
|
266 |
+
|
267 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
268 |
+
raise ValueError(
|
269 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
270 |
+
)
|
271 |
+
|
272 |
+
# input
|
273 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
274 |
+
self.conv_in = nn.Conv2d(
|
275 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
276 |
+
)
|
277 |
+
|
278 |
+
# time
|
279 |
+
if time_embedding_type == "fourier":
|
280 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
281 |
+
if time_embed_dim % 2 != 0:
|
282 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
283 |
+
self.time_proj = GaussianFourierProjection(
|
284 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
285 |
+
)
|
286 |
+
timestep_input_dim = time_embed_dim
|
287 |
+
elif time_embedding_type == "positional":
|
288 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
289 |
+
|
290 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
291 |
+
timestep_input_dim = block_out_channels[0]
|
292 |
+
else:
|
293 |
+
raise ValueError(
|
294 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
295 |
+
)
|
296 |
+
|
297 |
+
self.time_embedding = TimestepEmbedding(
|
298 |
+
timestep_input_dim,
|
299 |
+
time_embed_dim,
|
300 |
+
act_fn=act_fn,
|
301 |
+
post_act_fn=timestep_post_act,
|
302 |
+
cond_proj_dim=time_cond_proj_dim,
|
303 |
+
)
|
304 |
+
|
305 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
306 |
+
encoder_hid_dim_type = "text_proj"
|
307 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
308 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
309 |
+
|
310 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
311 |
+
raise ValueError(
|
312 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
313 |
+
)
|
314 |
+
|
315 |
+
if encoder_hid_dim_type == "text_proj":
|
316 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
317 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
318 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
319 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
320 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
321 |
+
self.encoder_hid_proj = TextImageProjection(
|
322 |
+
text_embed_dim=encoder_hid_dim,
|
323 |
+
image_embed_dim=cross_attention_dim,
|
324 |
+
cross_attention_dim=cross_attention_dim,
|
325 |
+
)
|
326 |
+
elif encoder_hid_dim_type == "image_proj":
|
327 |
+
# Kandinsky 2.2
|
328 |
+
self.encoder_hid_proj = ImageProjection(
|
329 |
+
image_embed_dim=encoder_hid_dim,
|
330 |
+
cross_attention_dim=cross_attention_dim,
|
331 |
+
)
|
332 |
+
elif encoder_hid_dim_type is not None:
|
333 |
+
raise ValueError(
|
334 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
self.encoder_hid_proj = None
|
338 |
+
|
339 |
+
# class embedding
|
340 |
+
if class_embed_type is None and num_class_embeds is not None:
|
341 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
342 |
+
elif class_embed_type == "timestep":
|
343 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
344 |
+
elif class_embed_type == "identity":
|
345 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
346 |
+
elif class_embed_type == "projection":
|
347 |
+
if projection_class_embeddings_input_dim is None:
|
348 |
+
raise ValueError(
|
349 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
350 |
+
)
|
351 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
352 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
353 |
+
# 2. it projects from an arbitrary input dimension.
|
354 |
+
#
|
355 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
356 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
357 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
358 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
359 |
+
elif class_embed_type == "simple_projection":
|
360 |
+
if projection_class_embeddings_input_dim is None:
|
361 |
+
raise ValueError(
|
362 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
363 |
+
)
|
364 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
365 |
+
else:
|
366 |
+
self.class_embedding = None
|
367 |
+
|
368 |
+
if addition_embed_type == "text":
|
369 |
+
if encoder_hid_dim is not None:
|
370 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
371 |
+
else:
|
372 |
+
text_time_embedding_from_dim = cross_attention_dim
|
373 |
+
|
374 |
+
self.add_embedding = TextTimeEmbedding(
|
375 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
376 |
+
)
|
377 |
+
elif addition_embed_type == "text_image":
|
378 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
379 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
380 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
381 |
+
self.add_embedding = TextImageTimeEmbedding(
|
382 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
383 |
+
)
|
384 |
+
elif addition_embed_type == "text_time":
|
385 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
386 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
387 |
+
elif addition_embed_type == "image":
|
388 |
+
# Kandinsky 2.2
|
389 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
390 |
+
elif addition_embed_type == "image_hint":
|
391 |
+
# Kandinsky 2.2 ControlNet
|
392 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
393 |
+
elif addition_embed_type is not None:
|
394 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
395 |
+
|
396 |
+
if time_embedding_act_fn is None:
|
397 |
+
self.time_embed_act = None
|
398 |
+
else:
|
399 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
400 |
+
|
401 |
+
self.down_blocks = nn.ModuleList([])
|
402 |
+
self.up_blocks = nn.ModuleList([])
|
403 |
+
|
404 |
+
if isinstance(only_cross_attention, bool):
|
405 |
+
if mid_block_only_cross_attention is None:
|
406 |
+
mid_block_only_cross_attention = only_cross_attention
|
407 |
+
|
408 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
409 |
+
|
410 |
+
if mid_block_only_cross_attention is None:
|
411 |
+
mid_block_only_cross_attention = False
|
412 |
+
|
413 |
+
if isinstance(num_attention_heads, int):
|
414 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
415 |
+
|
416 |
+
if isinstance(attention_head_dim, int):
|
417 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
418 |
+
|
419 |
+
if isinstance(cross_attention_dim, int):
|
420 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
421 |
+
|
422 |
+
if isinstance(layers_per_block, int):
|
423 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
424 |
+
|
425 |
+
if isinstance(transformer_layers_per_block, int):
|
426 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
427 |
+
|
428 |
+
if class_embeddings_concat:
|
429 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
430 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
431 |
+
# regular time embeddings
|
432 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
433 |
+
else:
|
434 |
+
blocks_time_embed_dim = time_embed_dim
|
435 |
+
|
436 |
+
# NOTE: we need to mark each cross attention id
|
437 |
+
image_prompt_settings["cross_attention_id"] = 0
|
438 |
+
# down
|
439 |
+
output_channel = block_out_channels[0]
|
440 |
+
# XL: ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D']
|
441 |
+
for i, down_block_type in enumerate(down_block_types):
|
442 |
+
input_channel = output_channel
|
443 |
+
output_channel = block_out_channels[i]
|
444 |
+
is_final_block = i == len(block_out_channels) - 1
|
445 |
+
|
446 |
+
down_block = get_down_block(
|
447 |
+
down_block_type,
|
448 |
+
num_layers=layers_per_block[i],
|
449 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
450 |
+
in_channels=input_channel,
|
451 |
+
out_channels=output_channel,
|
452 |
+
temb_channels=blocks_time_embed_dim,
|
453 |
+
add_downsample=not is_final_block,
|
454 |
+
resnet_eps=norm_eps,
|
455 |
+
resnet_act_fn=act_fn,
|
456 |
+
resnet_groups=norm_num_groups,
|
457 |
+
cross_attention_dim=cross_attention_dim[i],
|
458 |
+
num_attention_heads=num_attention_heads[i],
|
459 |
+
downsample_padding=downsample_padding,
|
460 |
+
dual_cross_attention=dual_cross_attention,
|
461 |
+
use_linear_projection=use_linear_projection,
|
462 |
+
only_cross_attention=only_cross_attention[i],
|
463 |
+
upcast_attention=upcast_attention,
|
464 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
465 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
466 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
467 |
+
cross_attention_norm=cross_attention_norm,
|
468 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
469 |
+
image_prompt_settings=image_prompt_settings,
|
470 |
+
)
|
471 |
+
self.down_blocks.append(down_block)
|
472 |
+
|
473 |
+
# mid, XL: UNetMidBlock2DCrossAttn
|
474 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
475 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
476 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
477 |
+
in_channels=block_out_channels[-1],
|
478 |
+
temb_channels=blocks_time_embed_dim,
|
479 |
+
resnet_eps=norm_eps,
|
480 |
+
resnet_act_fn=act_fn,
|
481 |
+
output_scale_factor=mid_block_scale_factor,
|
482 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
483 |
+
cross_attention_dim=cross_attention_dim[-1],
|
484 |
+
num_attention_heads=num_attention_heads[-1],
|
485 |
+
resnet_groups=norm_num_groups,
|
486 |
+
dual_cross_attention=dual_cross_attention,
|
487 |
+
use_linear_projection=use_linear_projection,
|
488 |
+
upcast_attention=upcast_attention,
|
489 |
+
image_prompt_settings=image_prompt_settings,
|
490 |
+
)
|
491 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
492 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
493 |
+
in_channels=block_out_channels[-1],
|
494 |
+
temb_channels=blocks_time_embed_dim,
|
495 |
+
resnet_eps=norm_eps,
|
496 |
+
resnet_act_fn=act_fn,
|
497 |
+
output_scale_factor=mid_block_scale_factor,
|
498 |
+
cross_attention_dim=cross_attention_dim[-1],
|
499 |
+
attention_head_dim=attention_head_dim[-1],
|
500 |
+
resnet_groups=norm_num_groups,
|
501 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
502 |
+
skip_time_act=resnet_skip_time_act,
|
503 |
+
only_cross_attention=mid_block_only_cross_attention,
|
504 |
+
cross_attention_norm=cross_attention_norm,
|
505 |
+
)
|
506 |
+
elif mid_block_type is None:
|
507 |
+
self.mid_block = None
|
508 |
+
else:
|
509 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
510 |
+
|
511 |
+
# count how many layers upsample the images
|
512 |
+
self.num_upsamplers = 0
|
513 |
+
|
514 |
+
# up, XL: ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D']
|
515 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
516 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
517 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
518 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
519 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
520 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
521 |
+
|
522 |
+
output_channel = reversed_block_out_channels[0]
|
523 |
+
for i, up_block_type in enumerate(up_block_types):
|
524 |
+
is_final_block = i == len(block_out_channels) - 1
|
525 |
+
|
526 |
+
prev_output_channel = output_channel
|
527 |
+
output_channel = reversed_block_out_channels[i]
|
528 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
529 |
+
|
530 |
+
# add upsample block for all BUT final layer
|
531 |
+
if not is_final_block:
|
532 |
+
add_upsample = True
|
533 |
+
self.num_upsamplers += 1
|
534 |
+
else:
|
535 |
+
add_upsample = False
|
536 |
+
|
537 |
+
up_block = get_up_block(
|
538 |
+
up_block_type,
|
539 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
540 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
541 |
+
in_channels=input_channel,
|
542 |
+
out_channels=output_channel,
|
543 |
+
prev_output_channel=prev_output_channel,
|
544 |
+
temb_channels=blocks_time_embed_dim,
|
545 |
+
add_upsample=add_upsample,
|
546 |
+
resnet_eps=norm_eps,
|
547 |
+
resnet_act_fn=act_fn,
|
548 |
+
resnet_groups=norm_num_groups,
|
549 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
550 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
551 |
+
dual_cross_attention=dual_cross_attention,
|
552 |
+
use_linear_projection=use_linear_projection,
|
553 |
+
only_cross_attention=only_cross_attention[i],
|
554 |
+
upcast_attention=upcast_attention,
|
555 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
556 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
557 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
558 |
+
cross_attention_norm=cross_attention_norm,
|
559 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
560 |
+
image_prompt_settings=image_prompt_settings
|
561 |
+
)
|
562 |
+
self.up_blocks.append(up_block)
|
563 |
+
prev_output_channel = output_channel
|
564 |
+
|
565 |
+
# out
|
566 |
+
if norm_num_groups is not None:
|
567 |
+
self.conv_norm_out = nn.GroupNorm(
|
568 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
569 |
+
)
|
570 |
+
|
571 |
+
self.conv_act = get_activation(act_fn)
|
572 |
+
|
573 |
+
else:
|
574 |
+
self.conv_norm_out = None
|
575 |
+
self.conv_act = None
|
576 |
+
|
577 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
578 |
+
self.conv_out = nn.Conv2d(
|
579 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
580 |
+
)
|
581 |
+
|
582 |
+
# NOTE: settings for IP-consistent generation
|
583 |
+
from utils import instantiate_from_config
|
584 |
+
self.vision_projection_type = image_prompt_settings.get("vision_projection_type", "none")
|
585 |
+
self.cross_attention_dim = cross_attention_dim[0]
|
586 |
+
if self.vision_projection_type != "none":
|
587 |
+
self.encoder_hidden_states_vision_projection = instantiate_from_config(image_prompt_settings["vision_projection_config"])
|
588 |
+
|
589 |
+
@property
|
590 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
591 |
+
r"""
|
592 |
+
Returns:
|
593 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
594 |
+
indexed by its weight name.
|
595 |
+
"""
|
596 |
+
# set recursively
|
597 |
+
processors = {}
|
598 |
+
|
599 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
600 |
+
if hasattr(module, "set_processor"):
|
601 |
+
processors[f"{name}.processor"] = module.processor
|
602 |
+
|
603 |
+
for sub_name, child in module.named_children():
|
604 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
605 |
+
|
606 |
+
return processors
|
607 |
+
|
608 |
+
for name, module in self.named_children():
|
609 |
+
fn_recursive_add_processors(name, module, processors)
|
610 |
+
|
611 |
+
return processors
|
612 |
+
|
613 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
614 |
+
r"""
|
615 |
+
Sets the attention processor to use to compute attention.
|
616 |
+
|
617 |
+
Parameters:
|
618 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
619 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
620 |
+
for **all** `Attention` layers.
|
621 |
+
|
622 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
623 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
624 |
+
|
625 |
+
"""
|
626 |
+
count = len(self.attn_processors.keys())
|
627 |
+
|
628 |
+
if isinstance(processor, dict) and len(processor) != count:
|
629 |
+
raise ValueError(
|
630 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
631 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
632 |
+
)
|
633 |
+
|
634 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
635 |
+
if hasattr(module, "set_processor"):
|
636 |
+
if not isinstance(processor, dict):
|
637 |
+
module.set_processor(processor)
|
638 |
+
else:
|
639 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
640 |
+
|
641 |
+
for sub_name, child in module.named_children():
|
642 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
643 |
+
|
644 |
+
for name, module in self.named_children():
|
645 |
+
fn_recursive_attn_processor(name, module, processor)
|
646 |
+
|
647 |
+
def set_default_attn_processor(self):
|
648 |
+
"""
|
649 |
+
Disables custom attention processors and sets the default attention implementation.
|
650 |
+
"""
|
651 |
+
self.set_attn_processor(AttnProcessor())
|
652 |
+
|
653 |
+
def set_attention_slice(self, slice_size):
|
654 |
+
r"""
|
655 |
+
Enable sliced attention computation.
|
656 |
+
|
657 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
658 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
659 |
+
|
660 |
+
Args:
|
661 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
662 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
663 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
664 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
665 |
+
must be a multiple of `slice_size`.
|
666 |
+
"""
|
667 |
+
sliceable_head_dims = []
|
668 |
+
|
669 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
670 |
+
if hasattr(module, "set_attention_slice"):
|
671 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
672 |
+
|
673 |
+
for child in module.children():
|
674 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
675 |
+
|
676 |
+
# retrieve number of attention layers
|
677 |
+
for module in self.children():
|
678 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
679 |
+
|
680 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
681 |
+
|
682 |
+
if slice_size == "auto":
|
683 |
+
# half the attention head size is usually a good trade-off between
|
684 |
+
# speed and memory
|
685 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
686 |
+
elif slice_size == "max":
|
687 |
+
# make smallest slice possible
|
688 |
+
slice_size = num_sliceable_layers * [1]
|
689 |
+
|
690 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
691 |
+
|
692 |
+
if len(slice_size) != len(sliceable_head_dims):
|
693 |
+
raise ValueError(
|
694 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
695 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
696 |
+
)
|
697 |
+
|
698 |
+
for i in range(len(slice_size)):
|
699 |
+
size = slice_size[i]
|
700 |
+
dim = sliceable_head_dims[i]
|
701 |
+
if size is not None and size > dim:
|
702 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
703 |
+
|
704 |
+
# Recursively walk through all the children.
|
705 |
+
# Any children which exposes the set_attention_slice method
|
706 |
+
# gets the message
|
707 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
708 |
+
if hasattr(module, "set_attention_slice"):
|
709 |
+
module.set_attention_slice(slice_size.pop())
|
710 |
+
|
711 |
+
for child in module.children():
|
712 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
713 |
+
|
714 |
+
reversed_slice_size = list(reversed(slice_size))
|
715 |
+
for module in self.children():
|
716 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
717 |
+
|
718 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
719 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
720 |
+
module.gradient_checkpointing = value
|
721 |
+
|
722 |
+
def forward(
|
723 |
+
self,
|
724 |
+
sample: torch.FloatTensor,
|
725 |
+
timestep: Union[torch.Tensor, float, int],
|
726 |
+
encoder_hidden_states: torch.Tensor,
|
727 |
+
class_labels: Optional[torch.Tensor] = None,
|
728 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
729 |
+
attention_mask: Optional[torch.Tensor] = None,
|
730 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
731 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
732 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
733 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
734 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
735 |
+
return_dict: bool = True,
|
736 |
+
|
737 |
+
vision_input_dict = None,
|
738 |
+
vision_guided_mask: Optional[torch.Tensor] = None,
|
739 |
+
return_text2image_mask: bool = False,
|
740 |
+
return_as_origin: bool = True,
|
741 |
+
return_self_attn_map = False,
|
742 |
+
multiple_reference_image = False,
|
743 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
744 |
+
r"""
|
745 |
+
The [`UNet2DConditionModel`] forward method.
|
746 |
+
|
747 |
+
Args:
|
748 |
+
sample (`torch.FloatTensor`):
|
749 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
750 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
751 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
752 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
753 |
+
encoder_attention_mask (`torch.Tensor`):
|
754 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
755 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
756 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
757 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
758 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
759 |
+
tuple.
|
760 |
+
cross_attention_kwargs (`dict`, *optional*):
|
761 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
762 |
+
added_cond_kwargs: (`dict`, *optional*):
|
763 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
764 |
+
are passed along to the UNet blocks.
|
765 |
+
|
766 |
+
Returns:
|
767 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
768 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
769 |
+
a `tuple` is returned where the first element is the sample tensor.
|
770 |
+
"""
|
771 |
+
extra_dict_outputs = {}
|
772 |
+
extra_dict_inputs = {}
|
773 |
+
extra_dict_inputs["multiple_reference_image"] = multiple_reference_image
|
774 |
+
|
775 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
776 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
777 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
778 |
+
# on the fly if necessary.
|
779 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
780 |
+
|
781 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
782 |
+
forward_upsample_size = False
|
783 |
+
upsample_size = None
|
784 |
+
|
785 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
786 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
787 |
+
forward_upsample_size = True
|
788 |
+
|
789 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
790 |
+
# expects mask of shape:
|
791 |
+
# [batch, key_tokens]
|
792 |
+
# adds singleton query_tokens dimension:
|
793 |
+
# [batch, 1, key_tokens]
|
794 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
795 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
796 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
797 |
+
if attention_mask is not None:
|
798 |
+
# assume that mask is expressed as:
|
799 |
+
# (1 = keep, 0 = discard)
|
800 |
+
# convert mask into a bias that can be added to attention scores:
|
801 |
+
# (keep = +0, discard = -10000.0)
|
802 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
803 |
+
attention_mask = attention_mask.unsqueeze(1)
|
804 |
+
|
805 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
806 |
+
if encoder_attention_mask is not None:
|
807 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
808 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
809 |
+
|
810 |
+
# 0. center input if necessary
|
811 |
+
if self.config.center_input_sample:
|
812 |
+
sample = 2 * sample - 1.0
|
813 |
+
|
814 |
+
# 1. time
|
815 |
+
timesteps = timestep
|
816 |
+
if not torch.is_tensor(timesteps):
|
817 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
818 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
819 |
+
is_mps = sample.device.type == "mps"
|
820 |
+
if isinstance(timestep, float):
|
821 |
+
dtype = torch.float32 if is_mps else torch.float64
|
822 |
+
else:
|
823 |
+
dtype = torch.int32 if is_mps else torch.int64
|
824 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
825 |
+
elif len(timesteps.shape) == 0:
|
826 |
+
timesteps = timesteps[None].to(sample.device)
|
827 |
+
|
828 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
829 |
+
timesteps = timesteps.expand(sample.shape[0])
|
830 |
+
|
831 |
+
t_emb = self.time_proj(timesteps)
|
832 |
+
|
833 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
834 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
835 |
+
# there might be better ways to encapsulate this.
|
836 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
837 |
+
|
838 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
839 |
+
aug_emb = None
|
840 |
+
|
841 |
+
if self.class_embedding is not None:
|
842 |
+
if class_labels is None:
|
843 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
844 |
+
|
845 |
+
if self.config.class_embed_type == "timestep":
|
846 |
+
class_labels = self.time_proj(class_labels)
|
847 |
+
|
848 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
849 |
+
# there might be better ways to encapsulate this.
|
850 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
851 |
+
|
852 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
853 |
+
|
854 |
+
if self.config.class_embeddings_concat:
|
855 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
856 |
+
else:
|
857 |
+
emb = emb + class_emb
|
858 |
+
|
859 |
+
if self.config.addition_embed_type == "text":
|
860 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
861 |
+
elif self.config.addition_embed_type == "text_image":
|
862 |
+
# Kandinsky 2.1 - style
|
863 |
+
if "image_embeds" not in added_cond_kwargs:
|
864 |
+
raise ValueError(
|
865 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
866 |
+
)
|
867 |
+
|
868 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
869 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
870 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
871 |
+
elif self.config.addition_embed_type == "text_time":
|
872 |
+
if "text_embeds" not in added_cond_kwargs:
|
873 |
+
raise ValueError(
|
874 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
875 |
+
)
|
876 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
877 |
+
if "time_ids" not in added_cond_kwargs:
|
878 |
+
raise ValueError(
|
879 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
880 |
+
)
|
881 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
882 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
883 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
884 |
+
|
885 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
886 |
+
add_embeds = add_embeds.to(emb.dtype)
|
887 |
+
aug_emb = self.add_embedding(add_embeds)
|
888 |
+
elif self.config.addition_embed_type == "image":
|
889 |
+
# Kandinsky 2.2 - style
|
890 |
+
if "image_embeds" not in added_cond_kwargs:
|
891 |
+
raise ValueError(
|
892 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
893 |
+
)
|
894 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
895 |
+
aug_emb = self.add_embedding(image_embs)
|
896 |
+
elif self.config.addition_embed_type == "image_hint":
|
897 |
+
# Kandinsky 2.2 - style
|
898 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
899 |
+
raise ValueError(
|
900 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
901 |
+
)
|
902 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
903 |
+
hint = added_cond_kwargs.get("hint")
|
904 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
905 |
+
sample = torch.cat([sample, hint], dim=1)
|
906 |
+
|
907 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
908 |
+
|
909 |
+
if self.time_embed_act is not None:
|
910 |
+
emb = self.time_embed_act(emb)
|
911 |
+
|
912 |
+
# 2. pre-process
|
913 |
+
sample = self.conv_in(sample)
|
914 |
+
|
915 |
+
# NOTE: image condition
|
916 |
+
encoder_hidden_states_vision = None
|
917 |
+
if vision_input_dict is not None and self.vision_projection_type != "none":
|
918 |
+
if multiple_reference_image:
|
919 |
+
encoder_hidden_states_vision = []
|
920 |
+
for encoder_hidden_states_vision_i in encoder_hidden_states_vision:
|
921 |
+
encoder_hidden_states_vision_i = self.encoder_hidden_states_vision_projection(vision_input_dict=vision_input_dict, time_emb=emb, image_latent=sample)
|
922 |
+
encoder_hidden_states_vision.append(encoder_hidden_states_vision_i)
|
923 |
+
else:
|
924 |
+
encoder_hidden_states_vision = self.encoder_hidden_states_vision_projection(vision_input_dict=vision_input_dict, time_emb=emb, image_latent=sample)
|
925 |
+
|
926 |
+
if type(encoder_hidden_states_vision) == dict:
|
927 |
+
if "l_disen" in encoder_hidden_states_vision.keys():
|
928 |
+
extra_dict_outputs["l_disen"] = encoder_hidden_states_vision["l_disen"]
|
929 |
+
|
930 |
+
|
931 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
932 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
933 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
934 |
+
# Kadinsky 2.1 - style
|
935 |
+
if "image_embeds" not in added_cond_kwargs:
|
936 |
+
raise ValueError(
|
937 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
938 |
+
)
|
939 |
+
|
940 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
941 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
942 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
943 |
+
# Kandinsky 2.2 - style
|
944 |
+
if "image_embeds" not in added_cond_kwargs:
|
945 |
+
raise ValueError(
|
946 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
947 |
+
)
|
948 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
949 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
950 |
+
|
951 |
+
# 3. down
|
952 |
+
down_block_res_samples = (sample,)
|
953 |
+
additional_residuals = {}
|
954 |
+
for downsample_block in self.down_blocks:
|
955 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
956 |
+
sample, res_samples, extra_dict_output_down = downsample_block(
|
957 |
+
hidden_states=sample,
|
958 |
+
temb=emb,
|
959 |
+
encoder_hidden_states=encoder_hidden_states,
|
960 |
+
attention_mask=attention_mask,
|
961 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
962 |
+
encoder_attention_mask=encoder_attention_mask,
|
963 |
+
encoder_hidden_states_vision=encoder_hidden_states_vision,
|
964 |
+
encoder_hidden_states_control=None,
|
965 |
+
vision_guided_mask=vision_guided_mask,
|
966 |
+
extra_dict_inputs=extra_dict_inputs,
|
967 |
+
return_self_attn_map=return_self_attn_map,
|
968 |
+
**additional_residuals,
|
969 |
+
)
|
970 |
+
else:
|
971 |
+
sample, res_samples, extra_dict_output_down = downsample_block(hidden_states=sample, temb=emb)
|
972 |
+
extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_down)
|
973 |
+
|
974 |
+
down_block_res_samples += res_samples
|
975 |
+
|
976 |
+
if down_block_additional_residuals is not None:
|
977 |
+
new_down_block_res_samples = ()
|
978 |
+
|
979 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
980 |
+
down_block_res_samples, down_block_additional_residuals
|
981 |
+
):
|
982 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
983 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
984 |
+
|
985 |
+
down_block_res_samples = new_down_block_res_samples
|
986 |
+
|
987 |
+
# 4. mid
|
988 |
+
if self.mid_block is not None:
|
989 |
+
sample, extra_dict_output_middel = self.mid_block(
|
990 |
+
sample,
|
991 |
+
emb,
|
992 |
+
encoder_hidden_states=encoder_hidden_states,
|
993 |
+
attention_mask=attention_mask,
|
994 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
995 |
+
encoder_attention_mask=encoder_attention_mask,
|
996 |
+
encoder_hidden_states_vision=encoder_hidden_states_vision,
|
997 |
+
encoder_hidden_states_control=None,
|
998 |
+
vision_guided_mask=vision_guided_mask,
|
999 |
+
extra_dict_inputs=extra_dict_inputs,
|
1000 |
+
return_self_attn_map=return_self_attn_map,
|
1001 |
+
)
|
1002 |
+
extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_middel)
|
1003 |
+
|
1004 |
+
|
1005 |
+
if mid_block_additional_residual is not None:
|
1006 |
+
sample = sample + mid_block_additional_residual
|
1007 |
+
|
1008 |
+
# 5. up
|
1009 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1010 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1011 |
+
|
1012 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1013 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
1014 |
+
|
1015 |
+
# if we have not reached the final block and need to forward the
|
1016 |
+
# upsample size, we do it here
|
1017 |
+
if not is_final_block and forward_upsample_size:
|
1018 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1019 |
+
|
1020 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
1021 |
+
sample, extra_dict_output_up = upsample_block(
|
1022 |
+
hidden_states=sample,
|
1023 |
+
temb=emb,
|
1024 |
+
res_hidden_states_tuple=res_samples,
|
1025 |
+
encoder_hidden_states=encoder_hidden_states,
|
1026 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1027 |
+
upsample_size=upsample_size,
|
1028 |
+
attention_mask=attention_mask,
|
1029 |
+
encoder_attention_mask=encoder_attention_mask,
|
1030 |
+
encoder_hidden_states_vision=encoder_hidden_states_vision,
|
1031 |
+
encoder_hidden_states_control=None,
|
1032 |
+
vision_guided_mask=vision_guided_mask,
|
1033 |
+
extra_dict_inputs=extra_dict_inputs,
|
1034 |
+
return_self_attn_map=return_self_attn_map
|
1035 |
+
)
|
1036 |
+
else:
|
1037 |
+
sample, extra_dict_output_up = upsample_block(
|
1038 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1039 |
+
)
|
1040 |
+
extra_dict_outputs = update_dict(extra_dict_outputs, extra_dict_output_up)
|
1041 |
+
|
1042 |
+
# 6. post-process
|
1043 |
+
if self.conv_norm_out:
|
1044 |
+
sample = self.conv_norm_out(sample)
|
1045 |
+
sample = self.conv_act(sample)
|
1046 |
+
sample = self.conv_out(sample)
|
1047 |
+
|
1048 |
+
if not return_dict:
|
1049 |
+
return (sample,)
|
1050 |
+
|
1051 |
+
if return_as_origin:
|
1052 |
+
return UNet2DConditionOutput(sample=sample)
|
1053 |
+
else:
|
1054 |
+
if return_text2image_mask:
|
1055 |
+
return UNet2DConditionOutput(sample=sample).sample, extra_dict_outputs
|
1056 |
+
else:
|
1057 |
+
return UNet2DConditionOutput(sample=sample).sample
|
1058 |
+
|
1059 |
+
|
models/vae.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import diffusers
|
17 |
+
import torch
|
18 |
+
|
19 |
+
class AutoencoderKL(diffusers.AutoencoderKL):
|
20 |
+
"""
|
21 |
+
We simply inherit the model code from diffusers
|
22 |
+
"""
|
23 |
+
def __init__(self, attention=True, *args, **kwargs):
|
24 |
+
super().__init__(*args, **kwargs)
|
25 |
+
# A hacky way to remove attention.
|
26 |
+
if not attention:
|
27 |
+
self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
|
28 |
+
self.decoder.mid_block.attentions = torch.nn.ModuleList([None])
|
29 |
+
|
30 |
+
def load_state_dict(self, state_dict, strict=True):
|
31 |
+
# Newer version of diffusers changed the model keys, causing incompatibility with old checkpoints.
|
32 |
+
# They provided a method for conversion. We call conversion before loading state_dict.
|
33 |
+
convert_deprecated_attention_blocks = getattr(self, "_convert_deprecated_attention_blocks", None)
|
34 |
+
if callable(convert_deprecated_attention_blocks):
|
35 |
+
convert_deprecated_attention_blocks(state_dict)
|
36 |
+
return super().load_state_dict(state_dict, strict)
|
prompts/validation_negative.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
(((naked))), deformed, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, ((((mutated hands and fingers)))), watermark, watermarked, oversaturated, censored, distorted hands, amputation, missing hands, obese, doubled face, double hands
|
requirements.txt
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
|
2 |
+
|
3 |
+
numpy
|
4 |
+
ftfy
|
5 |
+
|
6 |
+
# Training
|
7 |
+
bs4==0.0.1 # Needed for text cleaning
|
8 |
+
bson==0.5.10
|
9 |
+
diffusers==0.19.3 # diffusers[torch]==0.19.3 in control
|
10 |
+
einops==0.6.0
|
11 |
+
ftfy==6.1.1 # Needed for text cleaning
|
12 |
+
kornia==0.6.12
|
13 |
+
lpips==0.1.4
|
14 |
+
sentencepiece==0.1.99 # Needed for T5 tokenizer
|
15 |
+
transformers==4.36.2
|
16 |
+
tqdm==4.64.1
|
17 |
+
torchgeometry # Needed for ssim loss
|
18 |
+
expecttest # Needed for compile
|
19 |
+
accelerate==0.24.1 # model saving bugs when accelerate==0.25.0
|
20 |
+
|
21 |
+
# Inference
|
22 |
+
av==10.0.0
|
23 |
+
pims==0.6.1
|
24 |
+
opencv-python-headless==4.6.0.66
|
25 |
+
|
26 |
+
gradio==3.42.0
|
27 |
+
httpx==0.23.3
|
28 |
+
opencv-python
|
29 |
+
open_clip_torch
|
30 |
+
protobuf==3.20.0
|
31 |
+
huggingface_hub==0.25.0
|
32 |
+
|
33 |
+
open_clip_torch
|
34 |
+
git+https://github.com/openai/CLIP.git
|
schedulers/__pycache__/base.cpython-310.pyc
ADDED
Binary file (3.63 kB). View file
|
|
schedulers/__pycache__/ddim.cpython-310.pyc
ADDED
Binary file (1.98 kB). View file
|
|
schedulers/__pycache__/dpm_s.cpython-310.pyc
ADDED
Binary file (6.09 kB). View file
|
|
schedulers/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.8 kB). View file
|
|
schedulers/base.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from abc import ABC
|
18 |
+
from typing import Optional, Union, List
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class SchedulerConversionOutput:
|
23 |
+
pred_epsilon: torch.Tensor
|
24 |
+
pred_original_sample: torch.Tensor
|
25 |
+
pred_velocity: torch.Tensor
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class SchedulerStepOutput:
|
30 |
+
prev_sample: torch.Tensor
|
31 |
+
pred_original_sample: Optional[torch.Tensor] = None
|
32 |
+
|
33 |
+
|
34 |
+
class Scheduler(ABC):
|
35 |
+
prediction_types = ["epsilon", "sample", "v_prediction"]
|
36 |
+
timesteps_types = ["leading", "linspace", "trailing"]
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
num_train_timesteps: int,
|
41 |
+
num_inference_timesteps: int,
|
42 |
+
betas: torch.Tensor,
|
43 |
+
inference_timesteps: Union[str, List[int]] = "trailing",
|
44 |
+
set_alpha_to_one: bool = True,
|
45 |
+
device: Optional[Union[str, torch.device]] = None,
|
46 |
+
dtype: torch.dtype = torch.float32
|
47 |
+
):
|
48 |
+
assert num_train_timesteps > 0
|
49 |
+
assert num_train_timesteps >= num_inference_timesteps
|
50 |
+
assert num_train_timesteps == betas.size(0)
|
51 |
+
assert betas.ndim == 1
|
52 |
+
|
53 |
+
self.device = device or betas.device
|
54 |
+
self.dtype = dtype
|
55 |
+
|
56 |
+
self.num_train_timesteps = num_train_timesteps
|
57 |
+
self.num_inference_timesteps = num_inference_timesteps
|
58 |
+
|
59 |
+
self.betas = betas.to(device=device, dtype=dtype)
|
60 |
+
self.alphas = 1.0 - self.betas
|
61 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
62 |
+
self.final_alpha_cumprod = torch.tensor(1.0, device=self.device, dtype=self.dtype) if set_alpha_to_one else self.alphas_cumprod[0]
|
63 |
+
|
64 |
+
if isinstance(inference_timesteps, list):
|
65 |
+
# If user defines a custom inference timestep, directly assign it.
|
66 |
+
assert len(inference_timesteps) == num_inference_timesteps
|
67 |
+
self.timesteps = torch.tensor(inference_timesteps, device=self.device, dtype=torch.int)
|
68 |
+
elif inference_timesteps == "trailing":
|
69 |
+
# Example 20 steps: [999, 949, 899, 849, 799, 749, 699, 649, 599, 549, 499, 449, 399, 349, 299, 249, 199, 149, 99, 49]
|
70 |
+
self.timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=self.device).round().int()
|
71 |
+
elif inference_timesteps == "linspace":
|
72 |
+
# Example 20 steps: [999, 946, 894, 841, 789, 736, 684, 631, 578, 526, 473, 421, 368, 315, 263, 210, 158, 105, 53, 0]
|
73 |
+
self.timesteps = torch.linspace(0, num_train_timesteps - 1, num_inference_timesteps, device=self.device).round().int().flip(0)
|
74 |
+
elif inference_timesteps == "leading":
|
75 |
+
# Original SD and DDIM paper may have a bug: <https://github.com/huggingface/diffusers/issues/2585>
|
76 |
+
# The inference timestep does not start from 999.
|
77 |
+
# Example 20 steps: [950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350, 300, 250, 200, 150, 100, 50, 0]
|
78 |
+
self.timesteps = torch.arange(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps, device=self.device, dtype=torch.int).flip(0)
|
79 |
+
else:
|
80 |
+
raise NotImplementedError
|
81 |
+
|
82 |
+
def reset(self):
|
83 |
+
pass
|
84 |
+
|
85 |
+
def add_noise(
|
86 |
+
self,
|
87 |
+
original_samples: torch.Tensor,
|
88 |
+
noise: torch.Tensor,
|
89 |
+
timesteps: Union[torch.Tensor, int],
|
90 |
+
) -> torch.Tensor:
|
91 |
+
alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (original_samples.ndim - 1)))
|
92 |
+
return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise
|
93 |
+
|
94 |
+
def convert_output(
|
95 |
+
self,
|
96 |
+
model_output: torch.Tensor,
|
97 |
+
model_output_type: str,
|
98 |
+
sample: torch.Tensor,
|
99 |
+
timesteps: Union[torch.Tensor, int]
|
100 |
+
) -> SchedulerConversionOutput:
|
101 |
+
assert model_output_type in self.prediction_types
|
102 |
+
|
103 |
+
alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
|
104 |
+
beta_prod_t = 1 - alpha_prod_t
|
105 |
+
|
106 |
+
if model_output_type == "epsilon":
|
107 |
+
pred_epsilon = model_output
|
108 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
|
109 |
+
pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
|
110 |
+
elif model_output_type == "sample":
|
111 |
+
pred_original_sample = model_output
|
112 |
+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
113 |
+
pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
|
114 |
+
elif model_output_type == "v_prediction":
|
115 |
+
pred_velocity = model_output
|
116 |
+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
117 |
+
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
118 |
+
else:
|
119 |
+
raise ValueError("Unknown prediction type")
|
120 |
+
|
121 |
+
return SchedulerConversionOutput(
|
122 |
+
pred_epsilon=pred_epsilon,
|
123 |
+
pred_original_sample=pred_original_sample,
|
124 |
+
pred_velocity=pred_velocity)
|
125 |
+
|
126 |
+
def get_velocity(
|
127 |
+
self,
|
128 |
+
sample: torch.Tensor,
|
129 |
+
noise: torch.Tensor,
|
130 |
+
timesteps: torch.Tensor
|
131 |
+
) -> torch.FloatTensor:
|
132 |
+
alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
|
133 |
+
return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample
|
schedulers/ddim.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .base import *
|
16 |
+
|
17 |
+
|
18 |
+
class DDIMScheduler(Scheduler):
|
19 |
+
def step(
|
20 |
+
self,
|
21 |
+
model_output: torch.Tensor,
|
22 |
+
model_output_type: str,
|
23 |
+
timestep: Union[torch.Tensor, int],
|
24 |
+
sample: torch.Tensor,
|
25 |
+
eta: float = 0.0,
|
26 |
+
clip_sample: bool = False,
|
27 |
+
dynamic_threshold: Optional[float] = None,
|
28 |
+
variance_noise: Optional[torch.Tensor] = None,
|
29 |
+
) -> SchedulerStepOutput:
|
30 |
+
# 1. get previous step value (t-1)
|
31 |
+
if not isinstance(timestep, torch.Tensor):
|
32 |
+
timestep = torch.tensor(timestep, device=self.device, dtype=torch.int)
|
33 |
+
|
34 |
+
idx = timestep.reshape(-1, 1).eq(self.timesteps.reshape(1, -1)).nonzero()[:, 1]
|
35 |
+
prev_timestep = self.timesteps[idx.add(1).clamp_max(self.num_inference_timesteps - 1)]
|
36 |
+
|
37 |
+
# 2. compute alphas, betas
|
38 |
+
alpha_prod_t = self.alphas_cumprod[timestep].reshape(-1, *([1] * (sample.ndim - 1)))
|
39 |
+
alpha_prod_t_prev = torch.where(idx < self.num_inference_timesteps - 1, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod).reshape(-1, *([1] * (sample.ndim - 1)))
|
40 |
+
beta_prod_t = 1 - alpha_prod_t
|
41 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
42 |
+
|
43 |
+
# 3. compute predicted original sample from predicted noise also called
|
44 |
+
model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
|
45 |
+
pred_original_sample = model_output_conversion.pred_original_sample
|
46 |
+
pred_epsilon = model_output_conversion.pred_epsilon
|
47 |
+
|
48 |
+
# 4. Clip or threshold "predicted x_0"
|
49 |
+
if clip_sample:
|
50 |
+
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
51 |
+
pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
|
52 |
+
|
53 |
+
if dynamic_threshold is not None:
|
54 |
+
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
55 |
+
dynamic_max_val = pred_original_sample \
|
56 |
+
.flatten(1) \
|
57 |
+
.abs() \
|
58 |
+
.float() \
|
59 |
+
.quantile(dynamic_threshold, dim=1) \
|
60 |
+
.type_as(pred_original_sample) \
|
61 |
+
.clamp_min(1) \
|
62 |
+
.view(-1, *([1] * (pred_original_sample.ndim - 1)))
|
63 |
+
pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
64 |
+
pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
|
65 |
+
|
66 |
+
# 5. compute variance: "sigma_t(η)" -> see formula (16) from https://arxiv.org/pdf/2010.02502.pdf
|
67 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
68 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
69 |
+
std_dev_t = eta * variance ** (0.5)
|
70 |
+
|
71 |
+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
72 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
|
73 |
+
|
74 |
+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
75 |
+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
76 |
+
|
77 |
+
# 8. add "random noise" if needed.
|
78 |
+
if eta > 0:
|
79 |
+
if variance_noise is None:
|
80 |
+
variance_noise = torch.randn_like(model_output)
|
81 |
+
prev_sample = prev_sample + std_dev_t * variance_noise
|
82 |
+
|
83 |
+
return SchedulerStepOutput(
|
84 |
+
prev_sample=prev_sample,
|
85 |
+
pred_original_sample=pred_original_sample)
|
schedulers/dpm_m.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
17 |
+
|
18 |
+
from typing import List, Optional, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from .base import *
|
24 |
+
|
25 |
+
|
26 |
+
class DPMSolverMultistepScheduler(Scheduler):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
# Generic scheduler settings
|
30 |
+
num_inference_timesteps: int,
|
31 |
+
betas: torch.Tensor,
|
32 |
+
num_train_timesteps: int = 1000,
|
33 |
+
inference_timesteps: Union[str, List[str]] = "trailing",
|
34 |
+
set_alpha_to_one: bool = True,
|
35 |
+
device: Optional[torch.device] = None,
|
36 |
+
dtype: torch.dtype = torch.float32,
|
37 |
+
# DPM scheduler settings
|
38 |
+
solver_order: int = 2,
|
39 |
+
algorithm_type: str = "dpmsolver++",
|
40 |
+
solver_type: str = "midpoint",
|
41 |
+
lower_order_final: bool = True,
|
42 |
+
use_karras_sigmas: bool = False,
|
43 |
+
):
|
44 |
+
super().__init__(
|
45 |
+
num_train_timesteps=num_train_timesteps,
|
46 |
+
num_inference_timesteps=num_inference_timesteps,
|
47 |
+
betas=betas,
|
48 |
+
inference_timesteps=inference_timesteps,
|
49 |
+
set_alpha_to_one=set_alpha_to_one,
|
50 |
+
device=device,
|
51 |
+
dtype=dtype,
|
52 |
+
)
|
53 |
+
|
54 |
+
self.solver_order = solver_order
|
55 |
+
self.solver_type = solver_type
|
56 |
+
self.lower_order_final = lower_order_final
|
57 |
+
self.algorithm_type = algorithm_type
|
58 |
+
|
59 |
+
# Currently we only support VP-type noise schedule
|
60 |
+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
61 |
+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
62 |
+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
63 |
+
|
64 |
+
sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)
|
65 |
+
if use_karras_sigmas:
|
66 |
+
log_sigmas = torch.log(sigmas)
|
67 |
+
sigmas = self._convert_to_karras(
|
68 |
+
in_sigmas=sigmas, num_inference_timesteps=num_inference_timesteps)
|
69 |
+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas)
|
70 |
+
for sigma in sigmas]).round()
|
71 |
+
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
72 |
+
self.timesteps = torch.from_numpy(timesteps).to(device)
|
73 |
+
sigmas = torch.from_numpy(sigmas).to(device)
|
74 |
+
self.sigmas = sigmas
|
75 |
+
|
76 |
+
# standard deviation of the initial noise distribution
|
77 |
+
self.init_noise_sigma = 1.0
|
78 |
+
|
79 |
+
# settings for DPM-Solver
|
80 |
+
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++", "deis"]:
|
81 |
+
raise NotImplementedError(
|
82 |
+
f"{algorithm_type} does is not implemented for {self.__class__}")
|
83 |
+
|
84 |
+
if solver_type not in ["midpoint", "heun", "logrho", "bh1", "bh2"]:
|
85 |
+
raise NotImplementedError(
|
86 |
+
f"{solver_type} does is not implemented for {self.__class__}")
|
87 |
+
|
88 |
+
# setable values
|
89 |
+
self.reset()
|
90 |
+
|
91 |
+
def reset(self):
|
92 |
+
self.model_outputs = [None] * self.solver_order
|
93 |
+
self.lower_order_nums = 0
|
94 |
+
|
95 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
96 |
+
# get log sigma
|
97 |
+
log_sigma = np.log(sigma)
|
98 |
+
|
99 |
+
# get distribution
|
100 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
101 |
+
|
102 |
+
# get sigmas range
|
103 |
+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(
|
104 |
+
axis=0).clip(max=log_sigmas.shape[0] - 2)
|
105 |
+
high_idx = low_idx + 1
|
106 |
+
|
107 |
+
low = log_sigmas[low_idx]
|
108 |
+
high = log_sigmas[high_idx]
|
109 |
+
|
110 |
+
# interpolate sigmas
|
111 |
+
w = (low - log_sigma) / (low - high)
|
112 |
+
w = np.clip(w, 0, 1)
|
113 |
+
|
114 |
+
# transform interpolation to time range
|
115 |
+
t = (1 - w) * low_idx + w * high_idx
|
116 |
+
t = t.reshape(sigma.shape)
|
117 |
+
return t
|
118 |
+
|
119 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
120 |
+
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_timesteps) -> torch.FloatTensor:
|
121 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
122 |
+
|
123 |
+
sigma_min: float = in_sigmas[-1].item()
|
124 |
+
sigma_max: float = in_sigmas[0].item()
|
125 |
+
|
126 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
127 |
+
ramp = np.linspace(0, 1, num_inference_timesteps)
|
128 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
129 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
130 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
131 |
+
return sigmas
|
132 |
+
|
133 |
+
def dpm_solver_first_order_update(
|
134 |
+
self,
|
135 |
+
model_output: torch.FloatTensor,
|
136 |
+
timestep: int,
|
137 |
+
prev_timestep: int,
|
138 |
+
sample: torch.FloatTensor,
|
139 |
+
noise: Optional[torch.FloatTensor] = None,
|
140 |
+
) -> torch.FloatTensor:
|
141 |
+
"""
|
142 |
+
One step for the first-order DPM-Solver (equivalent to DDIM).
|
143 |
+
|
144 |
+
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
148 |
+
timestep (`int`): current discrete timestep in the diffusion chain.
|
149 |
+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
150 |
+
sample (`torch.FloatTensor`):
|
151 |
+
current instance of sample being created by diffusion process.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
155 |
+
"""
|
156 |
+
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
157 |
+
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
158 |
+
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
159 |
+
h = lambda_t - lambda_s
|
160 |
+
if self.algorithm_type == "dpmsolver++":
|
161 |
+
x_t = (sigma_t / sigma_s) * sample - \
|
162 |
+
(alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
163 |
+
elif self.algorithm_type == "dpmsolver":
|
164 |
+
x_t = (alpha_t / alpha_s) * sample - \
|
165 |
+
(sigma_t * (torch.exp(h) - 1.0)) * model_output
|
166 |
+
elif self.algorithm_type == "sde-dpmsolver++":
|
167 |
+
assert noise is not None
|
168 |
+
x_t = (
|
169 |
+
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
170 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
171 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
172 |
+
)
|
173 |
+
elif self.algorithm_type == "sde-dpmsolver":
|
174 |
+
assert noise is not None
|
175 |
+
x_t = (
|
176 |
+
(alpha_t / alpha_s) * sample
|
177 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
178 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
179 |
+
)
|
180 |
+
return x_t
|
181 |
+
|
182 |
+
def multistep_dpm_solver_second_order_update(
|
183 |
+
self,
|
184 |
+
model_output_list: List[torch.FloatTensor],
|
185 |
+
timestep_list: List[int],
|
186 |
+
prev_timestep: int,
|
187 |
+
sample: torch.FloatTensor,
|
188 |
+
noise: Optional[torch.FloatTensor] = None,
|
189 |
+
) -> torch.FloatTensor:
|
190 |
+
"""
|
191 |
+
One step for the second-order multistep DPM-Solver.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
model_output_list (`List[torch.FloatTensor]`):
|
195 |
+
direct outputs from learned diffusion model at current and latter timesteps.
|
196 |
+
timestep (`int`): current and latter discrete timestep in the diffusion chain.
|
197 |
+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
198 |
+
sample (`torch.FloatTensor`):
|
199 |
+
current instance of sample being created by diffusion process.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
203 |
+
"""
|
204 |
+
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
205 |
+
m0, m1 = model_output_list[-1], model_output_list[-2]
|
206 |
+
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
207 |
+
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
208 |
+
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
209 |
+
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
210 |
+
r0 = h_0 / h
|
211 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
212 |
+
if self.algorithm_type == "dpmsolver++":
|
213 |
+
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
214 |
+
if self.solver_type == "midpoint":
|
215 |
+
x_t = (
|
216 |
+
(sigma_t / sigma_s0) * sample
|
217 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
218 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
219 |
+
)
|
220 |
+
elif self.solver_type == "heun":
|
221 |
+
x_t = (
|
222 |
+
(sigma_t / sigma_s0) * sample
|
223 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
224 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
225 |
+
)
|
226 |
+
elif self.algorithm_type == "dpmsolver":
|
227 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
228 |
+
if self.solver_type == "midpoint":
|
229 |
+
x_t = (
|
230 |
+
(alpha_t / alpha_s0) * sample
|
231 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
232 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
233 |
+
)
|
234 |
+
elif self.solver_type == "heun":
|
235 |
+
x_t = (
|
236 |
+
(alpha_t / alpha_s0) * sample
|
237 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
238 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
239 |
+
)
|
240 |
+
elif self.algorithm_type == "sde-dpmsolver++":
|
241 |
+
assert noise is not None
|
242 |
+
if self.solver_type == "midpoint":
|
243 |
+
x_t = (
|
244 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
245 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
246 |
+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
247 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
248 |
+
)
|
249 |
+
elif self.solver_type == "heun":
|
250 |
+
x_t = (
|
251 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
252 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
253 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
254 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
255 |
+
)
|
256 |
+
elif self.algorithm_type == "sde-dpmsolver":
|
257 |
+
assert noise is not None
|
258 |
+
if self.solver_type == "midpoint":
|
259 |
+
x_t = (
|
260 |
+
(alpha_t / alpha_s0) * sample
|
261 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
262 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
263 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
264 |
+
)
|
265 |
+
elif self.solver_type == "heun":
|
266 |
+
x_t = (
|
267 |
+
(alpha_t / alpha_s0) * sample
|
268 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
269 |
+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
270 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
271 |
+
)
|
272 |
+
return x_t
|
273 |
+
|
274 |
+
def multistep_dpm_solver_third_order_update(
|
275 |
+
self,
|
276 |
+
model_output_list: List[torch.FloatTensor],
|
277 |
+
timestep_list: List[int],
|
278 |
+
prev_timestep: int,
|
279 |
+
sample: torch.FloatTensor,
|
280 |
+
) -> torch.FloatTensor:
|
281 |
+
"""
|
282 |
+
One step for the third-order multistep DPM-Solver.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
model_output_list (`List[torch.FloatTensor]`):
|
286 |
+
direct outputs from learned diffusion model at current and latter timesteps.
|
287 |
+
timestep (`int`): current and latter discrete timestep in the diffusion chain.
|
288 |
+
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
|
289 |
+
sample (`torch.FloatTensor`):
|
290 |
+
current instance of sample being created by diffusion process.
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
`torch.FloatTensor`: the sample tensor at the previous timestep.
|
294 |
+
"""
|
295 |
+
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
296 |
+
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
297 |
+
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
298 |
+
self.lambda_t[t],
|
299 |
+
self.lambda_t[s0],
|
300 |
+
self.lambda_t[s1],
|
301 |
+
self.lambda_t[s2],
|
302 |
+
)
|
303 |
+
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
304 |
+
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
305 |
+
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
306 |
+
r0, r1 = h_0 / h, h_1 / h
|
307 |
+
D0 = m0
|
308 |
+
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
309 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
310 |
+
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
311 |
+
if self.algorithm_type == "dpmsolver++":
|
312 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
313 |
+
x_t = (
|
314 |
+
(sigma_t / sigma_s0) * sample
|
315 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
316 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
317 |
+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
318 |
+
)
|
319 |
+
elif self.algorithm_type == "dpmsolver":
|
320 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
321 |
+
x_t = (
|
322 |
+
(alpha_t / alpha_s0) * sample
|
323 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
324 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
325 |
+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
326 |
+
)
|
327 |
+
return x_t
|
328 |
+
|
329 |
+
def step(
|
330 |
+
self,
|
331 |
+
model_output: torch.FloatTensor,
|
332 |
+
model_output_type: str,
|
333 |
+
timestep: int,
|
334 |
+
sample: torch.FloatTensor,
|
335 |
+
) -> SchedulerStepOutput:
|
336 |
+
"""
|
337 |
+
Step function propagating the sample with the multistep DPM-Solver.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
341 |
+
timestep (`int`): current discrete timestep in the diffusion chain.
|
342 |
+
sample (`torch.FloatTensor`):
|
343 |
+
current instance of sample being created by diffusion process.
|
344 |
+
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
|
348 |
+
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
349 |
+
|
350 |
+
"""
|
351 |
+
if self.num_inference_timesteps is None:
|
352 |
+
raise ValueError(
|
353 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
354 |
+
)
|
355 |
+
|
356 |
+
if isinstance(timestep, torch.Tensor):
|
357 |
+
timestep = timestep.to(self.timesteps.device)
|
358 |
+
step_index = (self.timesteps == timestep).nonzero()
|
359 |
+
if len(step_index) == 0:
|
360 |
+
step_index = len(self.timesteps) - 1
|
361 |
+
else:
|
362 |
+
step_index = step_index.item()
|
363 |
+
prev_timestep = 0 if step_index == len(
|
364 |
+
self.timesteps) - 1 else self.timesteps[step_index + 1]
|
365 |
+
lower_order_final = (
|
366 |
+
(step_index == len(self.timesteps) -
|
367 |
+
1) and self.lower_order_final and len(self.timesteps) < 15
|
368 |
+
)
|
369 |
+
lower_order_second = (
|
370 |
+
(step_index == len(self.timesteps) -
|
371 |
+
2) and self.lower_order_final and len(self.timesteps) < 15
|
372 |
+
)
|
373 |
+
|
374 |
+
model_output_convert = self.convert_output(
|
375 |
+
model_output, model_output_type=model_output_type, sample=sample, timesteps=timestep)
|
376 |
+
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
377 |
+
if self.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
378 |
+
model_output = model_output_convert.pred_original_sample
|
379 |
+
# DPM-Solver needs to solve an integral of the noise prediction model.
|
380 |
+
elif self.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
381 |
+
model_output = model_output_convert.pred_epsilon
|
382 |
+
|
383 |
+
for i in range(self.solver_order - 1):
|
384 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
385 |
+
self.model_outputs[-1] = model_output
|
386 |
+
|
387 |
+
if self.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
388 |
+
noise = torch.randn_like(
|
389 |
+
model_output, device=model_output.device, dtype=model_output.dtype)
|
390 |
+
else:
|
391 |
+
noise = None
|
392 |
+
|
393 |
+
if self.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
|
394 |
+
prev_sample = self.dpm_solver_first_order_update(
|
395 |
+
model_output, timestep, prev_timestep, sample, noise=noise
|
396 |
+
)
|
397 |
+
elif self.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
398 |
+
timestep_list = [self.timesteps[step_index - 1], timestep]
|
399 |
+
prev_sample = self.multistep_dpm_solver_second_order_update(
|
400 |
+
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
401 |
+
)
|
402 |
+
else:
|
403 |
+
timestep_list = [self.timesteps[step_index - 2],
|
404 |
+
self.timesteps[step_index - 1], timestep]
|
405 |
+
prev_sample = self.multistep_dpm_solver_third_order_update(
|
406 |
+
self.model_outputs, timestep_list, prev_timestep, sample
|
407 |
+
)
|
408 |
+
|
409 |
+
if self.lower_order_nums < self.solver_order:
|
410 |
+
self.lower_order_nums += 1
|
411 |
+
|
412 |
+
return SchedulerStepOutput(prev_sample=prev_sample, pred_original_sample=model_output_convert.pred_original_sample)
|
schedulers/dpm_s.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .base import *
|
16 |
+
|
17 |
+
|
18 |
+
class DPMSolverSingleStepScheduler(Scheduler):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
# Generic scheduler settings
|
22 |
+
num_train_timesteps: int,
|
23 |
+
num_inference_timesteps: int,
|
24 |
+
betas: torch.Tensor,
|
25 |
+
inference_timesteps: Union[str, List[int]] = "trailing",
|
26 |
+
set_alpha_to_one: bool = True,
|
27 |
+
device: Optional[Union[str, torch.device]] = None,
|
28 |
+
dtype: torch.dtype = torch.float32,
|
29 |
+
# DPM scheduler settings
|
30 |
+
algorithm_type: str = "dpmsolver++",
|
31 |
+
solver_type: str = "midpoint",
|
32 |
+
solver_order: int = 2,
|
33 |
+
lower_order_final: bool = True,
|
34 |
+
):
|
35 |
+
super().__init__(
|
36 |
+
num_train_timesteps=num_train_timesteps,
|
37 |
+
num_inference_timesteps=num_inference_timesteps,
|
38 |
+
betas=betas,
|
39 |
+
inference_timesteps=inference_timesteps,
|
40 |
+
set_alpha_to_one=set_alpha_to_one,
|
41 |
+
device=device,
|
42 |
+
dtype=dtype,
|
43 |
+
)
|
44 |
+
|
45 |
+
self.solver_order = solver_order
|
46 |
+
self.solver_type = solver_type
|
47 |
+
self.lower_order_final = lower_order_final
|
48 |
+
self.algorithm_type = algorithm_type
|
49 |
+
|
50 |
+
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
51 |
+
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
52 |
+
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
53 |
+
|
54 |
+
self.reset()
|
55 |
+
|
56 |
+
def reset(self):
|
57 |
+
self.model_outputs = [None] * self.solver_order
|
58 |
+
self.sample = None
|
59 |
+
self.order_list = self.get_order_list()
|
60 |
+
self.last_step_index = None
|
61 |
+
|
62 |
+
def get_order_list(self) -> List[int]:
|
63 |
+
steps = self.num_inference_timesteps
|
64 |
+
order = self.solver_order
|
65 |
+
# First step must be order 1
|
66 |
+
# Second step must be order 1 in case of terminal zero SNR
|
67 |
+
orders = [1] + [(i % order) + 1 for i in range(steps - 1)] + [1]
|
68 |
+
# Last step should be order 1 for better quality.
|
69 |
+
if self.lower_order_final:
|
70 |
+
orders[-1] = 1
|
71 |
+
return orders
|
72 |
+
|
73 |
+
def dpm_solver_first_order_update(
|
74 |
+
self,
|
75 |
+
model_output: torch.FloatTensor,
|
76 |
+
timestep: int,
|
77 |
+
prev_timestep: int,
|
78 |
+
sample: torch.FloatTensor,
|
79 |
+
) -> torch.FloatTensor:
|
80 |
+
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
81 |
+
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
82 |
+
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
83 |
+
h = lambda_t - lambda_s
|
84 |
+
if self.algorithm_type == "dpmsolver++":
|
85 |
+
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
86 |
+
elif self.algorithm_type == "dpmsolver":
|
87 |
+
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
88 |
+
return x_t
|
89 |
+
|
90 |
+
def singlestep_dpm_solver_second_order_update(
|
91 |
+
self,
|
92 |
+
model_output_list: List[torch.FloatTensor],
|
93 |
+
timestep_list: List[int],
|
94 |
+
prev_timestep: int,
|
95 |
+
sample: torch.FloatTensor,
|
96 |
+
) -> torch.FloatTensor:
|
97 |
+
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
98 |
+
m0, m1 = model_output_list[-1], model_output_list[-2]
|
99 |
+
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
100 |
+
alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
|
101 |
+
sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
|
102 |
+
h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
|
103 |
+
r0 = h_0 / h
|
104 |
+
D0, D1 = m1, (1.0 / r0) * (m0 - m1)
|
105 |
+
if self.algorithm_type == "dpmsolver++":
|
106 |
+
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
107 |
+
if self.solver_type == "midpoint":
|
108 |
+
x_t = (
|
109 |
+
(sigma_t / sigma_s1) * sample
|
110 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
111 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
112 |
+
)
|
113 |
+
elif self.solver_type == "heun":
|
114 |
+
x_t = (
|
115 |
+
(sigma_t / sigma_s1) * sample
|
116 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
117 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
118 |
+
)
|
119 |
+
elif self.algorithm_type == "dpmsolver":
|
120 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
121 |
+
if self.solver_type == "midpoint":
|
122 |
+
x_t = (
|
123 |
+
(alpha_t / alpha_s1) * sample
|
124 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
125 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
126 |
+
)
|
127 |
+
elif self.solver_type == "heun":
|
128 |
+
x_t = (
|
129 |
+
(alpha_t / alpha_s1) * sample
|
130 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
131 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
132 |
+
)
|
133 |
+
return x_t
|
134 |
+
|
135 |
+
def singlestep_dpm_solver_third_order_update(
|
136 |
+
self,
|
137 |
+
model_output_list: List[torch.FloatTensor],
|
138 |
+
timestep_list: List[int],
|
139 |
+
prev_timestep: int,
|
140 |
+
sample: torch.FloatTensor,
|
141 |
+
) -> torch.FloatTensor:
|
142 |
+
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
|
143 |
+
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
144 |
+
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
|
145 |
+
self.lambda_t[t],
|
146 |
+
self.lambda_t[s0],
|
147 |
+
self.lambda_t[s1],
|
148 |
+
self.lambda_t[s2],
|
149 |
+
)
|
150 |
+
alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
|
151 |
+
sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2]
|
152 |
+
h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
|
153 |
+
r0, r1 = h_0 / h, h_1 / h
|
154 |
+
D0 = m2
|
155 |
+
D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2)
|
156 |
+
D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
|
157 |
+
D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
|
158 |
+
if self.algorithm_type == "dpmsolver++":
|
159 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
160 |
+
if self.solver_type == "midpoint":
|
161 |
+
x_t = (
|
162 |
+
(sigma_t / sigma_s2) * sample
|
163 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
164 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1
|
165 |
+
)
|
166 |
+
elif self.solver_type == "heun":
|
167 |
+
x_t = (
|
168 |
+
(sigma_t / sigma_s2) * sample
|
169 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
170 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
171 |
+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
172 |
+
)
|
173 |
+
elif self.algorithm_type == "dpmsolver":
|
174 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
175 |
+
if self.solver_type == "midpoint":
|
176 |
+
x_t = (
|
177 |
+
(alpha_t / alpha_s2) * sample
|
178 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
179 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1
|
180 |
+
)
|
181 |
+
elif self.solver_type == "heun":
|
182 |
+
x_t = (
|
183 |
+
(alpha_t / alpha_s2) * sample
|
184 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
185 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
186 |
+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
187 |
+
)
|
188 |
+
return x_t
|
189 |
+
|
190 |
+
def step(
|
191 |
+
self,
|
192 |
+
model_output: torch.FloatTensor,
|
193 |
+
model_output_type: str,
|
194 |
+
timestep: int,
|
195 |
+
sample: torch.FloatTensor,
|
196 |
+
) -> SchedulerStepOutput:
|
197 |
+
|
198 |
+
step_index = (self.timesteps == timestep).nonzero().item()
|
199 |
+
|
200 |
+
# Check if this step is the follow-up of the previous step.
|
201 |
+
# If not, then we reset and treat it as a new run.
|
202 |
+
if self.last_step_index and self.last_step_index != step_index - 1:
|
203 |
+
self.reset()
|
204 |
+
self.last_step_index = step_index
|
205 |
+
|
206 |
+
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
207 |
+
model_output_convert = self.convert_output(model_output, model_output_type, sample, timestep)
|
208 |
+
|
209 |
+
if self.algorithm_type == "dpmsolver++":
|
210 |
+
model_output = model_output_convert.pred_original_sample
|
211 |
+
else:
|
212 |
+
model_output = model_output_convert.pred_epsilon
|
213 |
+
|
214 |
+
for i in range(self.solver_order - 1):
|
215 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
216 |
+
self.model_outputs[-1] = model_output
|
217 |
+
|
218 |
+
order = self.order_list[step_index]
|
219 |
+
|
220 |
+
# For img2img denoising might start with order>1 which is not possible
|
221 |
+
# In this case make sure that the first two steps are both order=1
|
222 |
+
while self.model_outputs[-order] is None:
|
223 |
+
order -= 1
|
224 |
+
|
225 |
+
# For single-step solvers, we use the initial value at each time with order = 1.
|
226 |
+
if order == 1:
|
227 |
+
self.sample = sample
|
228 |
+
|
229 |
+
timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep]
|
230 |
+
|
231 |
+
if order == 1:
|
232 |
+
prev_sample = self.dpm_solver_first_order_update(self.model_outputs[-1], timestep_list[-1], prev_timestep, self.sample)
|
233 |
+
elif order == 2:
|
234 |
+
prev_sample = self.singlestep_dpm_solver_second_order_update(self.model_outputs, timestep_list, prev_timestep, self.sample)
|
235 |
+
elif order == 3:
|
236 |
+
prev_sample = self.singlestep_dpm_solver_third_order_update(self.model_outputs, timestep_list, prev_timestep, self.sample)
|
237 |
+
else:
|
238 |
+
raise NotImplementedError
|
239 |
+
|
240 |
+
return SchedulerStepOutput(
|
241 |
+
prev_sample=prev_sample,
|
242 |
+
pred_original_sample=model_output_convert.pred_original_sample
|
243 |
+
)
|
schedulers/utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from typing import List
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
|
21 |
+
def get_betas(name: str, num_steps: int = 1000, shift_snr: float = 1, terminal_pure_noise: bool = False):
|
22 |
+
# Get betas
|
23 |
+
max_beta = 1 if terminal_pure_noise else 0.999
|
24 |
+
if name == "squared_linear":
|
25 |
+
betas = torch.linspace(0.00085**0.5, 0.012**0.5, num_steps) ** 2
|
26 |
+
elif name == "cosine":
|
27 |
+
betas = get_cosine_betas(num_steps, max_beta=max_beta)
|
28 |
+
elif name == "alphas_cumprod_linear":
|
29 |
+
betas = get_alphas_cumprod_linear_betas(num_steps, max_beta=max_beta)
|
30 |
+
elif name == "sigmoid":
|
31 |
+
betas = get_sigmoid_betas(num_steps, max_beta=max_beta, square=True, slop=0.7)
|
32 |
+
else:
|
33 |
+
raise NotImplementedError
|
34 |
+
|
35 |
+
# Shift snr
|
36 |
+
betas = shift_betas_by_snr_factor(betas, shift_snr)
|
37 |
+
|
38 |
+
# Ensure terminal pure noise
|
39 |
+
# Only non-cosine schedule uses rescale, cosine schedule can directly set max_beta=1 to ensure temrinal pure noise.
|
40 |
+
if name == "squared_linear" and terminal_pure_noise:
|
41 |
+
betas = rescale_betas_to_ensure_terminal_pure_noise(betas)
|
42 |
+
|
43 |
+
return betas
|
44 |
+
|
45 |
+
|
46 |
+
def validate_betas(betas: List[float]) -> bool:
|
47 |
+
"""
|
48 |
+
Validate betas is monotonic and within 0 to 1 range, i.e. 0 < beta_{t-1} < beta_{t} <= 1
|
49 |
+
|
50 |
+
Args:
|
51 |
+
betas (List[float]): betas
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
bool: True if betas is correct
|
55 |
+
"""
|
56 |
+
return all(b1 < b2 for b1, b2 in zip(betas, betas[1:])) and betas[0] > 0 and betas[-1] <= 1
|
57 |
+
|
58 |
+
|
59 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta=0.999):
|
60 |
+
betas = []
|
61 |
+
for i in range(num_diffusion_timesteps):
|
62 |
+
t1 = i / num_diffusion_timesteps
|
63 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
64 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
65 |
+
if not validate_betas(betas):
|
66 |
+
import logging
|
67 |
+
logging.warning("No feasible betas for given alpha bar")
|
68 |
+
return torch.tensor(betas, dtype=torch.float32)
|
69 |
+
|
70 |
+
|
71 |
+
def get_cosine_betas(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
72 |
+
def alpha_bar_fn(time_step):
|
73 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
74 |
+
return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta)
|
75 |
+
|
76 |
+
|
77 |
+
def get_sigmoid_betas(num_diffusion_timesteps, max_beta, square=False, slop=1):
|
78 |
+
def alpha_bar_fn(t):
|
79 |
+
def sigmoid(x):
|
80 |
+
return 1 / (1 + math.exp(-x * slop))
|
81 |
+
s = 6 # (-6, 6) from geodiff
|
82 |
+
vb = sigmoid(-s)
|
83 |
+
ve = sigmoid(s)
|
84 |
+
return ((sigmoid(s - t * 2 * s) - vb) / (ve - vb))**(1 if not square else 2)
|
85 |
+
return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta)
|
86 |
+
|
87 |
+
|
88 |
+
def get_alphas_cumprod_linear_betas(num_diffusion_timesteps, max_beta):
|
89 |
+
def alpha_bar_fn(t):
|
90 |
+
return 1 - t
|
91 |
+
return betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar_fn, max_beta=max_beta)
|
92 |
+
|
93 |
+
|
94 |
+
def shift_betas_by_snr_factor(betas: torch.Tensor, factor: float) -> torch.Tensor:
|
95 |
+
if factor == 1.0:
|
96 |
+
return betas
|
97 |
+
# Convert betas to snr
|
98 |
+
alphas = 1 - betas
|
99 |
+
alphas_cumprod = alphas.cumprod(dim=0)
|
100 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
101 |
+
# Shift snr
|
102 |
+
snr *= factor
|
103 |
+
# Convert snr to betas
|
104 |
+
alphas_cumprod = snr / (1 + snr)
|
105 |
+
alphas = torch.cat(
|
106 |
+
[alphas_cumprod[0:1], alphas_cumprod[1:] / alphas_cumprod[:-1]])
|
107 |
+
betas = 1 - alphas
|
108 |
+
return betas
|
109 |
+
|
110 |
+
|
111 |
+
def rescale_betas_to_ensure_terminal_pure_noise(betas: torch.Tensor) -> torch.Tensor:
|
112 |
+
# Convert betas to alphas_cumprod_sqrt
|
113 |
+
alphas = 1 - betas
|
114 |
+
alphas_cumprod = alphas.cumprod(0)
|
115 |
+
alphas_cumprod_sqrt = alphas_cumprod.sqrt()
|
116 |
+
# Rescale alphas_cumprod_sqrt such that alphas_cumprod_sqrt[0] remains unchanged but alphas_cumprod_sqrt[-1] = 0
|
117 |
+
alphas_cumprod_sqrt = (alphas_cumprod_sqrt - alphas_cumprod_sqrt[-1]) / (
|
118 |
+
alphas_cumprod_sqrt[0] - alphas_cumprod_sqrt[-1]) * alphas_cumprod_sqrt[0]
|
119 |
+
# Convert alphas_cumprod_sqrt to betas
|
120 |
+
alphas_cumprod = alphas_cumprod_sqrt ** 2
|
121 |
+
alphas = torch.cat(
|
122 |
+
[alphas_cumprod[0:1], alphas_cumprod[1:] / alphas_cumprod[:-1]])
|
123 |
+
betas = 1 - alphas
|
124 |
+
return betas
|
utils.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch.nn as nn
|
16 |
+
import importlib
|
17 |
+
|
18 |
+
def zero_module(module):
|
19 |
+
if isinstance(module, nn.Linear):
|
20 |
+
module.weight.data.zero_()
|
21 |
+
if module.bias is not None:
|
22 |
+
module.bias.data.zero_()
|
23 |
+
return module
|
24 |
+
|
25 |
+
def get_obj_from_str(string, reload=False):
|
26 |
+
module, cls = string.rsplit(".", 1)
|
27 |
+
if reload:
|
28 |
+
module_imp = importlib.import_module(module)
|
29 |
+
importlib.reload(module_imp)
|
30 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
31 |
+
|
32 |
+
def instantiate_from_config(config):
|
33 |
+
if not "target" in config:
|
34 |
+
if config == '__is_first_stage__':
|
35 |
+
return None
|
36 |
+
elif config == "__is_unconditional__":
|
37 |
+
return None
|
38 |
+
raise KeyError("Expected key `target` to instantiate.")
|
39 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
40 |
+
|
41 |
+
def update_dict(old_dict, new_dict):
|
42 |
+
old_keys = old_dict.keys()
|
43 |
+
for new_key in new_dict.keys():
|
44 |
+
if new_key in old_keys:
|
45 |
+
if type(old_dict[new_key]) == list:
|
46 |
+
if type(new_dict[new_key]) == list:
|
47 |
+
old_dict[new_key].extend(new_dict[new_key])
|
48 |
+
else:
|
49 |
+
old_dict[new_key].append(new_dict[new_key])
|
50 |
+
else:
|
51 |
+
old_dict[new_key] = [old_dict[new_key]]
|
52 |
+
old_dict[new_key].append(new_dict[new_key])
|
53 |
+
else:
|
54 |
+
old_dict[new_key] = new_dict[new_key]
|
55 |
+
return old_dict
|