Spaces:
Running
on
Zero
Running
on
Zero
Upload 35 files
Browse files- .gitattributes +11 -0
- app.py +246 -0
- assets/image/100.png +3 -0
- assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png +0 -0
- assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png +3 -0
- assets/image/579584fb-8d1c-4312-a3f0-f7a81bd16493.png +0 -0
- assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png +3 -0
- assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png +3 -0
- assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png +0 -0
- assets/model/100.glb +3 -0
- assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb +3 -0
- assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb +3 -0
- assets/model/579584fb-8d1c-4312-a3f0-f7a81bd16493.glb +3 -0
- assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb +3 -0
- assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb +3 -0
- assets/model/e799e6b4-3b47-40e0-befb-b156af8758ad.glb +3 -0
- detailgen3d/__init__.py +0 -0
- detailgen3d/inference_utils.py +17 -0
- detailgen3d/models/attention_processor.py +576 -0
- detailgen3d/models/autoencoders/__init__.py +1 -0
- detailgen3d/models/autoencoders/autoencoder_kl_triposg.py +536 -0
- detailgen3d/models/autoencoders/vae.py +69 -0
- detailgen3d/models/embeddings.py +96 -0
- detailgen3d/models/transformers/__init__.py +61 -0
- detailgen3d/models/transformers/detailgen3d_transformers.py +771 -0
- detailgen3d/models/transformers/modeling_outputs.py +8 -0
- detailgen3d/models/transformers/triposg_transformer.py +726 -0
- detailgen3d/pipelines/__init__.py +1 -0
- detailgen3d/pipelines/pipeline_detailgen3d.py +322 -0
- detailgen3d/pipelines/pipeline_detailgen3d_output.py +13 -0
- detailgen3d/pipelines/pipeline_utils.py +96 -0
- detailgen3d/schedulers/__init__.py +5 -0
- detailgen3d/schedulers/scheduling_rectified_flow.py +327 -0
- detailgen3d/utils/__init__.py +2 -0
- detailgen3d/utils/typing.py +64 -0
- scripts/inference_detailgen3d.py +70 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/image/100.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/model/100.glb filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/model/579584fb-8d1c-4312-a3f0-f7a81bd16493.glb filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/model/e799e6b4-3b47-40e0-befb-b156af8758ad.glb filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import tempfile
|
4 |
+
from typing import Any, List, Union
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from gradio_image_prompter import ImagePrompter
|
10 |
+
from gradio_litmodel3d import LitModel3D
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
from PIL import Image
|
13 |
+
import trimesh
|
14 |
+
from skimage import measure
|
15 |
+
|
16 |
+
from detailgen3d.pipelines.pipeline_detailgen3d import DetailGen3DPipeline
|
17 |
+
from detailgen3d.inference_utils import generate_dense_grid_points
|
18 |
+
|
19 |
+
# Constants
|
20 |
+
MAX_SEED = np.iinfo(np.int32).max
|
21 |
+
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
|
22 |
+
DTYPE = torch.bfloat16
|
23 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
+
REPO_ID = "" # 似乎还没有
|
25 |
+
|
26 |
+
MARKDOWN = """
|
27 |
+
## Generating geometry details guided by reference image with [DetailGen3D](https://detailgen3d.github.io/DetailGen3D/)
|
28 |
+
1. Upload a detailed image of the frontal view and a coarse model. Then clik "Run " to generate the refined result.
|
29 |
+
2. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
|
30 |
+
3. If you want the refine result to be more consistent with the image, please manually increase the CFG strength.
|
31 |
+
"""
|
32 |
+
EXAMPLES = [
|
33 |
+
[
|
34 |
+
{
|
35 |
+
"image": "assets/image/100.png",
|
36 |
+
},
|
37 |
+
"assets/model/100.glb",
|
38 |
+
42,
|
39 |
+
False,
|
40 |
+
],
|
41 |
+
[
|
42 |
+
{
|
43 |
+
"image": "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
|
44 |
+
},
|
45 |
+
"assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb",
|
46 |
+
42,
|
47 |
+
False,
|
48 |
+
],
|
49 |
+
[
|
50 |
+
{
|
51 |
+
"image": "assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png",
|
52 |
+
},
|
53 |
+
"assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb",
|
54 |
+
42,
|
55 |
+
False,
|
56 |
+
],
|
57 |
+
[
|
58 |
+
{
|
59 |
+
"image": "assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png",
|
60 |
+
},
|
61 |
+
"assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb",
|
62 |
+
42,
|
63 |
+
False,
|
64 |
+
],
|
65 |
+
[
|
66 |
+
{
|
67 |
+
"image": "assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png",
|
68 |
+
},
|
69 |
+
"assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb",
|
70 |
+
42,
|
71 |
+
False,
|
72 |
+
],
|
73 |
+
[
|
74 |
+
{
|
75 |
+
"image": "assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png",
|
76 |
+
},
|
77 |
+
"assets/model/instant3d/e799e6b4-3b47-40e0-befb-b156af8758ad.glb",
|
78 |
+
42,
|
79 |
+
False,
|
80 |
+
],
|
81 |
+
]
|
82 |
+
|
83 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
84 |
+
|
85 |
+
device = "cuda"
|
86 |
+
dtype = torch.float16
|
87 |
+
|
88 |
+
pipeline = DetailGen3DPipeline.from_pretrained(
|
89 |
+
"VAST-AI/DetailGen3D",
|
90 |
+
low_cpu_mem_usage=False
|
91 |
+
).to(device, dtype=dtype)
|
92 |
+
|
93 |
+
|
94 |
+
def load_mesh(mesh_path, num_pc=20480):
|
95 |
+
mesh = trimesh.load(mesh_path,force="mesh")
|
96 |
+
|
97 |
+
center = mesh.bounding_box.centroid
|
98 |
+
mesh.apply_translation(-center)
|
99 |
+
scale = max(mesh.bounding_box.extents)
|
100 |
+
mesh.apply_scale(1.9 / scale)
|
101 |
+
|
102 |
+
surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000,)
|
103 |
+
normal = mesh.face_normals[face_indices]
|
104 |
+
|
105 |
+
rng = np.random.default_rng()
|
106 |
+
ind = rng.choice(surface.shape[0], num_pc, replace=False)
|
107 |
+
surface = torch.FloatTensor(surface[ind])
|
108 |
+
normal = torch.FloatTensor(normal[ind])
|
109 |
+
surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
|
110 |
+
|
111 |
+
return surface
|
112 |
+
|
113 |
+
@torch.no_grad()
|
114 |
+
@torch.autocast(device_type=DEVICE)
|
115 |
+
def run_detailgen3d(
|
116 |
+
pipeline,
|
117 |
+
image,
|
118 |
+
mesh,
|
119 |
+
seed,
|
120 |
+
num_inference_steps,
|
121 |
+
guidance_scale,
|
122 |
+
):
|
123 |
+
surface = load_mesh(mesh)
|
124 |
+
|
125 |
+
batch_size = 1
|
126 |
+
|
127 |
+
# sample query points for decoding
|
128 |
+
box_min = np.array([-1.005, -1.005, -1.005])
|
129 |
+
box_max = np.array([1.005, 1.005, 1.005])
|
130 |
+
sampled_points, grid_size, bbox_size = generate_dense_grid_points(
|
131 |
+
bbox_min=box_min, bbox_max=box_max, octree_depth=8, indexing="ij"
|
132 |
+
)
|
133 |
+
sampled_points = torch.FloatTensor(sampled_points).to(device, dtype=dtype)
|
134 |
+
sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
|
135 |
+
|
136 |
+
# inference pipeline
|
137 |
+
sample = pipeline.vae.encode(surface).latent_dist.sample()
|
138 |
+
occ = pipeline(image, latents=sample, sampled_points=sampled_points, guidance_scale=guidance_scale, noise_aug_level=0, num_inference_steps=num_inference_steps).samples[0]
|
139 |
+
|
140 |
+
# marching cubes
|
141 |
+
grid_logits = occ.view(grid_size).cpu().numpy()
|
142 |
+
vertices, faces, normals, _ = measure.marching_cubes(
|
143 |
+
grid_logits, 0, method="lewiner"
|
144 |
+
)
|
145 |
+
vertices = vertices / grid_size * bbox_size + box_min
|
146 |
+
mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
|
147 |
+
return mesh
|
148 |
+
|
149 |
+
@torch.no_grad()
|
150 |
+
@torch.autocast(device_type=DEVICE)
|
151 |
+
def run_refinement(
|
152 |
+
rgb_image: Any,
|
153 |
+
mesh: Any,
|
154 |
+
seed: int,
|
155 |
+
randomize_seed: bool = False,
|
156 |
+
num_inference_steps: int = 50,
|
157 |
+
guidance_scale: float = 4.0,
|
158 |
+
):
|
159 |
+
if randomize_seed:
|
160 |
+
seed = random.randint(0, MAX_SEED)
|
161 |
+
|
162 |
+
# print("rgb_image", rgb_image)
|
163 |
+
# print("mesh", rgb_image)
|
164 |
+
|
165 |
+
if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
|
166 |
+
rgb_image = rgb_image["image"]
|
167 |
+
|
168 |
+
scene = run_detailgen3d(
|
169 |
+
pipeline,
|
170 |
+
rgb_image,
|
171 |
+
mesh,
|
172 |
+
seed,
|
173 |
+
num_inference_steps,
|
174 |
+
guidance_scale,
|
175 |
+
)
|
176 |
+
|
177 |
+
_, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
|
178 |
+
scene.export(tmp_path)
|
179 |
+
|
180 |
+
torch.cuda.empty_cache()
|
181 |
+
|
182 |
+
return tmp_path, tmp_path, seed
|
183 |
+
|
184 |
+
# Demo
|
185 |
+
with gr.Blocks() as demo:
|
186 |
+
gr.Markdown(MARKDOWN)
|
187 |
+
|
188 |
+
with gr.Row():
|
189 |
+
with gr.Column():
|
190 |
+
with gr.Row():
|
191 |
+
image_prompts = ImagePrompter(label="Input Image", type="pil")
|
192 |
+
mesh = gr.Model3D(label="Input Coarse Model",camera_position=(90,90,3))
|
193 |
+
|
194 |
+
with gr.Accordion("Generation Settings", open=False):
|
195 |
+
seed = gr.Slider(
|
196 |
+
label="Seed",
|
197 |
+
minimum=0,
|
198 |
+
maximum=MAX_SEED,
|
199 |
+
step=1,
|
200 |
+
value=0,
|
201 |
+
)
|
202 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
203 |
+
num_inference_steps = gr.Slider(
|
204 |
+
label="Number of inference steps",
|
205 |
+
minimum=1,
|
206 |
+
maximum=50,
|
207 |
+
step=1,
|
208 |
+
value=50,
|
209 |
+
)
|
210 |
+
guidance_scale = gr.Slider(
|
211 |
+
label="CFG scale",
|
212 |
+
minimum=0.0,
|
213 |
+
maximum=50.0,
|
214 |
+
step=0.1,
|
215 |
+
value=4.0,
|
216 |
+
)
|
217 |
+
gen_button = gr.Button("Run Refinement", variant="primary")
|
218 |
+
|
219 |
+
with gr.Column():
|
220 |
+
model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500,camera_position=(90,90,3))
|
221 |
+
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
222 |
+
|
223 |
+
with gr.Row():
|
224 |
+
gr.Examples(
|
225 |
+
examples=EXAMPLES,
|
226 |
+
fn=run_refinement,
|
227 |
+
inputs=[image_prompts, mesh, seed, randomize_seed],
|
228 |
+
outputs=[model_output, download_glb, seed],
|
229 |
+
cache_examples=False,
|
230 |
+
)
|
231 |
+
|
232 |
+
gen_button.click(
|
233 |
+
run_refinement,
|
234 |
+
inputs=[
|
235 |
+
image_prompts,
|
236 |
+
mesh,
|
237 |
+
seed,
|
238 |
+
randomize_seed,
|
239 |
+
num_inference_steps,
|
240 |
+
guidance_scale,
|
241 |
+
],
|
242 |
+
outputs=[model_output, download_glb, seed],
|
243 |
+
).then(lambda: gr.Button(interactive=True), outputs=[download_glb])
|
244 |
+
|
245 |
+
|
246 |
+
demo.launch()
|
assets/image/100.png
ADDED
![]() |
Git LFS Details
|
assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png
ADDED
![]() |
assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png
ADDED
![]() |
Git LFS Details
|
assets/image/579584fb-8d1c-4312-a3f0-f7a81bd16493.png
ADDED
![]() |
assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png
ADDED
![]() |
Git LFS Details
|
assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png
ADDED
![]() |
Git LFS Details
|
assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png
ADDED
![]() |
assets/model/100.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:353a255228aa9a95a0607a9da07decc4f9fa72378b58773540029b62c56b0680
|
3 |
+
size 650964
|
assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f81d1f33696ad3bae8a81df4754d1a582ac819a32d5837cfd27ad3f9419f830e
|
3 |
+
size 969996
|
assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1dfd2100d0997a3ee423ee40c0a7ce40c04f99a0bb7147962444c5ab5ae8550
|
3 |
+
size 961580
|
assets/model/579584fb-8d1c-4312-a3f0-f7a81bd16493.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:30c6f40b2fb6be7e887ba554e81da39ee7ee690238712eb8fb7dde6691c131ba
|
3 |
+
size 1886840
|
assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:da2b13e6939c0cef2d5c092a1814905abc948814552d8d3ff71163a2cc9e25d5
|
3 |
+
size 958896
|
assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ff09758624653fec32bda11b47c04a44c1c79f327220a953f0ba4633f7ac871
|
3 |
+
size 951492
|
assets/model/e799e6b4-3b47-40e0-befb-b156af8758ad.glb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d2843167e957e71b9a144e104dea8e41c9eba50de912aa8fd48b27e642d8983
|
3 |
+
size 944340
|
detailgen3d/__init__.py
ADDED
File without changes
|
detailgen3d/inference_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def generate_dense_grid_points(
|
5 |
+
bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
|
6 |
+
):
|
7 |
+
length = bbox_max - bbox_min
|
8 |
+
num_cells = np.exp2(octree_depth)
|
9 |
+
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
10 |
+
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
11 |
+
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
12 |
+
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
13 |
+
xyz = np.stack((xs, ys, zs), axis=-1)
|
14 |
+
xyz = xyz.reshape(-1, 3)
|
15 |
+
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
16 |
+
|
17 |
+
return xyz, grid_size, length
|
detailgen3d/models/attention_processor.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
from diffusers.utils import logging
|
7 |
+
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
|
8 |
+
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
|
9 |
+
from einops import rearrange
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
13 |
+
|
14 |
+
class FlashTripo2AttnProcessor2_0:
|
15 |
+
r"""
|
16 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
17 |
+
used in the Tripo2DiT model. It applies a s normalization layer and rotary embedding on query and key vector.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, topk=True):
|
21 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
22 |
+
raise ImportError(
|
23 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
24 |
+
)
|
25 |
+
self.topk = topk
|
26 |
+
|
27 |
+
def qkv(self, attn, q, k, v, attn_mask, dropout_p, is_causal):
|
28 |
+
if k.shape[-2] == 3072:
|
29 |
+
topk = 1024
|
30 |
+
elif k.shape[-2] == 512:
|
31 |
+
topk = 256
|
32 |
+
else:
|
33 |
+
topk = k.shape[-2] // 3
|
34 |
+
|
35 |
+
if self.topk is True:
|
36 |
+
q1 = q[:, :, ::100, :]
|
37 |
+
sim = q1 @ k.transpose(-1, -2)
|
38 |
+
sim = torch.mean(sim, -2)
|
39 |
+
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
|
40 |
+
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
|
41 |
+
v0 = torch.gather(v, dim=-2, index=topk_ind)
|
42 |
+
k0 = torch.gather(k, dim=-2, index=topk_ind)
|
43 |
+
out = F.scaled_dot_product_attention(q, k0, v0)
|
44 |
+
elif self.topk is False:
|
45 |
+
out = F.scaled_dot_product_attention(q, k, v)
|
46 |
+
else:
|
47 |
+
idx, counts = self.topk
|
48 |
+
start = 0
|
49 |
+
outs = []
|
50 |
+
for grid_coord, count in zip(idx, counts):
|
51 |
+
end = start + count
|
52 |
+
q_chunk = q[:, :, start:end, :]
|
53 |
+
q1 = q_chunk[:, :, ::50, :]
|
54 |
+
sim = q1 @ k.transpose(-1, -2)
|
55 |
+
sim = torch.mean(sim, -2)
|
56 |
+
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
|
57 |
+
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
|
58 |
+
v0 = torch.gather(v, dim=-2, index=topk_ind)
|
59 |
+
k0 = torch.gather(k, dim=-2, index=topk_ind)
|
60 |
+
out = F.scaled_dot_product_attention(q_chunk, k0, v0)
|
61 |
+
outs.append(out)
|
62 |
+
start += count
|
63 |
+
out = torch.cat(outs, dim=-2)
|
64 |
+
self.topk = False
|
65 |
+
return out
|
66 |
+
|
67 |
+
def __call__(
|
68 |
+
self,
|
69 |
+
attn: Attention,
|
70 |
+
hidden_states: torch.Tensor,
|
71 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
72 |
+
attention_mask: Optional[torch.Tensor] = None,
|
73 |
+
temb: Optional[torch.Tensor] = None,
|
74 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
75 |
+
) -> torch.Tensor:
|
76 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
77 |
+
|
78 |
+
residual = hidden_states
|
79 |
+
if attn.spatial_norm is not None:
|
80 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
81 |
+
|
82 |
+
input_ndim = hidden_states.ndim
|
83 |
+
|
84 |
+
if input_ndim == 4:
|
85 |
+
batch_size, channel, height, width = hidden_states.shape
|
86 |
+
hidden_states = hidden_states.view(
|
87 |
+
batch_size, channel, height * width
|
88 |
+
).transpose(1, 2)
|
89 |
+
|
90 |
+
batch_size, sequence_length, _ = (
|
91 |
+
hidden_states.shape
|
92 |
+
if encoder_hidden_states is None
|
93 |
+
else encoder_hidden_states.shape
|
94 |
+
)
|
95 |
+
|
96 |
+
if attention_mask is not None:
|
97 |
+
attention_mask = attn.prepare_attention_mask(
|
98 |
+
attention_mask, sequence_length, batch_size
|
99 |
+
)
|
100 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
101 |
+
# (batch, heads, source_length, target_length)
|
102 |
+
attention_mask = attention_mask.view(
|
103 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
104 |
+
)
|
105 |
+
|
106 |
+
if attn.group_norm is not None:
|
107 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
108 |
+
1, 2
|
109 |
+
)
|
110 |
+
|
111 |
+
query = attn.to_q(hidden_states)
|
112 |
+
|
113 |
+
if encoder_hidden_states is None:
|
114 |
+
encoder_hidden_states = hidden_states
|
115 |
+
elif attn.norm_cross:
|
116 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
117 |
+
encoder_hidden_states
|
118 |
+
)
|
119 |
+
|
120 |
+
key = attn.to_k(encoder_hidden_states)
|
121 |
+
value = attn.to_v(encoder_hidden_states)
|
122 |
+
|
123 |
+
# NOTE that tripo2 split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
|
124 |
+
# instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
|
125 |
+
if not attn.is_cross_attention:
|
126 |
+
qkv = torch.cat((query, key, value), dim=-1)
|
127 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
128 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
129 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
130 |
+
else:
|
131 |
+
kv = torch.cat((key, value), dim=-1)
|
132 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
133 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
134 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
135 |
+
|
136 |
+
head_dim = key.shape[-1]
|
137 |
+
|
138 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
139 |
+
|
140 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
141 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
142 |
+
|
143 |
+
if attn.norm_q is not None:
|
144 |
+
query = attn.norm_q(query)
|
145 |
+
if attn.norm_k is not None:
|
146 |
+
key = attn.norm_k(key)
|
147 |
+
|
148 |
+
# Apply RoPE if needed
|
149 |
+
if image_rotary_emb is not None:
|
150 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
151 |
+
if not attn.is_cross_attention:
|
152 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
153 |
+
|
154 |
+
# flashvdm topk
|
155 |
+
hidden_states = self.qkv(attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
|
156 |
+
|
157 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
158 |
+
batch_size, -1, attn.heads * head_dim
|
159 |
+
)
|
160 |
+
hidden_states = hidden_states.to(query.dtype)
|
161 |
+
|
162 |
+
# linear proj
|
163 |
+
hidden_states = attn.to_out[0](hidden_states)
|
164 |
+
# dropout
|
165 |
+
hidden_states = attn.to_out[1](hidden_states)
|
166 |
+
|
167 |
+
if input_ndim == 4:
|
168 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
169 |
+
batch_size, channel, height, width
|
170 |
+
)
|
171 |
+
|
172 |
+
if attn.residual_connection:
|
173 |
+
hidden_states = hidden_states + residual
|
174 |
+
|
175 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
176 |
+
|
177 |
+
return hidden_states
|
178 |
+
|
179 |
+
class TripoSGAttnProcessor2_0:
|
180 |
+
r"""
|
181 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
182 |
+
used in the TripoSG model. It applies a s normalization layer and rotary embedding on query and key vector.
|
183 |
+
"""
|
184 |
+
|
185 |
+
def __init__(self):
|
186 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
187 |
+
raise ImportError(
|
188 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
189 |
+
)
|
190 |
+
|
191 |
+
def __call__(
|
192 |
+
self,
|
193 |
+
attn: Attention,
|
194 |
+
hidden_states: torch.Tensor,
|
195 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
196 |
+
attention_mask: Optional[torch.Tensor] = None,
|
197 |
+
temb: Optional[torch.Tensor] = None,
|
198 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
199 |
+
) -> torch.Tensor:
|
200 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
201 |
+
|
202 |
+
residual = hidden_states
|
203 |
+
if attn.spatial_norm is not None:
|
204 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
205 |
+
|
206 |
+
input_ndim = hidden_states.ndim
|
207 |
+
|
208 |
+
if input_ndim == 4:
|
209 |
+
batch_size, channel, height, width = hidden_states.shape
|
210 |
+
hidden_states = hidden_states.view(
|
211 |
+
batch_size, channel, height * width
|
212 |
+
).transpose(1, 2)
|
213 |
+
|
214 |
+
batch_size, sequence_length, _ = (
|
215 |
+
hidden_states.shape
|
216 |
+
if encoder_hidden_states is None
|
217 |
+
else encoder_hidden_states.shape
|
218 |
+
)
|
219 |
+
|
220 |
+
if attention_mask is not None:
|
221 |
+
attention_mask = attn.prepare_attention_mask(
|
222 |
+
attention_mask, sequence_length, batch_size
|
223 |
+
)
|
224 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
225 |
+
# (batch, heads, source_length, target_length)
|
226 |
+
attention_mask = attention_mask.view(
|
227 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
228 |
+
)
|
229 |
+
|
230 |
+
if attn.group_norm is not None:
|
231 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
232 |
+
1, 2
|
233 |
+
)
|
234 |
+
|
235 |
+
query = attn.to_q(hidden_states)
|
236 |
+
|
237 |
+
if encoder_hidden_states is None:
|
238 |
+
encoder_hidden_states = hidden_states
|
239 |
+
elif attn.norm_cross:
|
240 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
241 |
+
encoder_hidden_states
|
242 |
+
)
|
243 |
+
|
244 |
+
key = attn.to_k(encoder_hidden_states)
|
245 |
+
value = attn.to_v(encoder_hidden_states)
|
246 |
+
|
247 |
+
# NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
|
248 |
+
# instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
|
249 |
+
if not attn.is_cross_attention:
|
250 |
+
qkv = torch.cat((query, key, value), dim=-1)
|
251 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
252 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
253 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
254 |
+
else:
|
255 |
+
kv = torch.cat((key, value), dim=-1)
|
256 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
257 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
258 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
259 |
+
|
260 |
+
head_dim = key.shape[-1]
|
261 |
+
|
262 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
263 |
+
|
264 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
265 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
266 |
+
|
267 |
+
if attn.norm_q is not None:
|
268 |
+
query = attn.norm_q(query)
|
269 |
+
if attn.norm_k is not None:
|
270 |
+
key = attn.norm_k(key)
|
271 |
+
|
272 |
+
# Apply RoPE if needed
|
273 |
+
if image_rotary_emb is not None:
|
274 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
275 |
+
if not attn.is_cross_attention:
|
276 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
277 |
+
|
278 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
279 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
280 |
+
hidden_states = F.scaled_dot_product_attention(
|
281 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
282 |
+
)
|
283 |
+
|
284 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
285 |
+
batch_size, -1, attn.heads * head_dim
|
286 |
+
)
|
287 |
+
hidden_states = hidden_states.to(query.dtype)
|
288 |
+
|
289 |
+
# linear proj
|
290 |
+
hidden_states = attn.to_out[0](hidden_states)
|
291 |
+
# dropout
|
292 |
+
hidden_states = attn.to_out[1](hidden_states)
|
293 |
+
|
294 |
+
if input_ndim == 4:
|
295 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
296 |
+
batch_size, channel, height, width
|
297 |
+
)
|
298 |
+
|
299 |
+
if attn.residual_connection:
|
300 |
+
hidden_states = hidden_states + residual
|
301 |
+
|
302 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
303 |
+
|
304 |
+
return hidden_states
|
305 |
+
|
306 |
+
|
307 |
+
class FusedTripoSGAttnProcessor2_0:
|
308 |
+
r"""
|
309 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
|
310 |
+
projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on
|
311 |
+
query and key vector.
|
312 |
+
"""
|
313 |
+
|
314 |
+
def __init__(self):
|
315 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
316 |
+
raise ImportError(
|
317 |
+
"FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
318 |
+
)
|
319 |
+
|
320 |
+
def __call__(
|
321 |
+
self,
|
322 |
+
attn: Attention,
|
323 |
+
hidden_states: torch.Tensor,
|
324 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
325 |
+
attention_mask: Optional[torch.Tensor] = None,
|
326 |
+
temb: Optional[torch.Tensor] = None,
|
327 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
328 |
+
) -> torch.Tensor:
|
329 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
330 |
+
|
331 |
+
residual = hidden_states
|
332 |
+
if attn.spatial_norm is not None:
|
333 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
334 |
+
|
335 |
+
input_ndim = hidden_states.ndim
|
336 |
+
|
337 |
+
if input_ndim == 4:
|
338 |
+
batch_size, channel, height, width = hidden_states.shape
|
339 |
+
hidden_states = hidden_states.view(
|
340 |
+
batch_size, channel, height * width
|
341 |
+
).transpose(1, 2)
|
342 |
+
|
343 |
+
batch_size, sequence_length, _ = (
|
344 |
+
hidden_states.shape
|
345 |
+
if encoder_hidden_states is None
|
346 |
+
else encoder_hidden_states.shape
|
347 |
+
)
|
348 |
+
|
349 |
+
if attention_mask is not None:
|
350 |
+
attention_mask = attn.prepare_attention_mask(
|
351 |
+
attention_mask, sequence_length, batch_size
|
352 |
+
)
|
353 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
354 |
+
# (batch, heads, source_length, target_length)
|
355 |
+
attention_mask = attention_mask.view(
|
356 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
357 |
+
)
|
358 |
+
|
359 |
+
if attn.group_norm is not None:
|
360 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
361 |
+
1, 2
|
362 |
+
)
|
363 |
+
|
364 |
+
# NOTE that pre-trained split heads first, then split qkv
|
365 |
+
if encoder_hidden_states is None:
|
366 |
+
qkv = attn.to_qkv(hidden_states)
|
367 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
368 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
369 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
370 |
+
else:
|
371 |
+
if attn.norm_cross:
|
372 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
373 |
+
encoder_hidden_states
|
374 |
+
)
|
375 |
+
query = attn.to_q(hidden_states)
|
376 |
+
|
377 |
+
kv = attn.to_kv(encoder_hidden_states)
|
378 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
379 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
380 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
381 |
+
|
382 |
+
head_dim = key.shape[-1]
|
383 |
+
|
384 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
385 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
386 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
387 |
+
|
388 |
+
if attn.norm_q is not None:
|
389 |
+
query = attn.norm_q(query)
|
390 |
+
if attn.norm_k is not None:
|
391 |
+
key = attn.norm_k(key)
|
392 |
+
|
393 |
+
# Apply RoPE if needed
|
394 |
+
if image_rotary_emb is not None:
|
395 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
396 |
+
if not attn.is_cross_attention:
|
397 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
398 |
+
|
399 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
400 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
401 |
+
hidden_states = F.scaled_dot_product_attention(
|
402 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
403 |
+
)
|
404 |
+
|
405 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
406 |
+
batch_size, -1, attn.heads * head_dim
|
407 |
+
)
|
408 |
+
hidden_states = hidden_states.to(query.dtype)
|
409 |
+
|
410 |
+
# linear proj
|
411 |
+
hidden_states = attn.to_out[0](hidden_states)
|
412 |
+
# dropout
|
413 |
+
hidden_states = attn.to_out[1](hidden_states)
|
414 |
+
|
415 |
+
if input_ndim == 4:
|
416 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
417 |
+
batch_size, channel, height, width
|
418 |
+
)
|
419 |
+
|
420 |
+
if attn.residual_connection:
|
421 |
+
hidden_states = hidden_states + residual
|
422 |
+
|
423 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
424 |
+
|
425 |
+
return hidden_states
|
426 |
+
|
427 |
+
|
428 |
+
class MIAttnProcessor2_0:
|
429 |
+
r"""
|
430 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
431 |
+
used in the TripoSG model. It applies a normalization layer and rotary embedding on query and key vector.
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(self, use_mi: bool = True):
|
435 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
436 |
+
raise ImportError(
|
437 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
438 |
+
)
|
439 |
+
|
440 |
+
self.use_mi = use_mi
|
441 |
+
|
442 |
+
def __call__(
|
443 |
+
self,
|
444 |
+
attn: Attention,
|
445 |
+
hidden_states: torch.Tensor,
|
446 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
447 |
+
attention_mask: Optional[torch.Tensor] = None,
|
448 |
+
temb: Optional[torch.Tensor] = None,
|
449 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
450 |
+
num_instances: Optional[torch.IntTensor] = None,
|
451 |
+
) -> torch.Tensor:
|
452 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
453 |
+
|
454 |
+
residual = hidden_states
|
455 |
+
if attn.spatial_norm is not None:
|
456 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
457 |
+
|
458 |
+
input_ndim = hidden_states.ndim
|
459 |
+
|
460 |
+
if input_ndim == 4:
|
461 |
+
batch_size, channel, height, width = hidden_states.shape
|
462 |
+
hidden_states = hidden_states.view(
|
463 |
+
batch_size, channel, height * width
|
464 |
+
).transpose(1, 2)
|
465 |
+
|
466 |
+
batch_size, sequence_length, _ = (
|
467 |
+
hidden_states.shape
|
468 |
+
if encoder_hidden_states is None
|
469 |
+
else encoder_hidden_states.shape
|
470 |
+
)
|
471 |
+
|
472 |
+
if attention_mask is not None:
|
473 |
+
attention_mask = attn.prepare_attention_mask(
|
474 |
+
attention_mask, sequence_length, batch_size
|
475 |
+
)
|
476 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
477 |
+
# (batch, heads, source_length, target_length)
|
478 |
+
attention_mask = attention_mask.view(
|
479 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
480 |
+
)
|
481 |
+
|
482 |
+
if attn.group_norm is not None:
|
483 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
484 |
+
1, 2
|
485 |
+
)
|
486 |
+
|
487 |
+
query = attn.to_q(hidden_states)
|
488 |
+
|
489 |
+
if encoder_hidden_states is None:
|
490 |
+
encoder_hidden_states = hidden_states
|
491 |
+
elif attn.norm_cross:
|
492 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
493 |
+
encoder_hidden_states
|
494 |
+
)
|
495 |
+
|
496 |
+
key = attn.to_k(encoder_hidden_states)
|
497 |
+
value = attn.to_v(encoder_hidden_states)
|
498 |
+
|
499 |
+
# NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
|
500 |
+
# instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
|
501 |
+
if not attn.is_cross_attention:
|
502 |
+
qkv = torch.cat((query, key, value), dim=-1)
|
503 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
504 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
505 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
506 |
+
else:
|
507 |
+
kv = torch.cat((key, value), dim=-1)
|
508 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
509 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
510 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
511 |
+
|
512 |
+
head_dim = key.shape[-1]
|
513 |
+
|
514 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
515 |
+
|
516 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
517 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
518 |
+
|
519 |
+
if attn.norm_q is not None:
|
520 |
+
query = attn.norm_q(query)
|
521 |
+
if attn.norm_k is not None:
|
522 |
+
key = attn.norm_k(key)
|
523 |
+
|
524 |
+
# Apply RoPE if needed
|
525 |
+
if image_rotary_emb is not None:
|
526 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
527 |
+
if not attn.is_cross_attention:
|
528 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
529 |
+
|
530 |
+
if self.use_mi and num_instances is not None:
|
531 |
+
key = rearrange(
|
532 |
+
key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
|
533 |
+
).repeat_interleave(num_instances, dim=0)
|
534 |
+
value = rearrange(
|
535 |
+
value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
|
536 |
+
).repeat_interleave(num_instances, dim=0)
|
537 |
+
|
538 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
539 |
+
hidden_states = F.scaled_dot_product_attention(
|
540 |
+
query,
|
541 |
+
key,
|
542 |
+
value,
|
543 |
+
dropout_p=0.0,
|
544 |
+
is_causal=False,
|
545 |
+
)
|
546 |
+
else:
|
547 |
+
hidden_states = F.scaled_dot_product_attention(
|
548 |
+
query,
|
549 |
+
key,
|
550 |
+
value,
|
551 |
+
attn_mask=attention_mask,
|
552 |
+
dropout_p=0.0,
|
553 |
+
is_causal=False,
|
554 |
+
)
|
555 |
+
|
556 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
557 |
+
batch_size, -1, attn.heads * head_dim
|
558 |
+
)
|
559 |
+
hidden_states = hidden_states.to(query.dtype)
|
560 |
+
|
561 |
+
# linear proj
|
562 |
+
hidden_states = attn.to_out[0](hidden_states)
|
563 |
+
# dropout
|
564 |
+
hidden_states = attn.to_out[1](hidden_states)
|
565 |
+
|
566 |
+
if input_ndim == 4:
|
567 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
568 |
+
batch_size, channel, height, width
|
569 |
+
)
|
570 |
+
|
571 |
+
if attn.residual_connection:
|
572 |
+
hidden_states = hidden_states + residual
|
573 |
+
|
574 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
575 |
+
|
576 |
+
return hidden_states
|
detailgen3d/models/autoencoders/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .autoencoder_kl_triposg import TripoSGVAEModel
|
detailgen3d/models/autoencoders/autoencoder_kl_triposg.py
ADDED
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
7 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
8 |
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
9 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
10 |
+
from diffusers.models.modeling_utils import ModelMixin
|
11 |
+
from diffusers.models.normalization import FP32LayerNorm, LayerNorm
|
12 |
+
from diffusers.utils import logging
|
13 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
14 |
+
from einops import repeat
|
15 |
+
from torch_cluster import fps
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, FlashTripo2AttnProcessor2_0
|
19 |
+
from ..embeddings import FrequencyPositionalEmbedding
|
20 |
+
from ..transformers.triposg_transformer import DiTBlock
|
21 |
+
from .vae import DiagonalGaussianDistribution
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
24 |
+
|
25 |
+
|
26 |
+
class TripoSGEncoder(nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
in_channels: int = 3,
|
30 |
+
dim: int = 512,
|
31 |
+
num_attention_heads: int = 8,
|
32 |
+
num_layers: int = 8,
|
33 |
+
):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.proj_in = nn.Linear(in_channels, dim, bias=True)
|
37 |
+
|
38 |
+
self.blocks = nn.ModuleList(
|
39 |
+
[
|
40 |
+
DiTBlock(
|
41 |
+
dim=dim,
|
42 |
+
num_attention_heads=num_attention_heads,
|
43 |
+
use_self_attention=False,
|
44 |
+
use_cross_attention=True,
|
45 |
+
cross_attention_dim=dim,
|
46 |
+
cross_attention_norm_type="layer_norm",
|
47 |
+
activation_fn="gelu",
|
48 |
+
norm_type="fp32_layer_norm",
|
49 |
+
norm_eps=1e-5,
|
50 |
+
qk_norm=False,
|
51 |
+
qkv_bias=False,
|
52 |
+
) # cross attention
|
53 |
+
]
|
54 |
+
+ [
|
55 |
+
DiTBlock(
|
56 |
+
dim=dim,
|
57 |
+
num_attention_heads=num_attention_heads,
|
58 |
+
use_self_attention=True,
|
59 |
+
self_attention_norm_type="fp32_layer_norm",
|
60 |
+
use_cross_attention=False,
|
61 |
+
activation_fn="gelu",
|
62 |
+
norm_type="fp32_layer_norm",
|
63 |
+
norm_eps=1e-5,
|
64 |
+
qk_norm=False,
|
65 |
+
qkv_bias=False,
|
66 |
+
)
|
67 |
+
for _ in range(num_layers) # self attention
|
68 |
+
]
|
69 |
+
)
|
70 |
+
|
71 |
+
self.norm_out = LayerNorm(dim)
|
72 |
+
|
73 |
+
def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
|
74 |
+
hidden_states = self.proj_in(sample_1)
|
75 |
+
encoder_hidden_states = self.proj_in(sample_2)
|
76 |
+
|
77 |
+
for layer, block in enumerate(self.blocks):
|
78 |
+
if layer == 0:
|
79 |
+
hidden_states = block(
|
80 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
hidden_states = block(hidden_states)
|
84 |
+
|
85 |
+
hidden_states = self.norm_out(hidden_states)
|
86 |
+
|
87 |
+
return hidden_states
|
88 |
+
|
89 |
+
|
90 |
+
class TripoSGDecoder(nn.Module):
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
in_channels: int = 3,
|
94 |
+
out_channels: int = 1,
|
95 |
+
dim: int = 512,
|
96 |
+
num_attention_heads: int = 8,
|
97 |
+
num_layers: int = 16,
|
98 |
+
grad_type: str = "analytical",
|
99 |
+
grad_interval: float = 0.001,
|
100 |
+
):
|
101 |
+
super().__init__()
|
102 |
+
|
103 |
+
if grad_type not in ["numerical", "analytical"]:
|
104 |
+
raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
|
105 |
+
self.grad_type = grad_type
|
106 |
+
self.grad_interval = grad_interval
|
107 |
+
|
108 |
+
self.blocks = nn.ModuleList(
|
109 |
+
[
|
110 |
+
DiTBlock(
|
111 |
+
dim=dim,
|
112 |
+
num_attention_heads=num_attention_heads,
|
113 |
+
use_self_attention=True,
|
114 |
+
self_attention_norm_type="fp32_layer_norm",
|
115 |
+
use_cross_attention=False,
|
116 |
+
activation_fn="gelu",
|
117 |
+
norm_type="fp32_layer_norm",
|
118 |
+
norm_eps=1e-5,
|
119 |
+
qk_norm=False,
|
120 |
+
qkv_bias=False,
|
121 |
+
)
|
122 |
+
for _ in range(num_layers) # self attention
|
123 |
+
]
|
124 |
+
+ [
|
125 |
+
DiTBlock(
|
126 |
+
dim=dim,
|
127 |
+
num_attention_heads=num_attention_heads,
|
128 |
+
use_self_attention=False,
|
129 |
+
use_cross_attention=True,
|
130 |
+
cross_attention_dim=dim,
|
131 |
+
cross_attention_norm_type="layer_norm",
|
132 |
+
activation_fn="gelu",
|
133 |
+
norm_type="fp32_layer_norm",
|
134 |
+
norm_eps=1e-5,
|
135 |
+
qk_norm=False,
|
136 |
+
qkv_bias=False,
|
137 |
+
) # cross attention
|
138 |
+
]
|
139 |
+
)
|
140 |
+
|
141 |
+
self.proj_query = nn.Linear(in_channels, dim, bias=True)
|
142 |
+
|
143 |
+
self.norm_out = LayerNorm(dim)
|
144 |
+
self.proj_out = nn.Linear(dim, out_channels, bias=True)
|
145 |
+
|
146 |
+
def set_topk(self, topk):
|
147 |
+
self.blocks[-1].set_topk(topk)
|
148 |
+
|
149 |
+
def set_flash_processor(self, processor):
|
150 |
+
self.blocks[-1].set_flash_processor(processor)
|
151 |
+
|
152 |
+
def query_geometry(
|
153 |
+
self,
|
154 |
+
model_fn: callable,
|
155 |
+
queries: torch.Tensor,
|
156 |
+
sample: torch.Tensor,
|
157 |
+
grad: bool = False,
|
158 |
+
):
|
159 |
+
logits = model_fn(queries, sample)
|
160 |
+
if grad:
|
161 |
+
with torch.autocast(device_type="cuda", dtype=torch.float32):
|
162 |
+
if self.grad_type == "numerical":
|
163 |
+
interval = self.grad_interval
|
164 |
+
grad_value = []
|
165 |
+
for offset in [
|
166 |
+
(interval, 0, 0),
|
167 |
+
(0, interval, 0),
|
168 |
+
(0, 0, interval),
|
169 |
+
]:
|
170 |
+
offset_tensor = torch.tensor(offset, device=queries.device)[
|
171 |
+
None, :
|
172 |
+
]
|
173 |
+
res_p = model_fn(queries + offset_tensor, sample)[..., 0]
|
174 |
+
res_n = model_fn(queries - offset_tensor, sample)[..., 0]
|
175 |
+
grad_value.append((res_p - res_n) / (2 * interval))
|
176 |
+
grad_value = torch.stack(grad_value, dim=-1)
|
177 |
+
else:
|
178 |
+
queries_d = torch.clone(queries)
|
179 |
+
queries_d.requires_grad = True
|
180 |
+
with torch.enable_grad():
|
181 |
+
res_d = model_fn(queries_d, sample)
|
182 |
+
grad_value = torch.autograd.grad(
|
183 |
+
res_d,
|
184 |
+
[queries_d],
|
185 |
+
grad_outputs=torch.ones_like(res_d),
|
186 |
+
create_graph=self.training,
|
187 |
+
)[0]
|
188 |
+
else:
|
189 |
+
grad_value = None
|
190 |
+
|
191 |
+
return logits, grad_value
|
192 |
+
|
193 |
+
def forward(
|
194 |
+
self,
|
195 |
+
sample: torch.Tensor,
|
196 |
+
queries: torch.Tensor,
|
197 |
+
kv_cache: Optional[torch.Tensor] = None,
|
198 |
+
):
|
199 |
+
if kv_cache is None:
|
200 |
+
hidden_states = sample
|
201 |
+
for _, block in enumerate(self.blocks[:-1]):
|
202 |
+
hidden_states = block(hidden_states)
|
203 |
+
kv_cache = hidden_states
|
204 |
+
|
205 |
+
# query grid logits by cross attention
|
206 |
+
def query_fn(q, kv):
|
207 |
+
q = self.proj_query(q)
|
208 |
+
l = self.blocks[-1](q, encoder_hidden_states=kv)
|
209 |
+
return self.proj_out(self.norm_out(l))
|
210 |
+
|
211 |
+
logits, grad = self.query_geometry(
|
212 |
+
query_fn, queries, kv_cache, grad=self.training
|
213 |
+
)
|
214 |
+
logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
|
215 |
+
|
216 |
+
return logits, kv_cache
|
217 |
+
|
218 |
+
|
219 |
+
class TripoSGVAEModel(ModelMixin, ConfigMixin):
|
220 |
+
@register_to_config
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
in_channels: int = 3, # NOTE xyz instead of feature dim
|
224 |
+
latent_channels: int = 64,
|
225 |
+
num_attention_heads: int = 8,
|
226 |
+
width_encoder: int = 512,
|
227 |
+
width_decoder: int = 1024,
|
228 |
+
num_layers_encoder: int = 8,
|
229 |
+
num_layers_decoder: int = 16,
|
230 |
+
embedding_type: str = "frequency",
|
231 |
+
embed_frequency: int = 8,
|
232 |
+
embed_include_pi: bool = False,
|
233 |
+
):
|
234 |
+
super().__init__()
|
235 |
+
|
236 |
+
self.out_channels = 1
|
237 |
+
|
238 |
+
if embedding_type == "frequency":
|
239 |
+
self.embedder = FrequencyPositionalEmbedding(
|
240 |
+
num_freqs=embed_frequency,
|
241 |
+
logspace=True,
|
242 |
+
input_dim=in_channels,
|
243 |
+
include_pi=embed_include_pi,
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
raise NotImplementedError(
|
247 |
+
f"Embedding type {embedding_type} is not supported."
|
248 |
+
)
|
249 |
+
|
250 |
+
self.encoder = TripoSGEncoder(
|
251 |
+
in_channels=in_channels + self.embedder.out_dim,
|
252 |
+
dim=width_encoder,
|
253 |
+
num_attention_heads=num_attention_heads,
|
254 |
+
num_layers=num_layers_encoder,
|
255 |
+
)
|
256 |
+
self.decoder = TripoSGDecoder(
|
257 |
+
in_channels=self.embedder.out_dim,
|
258 |
+
out_channels=self.out_channels,
|
259 |
+
dim=width_decoder,
|
260 |
+
num_attention_heads=num_attention_heads,
|
261 |
+
num_layers=num_layers_decoder,
|
262 |
+
)
|
263 |
+
|
264 |
+
self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
|
265 |
+
self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
|
266 |
+
|
267 |
+
self.use_slicing = False
|
268 |
+
self.slicing_length = 1
|
269 |
+
|
270 |
+
def set_flash_decoder(self):
|
271 |
+
self.decoder.set_flash_processor(FlashTripo2AttnProcessor2_0())
|
272 |
+
|
273 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
|
274 |
+
def fuse_qkv_projections(self):
|
275 |
+
"""
|
276 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
277 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
278 |
+
|
279 |
+
<Tip warning={true}>
|
280 |
+
|
281 |
+
This API is 🧪 experimental.
|
282 |
+
|
283 |
+
</Tip>
|
284 |
+
"""
|
285 |
+
self.original_attn_processors = None
|
286 |
+
|
287 |
+
for _, attn_processor in self.attn_processors.items():
|
288 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
289 |
+
raise ValueError(
|
290 |
+
"`fuse_qkv_projections()` is not supported for models having added KV projections."
|
291 |
+
)
|
292 |
+
|
293 |
+
self.original_attn_processors = self.attn_processors
|
294 |
+
|
295 |
+
for module in self.modules():
|
296 |
+
if isinstance(module, Attention):
|
297 |
+
module.fuse_projections(fuse=True)
|
298 |
+
|
299 |
+
self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
|
300 |
+
|
301 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
302 |
+
def unfuse_qkv_projections(self):
|
303 |
+
"""Disables the fused QKV projection if enabled.
|
304 |
+
|
305 |
+
<Tip warning={true}>
|
306 |
+
|
307 |
+
This API is 🧪 experimental.
|
308 |
+
|
309 |
+
</Tip>
|
310 |
+
|
311 |
+
"""
|
312 |
+
if self.original_attn_processors is not None:
|
313 |
+
self.set_attn_processor(self.original_attn_processors)
|
314 |
+
|
315 |
+
@property
|
316 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
317 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
318 |
+
r"""
|
319 |
+
Returns:
|
320 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
321 |
+
indexed by its weight name.
|
322 |
+
"""
|
323 |
+
# set recursively
|
324 |
+
processors = {}
|
325 |
+
|
326 |
+
def fn_recursive_add_processors(
|
327 |
+
name: str,
|
328 |
+
module: torch.nn.Module,
|
329 |
+
processors: Dict[str, AttentionProcessor],
|
330 |
+
):
|
331 |
+
if hasattr(module, "get_processor"):
|
332 |
+
processors[f"{name}.processor"] = module.get_processor()
|
333 |
+
|
334 |
+
for sub_name, child in module.named_children():
|
335 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
336 |
+
|
337 |
+
return processors
|
338 |
+
|
339 |
+
for name, module in self.named_children():
|
340 |
+
fn_recursive_add_processors(name, module, processors)
|
341 |
+
|
342 |
+
return processors
|
343 |
+
|
344 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
345 |
+
def set_attn_processor(
|
346 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
347 |
+
):
|
348 |
+
r"""
|
349 |
+
Sets the attention processor to use to compute attention.
|
350 |
+
|
351 |
+
Parameters:
|
352 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
353 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
354 |
+
for **all** `Attention` layers.
|
355 |
+
|
356 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
357 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
358 |
+
|
359 |
+
"""
|
360 |
+
count = len(self.attn_processors.keys())
|
361 |
+
|
362 |
+
if isinstance(processor, dict) and len(processor) != count:
|
363 |
+
raise ValueError(
|
364 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
365 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
366 |
+
)
|
367 |
+
|
368 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
369 |
+
if hasattr(module, "set_processor"):
|
370 |
+
if not isinstance(processor, dict):
|
371 |
+
module.set_processor(processor)
|
372 |
+
else:
|
373 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
374 |
+
|
375 |
+
for sub_name, child in module.named_children():
|
376 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
377 |
+
|
378 |
+
for name, module in self.named_children():
|
379 |
+
fn_recursive_attn_processor(name, module, processor)
|
380 |
+
|
381 |
+
def set_default_attn_processor(self):
|
382 |
+
"""
|
383 |
+
Disables custom attention processors and sets the default attention implementation.
|
384 |
+
"""
|
385 |
+
self.set_attn_processor(TripoSGAttnProcessor2_0())
|
386 |
+
|
387 |
+
def enable_slicing(self, slicing_length: int = 1) -> None:
|
388 |
+
r"""
|
389 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
390 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
391 |
+
"""
|
392 |
+
self.use_slicing = True
|
393 |
+
self.slicing_length = slicing_length
|
394 |
+
|
395 |
+
def disable_slicing(self) -> None:
|
396 |
+
r"""
|
397 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
398 |
+
decoding in one step.
|
399 |
+
"""
|
400 |
+
self.use_slicing = False
|
401 |
+
|
402 |
+
def _sample_features(
|
403 |
+
self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
|
404 |
+
):
|
405 |
+
"""
|
406 |
+
Sample points from features of the input point cloud.
|
407 |
+
|
408 |
+
Args:
|
409 |
+
x (torch.Tensor): The input point cloud. shape: (B, N, C)
|
410 |
+
num_tokens (int, optional): The number of points to sample. Defaults to 2048.
|
411 |
+
seed (Optional[int], optional): The random seed. Defaults to None.
|
412 |
+
"""
|
413 |
+
rng = np.random.default_rng(seed)
|
414 |
+
indices = rng.choice(
|
415 |
+
x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
|
416 |
+
)
|
417 |
+
selected_points = x[:, indices]
|
418 |
+
|
419 |
+
batch_size, num_points, num_channels = selected_points.shape
|
420 |
+
flattened_points = selected_points.view(batch_size * num_points, num_channels)
|
421 |
+
batch_indices = (
|
422 |
+
torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
|
423 |
+
)
|
424 |
+
|
425 |
+
# fps sampling
|
426 |
+
sampling_ratio = 1.0 / 4
|
427 |
+
sampled_indices = fps(
|
428 |
+
flattened_points[:, :3],
|
429 |
+
batch_indices,
|
430 |
+
ratio=sampling_ratio,
|
431 |
+
random_start=self.training,
|
432 |
+
)
|
433 |
+
sampled_points = flattened_points[sampled_indices].view(
|
434 |
+
batch_size, -1, num_channels
|
435 |
+
)
|
436 |
+
|
437 |
+
return sampled_points
|
438 |
+
|
439 |
+
def _encode(
|
440 |
+
self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
|
441 |
+
):
|
442 |
+
position_channels = self.config.in_channels
|
443 |
+
positions, features = x[..., :position_channels], x[..., position_channels:]
|
444 |
+
x_kv = torch.cat([self.embedder(positions), features], dim=-1)
|
445 |
+
|
446 |
+
sampled_x = self._sample_features(x, num_tokens, seed)
|
447 |
+
positions, features = (
|
448 |
+
sampled_x[..., :position_channels],
|
449 |
+
sampled_x[..., position_channels:],
|
450 |
+
)
|
451 |
+
x_q = torch.cat([self.embedder(positions), features], dim=-1)
|
452 |
+
|
453 |
+
x = self.encoder(x_q, x_kv)
|
454 |
+
|
455 |
+
x = self.quant(x)
|
456 |
+
|
457 |
+
return x
|
458 |
+
|
459 |
+
@apply_forward_hook
|
460 |
+
def encode(
|
461 |
+
self, x: torch.Tensor, return_dict: bool = True, **kwargs
|
462 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
463 |
+
"""
|
464 |
+
Encode a batch of point features into latents.
|
465 |
+
"""
|
466 |
+
if self.use_slicing and x.shape[0] > 1:
|
467 |
+
encoded_slices = [
|
468 |
+
self._encode(x_slice, **kwargs)
|
469 |
+
for x_slice in x.split(self.slicing_length)
|
470 |
+
]
|
471 |
+
h = torch.cat(encoded_slices)
|
472 |
+
else:
|
473 |
+
h = self._encode(x, **kwargs)
|
474 |
+
|
475 |
+
posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
|
476 |
+
|
477 |
+
if not return_dict:
|
478 |
+
return (posterior,)
|
479 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
480 |
+
|
481 |
+
def _decode(
|
482 |
+
self,
|
483 |
+
z: torch.Tensor,
|
484 |
+
sampled_points: torch.Tensor,
|
485 |
+
num_chunks: int = 50000,
|
486 |
+
to_cpu: bool = False,
|
487 |
+
return_dict: bool = True,
|
488 |
+
) -> Union[DecoderOutput, torch.Tensor]:
|
489 |
+
xyz_samples = sampled_points
|
490 |
+
|
491 |
+
z = self.post_quant(z)
|
492 |
+
|
493 |
+
num_points = xyz_samples.shape[1]
|
494 |
+
kv_cache = None
|
495 |
+
dec = []
|
496 |
+
|
497 |
+
for i in range(0, num_points, num_chunks):
|
498 |
+
queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
|
499 |
+
queries = self.embedder(queries)
|
500 |
+
|
501 |
+
z_, kv_cache = self.decoder(z, queries, kv_cache)
|
502 |
+
dec.append(z_ if not to_cpu else z_.cpu())
|
503 |
+
|
504 |
+
z = torch.cat(dec, dim=1)
|
505 |
+
|
506 |
+
if not return_dict:
|
507 |
+
return (z,)
|
508 |
+
|
509 |
+
return DecoderOutput(sample=z)
|
510 |
+
|
511 |
+
@apply_forward_hook
|
512 |
+
def decode(
|
513 |
+
self,
|
514 |
+
z: torch.Tensor,
|
515 |
+
sampled_points: torch.Tensor,
|
516 |
+
return_dict: bool = True,
|
517 |
+
**kwargs,
|
518 |
+
) -> Union[DecoderOutput, torch.Tensor]:
|
519 |
+
if self.use_slicing and z.shape[0] > 1:
|
520 |
+
decoded_slices = [
|
521 |
+
self._decode(z_slice, p_slice, **kwargs).sample
|
522 |
+
for z_slice, p_slice in zip(
|
523 |
+
z.split(self.slicing_length),
|
524 |
+
sampled_points.split(self.slicing_length),
|
525 |
+
)
|
526 |
+
]
|
527 |
+
decoded = torch.cat(decoded_slices)
|
528 |
+
else:
|
529 |
+
decoded = self._decode(z, sampled_points, **kwargs).sample
|
530 |
+
|
531 |
+
if not return_dict:
|
532 |
+
return (decoded,)
|
533 |
+
return DecoderOutput(sample=decoded)
|
534 |
+
|
535 |
+
def forward(self, x: torch.Tensor):
|
536 |
+
pass
|
detailgen3d/models/autoencoders/vae.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from diffusers.utils.torch_utils import randn_tensor
|
6 |
+
|
7 |
+
|
8 |
+
class DiagonalGaussianDistribution(object):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
parameters: torch.Tensor,
|
12 |
+
deterministic: bool = False,
|
13 |
+
feature_dim: int = 1,
|
14 |
+
):
|
15 |
+
self.parameters = parameters
|
16 |
+
self.feature_dim = feature_dim
|
17 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
|
18 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
19 |
+
self.deterministic = deterministic
|
20 |
+
self.std = torch.exp(0.5 * self.logvar)
|
21 |
+
self.var = torch.exp(self.logvar)
|
22 |
+
if self.deterministic:
|
23 |
+
self.var = self.std = torch.zeros_like(
|
24 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
25 |
+
)
|
26 |
+
|
27 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
28 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
29 |
+
sample = randn_tensor(
|
30 |
+
self.mean.shape,
|
31 |
+
generator=generator,
|
32 |
+
device=self.parameters.device,
|
33 |
+
dtype=self.parameters.dtype,
|
34 |
+
)
|
35 |
+
x = self.mean + self.std * sample
|
36 |
+
return x
|
37 |
+
|
38 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
39 |
+
if self.deterministic:
|
40 |
+
return torch.Tensor([0.0])
|
41 |
+
else:
|
42 |
+
if other is None:
|
43 |
+
return 0.5 * torch.sum(
|
44 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
45 |
+
dim=[1, 2, 3],
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
return 0.5 * torch.sum(
|
49 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
50 |
+
+ self.var / other.var
|
51 |
+
- 1.0
|
52 |
+
- self.logvar
|
53 |
+
+ other.logvar,
|
54 |
+
dim=[1, 2, 3],
|
55 |
+
)
|
56 |
+
|
57 |
+
def nll(
|
58 |
+
self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
|
59 |
+
) -> torch.Tensor:
|
60 |
+
if self.deterministic:
|
61 |
+
return torch.Tensor([0.0])
|
62 |
+
logtwopi = np.log(2.0 * np.pi)
|
63 |
+
return 0.5 * torch.sum(
|
64 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
65 |
+
dim=dims,
|
66 |
+
)
|
67 |
+
|
68 |
+
def mode(self) -> torch.Tensor:
|
69 |
+
return self.mean
|
detailgen3d/models/embeddings.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class FrequencyPositionalEmbedding(nn.Module):
|
6 |
+
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
7 |
+
each feature dimension of `x[..., i]` into:
|
8 |
+
[
|
9 |
+
sin(x[..., i]),
|
10 |
+
sin(f_1*x[..., i]),
|
11 |
+
sin(f_2*x[..., i]),
|
12 |
+
...
|
13 |
+
sin(f_N * x[..., i]),
|
14 |
+
cos(x[..., i]),
|
15 |
+
cos(f_1*x[..., i]),
|
16 |
+
cos(f_2*x[..., i]),
|
17 |
+
...
|
18 |
+
cos(f_N * x[..., i]),
|
19 |
+
x[..., i] # only present if include_input is True.
|
20 |
+
], here f_i is the frequency.
|
21 |
+
|
22 |
+
Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
|
23 |
+
If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
|
24 |
+
Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].
|
25 |
+
|
26 |
+
Args:
|
27 |
+
num_freqs (int): the number of frequencies, default is 6;
|
28 |
+
logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
29 |
+
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
|
30 |
+
input_dim (int): the input dimension, default is 3;
|
31 |
+
include_input (bool): include the input tensor or not, default is True.
|
32 |
+
|
33 |
+
Attributes:
|
34 |
+
frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
|
35 |
+
otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1);
|
36 |
+
|
37 |
+
out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
|
38 |
+
otherwise, it is input_dim * num_freqs * 2.
|
39 |
+
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
num_freqs: int = 6,
|
45 |
+
logspace: bool = True,
|
46 |
+
input_dim: int = 3,
|
47 |
+
include_input: bool = True,
|
48 |
+
include_pi: bool = True,
|
49 |
+
) -> None:
|
50 |
+
"""The initialization"""
|
51 |
+
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
if logspace:
|
55 |
+
frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
|
56 |
+
else:
|
57 |
+
frequencies = torch.linspace(
|
58 |
+
1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
|
59 |
+
)
|
60 |
+
|
61 |
+
if include_pi:
|
62 |
+
frequencies *= torch.pi
|
63 |
+
|
64 |
+
self.register_buffer("frequencies", frequencies, persistent=False)
|
65 |
+
self.include_input = include_input
|
66 |
+
self.num_freqs = num_freqs
|
67 |
+
|
68 |
+
self.out_dim = self.get_dims(input_dim)
|
69 |
+
|
70 |
+
def get_dims(self, input_dim):
|
71 |
+
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
72 |
+
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
73 |
+
|
74 |
+
return out_dim
|
75 |
+
|
76 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
77 |
+
"""Forward process.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
x: tensor of shape [..., dim]
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
|
84 |
+
where temp is 1 if include_input is True and 0 otherwise.
|
85 |
+
"""
|
86 |
+
|
87 |
+
if self.num_freqs > 0:
|
88 |
+
embed = (x[..., None].contiguous() * self.frequencies).view(
|
89 |
+
*x.shape[:-1], -1
|
90 |
+
)
|
91 |
+
if self.include_input:
|
92 |
+
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
93 |
+
else:
|
94 |
+
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
95 |
+
else:
|
96 |
+
return x
|
detailgen3d/models/transformers/__init__.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional
|
2 |
+
|
3 |
+
from .detailgen3d_transformers import DetailGen3DDiTModel
|
4 |
+
|
5 |
+
|
6 |
+
def default_set_attn_proc_func(
|
7 |
+
name: str,
|
8 |
+
hidden_size: int,
|
9 |
+
cross_attention_dim: Optional[int],
|
10 |
+
ori_attn_proc: object,
|
11 |
+
) -> object:
|
12 |
+
return ori_attn_proc
|
13 |
+
|
14 |
+
|
15 |
+
def set_transformer_attn_processor(
|
16 |
+
transformer: DetailGen3DDiTModel,
|
17 |
+
set_self_attn_proc_func: Callable = default_set_attn_proc_func,
|
18 |
+
set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
|
19 |
+
set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
|
20 |
+
set_self_attn_module_names: Optional[list[str]] = None,
|
21 |
+
set_cross_attn_1_module_names: Optional[list[str]] = None,
|
22 |
+
set_cross_attn_2_module_names: Optional[list[str]] = None,
|
23 |
+
) -> None:
|
24 |
+
do_set_processor = lambda name, module_names: (
|
25 |
+
any([name.startswith(module_name) for module_name in module_names])
|
26 |
+
if module_names is not None
|
27 |
+
else True
|
28 |
+
) # prefix match
|
29 |
+
|
30 |
+
attn_procs = {}
|
31 |
+
for name, attn_processor in transformer.attn_processors.items():
|
32 |
+
hidden_size = transformer.config.width
|
33 |
+
if name.endswith("attn1.processor"):
|
34 |
+
# self attention
|
35 |
+
attn_procs[name] = (
|
36 |
+
set_self_attn_proc_func(name, hidden_size, None, attn_processor)
|
37 |
+
if do_set_processor(name, set_self_attn_module_names)
|
38 |
+
else attn_processor
|
39 |
+
)
|
40 |
+
elif name.endswith("attn2.processor"):
|
41 |
+
# cross attention
|
42 |
+
cross_attention_dim = transformer.config.cross_attention_dim
|
43 |
+
attn_procs[name] = (
|
44 |
+
set_cross_attn_1_proc_func(
|
45 |
+
name, hidden_size, cross_attention_dim, attn_processor
|
46 |
+
)
|
47 |
+
if do_set_processor(name, set_cross_attn_1_module_names)
|
48 |
+
else attn_processor
|
49 |
+
)
|
50 |
+
elif name.endswith("attn2_2.processor"):
|
51 |
+
# cross attention 2
|
52 |
+
cross_attention_dim = transformer.config.cross_attention_2_dim
|
53 |
+
attn_procs[name] = (
|
54 |
+
set_cross_attn_2_proc_func(
|
55 |
+
name, hidden_size, cross_attention_dim, attn_processor
|
56 |
+
)
|
57 |
+
if do_set_processor(name, set_cross_attn_2_module_names)
|
58 |
+
else attn_processor
|
59 |
+
)
|
60 |
+
|
61 |
+
transformer.set_attn_processor(attn_procs)
|
detailgen3d/models/transformers/detailgen3d_transformers.py
ADDED
@@ -0,0 +1,771 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 VAST-AI-Research and contributors
|
2 |
+
|
3 |
+
# This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
|
4 |
+
# which is licensed under the Tencent Hunyuan Community License Agreement.
|
5 |
+
# Portions of this code are copied or adapted from HunyuanDiT.
|
6 |
+
# See the original license below:
|
7 |
+
|
8 |
+
# ---- Start of Tencent Hunyuan Community License Agreement ----
|
9 |
+
|
10 |
+
# TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
11 |
+
# Tencent Hunyuan DiT Release Date: 14 May 2024
|
12 |
+
# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
13 |
+
# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
14 |
+
# 1. DEFINITIONS.
|
15 |
+
# a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
16 |
+
# b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
17 |
+
# c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
18 |
+
# d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
19 |
+
# e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
20 |
+
# f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
21 |
+
# g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
22 |
+
# h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
23 |
+
# i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
|
24 |
+
# j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
|
25 |
+
# k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
26 |
+
# l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
|
27 |
+
# m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
28 |
+
# n. “including” shall mean including but not limited to.
|
29 |
+
# 2. GRANT OF RIGHTS.
|
30 |
+
# We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
31 |
+
# 3. DISTRIBUTION.
|
32 |
+
# You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
33 |
+
# a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
34 |
+
# b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
35 |
+
# c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
36 |
+
# d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
37 |
+
# You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
38 |
+
# 4. ADDITIONAL COMMERCIAL TERMS.
|
39 |
+
# If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
40 |
+
# 5. RULES OF USE.
|
41 |
+
# a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
42 |
+
# b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
|
43 |
+
# c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
44 |
+
# 6. INTELLECTUAL PROPERTY.
|
45 |
+
# a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
46 |
+
# b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
47 |
+
# c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
48 |
+
# d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
49 |
+
# 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
50 |
+
# a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
51 |
+
# b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
52 |
+
# c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
53 |
+
# 8. SURVIVAL AND TERMINATION.
|
54 |
+
# a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
55 |
+
# b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
56 |
+
# 9. GOVERNING LAW AND JURISDICTION.
|
57 |
+
# a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
58 |
+
# b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
59 |
+
#
|
60 |
+
# EXHIBIT A
|
61 |
+
# ACCEPTABLE USE POLICY
|
62 |
+
|
63 |
+
# Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
64 |
+
# Last modified: [insert date]
|
65 |
+
|
66 |
+
# Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
67 |
+
# 1. Outside the Territory;
|
68 |
+
# 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
69 |
+
# 3. To harm Yourself or others;
|
70 |
+
# 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
71 |
+
# 5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
72 |
+
# 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
73 |
+
# 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
74 |
+
# 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
75 |
+
# 9. To intentionally defame, disparage or otherwise harass others;
|
76 |
+
# 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
77 |
+
# 11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
78 |
+
# 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
79 |
+
# 13. To impersonate another individual without consent, authorization, or legal right;
|
80 |
+
# 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
81 |
+
# 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
82 |
+
# 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
83 |
+
# 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
84 |
+
# 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
85 |
+
# 19. For military purposes;
|
86 |
+
# 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
87 |
+
|
88 |
+
# ---- End of Tencent Hunyuan Community License Agreement ----
|
89 |
+
|
90 |
+
# Please note that the use of this code is subject to the terms and conditions
|
91 |
+
# of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
|
92 |
+
|
93 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
94 |
+
|
95 |
+
import torch
|
96 |
+
import torch.utils.checkpoint
|
97 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
98 |
+
from diffusers.models.attention import FeedForward
|
99 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
100 |
+
from diffusers.models.embeddings import (
|
101 |
+
GaussianFourierProjection,
|
102 |
+
TimestepEmbedding,
|
103 |
+
Timesteps,
|
104 |
+
)
|
105 |
+
from diffusers.models.modeling_utils import ModelMixin
|
106 |
+
from diffusers.models.normalization import (
|
107 |
+
AdaLayerNormContinuous,
|
108 |
+
FP32LayerNorm,
|
109 |
+
LayerNorm,
|
110 |
+
)
|
111 |
+
from diffusers.utils import (
|
112 |
+
USE_PEFT_BACKEND,
|
113 |
+
is_torch_version,
|
114 |
+
logging,
|
115 |
+
scale_lora_layers,
|
116 |
+
unscale_lora_layers,
|
117 |
+
)
|
118 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
119 |
+
from torch import nn
|
120 |
+
|
121 |
+
from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
|
122 |
+
from .modeling_outputs import Transformer1DModelOutput
|
123 |
+
|
124 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
125 |
+
|
126 |
+
|
127 |
+
@maybe_allow_in_graph
|
128 |
+
class DiTBlock(nn.Module):
|
129 |
+
r"""
|
130 |
+
Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
|
131 |
+
QKNorm
|
132 |
+
|
133 |
+
Parameters:
|
134 |
+
dim (`int`):
|
135 |
+
The number of channels in the input and output.
|
136 |
+
num_attention_heads (`int`):
|
137 |
+
The number of headsto use for multi-head attention.
|
138 |
+
cross_attention_dim (`int`,*optional*):
|
139 |
+
The size of the encoder_hidden_states vector for cross attention.
|
140 |
+
dropout(`float`, *optional*, defaults to 0.0):
|
141 |
+
The dropout probability to use.
|
142 |
+
activation_fn (`str`,*optional*, defaults to `"geglu"`):
|
143 |
+
Activation function to be used in feed-forward. .
|
144 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
145 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
146 |
+
norm_eps (`float`, *optional*, defaults to 1e-6):
|
147 |
+
A small constant added to the denominator in normalization layers to prevent division by zero.
|
148 |
+
final_dropout (`bool` *optional*, defaults to False):
|
149 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
150 |
+
ff_inner_dim (`int`, *optional*):
|
151 |
+
The size of the hidden layer in the feed-forward block. Defaults to `None`.
|
152 |
+
ff_bias (`bool`, *optional*, defaults to `True`):
|
153 |
+
Whether to use bias in the feed-forward block.
|
154 |
+
skip (`bool`, *optional*, defaults to `False`):
|
155 |
+
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
|
156 |
+
qk_norm (`bool`, *optional*, defaults to `True`):
|
157 |
+
Whether to use normalization in QK calculation. Defaults to `True`.
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
dim: int,
|
163 |
+
num_attention_heads: int,
|
164 |
+
use_self_attention: bool = True,
|
165 |
+
use_cross_attention: bool = False,
|
166 |
+
self_attention_norm_type: Optional[str] = None, # ada layer norm
|
167 |
+
cross_attention_dim: Optional[int] = None,
|
168 |
+
cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
|
169 |
+
# parallel second cross attention
|
170 |
+
use_cross_attention_2: bool = False,
|
171 |
+
cross_attention_2_dim: Optional[int] = None,
|
172 |
+
cross_attention_2_norm_type: Optional[str] = None,
|
173 |
+
dropout=0.0,
|
174 |
+
activation_fn: str = "gelu",
|
175 |
+
norm_type: str = "fp32_layer_norm", # TODO
|
176 |
+
norm_elementwise_affine: bool = True,
|
177 |
+
norm_eps: float = 1e-5,
|
178 |
+
final_dropout: bool = False,
|
179 |
+
ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
|
180 |
+
ff_bias: bool = True,
|
181 |
+
skip: bool = False,
|
182 |
+
skip_concat_front: bool = False, # [x, skip] or [skip, x]
|
183 |
+
skip_norm_last: bool = False, # this is an error
|
184 |
+
qk_norm: bool = True,
|
185 |
+
qkv_bias: bool = True,
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
|
189 |
+
self.use_self_attention = use_self_attention
|
190 |
+
self.use_cross_attention = use_cross_attention
|
191 |
+
self.use_cross_attention_2 = use_cross_attention_2
|
192 |
+
self.skip_concat_front = skip_concat_front
|
193 |
+
self.skip_norm_last = skip_norm_last
|
194 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
195 |
+
# NOTE: when new version comes, check norm2 and norm 3
|
196 |
+
# 1. Self-Attn
|
197 |
+
if use_self_attention:
|
198 |
+
if (
|
199 |
+
self_attention_norm_type == "fp32_layer_norm"
|
200 |
+
or self_attention_norm_type is None
|
201 |
+
):
|
202 |
+
self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
203 |
+
else:
|
204 |
+
raise NotImplementedError
|
205 |
+
|
206 |
+
self.attn1 = Attention(
|
207 |
+
query_dim=dim,
|
208 |
+
cross_attention_dim=None,
|
209 |
+
dim_head=dim // num_attention_heads,
|
210 |
+
heads=num_attention_heads,
|
211 |
+
qk_norm="rms_norm" if qk_norm else None,
|
212 |
+
eps=1e-6,
|
213 |
+
bias=qkv_bias,
|
214 |
+
processor=TripoSGAttnProcessor2_0(),
|
215 |
+
)
|
216 |
+
|
217 |
+
# 2. Cross-Attn
|
218 |
+
if use_cross_attention:
|
219 |
+
assert cross_attention_dim is not None
|
220 |
+
|
221 |
+
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
222 |
+
|
223 |
+
self.attn2 = Attention(
|
224 |
+
query_dim=dim,
|
225 |
+
cross_attention_dim=cross_attention_dim,
|
226 |
+
dim_head=dim // num_attention_heads,
|
227 |
+
heads=num_attention_heads,
|
228 |
+
qk_norm="rms_norm" if qk_norm else None,
|
229 |
+
cross_attention_norm=cross_attention_norm_type,
|
230 |
+
eps=1e-6,
|
231 |
+
bias=qkv_bias,
|
232 |
+
processor=TripoSGAttnProcessor2_0(),
|
233 |
+
)
|
234 |
+
|
235 |
+
# 2'. Parallel Second Cross-Attn
|
236 |
+
if use_cross_attention_2:
|
237 |
+
assert cross_attention_2_dim is not None
|
238 |
+
|
239 |
+
self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
240 |
+
|
241 |
+
self.attn2_2 = Attention(
|
242 |
+
query_dim=dim,
|
243 |
+
cross_attention_dim=cross_attention_2_dim,
|
244 |
+
dim_head=dim // num_attention_heads,
|
245 |
+
heads=num_attention_heads,
|
246 |
+
qk_norm="rms_norm" if qk_norm else None,
|
247 |
+
cross_attention_norm=cross_attention_2_norm_type,
|
248 |
+
eps=1e-6,
|
249 |
+
bias=qkv_bias,
|
250 |
+
processor=TripoSGAttnProcessor2_0(),
|
251 |
+
)
|
252 |
+
|
253 |
+
# 3. Feed-forward
|
254 |
+
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
255 |
+
|
256 |
+
self.ff = FeedForward(
|
257 |
+
dim,
|
258 |
+
dropout=dropout, ### 0.0
|
259 |
+
activation_fn=activation_fn, ### approx GeLU
|
260 |
+
final_dropout=final_dropout, ### 0.0
|
261 |
+
inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
|
262 |
+
bias=ff_bias,
|
263 |
+
)
|
264 |
+
|
265 |
+
# 4. Skip Connection
|
266 |
+
if skip:
|
267 |
+
self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
|
268 |
+
self.skip_linear = nn.Linear(2 * dim, dim)
|
269 |
+
else:
|
270 |
+
self.skip_linear = None
|
271 |
+
|
272 |
+
# 5. adaLN time embedding
|
273 |
+
self.adaln_modulation = nn.Sequential(
|
274 |
+
nn.SiLU(),
|
275 |
+
nn.Linear(dim, 9 * dim, bias=True)
|
276 |
+
)
|
277 |
+
|
278 |
+
# let chunk size default to None
|
279 |
+
self._chunk_size = None
|
280 |
+
self._chunk_dim = 0
|
281 |
+
|
282 |
+
|
283 |
+
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
284 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
285 |
+
# Sets chunk feed-forward
|
286 |
+
self._chunk_size = chunk_size
|
287 |
+
self._chunk_dim = dim
|
288 |
+
|
289 |
+
def forward(
|
290 |
+
self,
|
291 |
+
hidden_states: torch.Tensor,
|
292 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
293 |
+
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
294 |
+
temb: Optional[torch.Tensor] = None,
|
295 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
296 |
+
skip: Optional[torch.Tensor] = None,
|
297 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
298 |
+
) -> torch.Tensor:
|
299 |
+
# Prepare attention kwargs
|
300 |
+
attention_kwargs = attention_kwargs or {}
|
301 |
+
|
302 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
303 |
+
# 0. Long Skip Connection
|
304 |
+
if self.skip_linear is not None:
|
305 |
+
cat = torch.cat(
|
306 |
+
(
|
307 |
+
[skip, hidden_states]
|
308 |
+
if self.skip_concat_front
|
309 |
+
else [hidden_states, skip]
|
310 |
+
),
|
311 |
+
dim=-1,
|
312 |
+
)
|
313 |
+
if self.skip_norm_last:
|
314 |
+
# don't do this
|
315 |
+
hidden_states = self.skip_linear(cat)
|
316 |
+
hidden_states = self.skip_norm(hidden_states)
|
317 |
+
else:
|
318 |
+
cat = self.skip_norm(cat)
|
319 |
+
hidden_states = self.skip_linear(cat)
|
320 |
+
|
321 |
+
# 0. adaLN time embedding
|
322 |
+
shift_msa, scale_msa, gate_msa, shift_mca, scale_mca, gate_mca, shift_mlp, scale_mlp, gate_mlp = self.adaln_modulation(
|
323 |
+
temb
|
324 |
+
).chunk(9, dim=-1)
|
325 |
+
|
326 |
+
# 1. Self-Attention
|
327 |
+
if self.use_self_attention:
|
328 |
+
norm_hidden_states = self.norm1(hidden_states) * (1 + scale_msa) + shift_msa
|
329 |
+
attn_output = self.attn1(
|
330 |
+
norm_hidden_states,
|
331 |
+
image_rotary_emb=image_rotary_emb,
|
332 |
+
**attention_kwargs,
|
333 |
+
)
|
334 |
+
hidden_states = hidden_states + gate_msa * attn_output
|
335 |
+
|
336 |
+
# 2. Cross-Attention
|
337 |
+
if self.use_cross_attention:
|
338 |
+
if self.use_cross_attention_2:
|
339 |
+
hidden_states = (
|
340 |
+
hidden_states
|
341 |
+
+ self.attn2(
|
342 |
+
self.norm2(hidden_states),
|
343 |
+
encoder_hidden_states=encoder_hidden_states,
|
344 |
+
image_rotary_emb=image_rotary_emb,
|
345 |
+
**attention_kwargs,
|
346 |
+
)
|
347 |
+
+ self.attn2_2(
|
348 |
+
self.norm2_2(hidden_states),
|
349 |
+
encoder_hidden_states=encoder_hidden_states_2,
|
350 |
+
image_rotary_emb=image_rotary_emb,
|
351 |
+
**attention_kwargs,
|
352 |
+
)
|
353 |
+
)
|
354 |
+
else:
|
355 |
+
hidden_states = hidden_states + gate_mca * self.attn2(
|
356 |
+
self.norm2(hidden_states) * (1 + scale_mca) + shift_mca,
|
357 |
+
encoder_hidden_states=encoder_hidden_states,
|
358 |
+
image_rotary_emb=image_rotary_emb,
|
359 |
+
**attention_kwargs,
|
360 |
+
)
|
361 |
+
|
362 |
+
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
|
363 |
+
mlp_inputs = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
|
364 |
+
hidden_states = hidden_states + gate_mlp * self.ff(mlp_inputs)
|
365 |
+
|
366 |
+
return hidden_states
|
367 |
+
|
368 |
+
|
369 |
+
class DetailGen3DDiTModel(ModelMixin, ConfigMixin):
|
370 |
+
"""
|
371 |
+
DetailGen3DDiT: Diffusion model with a Transformer backbone.
|
372 |
+
|
373 |
+
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
374 |
+
|
375 |
+
Parameters:
|
376 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
377 |
+
The number of heads to use for multi-head attention.
|
378 |
+
attention_head_dim (`int`, *optional*, defaults to 88):
|
379 |
+
The number of channels in each head.
|
380 |
+
in_channels (`int`, *optional*):
|
381 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
382 |
+
patch_size (`int`, *optional*):
|
383 |
+
The size of the patch to use for the input.
|
384 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
385 |
+
Activation function to use in feed-forward.
|
386 |
+
sample_size (`int`, *optional*):
|
387 |
+
The width of the latent images. This is fixed during training since it is used to learn a number of
|
388 |
+
position embeddings.
|
389 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
390 |
+
The dropout probability to use.
|
391 |
+
cross_attention_dim (`int`, *optional*):
|
392 |
+
The number of dimension in the clip text embedding.
|
393 |
+
hidden_size (`int`, *optional*):
|
394 |
+
The size of hidden layer in the conditioning embedding layers.
|
395 |
+
num_layers (`int`, *optional*, defaults to 1):
|
396 |
+
The number of layers of Transformer blocks to use.
|
397 |
+
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
398 |
+
The ratio of the hidden layer size to the input size.
|
399 |
+
learn_sigma (`bool`, *optional*, defaults to `True`):
|
400 |
+
Whether to predict variance.
|
401 |
+
cross_attention_dim_t5 (`int`, *optional*):
|
402 |
+
The number dimensions in t5 text embedding.
|
403 |
+
pooled_projection_dim (`int`, *optional*):
|
404 |
+
The size of the pooled projection.
|
405 |
+
text_len (`int`, *optional*):
|
406 |
+
The length of the clip text embedding.
|
407 |
+
text_len_t5 (`int`, *optional*):
|
408 |
+
The length of the T5 text embedding.
|
409 |
+
use_style_cond_and_image_meta_size (`bool`, *optional*):
|
410 |
+
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
|
411 |
+
"""
|
412 |
+
|
413 |
+
_supports_gradient_checkpointing = True
|
414 |
+
|
415 |
+
@register_to_config
|
416 |
+
def __init__(
|
417 |
+
self,
|
418 |
+
num_attention_heads: int = 12,
|
419 |
+
width: int = 768,
|
420 |
+
in_channels: int = 64,
|
421 |
+
num_layers: int = 24,
|
422 |
+
cross_attention_dim: int = 1024,
|
423 |
+
):
|
424 |
+
super().__init__()
|
425 |
+
self.out_channels = in_channels
|
426 |
+
self.num_heads = num_attention_heads
|
427 |
+
self.inner_dim = width
|
428 |
+
self.mlp_ratio = 4.0
|
429 |
+
|
430 |
+
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
431 |
+
"positional",
|
432 |
+
inner_dim=self.inner_dim,
|
433 |
+
flip_sin_to_cos=False,
|
434 |
+
freq_shift=0,
|
435 |
+
time_embedding_dim=None,
|
436 |
+
)
|
437 |
+
self.time_proj = TimestepEmbedding(
|
438 |
+
timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
|
439 |
+
)
|
440 |
+
self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
|
441 |
+
|
442 |
+
self.blocks = nn.ModuleList(
|
443 |
+
[
|
444 |
+
DiTBlock(
|
445 |
+
dim=self.inner_dim,
|
446 |
+
num_attention_heads=self.config.num_attention_heads,
|
447 |
+
use_self_attention=True,
|
448 |
+
use_cross_attention=True,
|
449 |
+
self_attention_norm_type="fp32_layer_norm",
|
450 |
+
cross_attention_dim=self.config.cross_attention_dim,
|
451 |
+
cross_attention_norm_type=None,
|
452 |
+
use_cross_attention_2=False,
|
453 |
+
cross_attention_2_norm_type=None,
|
454 |
+
activation_fn="gelu",
|
455 |
+
norm_type="fp32_layer_norm", # TODO
|
456 |
+
norm_eps=1e-5,
|
457 |
+
ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
|
458 |
+
qk_norm=False, # See http://arxiv.org/abs/2302.05442 for details.
|
459 |
+
qkv_bias=False,
|
460 |
+
)
|
461 |
+
for layer in range(num_layers)
|
462 |
+
]
|
463 |
+
)
|
464 |
+
|
465 |
+
self.norm_out = LayerNorm(self.inner_dim)
|
466 |
+
self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
|
467 |
+
|
468 |
+
self.gradient_checkpointing = False
|
469 |
+
|
470 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
471 |
+
self.gradient_checkpointing = value
|
472 |
+
|
473 |
+
def _set_time_proj(
|
474 |
+
self,
|
475 |
+
time_embedding_type: str,
|
476 |
+
inner_dim: int,
|
477 |
+
flip_sin_to_cos: bool,
|
478 |
+
freq_shift: float,
|
479 |
+
time_embedding_dim: int,
|
480 |
+
) -> Tuple[int, int]:
|
481 |
+
if time_embedding_type == "fourier":
|
482 |
+
time_embed_dim = time_embedding_dim or inner_dim * 2
|
483 |
+
if time_embed_dim % 2 != 0:
|
484 |
+
raise ValueError(
|
485 |
+
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
|
486 |
+
)
|
487 |
+
self.time_embed = GaussianFourierProjection(
|
488 |
+
time_embed_dim // 2,
|
489 |
+
set_W_to_weight=False,
|
490 |
+
log=False,
|
491 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
492 |
+
)
|
493 |
+
timestep_input_dim = time_embed_dim
|
494 |
+
elif time_embedding_type == "positional":
|
495 |
+
time_embed_dim = time_embedding_dim or inner_dim * 4
|
496 |
+
|
497 |
+
self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
498 |
+
timestep_input_dim = inner_dim
|
499 |
+
else:
|
500 |
+
raise ValueError(
|
501 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
502 |
+
)
|
503 |
+
|
504 |
+
return time_embed_dim, timestep_input_dim
|
505 |
+
|
506 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
|
507 |
+
def fuse_qkv_projections(self):
|
508 |
+
"""
|
509 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
510 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
511 |
+
|
512 |
+
<Tip warning={true}>
|
513 |
+
|
514 |
+
This API is 🧪 experimental.
|
515 |
+
|
516 |
+
</Tip>
|
517 |
+
"""
|
518 |
+
self.original_attn_processors = None
|
519 |
+
|
520 |
+
for _, attn_processor in self.attn_processors.items():
|
521 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
522 |
+
raise ValueError(
|
523 |
+
"`fuse_qkv_projections()` is not supported for models having added KV projections."
|
524 |
+
)
|
525 |
+
|
526 |
+
self.original_attn_processors = self.attn_processors
|
527 |
+
|
528 |
+
for module in self.modules():
|
529 |
+
if isinstance(module, Attention):
|
530 |
+
module.fuse_projections(fuse=True)
|
531 |
+
|
532 |
+
self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
|
533 |
+
|
534 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
535 |
+
def unfuse_qkv_projections(self):
|
536 |
+
"""Disables the fused QKV projection if enabled.
|
537 |
+
|
538 |
+
<Tip warning={true}>
|
539 |
+
|
540 |
+
This API is 🧪 experimental.
|
541 |
+
|
542 |
+
</Tip>
|
543 |
+
|
544 |
+
"""
|
545 |
+
if self.original_attn_processors is not None:
|
546 |
+
self.set_attn_processor(self.original_attn_processors)
|
547 |
+
|
548 |
+
@property
|
549 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
550 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
551 |
+
r"""
|
552 |
+
Returns:
|
553 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
554 |
+
indexed by its weight name.
|
555 |
+
"""
|
556 |
+
# set recursively
|
557 |
+
processors = {}
|
558 |
+
|
559 |
+
def fn_recursive_add_processors(
|
560 |
+
name: str,
|
561 |
+
module: torch.nn.Module,
|
562 |
+
processors: Dict[str, AttentionProcessor],
|
563 |
+
):
|
564 |
+
if hasattr(module, "get_processor"):
|
565 |
+
processors[f"{name}.processor"] = module.get_processor()
|
566 |
+
|
567 |
+
for sub_name, child in module.named_children():
|
568 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
569 |
+
|
570 |
+
return processors
|
571 |
+
|
572 |
+
for name, module in self.named_children():
|
573 |
+
fn_recursive_add_processors(name, module, processors)
|
574 |
+
|
575 |
+
return processors
|
576 |
+
|
577 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
578 |
+
def set_attn_processor(
|
579 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
580 |
+
):
|
581 |
+
r"""
|
582 |
+
Sets the attention processor to use to compute attention.
|
583 |
+
|
584 |
+
Parameters:
|
585 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
586 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
587 |
+
for **all** `Attention` layers.
|
588 |
+
|
589 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
590 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
591 |
+
|
592 |
+
"""
|
593 |
+
count = len(self.attn_processors.keys())
|
594 |
+
|
595 |
+
if isinstance(processor, dict) and len(processor) != count:
|
596 |
+
raise ValueError(
|
597 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
598 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
599 |
+
)
|
600 |
+
|
601 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
602 |
+
if hasattr(module, "set_processor"):
|
603 |
+
if not isinstance(processor, dict):
|
604 |
+
module.set_processor(processor)
|
605 |
+
else:
|
606 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
607 |
+
|
608 |
+
for sub_name, child in module.named_children():
|
609 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
610 |
+
|
611 |
+
for name, module in self.named_children():
|
612 |
+
fn_recursive_attn_processor(name, module, processor)
|
613 |
+
|
614 |
+
def set_default_attn_processor(self):
|
615 |
+
"""
|
616 |
+
Disables custom attention processors and sets the default attention implementation.
|
617 |
+
"""
|
618 |
+
self.set_attn_processor(TripoSGAttnProcessor2_0())
|
619 |
+
|
620 |
+
def forward(
|
621 |
+
self,
|
622 |
+
hidden_states: Optional[torch.Tensor],
|
623 |
+
timestep: Union[int, float, torch.LongTensor],
|
624 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
625 |
+
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
626 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
627 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
628 |
+
return_dict: bool = True,
|
629 |
+
):
|
630 |
+
"""
|
631 |
+
The [`HunyuanDiT2DModel`] forward method.
|
632 |
+
|
633 |
+
Args:
|
634 |
+
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
|
635 |
+
The input tensor.
|
636 |
+
timestep ( `torch.LongTensor`, *optional*):
|
637 |
+
Used to indicate denoising step.
|
638 |
+
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
639 |
+
Conditional embeddings for cross attention layer.
|
640 |
+
encoder_hidden_states_2 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
641 |
+
Conditional embeddings for cross attention layer.
|
642 |
+
return_dict: bool
|
643 |
+
Whether to return a dictionary.
|
644 |
+
"""
|
645 |
+
|
646 |
+
if attention_kwargs is not None:
|
647 |
+
attention_kwargs = attention_kwargs.copy()
|
648 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
649 |
+
else:
|
650 |
+
lora_scale = 1.0
|
651 |
+
|
652 |
+
if USE_PEFT_BACKEND:
|
653 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
654 |
+
scale_lora_layers(self, lora_scale)
|
655 |
+
else:
|
656 |
+
if (
|
657 |
+
attention_kwargs is not None
|
658 |
+
and attention_kwargs.get("scale", None) is not None
|
659 |
+
):
|
660 |
+
logger.warning(
|
661 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
662 |
+
)
|
663 |
+
|
664 |
+
_, N, _ = hidden_states.shape
|
665 |
+
|
666 |
+
temb = self.time_embed(timestep).to(hidden_states.dtype)
|
667 |
+
temb = self.time_proj(temb)
|
668 |
+
temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
|
669 |
+
|
670 |
+
hidden_states = self.proj_in(hidden_states)
|
671 |
+
|
672 |
+
skips = []
|
673 |
+
for layer, block in enumerate(self.blocks):
|
674 |
+
skip = None if layer <= self.config.num_layers // 2 else skips.pop()
|
675 |
+
|
676 |
+
if self.training and self.gradient_checkpointing:
|
677 |
+
|
678 |
+
def create_custom_forward(module):
|
679 |
+
def custom_forward(*inputs):
|
680 |
+
return module(*inputs)
|
681 |
+
|
682 |
+
return custom_forward
|
683 |
+
|
684 |
+
ckpt_kwargs: Dict[str, Any] = (
|
685 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
686 |
+
)
|
687 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
688 |
+
create_custom_forward(block),
|
689 |
+
hidden_states,
|
690 |
+
encoder_hidden_states,
|
691 |
+
encoder_hidden_states_2,
|
692 |
+
temb,
|
693 |
+
image_rotary_emb,
|
694 |
+
skip,
|
695 |
+
attention_kwargs,
|
696 |
+
**ckpt_kwargs,
|
697 |
+
)
|
698 |
+
else:
|
699 |
+
hidden_states = block(
|
700 |
+
hidden_states,
|
701 |
+
encoder_hidden_states=encoder_hidden_states,
|
702 |
+
encoder_hidden_states_2=encoder_hidden_states_2,
|
703 |
+
temb=temb,
|
704 |
+
image_rotary_emb=image_rotary_emb,
|
705 |
+
skip=skip,
|
706 |
+
attention_kwargs=attention_kwargs,
|
707 |
+
) # (N, L, D)
|
708 |
+
|
709 |
+
if layer < self.config.num_layers // 2:
|
710 |
+
skips.append(hidden_states)
|
711 |
+
|
712 |
+
# final layer
|
713 |
+
hidden_states = self.norm_out(hidden_states)
|
714 |
+
hidden_states = self.proj_out(hidden_states)
|
715 |
+
|
716 |
+
if USE_PEFT_BACKEND:
|
717 |
+
# remove `lora_scale` from each PEFT layer
|
718 |
+
unscale_lora_layers(self, lora_scale)
|
719 |
+
|
720 |
+
if not return_dict:
|
721 |
+
return (hidden_states,)
|
722 |
+
|
723 |
+
return Transformer1DModelOutput(sample=hidden_states)
|
724 |
+
|
725 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
726 |
+
def enable_forward_chunking(
|
727 |
+
self, chunk_size: Optional[int] = None, dim: int = 0
|
728 |
+
) -> None:
|
729 |
+
"""
|
730 |
+
Sets the attention processor to use [feed forward
|
731 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
732 |
+
|
733 |
+
Parameters:
|
734 |
+
chunk_size (`int`, *optional*):
|
735 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
736 |
+
over each tensor of dim=`dim`.
|
737 |
+
dim (`int`, *optional*, defaults to `0`):
|
738 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
739 |
+
or dim=1 (sequence length).
|
740 |
+
"""
|
741 |
+
if dim not in [0, 1]:
|
742 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
743 |
+
|
744 |
+
# By default chunk size is 1
|
745 |
+
chunk_size = chunk_size or 1
|
746 |
+
|
747 |
+
def fn_recursive_feed_forward(
|
748 |
+
module: torch.nn.Module, chunk_size: int, dim: int
|
749 |
+
):
|
750 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
751 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
752 |
+
|
753 |
+
for child in module.children():
|
754 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
755 |
+
|
756 |
+
for module in self.children():
|
757 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
758 |
+
|
759 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
760 |
+
def disable_forward_chunking(self):
|
761 |
+
def fn_recursive_feed_forward(
|
762 |
+
module: torch.nn.Module, chunk_size: int, dim: int
|
763 |
+
):
|
764 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
765 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
766 |
+
|
767 |
+
for child in module.children():
|
768 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
769 |
+
|
770 |
+
for module in self.children():
|
771 |
+
fn_recursive_feed_forward(module, None, 0)
|
detailgen3d/models/transformers/modeling_outputs.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class Transformer1DModelOutput:
|
8 |
+
sample: torch.FloatTensor
|
detailgen3d/models/transformers/triposg_transformer.py
ADDED
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 VAST-AI-Research and contributors
|
2 |
+
|
3 |
+
# This code is based on Tencent HunyuanDiT (https://huggingface.co/Tencent-Hunyuan/HunyuanDiT),
|
4 |
+
# which is licensed under the Tencent Hunyuan Community License Agreement.
|
5 |
+
# Portions of this code are copied or adapted from HunyuanDiT.
|
6 |
+
# See the original license below:
|
7 |
+
|
8 |
+
# ---- Start of Tencent Hunyuan Community License Agreement ----
|
9 |
+
|
10 |
+
# TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
11 |
+
# Tencent Hunyuan DiT Release Date: 14 May 2024
|
12 |
+
# THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
13 |
+
# By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
14 |
+
# 1. DEFINITIONS.
|
15 |
+
# a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
16 |
+
# b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
17 |
+
# c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
18 |
+
# d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
19 |
+
# e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
20 |
+
# f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
21 |
+
# g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
22 |
+
# h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
23 |
+
# i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
|
24 |
+
# j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan DiT released at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT.
|
25 |
+
# k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
26 |
+
# l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union.
|
27 |
+
# m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
28 |
+
# n. “including” shall mean including but not limited to.
|
29 |
+
# 2. GRANT OF RIGHTS.
|
30 |
+
# We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
31 |
+
# 3. DISTRIBUTION.
|
32 |
+
# You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
33 |
+
# a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
34 |
+
# b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
35 |
+
# c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
36 |
+
# d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
37 |
+
# You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
38 |
+
# 4. ADDITIONAL COMMERCIAL TERMS.
|
39 |
+
# If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
40 |
+
# 5. RULES OF USE.
|
41 |
+
# a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
42 |
+
# b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
|
43 |
+
# c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
44 |
+
# 6. INTELLECTUAL PROPERTY.
|
45 |
+
# a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
46 |
+
# b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
47 |
+
# c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
48 |
+
# d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
49 |
+
# 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
50 |
+
# a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
51 |
+
# b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
52 |
+
# c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
53 |
+
# 8. SURVIVAL AND TERMINATION.
|
54 |
+
# a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
55 |
+
# b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
56 |
+
# 9. GOVERNING LAW AND JURISDICTION.
|
57 |
+
# a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
58 |
+
# b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
59 |
+
#
|
60 |
+
# EXHIBIT A
|
61 |
+
# ACCEPTABLE USE POLICY
|
62 |
+
|
63 |
+
# Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
64 |
+
# Last modified: [insert date]
|
65 |
+
|
66 |
+
# Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
67 |
+
# 1. Outside the Territory;
|
68 |
+
# 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
69 |
+
# 3. To harm Yourself or others;
|
70 |
+
# 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
71 |
+
# 5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
72 |
+
# 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
73 |
+
# 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
74 |
+
# 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
75 |
+
# 9. To intentionally defame, disparage or otherwise harass others;
|
76 |
+
# 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
77 |
+
# 11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
78 |
+
# 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
79 |
+
# 13. To impersonate another individual without consent, authorization, or legal right;
|
80 |
+
# 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
81 |
+
# 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
82 |
+
# 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
83 |
+
# 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
84 |
+
# 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
85 |
+
# 19. For military purposes;
|
86 |
+
# 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
87 |
+
|
88 |
+
# ---- End of Tencent Hunyuan Community License Agreement ----
|
89 |
+
|
90 |
+
# Please note that the use of this code is subject to the terms and conditions
|
91 |
+
# of the Tencent Hunyuan Community License Agreement, including the Acceptable Use Policy.
|
92 |
+
|
93 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
94 |
+
|
95 |
+
import torch
|
96 |
+
import torch.utils.checkpoint
|
97 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
98 |
+
from diffusers.loaders import PeftAdapterMixin
|
99 |
+
from diffusers.models.attention import FeedForward
|
100 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
101 |
+
from diffusers.models.embeddings import (
|
102 |
+
GaussianFourierProjection,
|
103 |
+
TimestepEmbedding,
|
104 |
+
Timesteps,
|
105 |
+
)
|
106 |
+
from diffusers.models.modeling_utils import ModelMixin
|
107 |
+
from diffusers.models.normalization import (
|
108 |
+
AdaLayerNormContinuous,
|
109 |
+
FP32LayerNorm,
|
110 |
+
LayerNorm,
|
111 |
+
)
|
112 |
+
from diffusers.utils import (
|
113 |
+
USE_PEFT_BACKEND,
|
114 |
+
is_torch_version,
|
115 |
+
logging,
|
116 |
+
scale_lora_layers,
|
117 |
+
unscale_lora_layers,
|
118 |
+
)
|
119 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
120 |
+
from torch import nn
|
121 |
+
|
122 |
+
from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
|
123 |
+
from .modeling_outputs import Transformer1DModelOutput
|
124 |
+
|
125 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
126 |
+
|
127 |
+
|
128 |
+
@maybe_allow_in_graph
|
129 |
+
class DiTBlock(nn.Module):
|
130 |
+
r"""
|
131 |
+
Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
|
132 |
+
QKNorm
|
133 |
+
|
134 |
+
Parameters:
|
135 |
+
dim (`int`):
|
136 |
+
The number of channels in the input and output.
|
137 |
+
num_attention_heads (`int`):
|
138 |
+
The number of headsto use for multi-head attention.
|
139 |
+
cross_attention_dim (`int`,*optional*):
|
140 |
+
The size of the encoder_hidden_states vector for cross attention.
|
141 |
+
dropout(`float`, *optional*, defaults to 0.0):
|
142 |
+
The dropout probability to use.
|
143 |
+
activation_fn (`str`,*optional*, defaults to `"geglu"`):
|
144 |
+
Activation function to be used in feed-forward. .
|
145 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
146 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
147 |
+
norm_eps (`float`, *optional*, defaults to 1e-6):
|
148 |
+
A small constant added to the denominator in normalization layers to prevent division by zero.
|
149 |
+
final_dropout (`bool` *optional*, defaults to False):
|
150 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
151 |
+
ff_inner_dim (`int`, *optional*):
|
152 |
+
The size of the hidden layer in the feed-forward block. Defaults to `None`.
|
153 |
+
ff_bias (`bool`, *optional*, defaults to `True`):
|
154 |
+
Whether to use bias in the feed-forward block.
|
155 |
+
skip (`bool`, *optional*, defaults to `False`):
|
156 |
+
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
|
157 |
+
qk_norm (`bool`, *optional*, defaults to `True`):
|
158 |
+
Whether to use normalization in QK calculation. Defaults to `True`.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
dim: int,
|
164 |
+
num_attention_heads: int,
|
165 |
+
use_self_attention: bool = True,
|
166 |
+
self_attention_norm_type: Optional[str] = None,
|
167 |
+
use_cross_attention: bool = True, # ada layer norm
|
168 |
+
cross_attention_dim: Optional[int] = None,
|
169 |
+
cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
|
170 |
+
dropout=0.0,
|
171 |
+
activation_fn: str = "gelu",
|
172 |
+
norm_type: str = "fp32_layer_norm", # TODO
|
173 |
+
norm_elementwise_affine: bool = True,
|
174 |
+
norm_eps: float = 1e-5,
|
175 |
+
final_dropout: bool = False,
|
176 |
+
ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
|
177 |
+
ff_bias: bool = True,
|
178 |
+
skip: bool = False,
|
179 |
+
skip_concat_front: bool = False, # [x, skip] or [skip, x]
|
180 |
+
skip_norm_last: bool = False, # this is an error
|
181 |
+
qk_norm: bool = True,
|
182 |
+
qkv_bias: bool = True,
|
183 |
+
):
|
184 |
+
super().__init__()
|
185 |
+
|
186 |
+
self.use_self_attention = use_self_attention
|
187 |
+
self.use_cross_attention = use_cross_attention
|
188 |
+
self.skip_concat_front = skip_concat_front
|
189 |
+
self.skip_norm_last = skip_norm_last
|
190 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
191 |
+
# NOTE: when new version comes, check norm2 and norm 3
|
192 |
+
# 1. Self-Attn
|
193 |
+
if use_self_attention:
|
194 |
+
if (
|
195 |
+
self_attention_norm_type == "fp32_layer_norm"
|
196 |
+
or self_attention_norm_type is None
|
197 |
+
):
|
198 |
+
self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
199 |
+
else:
|
200 |
+
raise NotImplementedError
|
201 |
+
|
202 |
+
self.attn1 = Attention(
|
203 |
+
query_dim=dim,
|
204 |
+
cross_attention_dim=None,
|
205 |
+
dim_head=dim // num_attention_heads,
|
206 |
+
heads=num_attention_heads,
|
207 |
+
qk_norm="rms_norm" if qk_norm else None,
|
208 |
+
eps=1e-6,
|
209 |
+
bias=qkv_bias,
|
210 |
+
processor=TripoSGAttnProcessor2_0(),
|
211 |
+
)
|
212 |
+
|
213 |
+
# 2. Cross-Attn
|
214 |
+
if use_cross_attention:
|
215 |
+
assert cross_attention_dim is not None
|
216 |
+
|
217 |
+
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
218 |
+
|
219 |
+
self.attn2 = Attention(
|
220 |
+
query_dim=dim,
|
221 |
+
cross_attention_dim=cross_attention_dim,
|
222 |
+
dim_head=dim // num_attention_heads,
|
223 |
+
heads=num_attention_heads,
|
224 |
+
qk_norm="rms_norm" if qk_norm else None,
|
225 |
+
cross_attention_norm=cross_attention_norm_type,
|
226 |
+
eps=1e-6,
|
227 |
+
bias=qkv_bias,
|
228 |
+
processor=TripoSGAttnProcessor2_0(),
|
229 |
+
)
|
230 |
+
|
231 |
+
# 3. Feed-forward
|
232 |
+
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
233 |
+
|
234 |
+
self.ff = FeedForward(
|
235 |
+
dim,
|
236 |
+
dropout=dropout, ### 0.0
|
237 |
+
activation_fn=activation_fn, ### approx GeLU
|
238 |
+
final_dropout=final_dropout, ### 0.0
|
239 |
+
inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
|
240 |
+
bias=ff_bias,
|
241 |
+
)
|
242 |
+
|
243 |
+
# 4. Skip Connection
|
244 |
+
if skip:
|
245 |
+
self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
|
246 |
+
self.skip_linear = nn.Linear(2 * dim, dim)
|
247 |
+
else:
|
248 |
+
self.skip_linear = None
|
249 |
+
|
250 |
+
# let chunk size default to None
|
251 |
+
self._chunk_size = None
|
252 |
+
self._chunk_dim = 0
|
253 |
+
|
254 |
+
def set_topk(self, topk):
|
255 |
+
self.flash_processor.topk = topk
|
256 |
+
|
257 |
+
def set_flash_processor(self, flash_processor):
|
258 |
+
self.flash_processor = flash_processor
|
259 |
+
self.attn2.processor = self.flash_processor
|
260 |
+
|
261 |
+
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
262 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
263 |
+
# Sets chunk feed-forward
|
264 |
+
self._chunk_size = chunk_size
|
265 |
+
self._chunk_dim = dim
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
hidden_states: torch.Tensor,
|
270 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
271 |
+
temb: Optional[torch.Tensor] = None,
|
272 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
273 |
+
skip: Optional[torch.Tensor] = None,
|
274 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
275 |
+
) -> torch.Tensor:
|
276 |
+
# Prepare attention kwargs
|
277 |
+
attention_kwargs = attention_kwargs or {}
|
278 |
+
|
279 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
280 |
+
# 0. Long Skip Connection
|
281 |
+
if self.skip_linear is not None:
|
282 |
+
cat = torch.cat(
|
283 |
+
(
|
284 |
+
[skip, hidden_states]
|
285 |
+
if self.skip_concat_front
|
286 |
+
else [hidden_states, skip]
|
287 |
+
),
|
288 |
+
dim=-1,
|
289 |
+
)
|
290 |
+
if self.skip_norm_last:
|
291 |
+
# don't do this
|
292 |
+
hidden_states = self.skip_linear(cat)
|
293 |
+
hidden_states = self.skip_norm(hidden_states)
|
294 |
+
else:
|
295 |
+
cat = self.skip_norm(cat)
|
296 |
+
hidden_states = self.skip_linear(cat)
|
297 |
+
|
298 |
+
# 1. Self-Attention
|
299 |
+
if self.use_self_attention:
|
300 |
+
norm_hidden_states = self.norm1(hidden_states)
|
301 |
+
attn_output = self.attn1(
|
302 |
+
norm_hidden_states,
|
303 |
+
image_rotary_emb=image_rotary_emb,
|
304 |
+
**attention_kwargs,
|
305 |
+
)
|
306 |
+
hidden_states = hidden_states + attn_output
|
307 |
+
|
308 |
+
# 2. Cross-Attention
|
309 |
+
if self.use_cross_attention:
|
310 |
+
hidden_states = hidden_states + self.attn2(
|
311 |
+
self.norm2(hidden_states),
|
312 |
+
encoder_hidden_states=encoder_hidden_states,
|
313 |
+
image_rotary_emb=image_rotary_emb,
|
314 |
+
**attention_kwargs,
|
315 |
+
)
|
316 |
+
|
317 |
+
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
|
318 |
+
mlp_inputs = self.norm3(hidden_states)
|
319 |
+
hidden_states = hidden_states + self.ff(mlp_inputs)
|
320 |
+
|
321 |
+
return hidden_states
|
322 |
+
|
323 |
+
|
324 |
+
class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
325 |
+
"""
|
326 |
+
TripoSG: Diffusion model with a Transformer backbone.
|
327 |
+
|
328 |
+
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
329 |
+
|
330 |
+
Parameters:
|
331 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
332 |
+
The number of heads to use for multi-head attention.
|
333 |
+
attention_head_dim (`int`, *optional*, defaults to 88):
|
334 |
+
The number of channels in each head.
|
335 |
+
in_channels (`int`, *optional*):
|
336 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
337 |
+
patch_size (`int`, *optional*):
|
338 |
+
The size of the patch to use for the input.
|
339 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
340 |
+
Activation function to use in feed-forward.
|
341 |
+
sample_size (`int`, *optional*):
|
342 |
+
The width of the latent images. This is fixed during training since it is used to learn a number of
|
343 |
+
position embeddings.
|
344 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
345 |
+
The dropout probability to use.
|
346 |
+
cross_attention_dim (`int`, *optional*):
|
347 |
+
The number of dimension in the clip text embedding.
|
348 |
+
hidden_size (`int`, *optional*):
|
349 |
+
The size of hidden layer in the conditioning embedding layers.
|
350 |
+
num_layers (`int`, *optional*, defaults to 1):
|
351 |
+
The number of layers of Transformer blocks to use.
|
352 |
+
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
353 |
+
The ratio of the hidden layer size to the input size.
|
354 |
+
learn_sigma (`bool`, *optional*, defaults to `True`):
|
355 |
+
Whether to predict variance.
|
356 |
+
cross_attention_dim_t5 (`int`, *optional*):
|
357 |
+
The number dimensions in t5 text embedding.
|
358 |
+
pooled_projection_dim (`int`, *optional*):
|
359 |
+
The size of the pooled projection.
|
360 |
+
text_len (`int`, *optional*):
|
361 |
+
The length of the clip text embedding.
|
362 |
+
text_len_t5 (`int`, *optional*):
|
363 |
+
The length of the T5 text embedding.
|
364 |
+
use_style_cond_and_image_meta_size (`bool`, *optional*):
|
365 |
+
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
|
366 |
+
"""
|
367 |
+
|
368 |
+
_supports_gradient_checkpointing = True
|
369 |
+
|
370 |
+
@register_to_config
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
num_attention_heads: int = 16,
|
374 |
+
width: int = 2048,
|
375 |
+
in_channels: int = 64,
|
376 |
+
num_layers: int = 21,
|
377 |
+
cross_attention_dim: int = 1024,
|
378 |
+
):
|
379 |
+
super().__init__()
|
380 |
+
self.out_channels = in_channels
|
381 |
+
self.num_heads = num_attention_heads
|
382 |
+
self.inner_dim = width
|
383 |
+
self.mlp_ratio = 4.0
|
384 |
+
|
385 |
+
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
386 |
+
"positional",
|
387 |
+
inner_dim=self.inner_dim,
|
388 |
+
flip_sin_to_cos=False,
|
389 |
+
freq_shift=0,
|
390 |
+
time_embedding_dim=None,
|
391 |
+
)
|
392 |
+
self.time_proj = TimestepEmbedding(
|
393 |
+
timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
|
394 |
+
)
|
395 |
+
self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
|
396 |
+
|
397 |
+
self.blocks = nn.ModuleList(
|
398 |
+
[
|
399 |
+
DiTBlock(
|
400 |
+
dim=self.inner_dim,
|
401 |
+
num_attention_heads=self.config.num_attention_heads,
|
402 |
+
use_self_attention=True,
|
403 |
+
self_attention_norm_type="fp32_layer_norm",
|
404 |
+
use_cross_attention=True,
|
405 |
+
cross_attention_dim=cross_attention_dim,
|
406 |
+
cross_attention_norm_type=None,
|
407 |
+
activation_fn="gelu",
|
408 |
+
norm_type="fp32_layer_norm", # TODO
|
409 |
+
norm_eps=1e-5,
|
410 |
+
ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
|
411 |
+
skip=layer > num_layers // 2,
|
412 |
+
skip_concat_front=True,
|
413 |
+
skip_norm_last=True, # this is an error
|
414 |
+
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
|
415 |
+
qkv_bias=False,
|
416 |
+
)
|
417 |
+
for layer in range(num_layers)
|
418 |
+
]
|
419 |
+
)
|
420 |
+
|
421 |
+
self.norm_out = LayerNorm(self.inner_dim)
|
422 |
+
self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
|
423 |
+
|
424 |
+
self.gradient_checkpointing = False
|
425 |
+
|
426 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
427 |
+
self.gradient_checkpointing = value
|
428 |
+
|
429 |
+
def _set_time_proj(
|
430 |
+
self,
|
431 |
+
time_embedding_type: str,
|
432 |
+
inner_dim: int,
|
433 |
+
flip_sin_to_cos: bool,
|
434 |
+
freq_shift: float,
|
435 |
+
time_embedding_dim: int,
|
436 |
+
) -> Tuple[int, int]:
|
437 |
+
if time_embedding_type == "fourier":
|
438 |
+
time_embed_dim = time_embedding_dim or inner_dim * 2
|
439 |
+
if time_embed_dim % 2 != 0:
|
440 |
+
raise ValueError(
|
441 |
+
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
|
442 |
+
)
|
443 |
+
self.time_embed = GaussianFourierProjection(
|
444 |
+
time_embed_dim // 2,
|
445 |
+
set_W_to_weight=False,
|
446 |
+
log=False,
|
447 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
448 |
+
)
|
449 |
+
timestep_input_dim = time_embed_dim
|
450 |
+
elif time_embedding_type == "positional":
|
451 |
+
time_embed_dim = time_embedding_dim or inner_dim * 4
|
452 |
+
|
453 |
+
self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
454 |
+
timestep_input_dim = inner_dim
|
455 |
+
else:
|
456 |
+
raise ValueError(
|
457 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
458 |
+
)
|
459 |
+
|
460 |
+
return time_embed_dim, timestep_input_dim
|
461 |
+
|
462 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
|
463 |
+
def fuse_qkv_projections(self):
|
464 |
+
"""
|
465 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
466 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
467 |
+
|
468 |
+
<Tip warning={true}>
|
469 |
+
|
470 |
+
This API is 🧪 experimental.
|
471 |
+
|
472 |
+
</Tip>
|
473 |
+
"""
|
474 |
+
self.original_attn_processors = None
|
475 |
+
|
476 |
+
for _, attn_processor in self.attn_processors.items():
|
477 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
478 |
+
raise ValueError(
|
479 |
+
"`fuse_qkv_projections()` is not supported for models having added KV projections."
|
480 |
+
)
|
481 |
+
|
482 |
+
self.original_attn_processors = self.attn_processors
|
483 |
+
|
484 |
+
for module in self.modules():
|
485 |
+
if isinstance(module, Attention):
|
486 |
+
module.fuse_projections(fuse=True)
|
487 |
+
|
488 |
+
self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
|
489 |
+
|
490 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
491 |
+
def unfuse_qkv_projections(self):
|
492 |
+
"""Disables the fused QKV projection if enabled.
|
493 |
+
|
494 |
+
<Tip warning={true}>
|
495 |
+
|
496 |
+
This API is 🧪 experimental.
|
497 |
+
|
498 |
+
</Tip>
|
499 |
+
|
500 |
+
"""
|
501 |
+
if self.original_attn_processors is not None:
|
502 |
+
self.set_attn_processor(self.original_attn_processors)
|
503 |
+
|
504 |
+
@property
|
505 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
506 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
507 |
+
r"""
|
508 |
+
Returns:
|
509 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
510 |
+
indexed by its weight name.
|
511 |
+
"""
|
512 |
+
# set recursively
|
513 |
+
processors = {}
|
514 |
+
|
515 |
+
def fn_recursive_add_processors(
|
516 |
+
name: str,
|
517 |
+
module: torch.nn.Module,
|
518 |
+
processors: Dict[str, AttentionProcessor],
|
519 |
+
):
|
520 |
+
if hasattr(module, "get_processor"):
|
521 |
+
processors[f"{name}.processor"] = module.get_processor()
|
522 |
+
|
523 |
+
for sub_name, child in module.named_children():
|
524 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
525 |
+
|
526 |
+
return processors
|
527 |
+
|
528 |
+
for name, module in self.named_children():
|
529 |
+
fn_recursive_add_processors(name, module, processors)
|
530 |
+
|
531 |
+
return processors
|
532 |
+
|
533 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
534 |
+
def set_attn_processor(
|
535 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
536 |
+
):
|
537 |
+
r"""
|
538 |
+
Sets the attention processor to use to compute attention.
|
539 |
+
|
540 |
+
Parameters:
|
541 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
542 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
543 |
+
for **all** `Attention` layers.
|
544 |
+
|
545 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
546 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
547 |
+
|
548 |
+
"""
|
549 |
+
count = len(self.attn_processors.keys())
|
550 |
+
|
551 |
+
if isinstance(processor, dict) and len(processor) != count:
|
552 |
+
raise ValueError(
|
553 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
554 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
555 |
+
)
|
556 |
+
|
557 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
558 |
+
if hasattr(module, "set_processor"):
|
559 |
+
if not isinstance(processor, dict):
|
560 |
+
module.set_processor(processor)
|
561 |
+
else:
|
562 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
563 |
+
|
564 |
+
for sub_name, child in module.named_children():
|
565 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
566 |
+
|
567 |
+
for name, module in self.named_children():
|
568 |
+
fn_recursive_attn_processor(name, module, processor)
|
569 |
+
|
570 |
+
def set_default_attn_processor(self):
|
571 |
+
"""
|
572 |
+
Disables custom attention processors and sets the default attention implementation.
|
573 |
+
"""
|
574 |
+
self.set_attn_processor(TripoSGAttnProcessor2_0())
|
575 |
+
|
576 |
+
def forward(
|
577 |
+
self,
|
578 |
+
hidden_states: Optional[torch.Tensor],
|
579 |
+
timestep: Union[int, float, torch.LongTensor],
|
580 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
581 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
582 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
583 |
+
return_dict: bool = True,
|
584 |
+
):
|
585 |
+
"""
|
586 |
+
The [`HunyuanDiT2DModel`] forward method.
|
587 |
+
|
588 |
+
Args:
|
589 |
+
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
|
590 |
+
The input tensor.
|
591 |
+
timestep ( `torch.LongTensor`, *optional*):
|
592 |
+
Used to indicate denoising step.
|
593 |
+
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
594 |
+
Conditional embeddings for cross attention layer.
|
595 |
+
return_dict: bool
|
596 |
+
Whether to return a dictionary.
|
597 |
+
"""
|
598 |
+
|
599 |
+
if attention_kwargs is not None:
|
600 |
+
attention_kwargs = attention_kwargs.copy()
|
601 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
602 |
+
else:
|
603 |
+
lora_scale = 1.0
|
604 |
+
|
605 |
+
if USE_PEFT_BACKEND:
|
606 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
607 |
+
scale_lora_layers(self, lora_scale)
|
608 |
+
else:
|
609 |
+
if (
|
610 |
+
attention_kwargs is not None
|
611 |
+
and attention_kwargs.get("scale", None) is not None
|
612 |
+
):
|
613 |
+
logger.warning(
|
614 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
615 |
+
)
|
616 |
+
|
617 |
+
_, N, _ = hidden_states.shape
|
618 |
+
|
619 |
+
temb = self.time_embed(timestep).to(hidden_states.dtype)
|
620 |
+
temb = self.time_proj(temb)
|
621 |
+
temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
|
622 |
+
|
623 |
+
hidden_states = self.proj_in(hidden_states)
|
624 |
+
|
625 |
+
# N + 1 token
|
626 |
+
hidden_states = torch.cat([temb, hidden_states], dim=1)
|
627 |
+
|
628 |
+
skips = []
|
629 |
+
for layer, block in enumerate(self.blocks):
|
630 |
+
skip = None if layer <= self.config.num_layers // 2 else skips.pop()
|
631 |
+
|
632 |
+
if self.training and self.gradient_checkpointing:
|
633 |
+
|
634 |
+
def create_custom_forward(module):
|
635 |
+
def custom_forward(*inputs):
|
636 |
+
return module(*inputs)
|
637 |
+
|
638 |
+
return custom_forward
|
639 |
+
|
640 |
+
ckpt_kwargs: Dict[str, Any] = (
|
641 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
642 |
+
)
|
643 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
644 |
+
create_custom_forward(block),
|
645 |
+
hidden_states,
|
646 |
+
encoder_hidden_states,
|
647 |
+
temb,
|
648 |
+
image_rotary_emb,
|
649 |
+
skip,
|
650 |
+
attention_kwargs,
|
651 |
+
**ckpt_kwargs,
|
652 |
+
)
|
653 |
+
else:
|
654 |
+
hidden_states = block(
|
655 |
+
hidden_states,
|
656 |
+
encoder_hidden_states=encoder_hidden_states,
|
657 |
+
temb=temb,
|
658 |
+
image_rotary_emb=image_rotary_emb,
|
659 |
+
skip=skip,
|
660 |
+
attention_kwargs=attention_kwargs,
|
661 |
+
) # (N, L, D)
|
662 |
+
|
663 |
+
if layer < self.config.num_layers // 2:
|
664 |
+
skips.append(hidden_states)
|
665 |
+
|
666 |
+
# final layer
|
667 |
+
hidden_states = self.norm_out(hidden_states)
|
668 |
+
hidden_states = hidden_states[:, -N:]
|
669 |
+
hidden_states = self.proj_out(hidden_states)
|
670 |
+
|
671 |
+
if USE_PEFT_BACKEND:
|
672 |
+
# remove `lora_scale` from each PEFT layer
|
673 |
+
unscale_lora_layers(self, lora_scale)
|
674 |
+
|
675 |
+
if not return_dict:
|
676 |
+
return (hidden_states,)
|
677 |
+
|
678 |
+
return Transformer1DModelOutput(sample=hidden_states)
|
679 |
+
|
680 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
681 |
+
def enable_forward_chunking(
|
682 |
+
self, chunk_size: Optional[int] = None, dim: int = 0
|
683 |
+
) -> None:
|
684 |
+
"""
|
685 |
+
Sets the attention processor to use [feed forward
|
686 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
687 |
+
|
688 |
+
Parameters:
|
689 |
+
chunk_size (`int`, *optional*):
|
690 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
691 |
+
over each tensor of dim=`dim`.
|
692 |
+
dim (`int`, *optional*, defaults to `0`):
|
693 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
694 |
+
or dim=1 (sequence length).
|
695 |
+
"""
|
696 |
+
if dim not in [0, 1]:
|
697 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
698 |
+
|
699 |
+
# By default chunk size is 1
|
700 |
+
chunk_size = chunk_size or 1
|
701 |
+
|
702 |
+
def fn_recursive_feed_forward(
|
703 |
+
module: torch.nn.Module, chunk_size: int, dim: int
|
704 |
+
):
|
705 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
706 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
707 |
+
|
708 |
+
for child in module.children():
|
709 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
710 |
+
|
711 |
+
for module in self.children():
|
712 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
713 |
+
|
714 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
715 |
+
def disable_forward_chunking(self):
|
716 |
+
def fn_recursive_feed_forward(
|
717 |
+
module: torch.nn.Module, chunk_size: int, dim: int
|
718 |
+
):
|
719 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
720 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
721 |
+
|
722 |
+
for child in module.children():
|
723 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
724 |
+
|
725 |
+
for module in self.children():
|
726 |
+
fn_recursive_feed_forward(module, None, 0)
|
detailgen3d/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .pipeline_detailgen3d import DetailGen3DPipeline
|
detailgen3d/pipelines/pipeline_detailgen3d.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import math
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import PIL
|
7 |
+
import PIL.Image
|
8 |
+
import torch
|
9 |
+
from diffusers.image_processor import PipelineImageInput
|
10 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
11 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler # not sure
|
12 |
+
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
13 |
+
from diffusers.utils import logging
|
14 |
+
from diffusers.utils.torch_utils import randn_tensor
|
15 |
+
from transformers import (
|
16 |
+
BitImageProcessor,
|
17 |
+
CLIPImageProcessor,
|
18 |
+
CLIPVisionModelWithProjection,
|
19 |
+
Dinov2Model,
|
20 |
+
)
|
21 |
+
|
22 |
+
from ..models.autoencoders import TripoSGVAEModel
|
23 |
+
from ..models.transformers import DetailGen3DDiTModel
|
24 |
+
from .pipeline_detailgen3d_output import DetailGen3DPipelineOutput
|
25 |
+
from .pipeline_utils import TransformerDiffusionMixin
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
28 |
+
|
29 |
+
|
30 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
31 |
+
def retrieve_timesteps(
|
32 |
+
scheduler,
|
33 |
+
num_inference_steps: Optional[int] = None,
|
34 |
+
device: Optional[Union[str, torch.device]] = None,
|
35 |
+
timesteps: Optional[List[int]] = None,
|
36 |
+
sigmas: Optional[List[float]] = None,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
41 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
scheduler (`SchedulerMixin`):
|
45 |
+
The scheduler to get timesteps from.
|
46 |
+
num_inference_steps (`int`):
|
47 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
48 |
+
must be `None`.
|
49 |
+
device (`str` or `torch.device`, *optional*):
|
50 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
51 |
+
timesteps (`List[int]`, *optional*):
|
52 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
53 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
54 |
+
sigmas (`List[float]`, *optional*):
|
55 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
56 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
60 |
+
second element is the number of inference steps.
|
61 |
+
"""
|
62 |
+
if timesteps is not None and sigmas is not None:
|
63 |
+
raise ValueError(
|
64 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
65 |
+
)
|
66 |
+
if timesteps is not None:
|
67 |
+
accepts_timesteps = "timesteps" in set(
|
68 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
69 |
+
)
|
70 |
+
if not accepts_timesteps:
|
71 |
+
raise ValueError(
|
72 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
73 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
74 |
+
)
|
75 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
76 |
+
timesteps = scheduler.timesteps
|
77 |
+
num_inference_steps = len(timesteps)
|
78 |
+
elif sigmas is not None:
|
79 |
+
accept_sigmas = "sigmas" in set(
|
80 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
81 |
+
)
|
82 |
+
if not accept_sigmas:
|
83 |
+
raise ValueError(
|
84 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
85 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
86 |
+
)
|
87 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
88 |
+
timesteps = scheduler.timesteps
|
89 |
+
num_inference_steps = len(timesteps)
|
90 |
+
else:
|
91 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
92 |
+
timesteps = scheduler.timesteps
|
93 |
+
return timesteps, num_inference_steps
|
94 |
+
|
95 |
+
|
96 |
+
class DetailGen3DPipeline(
|
97 |
+
DiffusionPipeline, TransformerDiffusionMixin
|
98 |
+
):
|
99 |
+
"""
|
100 |
+
Pipeline for detail generation using DetailGen3D.
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
vae: TripoSGVAEModel,
|
106 |
+
transformer: DetailGen3DDiTModel,
|
107 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
108 |
+
noise_scheduler: DDPMScheduler,
|
109 |
+
image_encoder_1: Dinov2Model,
|
110 |
+
feature_extractor_1: BitImageProcessor,
|
111 |
+
):
|
112 |
+
super().__init__()
|
113 |
+
|
114 |
+
self.register_modules(
|
115 |
+
vae=vae,
|
116 |
+
transformer=transformer,
|
117 |
+
scheduler=scheduler,
|
118 |
+
noise_scheduler=noise_scheduler,
|
119 |
+
image_encoder_1=image_encoder_1,
|
120 |
+
feature_extractor_1=feature_extractor_1,
|
121 |
+
)
|
122 |
+
|
123 |
+
@property
|
124 |
+
def guidance_scale(self):
|
125 |
+
return self._guidance_scale
|
126 |
+
|
127 |
+
@property
|
128 |
+
def do_classifier_free_guidance(self):
|
129 |
+
return self._guidance_scale > 1
|
130 |
+
|
131 |
+
@property
|
132 |
+
def num_timesteps(self):
|
133 |
+
return self._num_timesteps
|
134 |
+
|
135 |
+
@property
|
136 |
+
def attention_kwargs(self):
|
137 |
+
return self._attention_kwargs
|
138 |
+
|
139 |
+
@property
|
140 |
+
def interrupt(self):
|
141 |
+
return self._interrupt
|
142 |
+
|
143 |
+
def encode_image_1(self, image, device, num_images_per_prompt):
|
144 |
+
dtype = next(self.image_encoder_1.parameters()).dtype
|
145 |
+
|
146 |
+
if not isinstance(image, torch.Tensor):
|
147 |
+
image = self.feature_extractor_1(image, return_tensors="pt").pixel_values
|
148 |
+
|
149 |
+
image = image.to(device=device, dtype=dtype)
|
150 |
+
image_embeds = self.image_encoder_1(image).last_hidden_state
|
151 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
152 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
153 |
+
|
154 |
+
return image_embeds, uncond_image_embeds
|
155 |
+
|
156 |
+
def prepare_latents(
|
157 |
+
self,
|
158 |
+
batch_size,
|
159 |
+
num_tokens,
|
160 |
+
num_channels_latents,
|
161 |
+
dtype,
|
162 |
+
device,
|
163 |
+
generator,
|
164 |
+
latents: Optional[torch.Tensor] = None,
|
165 |
+
noise_aug_level = 0,
|
166 |
+
):
|
167 |
+
if latents is not None:
|
168 |
+
latents = latents.to(device=device, dtype=dtype)
|
169 |
+
latents = self.noise_scheduler.add_noise(latents, torch.randn_like(latents), torch.tensor(noise_aug_level))
|
170 |
+
return latents
|
171 |
+
|
172 |
+
raise Exception(
|
173 |
+
f"You have to pass latents of geometry you want to refine."
|
174 |
+
)
|
175 |
+
|
176 |
+
@torch.no_grad()
|
177 |
+
def __call__(
|
178 |
+
self,
|
179 |
+
image: PipelineImageInput,
|
180 |
+
image_2: Optional[PipelineImageInput] = None,
|
181 |
+
num_inference_steps: int = 10,
|
182 |
+
timesteps: List[int] = None,
|
183 |
+
guidance_scale: float = 4.0,
|
184 |
+
num_images_per_prompt: int = 1,
|
185 |
+
sampled_points: Optional[torch.Tensor] = None,
|
186 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
187 |
+
latents: Optional[torch.FloatTensor] = None,
|
188 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
189 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
190 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
191 |
+
output_type: Optional[str] = "mesh_vf",
|
192 |
+
return_dict: bool = True,
|
193 |
+
noise_aug_level = 0,
|
194 |
+
):
|
195 |
+
# 1. Check inputs. Raise error if not correct
|
196 |
+
# TODO
|
197 |
+
|
198 |
+
self._guidance_scale = guidance_scale
|
199 |
+
self._attention_kwargs = attention_kwargs
|
200 |
+
self._interrupt = False
|
201 |
+
|
202 |
+
# 2. Define call parameters
|
203 |
+
if isinstance(image, PIL.Image.Image):
|
204 |
+
batch_size = 1
|
205 |
+
elif isinstance(image, list):
|
206 |
+
batch_size = len(image)
|
207 |
+
elif isinstance(image, torch.Tensor):
|
208 |
+
batch_size = image.shape[0]
|
209 |
+
else:
|
210 |
+
raise ValueError("Invalid input type for image")
|
211 |
+
|
212 |
+
device = self._execution_device
|
213 |
+
|
214 |
+
# 3. Encode condition
|
215 |
+
image_embeds_1, negative_image_embeds_1 = self.encode_image_1(
|
216 |
+
image, device, num_images_per_prompt
|
217 |
+
)
|
218 |
+
|
219 |
+
if self.do_classifier_free_guidance:
|
220 |
+
image_embeds_1 = torch.cat([negative_image_embeds_1, image_embeds_1], dim=0)
|
221 |
+
|
222 |
+
# 4. Prepare timesteps
|
223 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
224 |
+
self.scheduler, num_inference_steps, device, timesteps
|
225 |
+
)
|
226 |
+
num_warmup_steps = max(
|
227 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
228 |
+
)
|
229 |
+
self._num_timesteps = len(timesteps)
|
230 |
+
|
231 |
+
# 5. Prepare latent variables
|
232 |
+
num_tokens = self.transformer.config.width
|
233 |
+
num_channels_latents = self.transformer.config.in_channels
|
234 |
+
latents = self.prepare_latents(
|
235 |
+
batch_size * num_images_per_prompt,
|
236 |
+
num_tokens,
|
237 |
+
num_channels_latents,
|
238 |
+
image_embeds_1.dtype,
|
239 |
+
device,
|
240 |
+
generator,
|
241 |
+
latents,
|
242 |
+
noise_aug_level,
|
243 |
+
)
|
244 |
+
|
245 |
+
# 6. Denoising loop
|
246 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
247 |
+
for i, t in enumerate(timesteps):
|
248 |
+
if self.interrupt:
|
249 |
+
continue
|
250 |
+
|
251 |
+
# expand the latents if we are doing classifier free guidance
|
252 |
+
latent_model_input = (
|
253 |
+
torch.cat([latents] * 2)
|
254 |
+
if self.do_classifier_free_guidance
|
255 |
+
else latents
|
256 |
+
)
|
257 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
258 |
+
timestep = t.expand(latent_model_input.shape[0])
|
259 |
+
|
260 |
+
noise_pred = self.transformer(
|
261 |
+
latent_model_input,
|
262 |
+
timestep,
|
263 |
+
encoder_hidden_states=image_embeds_1,
|
264 |
+
attention_kwargs=attention_kwargs,
|
265 |
+
return_dict=False,
|
266 |
+
)[0]
|
267 |
+
|
268 |
+
# perform guidance
|
269 |
+
if self.do_classifier_free_guidance:
|
270 |
+
noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
|
271 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
272 |
+
noise_pred_image - noise_pred_uncond
|
273 |
+
)
|
274 |
+
|
275 |
+
# compute the previous noisy sample x_t -> x_t-1
|
276 |
+
latents_dtype = latents.dtype
|
277 |
+
latents = self.scheduler.step(
|
278 |
+
noise_pred, t, latents, return_dict=False
|
279 |
+
)[0]
|
280 |
+
|
281 |
+
if latents.dtype != latents_dtype:
|
282 |
+
if torch.backends.mps.is_available():
|
283 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
284 |
+
latents = latents.to(latents_dtype)
|
285 |
+
|
286 |
+
if callback_on_step_end is not None:
|
287 |
+
callback_kwargs = {}
|
288 |
+
for k in callback_on_step_end_tensor_inputs:
|
289 |
+
callback_kwargs[k] = locals()[k]
|
290 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
291 |
+
|
292 |
+
latents = callback_outputs.pop("latents", latents)
|
293 |
+
image_embeds_1 = callback_outputs.pop(
|
294 |
+
"image_embeds_1", image_embeds_1
|
295 |
+
)
|
296 |
+
negative_image_embeds_1 = callback_outputs.pop(
|
297 |
+
"negative_image_embeds_1", negative_image_embeds_1
|
298 |
+
)
|
299 |
+
|
300 |
+
# call the callback, if provided
|
301 |
+
if i == len(timesteps) - 1 or (
|
302 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
303 |
+
):
|
304 |
+
progress_bar.update()
|
305 |
+
|
306 |
+
if output_type == "latent":
|
307 |
+
output = latents
|
308 |
+
else:
|
309 |
+
if sampled_points is None:
|
310 |
+
raise ValueError(
|
311 |
+
"sampled_points must be provided when output_type is not 'latent'"
|
312 |
+
)
|
313 |
+
|
314 |
+
output = self.vae.decode(latents, sampled_points=sampled_points).sample
|
315 |
+
|
316 |
+
# Offload all models
|
317 |
+
self.maybe_free_model_hooks()
|
318 |
+
|
319 |
+
if not return_dict:
|
320 |
+
return (output,)
|
321 |
+
|
322 |
+
return DetailGen3DPipelineOutput(samples=output)
|
detailgen3d/pipelines/pipeline_detailgen3d_output.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers.utils import BaseOutput
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class DetailGen3DPipelineOutput(BaseOutput):
|
9 |
+
r"""
|
10 |
+
Output class for DetailGen3D pipelines.
|
11 |
+
"""
|
12 |
+
|
13 |
+
samples: torch.Tensor
|
detailgen3d/pipelines/pipeline_utils.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers.utils import logging
|
2 |
+
|
3 |
+
logger = logging.get_logger(__name__)
|
4 |
+
|
5 |
+
|
6 |
+
class TransformerDiffusionMixin:
|
7 |
+
r"""
|
8 |
+
Helper for DiffusionPipeline with vae and transformer.(mainly for DIT)
|
9 |
+
"""
|
10 |
+
|
11 |
+
def enable_vae_slicing(self):
|
12 |
+
r"""
|
13 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
14 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
15 |
+
"""
|
16 |
+
self.vae.enable_slicing()
|
17 |
+
|
18 |
+
def disable_vae_slicing(self):
|
19 |
+
r"""
|
20 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
21 |
+
computing decoding in one step.
|
22 |
+
"""
|
23 |
+
self.vae.disable_slicing()
|
24 |
+
|
25 |
+
def enable_vae_tiling(self):
|
26 |
+
r"""
|
27 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
28 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
29 |
+
processing larger images.
|
30 |
+
"""
|
31 |
+
self.vae.enable_tiling()
|
32 |
+
|
33 |
+
def disable_vae_tiling(self):
|
34 |
+
r"""
|
35 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
36 |
+
computing decoding in one step.
|
37 |
+
"""
|
38 |
+
self.vae.disable_tiling()
|
39 |
+
|
40 |
+
def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
|
41 |
+
"""
|
42 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
43 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
44 |
+
|
45 |
+
<Tip warning={true}>
|
46 |
+
|
47 |
+
This API is 🧪 experimental.
|
48 |
+
|
49 |
+
</Tip>
|
50 |
+
|
51 |
+
Args:
|
52 |
+
transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
|
53 |
+
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
54 |
+
"""
|
55 |
+
self.fusing_transformer = False
|
56 |
+
self.fusing_vae = False
|
57 |
+
|
58 |
+
if transformer:
|
59 |
+
self.fusing_transformer = True
|
60 |
+
self.transformer.fuse_qkv_projections()
|
61 |
+
|
62 |
+
if vae:
|
63 |
+
self.fusing_vae = True
|
64 |
+
self.vae.fuse_qkv_projections()
|
65 |
+
|
66 |
+
def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
|
67 |
+
"""Disable QKV projection fusion if enabled.
|
68 |
+
|
69 |
+
<Tip warning={true}>
|
70 |
+
|
71 |
+
This API is 🧪 experimental.
|
72 |
+
|
73 |
+
</Tip>
|
74 |
+
|
75 |
+
Args:
|
76 |
+
transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
|
77 |
+
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
78 |
+
|
79 |
+
"""
|
80 |
+
if transformer:
|
81 |
+
if not self.fusing_transformer:
|
82 |
+
logger.warning(
|
83 |
+
"The UNet was not initially fused for QKV projections. Doing nothing."
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
self.transformer.unfuse_qkv_projections()
|
87 |
+
self.fusing_transformer = False
|
88 |
+
|
89 |
+
if vae:
|
90 |
+
if not self.fusing_vae:
|
91 |
+
logger.warning(
|
92 |
+
"The VAE was not initially fused for QKV projections. Doing nothing."
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
self.vae.unfuse_qkv_projections()
|
96 |
+
self.fusing_vae = False
|
detailgen3d/schedulers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .scheduling_rectified_flow import (
|
2 |
+
RectifiedFlowScheduler,
|
3 |
+
compute_density_for_timestep_sampling,
|
4 |
+
compute_loss_weighting,
|
5 |
+
)
|
detailgen3d/schedulers/scheduling_rectified_flow.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
12 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
13 |
+
from diffusers.utils import BaseOutput, logging
|
14 |
+
from torch.distributions import LogisticNormal
|
15 |
+
|
16 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
17 |
+
|
18 |
+
|
19 |
+
# TODO: may move to training_utils.py
|
20 |
+
def compute_density_for_timestep_sampling(
|
21 |
+
weighting_scheme: str,
|
22 |
+
batch_size: int,
|
23 |
+
logit_mean: float = 0.0,
|
24 |
+
logit_std: float = 1.0,
|
25 |
+
mode_scale: float = None,
|
26 |
+
):
|
27 |
+
if weighting_scheme == "logit_normal":
|
28 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
29 |
+
u = torch.normal(
|
30 |
+
mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
|
31 |
+
)
|
32 |
+
u = torch.nn.functional.sigmoid(u)
|
33 |
+
elif weighting_scheme == "logit_normal_dist":
|
34 |
+
u = (
|
35 |
+
LogisticNormal(loc=logit_mean, scale=logit_std)
|
36 |
+
.sample((batch_size,))[:, 0]
|
37 |
+
.to("cpu")
|
38 |
+
)
|
39 |
+
elif weighting_scheme == "mode":
|
40 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
41 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
42 |
+
else:
|
43 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
44 |
+
return u
|
45 |
+
|
46 |
+
|
47 |
+
def compute_loss_weighting(weighting_scheme: str, sigmas=None):
|
48 |
+
"""
|
49 |
+
Computes loss weighting scheme for SD3 training.
|
50 |
+
|
51 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
52 |
+
|
53 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
54 |
+
"""
|
55 |
+
if weighting_scheme == "sigma_sqrt":
|
56 |
+
weighting = (sigmas**-2.0).float()
|
57 |
+
elif weighting_scheme == "cosmap":
|
58 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
59 |
+
weighting = 2 / (math.pi * bot)
|
60 |
+
else:
|
61 |
+
weighting = torch.ones_like(sigmas)
|
62 |
+
return weighting
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class RectifiedFlowSchedulerOutput(BaseOutput):
|
67 |
+
"""
|
68 |
+
Output class for the scheduler's `step` function output.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
72 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
73 |
+
denoising loop.
|
74 |
+
"""
|
75 |
+
|
76 |
+
prev_sample: torch.FloatTensor
|
77 |
+
|
78 |
+
|
79 |
+
class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
|
80 |
+
"""
|
81 |
+
The rectified flow scheduler is a scheduler that is used to propagate the diffusion process in the rectified flow.
|
82 |
+
|
83 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
84 |
+
methods the library implements for all schedulers such as loading and saving.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
num_train_timesteps (`int`, defaults to 1000):
|
88 |
+
The number of diffusion steps to train the model.
|
89 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
90 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
91 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
92 |
+
shift (`float`, defaults to 1.0):
|
93 |
+
The shift value for the timestep schedule.
|
94 |
+
"""
|
95 |
+
|
96 |
+
_compatibles = []
|
97 |
+
order = 1
|
98 |
+
|
99 |
+
@register_to_config
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
num_train_timesteps: int = 1000,
|
103 |
+
shift: float = 1.0,
|
104 |
+
use_dynamic_shifting: bool = False,
|
105 |
+
):
|
106 |
+
# pre-compute timesteps and sigmas; no use in fact
|
107 |
+
# NOTE that shape diffusion sample timesteps randomly or in a distribution,
|
108 |
+
# instead of sampling from the pre-defined linspace
|
109 |
+
timesteps = np.array(
|
110 |
+
[
|
111 |
+
(1.0 - i / num_train_timesteps) * num_train_timesteps
|
112 |
+
for i in range(num_train_timesteps)
|
113 |
+
]
|
114 |
+
)
|
115 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
116 |
+
|
117 |
+
sigmas = timesteps / num_train_timesteps
|
118 |
+
if not use_dynamic_shifting:
|
119 |
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
120 |
+
sigmas = self.time_shift(sigmas)
|
121 |
+
|
122 |
+
self.timesteps = sigmas * num_train_timesteps
|
123 |
+
|
124 |
+
self._step_index = None
|
125 |
+
self._begin_index = None
|
126 |
+
|
127 |
+
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
128 |
+
|
129 |
+
@property
|
130 |
+
def step_index(self):
|
131 |
+
"""
|
132 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
133 |
+
"""
|
134 |
+
return self._step_index
|
135 |
+
|
136 |
+
@property
|
137 |
+
def begin_index(self):
|
138 |
+
"""
|
139 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
140 |
+
"""
|
141 |
+
return self._begin_index
|
142 |
+
|
143 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
144 |
+
def set_begin_index(self, begin_index: int = 0):
|
145 |
+
"""
|
146 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
begin_index (`int`):
|
150 |
+
The begin index for the scheduler.
|
151 |
+
"""
|
152 |
+
self._begin_index = begin_index
|
153 |
+
|
154 |
+
def _sigma_to_t(self, sigma):
|
155 |
+
return sigma * self.config.num_train_timesteps
|
156 |
+
|
157 |
+
def _t_to_sigma(self, timestep):
|
158 |
+
return timestep / self.config.num_train_timesteps
|
159 |
+
|
160 |
+
def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
|
161 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
162 |
+
|
163 |
+
def time_shift(self, t: torch.Tensor):
|
164 |
+
return self.config.shift * t / (1 + (self.config.shift - 1) * t)
|
165 |
+
|
166 |
+
def set_timesteps(
|
167 |
+
self,
|
168 |
+
num_inference_steps: int = None,
|
169 |
+
device: Union[str, torch.device] = None,
|
170 |
+
sigmas: Optional[List[float]] = None,
|
171 |
+
mu: Optional[float] = None,
|
172 |
+
):
|
173 |
+
"""
|
174 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
175 |
+
|
176 |
+
Args:
|
177 |
+
num_inference_steps (`int`):
|
178 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
179 |
+
device (`str` or `torch.device`, *optional*):
|
180 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
181 |
+
"""
|
182 |
+
|
183 |
+
if self.config.use_dynamic_shifting and mu is None:
|
184 |
+
raise ValueError(
|
185 |
+
" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
|
186 |
+
)
|
187 |
+
|
188 |
+
if sigmas is None:
|
189 |
+
self.num_inference_steps = num_inference_steps
|
190 |
+
timesteps = np.array(
|
191 |
+
[
|
192 |
+
(1.0 - i / num_inference_steps) * self.config.num_train_timesteps
|
193 |
+
for i in range(num_inference_steps)
|
194 |
+
]
|
195 |
+
) # different from the original code in SD3
|
196 |
+
sigmas = timesteps / self.config.num_train_timesteps
|
197 |
+
|
198 |
+
if self.config.use_dynamic_shifting:
|
199 |
+
sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
|
200 |
+
else:
|
201 |
+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
202 |
+
|
203 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
204 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
205 |
+
|
206 |
+
self.timesteps = timesteps.to(device=device)
|
207 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
208 |
+
|
209 |
+
self._step_index = None
|
210 |
+
self._begin_index = None
|
211 |
+
|
212 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
213 |
+
if schedule_timesteps is None:
|
214 |
+
schedule_timesteps = self.timesteps
|
215 |
+
|
216 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
217 |
+
|
218 |
+
# The sigma index that is taken for the **very** first `step`
|
219 |
+
# is always the second index (or the last index if there is only 1)
|
220 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
221 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
222 |
+
pos = 1 if len(indices) > 1 else 0
|
223 |
+
|
224 |
+
return indices[pos].item()
|
225 |
+
|
226 |
+
def _init_step_index(self, timestep):
|
227 |
+
if self.begin_index is None:
|
228 |
+
if isinstance(timestep, torch.Tensor):
|
229 |
+
timestep = timestep.to(self.timesteps.device)
|
230 |
+
self._step_index = self.index_for_timestep(timestep)
|
231 |
+
else:
|
232 |
+
self._step_index = self._begin_index
|
233 |
+
|
234 |
+
def step(
|
235 |
+
self,
|
236 |
+
model_output: torch.FloatTensor,
|
237 |
+
timestep: Union[float, torch.FloatTensor],
|
238 |
+
sample: torch.FloatTensor,
|
239 |
+
s_churn: float = 0.0,
|
240 |
+
s_tmin: float = 0.0,
|
241 |
+
s_tmax: float = float("inf"),
|
242 |
+
s_noise: float = 1.0,
|
243 |
+
generator: Optional[torch.Generator] = None,
|
244 |
+
return_dict: bool = True,
|
245 |
+
) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
|
246 |
+
"""
|
247 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
248 |
+
process from the learned model outputs (most often the predicted noise).
|
249 |
+
|
250 |
+
Args:
|
251 |
+
model_output (`torch.FloatTensor`):
|
252 |
+
The direct output from learned diffusion model.
|
253 |
+
timestep (`float`):
|
254 |
+
The current discrete timestep in the diffusion chain.
|
255 |
+
sample (`torch.FloatTensor`):
|
256 |
+
A current instance of a sample created by the diffusion process.
|
257 |
+
s_churn (`float`):
|
258 |
+
s_tmin (`float`):
|
259 |
+
s_tmax (`float`):
|
260 |
+
s_noise (`float`, defaults to 1.0):
|
261 |
+
Scaling factor for noise added to the sample.
|
262 |
+
generator (`torch.Generator`, *optional*):
|
263 |
+
A random number generator.
|
264 |
+
return_dict (`bool`):
|
265 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
266 |
+
tuple.
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
270 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
271 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
272 |
+
"""
|
273 |
+
|
274 |
+
if (
|
275 |
+
isinstance(timestep, int)
|
276 |
+
or isinstance(timestep, torch.IntTensor)
|
277 |
+
or isinstance(timestep, torch.LongTensor)
|
278 |
+
):
|
279 |
+
raise ValueError(
|
280 |
+
(
|
281 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
282 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
283 |
+
" one of the `scheduler.timesteps` as a timestep."
|
284 |
+
),
|
285 |
+
)
|
286 |
+
|
287 |
+
if self.step_index is None:
|
288 |
+
self._init_step_index(timestep)
|
289 |
+
|
290 |
+
# Upcast to avoid precision issues when computing prev_sample
|
291 |
+
sample = sample.to(torch.float32)
|
292 |
+
|
293 |
+
sigma = self.sigmas[self.step_index]
|
294 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
295 |
+
|
296 |
+
# Here different directions are used for the flow matching
|
297 |
+
prev_sample = sample + (sigma - sigma_next) * model_output
|
298 |
+
|
299 |
+
# Cast sample back to model compatible dtype
|
300 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
301 |
+
|
302 |
+
# upon completion increase step index by one
|
303 |
+
self._step_index += 1
|
304 |
+
|
305 |
+
if not return_dict:
|
306 |
+
return (prev_sample,)
|
307 |
+
|
308 |
+
return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
|
309 |
+
|
310 |
+
def scale_noise(
|
311 |
+
self,
|
312 |
+
original_samples: torch.Tensor,
|
313 |
+
noise: torch.Tensor,
|
314 |
+
timesteps: torch.IntTensor,
|
315 |
+
) -> torch.Tensor:
|
316 |
+
"""
|
317 |
+
Forward function for the noise scaling in the flow matching.
|
318 |
+
"""
|
319 |
+
sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
|
320 |
+
|
321 |
+
while len(sigmas.shape) < len(original_samples.shape):
|
322 |
+
sigmas = sigmas.unsqueeze(-1)
|
323 |
+
|
324 |
+
return (1.0 - sigmas) * original_samples + sigmas * noise
|
325 |
+
|
326 |
+
def __len__(self):
|
327 |
+
return self.config.num_train_timesteps
|
detailgen3d/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .constants import USE_FLASH3_BACKEND, USE_SDPA_BACKEND, disable_flash3
|
2 |
+
from .import_utils import is_flash3_available
|
detailgen3d/utils/typing.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module contains type annotations for the project, using
|
3 |
+
1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
|
4 |
+
2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
|
5 |
+
|
6 |
+
Two types of typing checking can be used:
|
7 |
+
1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
|
8 |
+
2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
|
9 |
+
"""
|
10 |
+
|
11 |
+
# Basic types
|
12 |
+
from typing import (
|
13 |
+
Any,
|
14 |
+
Callable,
|
15 |
+
Dict,
|
16 |
+
Iterable,
|
17 |
+
List,
|
18 |
+
Literal,
|
19 |
+
NamedTuple,
|
20 |
+
NewType,
|
21 |
+
Optional,
|
22 |
+
Sized,
|
23 |
+
Tuple,
|
24 |
+
Type,
|
25 |
+
TypedDict,
|
26 |
+
TypeVar,
|
27 |
+
Union,
|
28 |
+
)
|
29 |
+
|
30 |
+
# Tensor dtype
|
31 |
+
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
|
32 |
+
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
|
33 |
+
|
34 |
+
# Config type
|
35 |
+
from omegaconf import DictConfig, ListConfig
|
36 |
+
|
37 |
+
# PyTorch Tensor type
|
38 |
+
from torch import Tensor
|
39 |
+
|
40 |
+
# Runtime type checking decorator
|
41 |
+
from typeguard import typechecked as typechecker
|
42 |
+
|
43 |
+
|
44 |
+
# Custom types
|
45 |
+
class FuncArgs(TypedDict):
|
46 |
+
"""Type for instantiating a function with keyword arguments"""
|
47 |
+
|
48 |
+
name: str
|
49 |
+
kwargs: Dict[str, Any]
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
def validate(variable):
|
53 |
+
necessary_keys = ["name", "kwargs"]
|
54 |
+
for key in necessary_keys:
|
55 |
+
assert key in variable, f"Key {key} is missing in {variable}"
|
56 |
+
if not isinstance(variable["name"], str):
|
57 |
+
raise TypeError(
|
58 |
+
f"Key 'name' should be a string, not {type(variable['name'])}"
|
59 |
+
)
|
60 |
+
if not isinstance(variable["kwargs"], dict):
|
61 |
+
raise TypeError(
|
62 |
+
f"Key 'kwargs' should be a dictionary, not {type(variable['kwargs'])}"
|
63 |
+
)
|
64 |
+
return variable
|
scripts/inference_detailgen3d.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import trimesh
|
4 |
+
from PIL import Image
|
5 |
+
from skimage import measure
|
6 |
+
|
7 |
+
from detailgen3d.inference_utils import generate_dense_grid_points
|
8 |
+
from detailgen3d.pipelines.pipeline_detailgen3d import (
|
9 |
+
DetailGen3DPipeline,
|
10 |
+
)
|
11 |
+
|
12 |
+
def load_mesh(mesh_path, num_pc=20480):
|
13 |
+
mesh = trimesh.load(mesh_path,force="mesh")
|
14 |
+
|
15 |
+
center = mesh.bounding_box.centroid
|
16 |
+
mesh.apply_translation(-center)
|
17 |
+
scale = max(mesh.bounding_box.extents)
|
18 |
+
mesh.apply_scale(1.9 / scale)
|
19 |
+
|
20 |
+
surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000,)
|
21 |
+
normal = mesh.face_normals[face_indices]
|
22 |
+
|
23 |
+
rng = np.random.default_rng()
|
24 |
+
ind = rng.choice(surface.shape[0], num_pc, replace=False)
|
25 |
+
surface = torch.FloatTensor(surface[ind])
|
26 |
+
normal = torch.FloatTensor(normal[ind])
|
27 |
+
surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
|
28 |
+
|
29 |
+
return surface
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
device = "cuda"
|
33 |
+
dtype = torch.float16
|
34 |
+
|
35 |
+
# prepare pipeline
|
36 |
+
pipeline = DetailGen3DPipeline.from_pretrained(
|
37 |
+
"VAST-AI/DetailGen3D",
|
38 |
+
low_cpu_mem_usage=False
|
39 |
+
).to(device, dtype=dtype)
|
40 |
+
|
41 |
+
# prepare data
|
42 |
+
image_path = "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png"
|
43 |
+
image = Image.open(image_path).convert("RGB")
|
44 |
+
|
45 |
+
mesh_path = "assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb"
|
46 |
+
surface = load_mesh(mesh_path).to(device, dtype=dtype)
|
47 |
+
|
48 |
+
batch_size = 1
|
49 |
+
|
50 |
+
# sample query points for decoding
|
51 |
+
box_min = np.array([-1.005, -1.005, -1.005])
|
52 |
+
box_max = np.array([1.005, 1.005, 1.005])
|
53 |
+
sampled_points, grid_size, bbox_size = generate_dense_grid_points(
|
54 |
+
bbox_min=box_min, bbox_max=box_max, octree_depth=9, indexing="ij"
|
55 |
+
)
|
56 |
+
sampled_points = torch.FloatTensor(sampled_points).to(device, dtype=dtype)
|
57 |
+
sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
|
58 |
+
|
59 |
+
# inference pipeline
|
60 |
+
sample = pipeline.vae.encode(surface).latent_dist.sample()
|
61 |
+
sdf = pipeline(image, latents=sample, sampled_points=sampled_points, noise_aug_level=0).samples[0]
|
62 |
+
|
63 |
+
# marching cubes
|
64 |
+
grid_logits = sdf.view(grid_size).cpu().numpy()
|
65 |
+
vertices, faces, normals, _ = measure.marching_cubes(
|
66 |
+
grid_logits, 0, method="lewiner"
|
67 |
+
)
|
68 |
+
vertices = vertices / grid_size * bbox_size + box_min
|
69 |
+
mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
|
70 |
+
mesh.export("output.glb", file_type="glb")
|