zerchen committed on
Commit 717b269 · 1 Parent(s): fe0ef0e

init test without models

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +139 -0
  3. app.py +202 -0
  4. assets/test1.png +3 -0
  5. assets/test2.png +3 -0
  6. assets/test3.jpg +0 -0
  7. assets/test4.jpeg +0 -0
  8. assets/test5.jpeg +0 -0
  9. hort/models/__init__.py +114 -0
  10. hort/models/network/pointnet.py +36 -0
  11. hort/models/tgs/__init__.py +9 -0
  12. hort/models/tgs/data.py +265 -0
  13. hort/models/tgs/models/__init__.py +0 -0
  14. hort/models/tgs/models/image_feature.py +48 -0
  15. hort/models/tgs/models/networks.py +204 -0
  16. hort/models/tgs/models/pointclouds/LICENSE_POINTNET +21 -0
  17. hort/models/tgs/models/pointclouds/pointnet.py +121 -0
  18. hort/models/tgs/models/pointclouds/simplepoint.py +110 -0
  19. hort/models/tgs/models/renderer.py +427 -0
  20. hort/models/tgs/models/snowflake/LICENSE +21 -0
  21. hort/models/tgs/models/snowflake/SPD.py +68 -0
  22. hort/models/tgs/models/snowflake/SPD_crossattn.py +81 -0
  23. hort/models/tgs/models/snowflake/SPD_pp.py +71 -0
  24. hort/models/tgs/models/snowflake/attention.py +239 -0
  25. hort/models/tgs/models/snowflake/model_spdpp.py +239 -0
  26. hort/models/tgs/models/snowflake/pointnet2.py +126 -0
  27. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py +3 -0
  28. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h +5 -0
  29. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h +41 -0
  30. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h +5 -0
  31. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h +10 -0
  32. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h +6 -0
  33. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h +25 -0
  34. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp +32 -0
  35. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu +54 -0
  36. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp +19 -0
  37. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp +62 -0
  38. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu +75 -0
  39. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp +99 -0
  40. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu +154 -0
  41. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp +87 -0
  42. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu +229 -0
  43. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py +1 -0
  44. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py +209 -0
  45. hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py +391 -0
  46. hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py +41 -0
  47. hort/models/tgs/models/snowflake/skip_transformer.py +69 -0
  48. hort/models/tgs/models/snowflake/utils.py +741 -0
  49. hort/models/tgs/models/tokenizers/dinov2.py +1179 -0
  50. hort/models/tgs/models/tokenizers/image.py +123 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/test1.png filter=lfs diff=lfs merge=lfs -text
+ assets/test2.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class

# pyc
*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode

*.swp
*.h5
*.mp4
app.py ADDED
@@ -0,0 +1,202 @@
import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"

import gradio as gr
import spaces  # needed for the @spaces.GPU() decorator below
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from pathlib import Path
import argparse
import json
from torchvision import transforms
from typing import Dict, Optional
from PIL import Image, ImageDraw
from lang_sam import LangSAM

from wilor.models import load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset
from hort.models import load_hort
from hort.utils.renderer import Renderer, cam_crop_to_new
from hort.utils.img_utils import process_bbox, generate_patch_image, PerspectiveCamera
from ultralytics import YOLO
LIGHT_PURPLE = (0.25098039, 0.274117647, 0.65882353)
STEEL_BLUE = (0.2745098, 0.5098039, 0.7058824)

# Download and load checkpoints
wilor_model, wilor_model_cfg = load_wilor(checkpoint_path='./pretrained_models/wilor_final.ckpt', cfg_path='./pretrained_models/model_config.yaml')
hand_detector = YOLO('./pretrained_models/detector.pt')
# Setup the renderer
renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)
# Setup the SAM model
sam_model = LangSAM(sam_type="sam2.1_hiera_large")
# Setup the HORT model
hort_model = load_hort("./pretrained_models/hort_final.pth.tar")

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wilor_model = wilor_model.to(device)
hand_detector = hand_detector.to(device)
hort_model = hort_model.to(device)
wilor_model.eval()
hort_model.eval()

image_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

@spaces.GPU()
def run_model(image, conf, IoU_threshold=0.5):
    img_cv2 = image[..., ::-1]
    img_pil = Image.fromarray(image)

    pred_obj = sam_model.predict([img_pil], ["manipulated object"])
    pred_hand = sam_model.predict([img_pil], ["hand"])

    bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2))
    mask_obj = pred_obj[0]["masks"][0]
    bbox_hand = pred_hand[0]["boxes"][0].reshape((-1, 2))
    mask_hand = pred_hand[0]["masks"][0]

    tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    box_size = br - tl
    bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
    ho_bbox = process_bbox(bbox)

    detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]

    bboxes = []
    is_right = []
    for det in detections:
        Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
        is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())

    if len(bboxes) == 1:
        boxes = np.stack(bboxes)
        right = np.stack(is_right)
        if not right:
            new_x1 = img_cv2.shape[1] - boxes[0][2]
            new_x2 = img_cv2.shape[1] - boxes[0][0]
            boxes[0][0] = new_x1
            boxes[0][2] = new_x2
            ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
            img_cv2 = cv2.flip(img_cv2, 1)
            right[0] = 1.
        crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)

        dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

        for batch in dataloader:
            batch = recursive_to(batch, device)

            with torch.no_grad():
                out = wilor_model(batch)

            pred_cam = out['pred_cam']
            box_center = batch["box_center"].float()
            box_size = batch["box_size"].float()
            img_size = batch["img_size"].float()
            scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
            pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy()

            batch_size = batch['img'].shape[0]
            for n in range(batch_size):
                verts = out['pred_vertices'][n].detach().cpu().numpy()
                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()

                is_right = batch['right'][n].cpu().numpy()
                palm = (verts[95] + verts[22]) / 2
                cam_t = pred_cam_t_full[n]

                img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
                camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
                cam_intr = camera.intrinsics

                metas = dict()
                metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
                metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
                metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
                metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    pc_results = hort_model(img_input, metas)
                    objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
                    pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3

                reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}

        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions
    else:
        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), None


def render_reconstruction(image, conf, IoU_threshold=0.3):
    input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=0.5)
    if num_dets == 1:
        # Render front view
        misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
        cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)

        # Overlay image
        input_img = np.concatenate([input_img, np.ones_like(input_img[:, :, :1])], axis=2)  # Add alpha channel
        input_img_overlay = input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]

        return input_img_overlay, f'{num_dets} hands detected'
    else:
        return input_img, f'{num_dets} hands detected'


header = ('''
<div class="embed_hidden" style="text-align: center;">
    <h1> <b>HORT</b>: Monocular Hand-held Objects Reconstruction with Transformers</h1>
    <h3>
        <a href="https://zerchen.github.io/" target="_blank" rel="noopener noreferrer">Zerui Chen</a><sup>1</sup>,
        <a href="https://rolpotamias.github.io" target="_blank" rel="noopener noreferrer">Rolandos Alexandros Potamias</a><sup>2</sup>,
        <br>
        <a href="https://cshizhe.github.io/" target="_blank" rel="noopener noreferrer">Shizhe Chen</a><sup>1</sup>,
        <a href="https://cordeliaschmid.github.io/" target="_blank" rel="noopener noreferrer">Cordelia Schmid</a><sup>1</sup>
    </h3>
    <h3>
        <sup>1</sup>Inria, Ecole normale supérieure, CNRS, PSL Research University;
        <sup>2</sup>Imperial College London
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
    <a href='https://arxiv.org/abs/2503.21313'><img src='https://img.shields.io/badge/Arxiv-2503.21313-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
    <a href='https://arxiv.org/pdf/2503.21313'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
    <a href='https://zerchen.github.io/projects/hort.html'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
    <a href='https://github.com/zerchen/hort'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
''')


with gr.Blocks(title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo:

    gr.Markdown(header)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image", type="numpy")
            threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold')
            submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            reconstruction = gr.Image(label="Reconstructions", type="numpy")
            hands_detected = gr.Textbox(label="Hands Detected")

    submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])

    with gr.Row():
        example_images = gr.Examples([
            ['/home/user/app/assets/test1.png'],
            ['./demo_img/app/assets/test2.png'],
            ['./demo_img/app/assets/test3.jpg'],
            ['./demo_img/app/assets/test4.jpeg'],
            ['./demo_img/app/assets/test5.jpeg']
        ],
        inputs=input_image)

demo.launch(debug=True)
assets/test1.png ADDED

Git LFS Details

  • SHA256: 220310a89f9777975b10d933eb6aef34c3fe036ae2f453c2e31d537f8827111b
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
assets/test2.png ADDED

Git LFS Details

  • SHA256: 29e4602efe21a483442c42a50ebf1c666c9e525dc630ec801a5af1d3acee18b1
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/test3.jpg ADDED
assets/test4.jpeg ADDED
assets/test5.jpeg ADDED
hort/models/__init__.py ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import sys
import os.path as osp
import numpy as np
from yacs.config import CfgNode as CN
this_dir = osp.dirname(__file__)
sys.path.insert(0, this_dir)
import tgs
from network.pointnet import PointNetEncoder

hort_cfg = CN()
hort_cfg.image_tokenizer_cls = "tgs.models.tokenizers.image.DINOV2SingleImageTokenizer"
hort_cfg.image_tokenizer = CN()
hort_cfg.image_tokenizer.pretrained_model_name_or_path = "facebook/dinov2-large"
hort_cfg.image_tokenizer.width = 224
hort_cfg.image_tokenizer.height = 224
hort_cfg.image_tokenizer.modulation = False
hort_cfg.image_tokenizer.modulation_zero_init = True
hort_cfg.image_tokenizer.modulation_cond_dim = 1024
hort_cfg.image_tokenizer.freeze_backbone_params = False
hort_cfg.image_tokenizer.enable_memory_efficient_attention = False
hort_cfg.image_tokenizer.enable_gradient_checkpointing = False

hort_cfg.tokenizer_cls = "tgs.models.tokenizers.point.PointLearnablePositionalEmbedding"
hort_cfg.tokenizer = CN()
hort_cfg.tokenizer.num_pcl = 2049
hort_cfg.tokenizer.num_channels = 512

hort_cfg.backbone_cls = "tgs.models.transformers.Transformer1D"
hort_cfg.backbone = CN()
hort_cfg.backbone.in_channels = 512
hort_cfg.backbone.num_attention_heads = 8
hort_cfg.backbone.attention_head_dim = 64
hort_cfg.backbone.num_layers = 10
hort_cfg.backbone.cross_attention_dim = 1024
hort_cfg.backbone.norm_type = "layer_norm"
hort_cfg.backbone.enable_memory_efficient_attention = False
hort_cfg.backbone.gradient_checkpointing = False

hort_cfg.post_processor_cls = "tgs.models.networks.PointOutLayer"
hort_cfg.post_processor = CN()
hort_cfg.post_processor.in_channels = 512
hort_cfg.post_processor.out_channels = 3

hort_cfg.pointcloud_upsampler_cls = "tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP"
hort_cfg.pointcloud_upsampler = CN()
hort_cfg.pointcloud_upsampler.input_channels = 1024
hort_cfg.pointcloud_upsampler.dim_feat = 128
hort_cfg.pointcloud_upsampler.num_p0 = 2048
hort_cfg.pointcloud_upsampler.radius = 1
hort_cfg.pointcloud_upsampler.bounding = True
hort_cfg.pointcloud_upsampler.use_fps = True
hort_cfg.pointcloud_upsampler.up_factors = [2, 4]
hort_cfg.pointcloud_upsampler.token_type = "image_token"


class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.image_tokenizer = tgs.find(hort_cfg.image_tokenizer_cls)(hort_cfg.image_tokenizer)
        self.pointnet = PointNetEncoder(67, 1024)
        self.tokenizer = tgs.find(hort_cfg.tokenizer_cls)(hort_cfg.tokenizer)
        self.backbone = tgs.find(hort_cfg.backbone_cls)(hort_cfg.backbone)
        self.post_processor = tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor)
        self.post_processor_trans = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 3))
        self.pointcloud_upsampler = tgs.find(hort_cfg.pointcloud_upsampler_cls)(hort_cfg.pointcloud_upsampler)

    def forward(self, input_img, metas):
        with torch.no_grad():
            batch_size = input_img.shape[0]

            encoder_hidden_states = self.image_tokenizer(input_img, None)  # B * C * Nt
            encoder_hidden_states = encoder_hidden_states.transpose(2, 1)  # B * Nt * C

            palm_norm_hand_verts_3d = metas['right_hand_verts_3d'] - metas['right_hand_palm'].unsqueeze(1)
            point_idx = torch.arange(778).view(1, 778, 1).expand(batch_size, -1, -1).to(input_img.device) / 778.
            palm_norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, point_idx], -1)
            tip_norm_hand_verts_3d = (metas['right_hand_verts_3d'].unsqueeze(2) - metas['right_hand_joints_3d'].unsqueeze(1)).reshape((batch_size, 778, -1))
            norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, tip_norm_hand_verts_3d], -1)
            hand_feats = self.pointnet(norm_hand_verts_3d)

            tokens = self.tokenizer(batch_size)
            tokens = self.backbone(tokens, torch.cat([encoder_hidden_states, hand_feats.unsqueeze(1)], 1), modulation_cond=None)
            tokens = self.tokenizer.detokenize(tokens)

            pointclouds = self.post_processor(tokens[:, :2048, :])
            pred_obj_trans = self.post_processor_trans(tokens[:, -1, :])

            upsampling_input = {
                "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
                "intrinsic_cond": metas['cam_intr'],
                "points": pointclouds,
                "hand_points": metas["right_hand_verts_3d"],
                "trans": pred_obj_trans + metas['right_hand_palm'],
                "scale": 0.3
            }
            up_results = self.pointcloud_upsampler(upsampling_input)
            pointclouds_up = up_results[-1]

            pc_results = {}
            pc_results['pointclouds'] = pointclouds
            pc_results['objtrans'] = pred_obj_trans
            pc_results['handpalm'] = metas['right_hand_palm']
            pc_results['pointclouds_up'] = pointclouds_up

            return pc_results


def load_hort(ckpt_path):
    hort_model = model()
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["network"]
    ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
    hort_model.load_state_dict(ckpt)
    return hort_model
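
The 67 input channels fed to PointNetEncoder above are not spelled out in any config; they follow from the forward pass: per MANO vertex, 3 palm-relative coordinates, 1 normalized vertex index, and vertex-to-joint offsets (which implies 21 hand joints, since 67 = 3 + 1 + 21 * 3). A minimal shape-check sketch, illustrative only, with random tensors standing in for the metas that app.py builds from WiLoR outputs:

import torch

B = 1
verts = torch.randn(B, 778, 3)   # metas["right_hand_verts_3d"], MANO vertices
joints = torch.randn(B, 21, 3)   # metas["right_hand_joints_3d"]
palm = torch.randn(B, 3)         # metas["right_hand_palm"]

palm_rel = verts - palm.unsqueeze(1)                                        # (B, 778, 3)
point_idx = torch.arange(778).view(1, 778, 1).expand(B, -1, -1) / 778.      # (B, 778, 1)
joint_rel = (verts.unsqueeze(2) - joints.unsqueeze(1)).reshape(B, 778, -1)  # (B, 778, 63)
hand_input = torch.cat([palm_rel, point_idx, joint_rel], -1)                # (B, 778, 67)
assert hand_input.shape == (B, 778, 67)
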
hort/models/network/pointnet.py ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn


class PointNetEncoder(nn.Module):
    """Encoder for Pointcloud
    """
    def __init__(self, in_channels: int = 3, output_channels: int = 768):
        super().__init__()

        block_channel = [64, 128, 256, 512]
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, block_channel[0]),
            nn.LayerNorm(block_channel[0]),
            nn.ReLU(),
            nn.Linear(block_channel[0], block_channel[1]),
            nn.LayerNorm(block_channel[1]),
            nn.ReLU(),
            nn.Linear(block_channel[1], block_channel[2]),
            nn.LayerNorm(block_channel[2]),
            nn.ReLU(),
            nn.Linear(block_channel[2], block_channel[3]),
            nn.LayerNorm(block_channel[3]),
            nn.ReLU(),
        )

        self.final_projection = nn.Sequential(
            nn.Linear(block_channel[-1], output_channels),
            nn.LayerNorm(output_channels)
        )

    def forward(self, x):
        x = self.mlp(x)
        x = torch.max(x, 1)[0]
        x = self.final_projection(x)
        return x
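
A minimal usage sketch for this encoder, with the channel sizes used in hort/models/__init__.py (PointNetEncoder(67, 1024)): the per-point MLP runs on every point, and the max over the point dimension yields one global descriptor per sample. It assumes PointNetEncoder from the file above is in scope:

import torch

encoder = PointNetEncoder(in_channels=67, output_channels=1024)
pts = torch.randn(2, 778, 67)  # 2 hands, 778 MANO vertices, 67 features each
feats = encoder(pts)           # per-point MLP, then max-pool over the 778 points
print(feats.shape)             # torch.Size([2, 1024])
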
hort/models/tgs/__init__.py ADDED
@@ -0,0 +1,9 @@
import importlib
from tgs.utils.typing import *

def find(cls_string) -> Type:
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls
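
find is the small reflection helper behind every *_cls string in hort_cfg: it imports the module part of the dotted path and returns the named class. An illustrative call, mirroring how hort/models/__init__.py instantiates its post-processor:

import tgs
from yacs.config import CfgNode as CN

cls = tgs.find("tgs.models.networks.PointOutLayer")  # same string as hort_cfg.post_processor_cls
cfg = CN()
cfg.in_channels = 512
cfg.out_channels = 3
post_processor = cls(cfg)  # equivalent to tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor)
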
hort/models/tgs/data.py ADDED
@@ -0,0 +1,265 @@
1
+ import json
2
+ import math
3
+ from dataclasses import dataclass, field
4
+
5
+ import os
6
+ import imageio
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+
13
+ from tgs.utils.config import parse_structured
14
+ from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays
15
+ from tgs.utils.typing import *
16
+
17
+
18
+ def _parse_scene_list_single(scene_list_path: str):
19
+ if scene_list_path.endswith(".json"):
20
+ with open(scene_list_path) as f:
21
+ all_scenes = json.loads(f.read())
22
+ elif scene_list_path.endswith(".txt"):
23
+ with open(scene_list_path) as f:
24
+ all_scenes = [p.strip() for p in f.readlines()]
25
+ else:
26
+ all_scenes = [scene_list_path]
27
+
28
+ return all_scenes
29
+
30
+
31
+ def _parse_scene_list(scene_list_path: Union[str, List[str]]):
32
+ all_scenes = []
33
+ if isinstance(scene_list_path, str):
34
+ scene_list_path = [scene_list_path]
35
+ for scene_list_path_ in scene_list_path:
36
+ all_scenes += _parse_scene_list_single(scene_list_path_)
37
+ return all_scenes
38
+
39
+ @dataclass
40
+ class CustomImageDataModuleConfig:
41
+ image_list: Any = ""
42
+ background_color: Tuple[float, float, float] = field(
43
+ default_factory=lambda: (1.0, 1.0, 1.0)
44
+ )
45
+
46
+ relative_pose: bool = False
47
+ cond_height: int = 512
48
+ cond_width: int = 512
49
+ cond_camera_distance: float = 1.6
50
+ cond_fovy_deg: float = 40.0
51
+ cond_elevation_deg: float = 0.0
52
+ cond_azimuth_deg: float = 0.0
53
+ num_workers: int = 16
54
+
55
+ eval_height: int = 512
56
+ eval_width: int = 512
57
+ eval_batch_size: int = 1
58
+ eval_elevation_deg: float = 0.0
59
+ eval_camera_distance: float = 1.6
60
+ eval_fovy_deg: float = 40.0
61
+ n_test_views: int = 120
62
+ num_views_output: int = 120
63
+ only_3dgs: bool = False
64
+
65
+ class CustomImageOrbitDataset(Dataset):
66
+ def __init__(self, cfg: Any) -> None:
67
+ super().__init__()
68
+ self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg)
69
+
70
+ self.n_views = self.cfg.n_test_views
71
+ assert self.n_views % self.cfg.num_views_output == 0
72
+
73
+ self.all_scenes = _parse_scene_list(self.cfg.image_list)
74
+
75
+ azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[
76
+ : self.n_views
77
+ ]
78
+ elevation_deg: Float[Tensor, "B"] = torch.full_like(
79
+ azimuth_deg, self.cfg.eval_elevation_deg
80
+ )
81
+ camera_distances: Float[Tensor, "B"] = torch.full_like(
82
+ elevation_deg, self.cfg.eval_camera_distance
83
+ )
84
+
85
+ elevation = elevation_deg * math.pi / 180
86
+ azimuth = azimuth_deg * math.pi / 180
87
+
88
+ # convert spherical coordinates to cartesian coordinates
89
+ # right hand coordinate system, x back, y right, z up
90
+ # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
91
+ camera_positions: Float[Tensor, "B 3"] = torch.stack(
92
+ [
93
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
94
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
95
+ camera_distances * torch.sin(elevation),
96
+ ],
97
+ dim=-1,
98
+ )
99
+
100
+ # default scene center at origin
101
+ center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
102
+ # default camera up direction as +z
103
+ up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
104
+ None, :
105
+ ].repeat(self.n_views, 1)
106
+
107
+ fovy_deg: Float[Tensor, "B"] = torch.full_like(
108
+ elevation_deg, self.cfg.eval_fovy_deg
109
+ )
110
+ fovy = fovy_deg * math.pi / 180
111
+
112
+ lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
113
+ right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
114
+ up = F.normalize(torch.cross(right, lookat), dim=-1)
115
+ c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
116
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
117
+ dim=-1,
118
+ )
119
+ c2w: Float[Tensor, "B 4 4"] = torch.cat(
120
+ [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
121
+ )
122
+ c2w[:, 3, 3] = 1.0
123
+
124
+ # get directions by dividing directions_unit_focal by focal length
125
+ focal_length: Float[Tensor, "B"] = (
126
+ 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
127
+ )
128
+ directions_unit_focal = get_ray_directions(
129
+ H=self.cfg.eval_height,
130
+ W=self.cfg.eval_width,
131
+ focal=1.0,
132
+ )
133
+ directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
134
+ None, :, :, :
135
+ ].repeat(self.n_views, 1, 1, 1)
136
+ directions[:, :, :, :2] = (
137
+ directions[:, :, :, :2] / focal_length[:, None, None, None]
138
+ )
139
+ # must use normalize=True to normalize directions here
140
+ rays_o, rays_d = get_rays(directions, c2w, keepdim=True)
141
+
142
+ intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov(
143
+ self.cfg.eval_fovy_deg * math.pi / 180,
144
+ H=self.cfg.eval_height,
145
+ W=self.cfg.eval_width,
146
+ bs=self.n_views,
147
+ )
148
+ intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone()
149
+ intrinsic_normed[..., 0, 2] /= self.cfg.eval_width
150
+ intrinsic_normed[..., 1, 2] /= self.cfg.eval_height
151
+ intrinsic_normed[..., 0, 0] /= self.cfg.eval_width
152
+ intrinsic_normed[..., 1, 1] /= self.cfg.eval_height
153
+
154
+ self.rays_o, self.rays_d = rays_o, rays_d
155
+ self.intrinsic = intrinsic
156
+ self.intrinsic_normed = intrinsic_normed
157
+ self.c2w = c2w
158
+ self.camera_positions = camera_positions
159
+
160
+ self.background_color = torch.as_tensor(self.cfg.background_color)
161
+
162
+ # condition
163
+ self.intrinsic_cond = get_intrinsic_from_fov(
164
+ np.deg2rad(self.cfg.cond_fovy_deg),
165
+ H=self.cfg.cond_height,
166
+ W=self.cfg.cond_width,
167
+ )
168
+ self.intrinsic_normed_cond = self.intrinsic_cond.clone()
169
+ self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width
170
+ self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height
171
+ self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width
172
+ self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height
173
+
174
+
175
+ if self.cfg.relative_pose:
176
+ self.c2w_cond = torch.as_tensor(
177
+ [
178
+ [0, 0, 1, self.cfg.cond_camera_distance],
179
+ [1, 0, 0, 0],
180
+ [0, 1, 0, 0],
181
+ [0, 0, 0, 1],
182
+ ]
183
+ ).float()
184
+ else:
185
+ cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180
186
+ cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180
187
+ cond_camera_position: Float[Tensor, "3"] = torch.as_tensor(
188
+ [
189
+ self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.cos(cond_azimuth),
190
+ self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.sin(cond_azimuth),
191
+ self.cfg.cond_camera_distance * np.sin(cond_elevation),
192
+ ], dtype=torch.float32
193
+ )
194
+
195
+ cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position)
196
+ cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)
197
+ cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1)
198
+ cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up), dim=-1)
199
+ cond_up = F.normalize(torch.cross(cond_right, cond_lookat), dim=-1)
200
+ cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat(
201
+ [torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]],
202
+ dim=-1,
203
+ )
204
+ cond_c2w: Float[Tensor, "4 4"] = torch.cat(
205
+ [cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0
206
+ )
207
+ cond_c2w[3, 3] = 1.0
208
+ self.c2w_cond = cond_c2w
209
+
210
+ def __len__(self):
211
+ if self.cfg.only_3dgs:
212
+ return len(self.all_scenes)
213
+ else:
214
+ return len(self.all_scenes) * self.n_views // self.cfg.num_views_output
215
+
216
+ def __getitem__(self, index):
217
+ if self.cfg.only_3dgs:
218
+ scene_index = index
219
+ view_index = [0]
220
+ else:
221
+ scene_index = index * self.cfg.num_views_output // self.n_views
222
+ view_start = index % (self.n_views // self.cfg.num_views_output)
223
+ view_index = list(range(self.n_views))[view_start * self.cfg.num_views_output :
224
+ (view_start + 1) * self.cfg.num_views_output]
225
+
226
+ img_path = self.all_scenes[scene_index]
227
+ img_cond = torch.from_numpy(
228
+ np.asarray(
229
+ Image.fromarray(imageio.v2.imread(img_path))
230
+ .convert("RGBA")
231
+ .resize((self.cfg.cond_width, self.cfg.cond_height))
232
+ )
233
+ / 255.0
234
+ ).float()
235
+ mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:]
236
+ rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[
237
+ :, :, :3
238
+ ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond)
239
+
240
+ out = {
241
+ "rgb_cond": rgb_cond.unsqueeze(0),
242
+ "c2w_cond": self.c2w_cond.unsqueeze(0),
243
+ "mask_cond": mask_cond.unsqueeze(0),
244
+ "intrinsic_cond": self.intrinsic_cond.unsqueeze(0),
245
+ "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0),
246
+ "view_index": torch.as_tensor(view_index),
247
+ "rays_o": self.rays_o[view_index],
248
+ "rays_d": self.rays_d[view_index],
249
+ "intrinsic": self.intrinsic[view_index],
250
+ "intrinsic_normed": self.intrinsic_normed[view_index],
251
+ "c2w": self.c2w[view_index],
252
+ "camera_positions": self.camera_positions[view_index],
253
+ }
254
+ out["c2w"][..., :3, 1:3] *= -1
255
+ out["c2w_cond"][..., :3, 1:3] *= -1
256
+ instance_id = os.path.split(img_path)[-1].split('.')[0]
257
+ out["index"] = torch.as_tensor(scene_index)
258
+ out["background_color"] = self.background_color
259
+ out["instance_id"] = instance_id
260
+ return out
261
+
262
+ def collate(self, batch):
263
+ batch = torch.utils.data.default_collate(batch)
264
+ batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
265
+ return batch
hort/models/tgs/models/__init__.py ADDED
File without changes
hort/models/tgs/models/image_feature.py ADDED
@@ -0,0 +1,48 @@
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from einops import rearrange

from tgs.utils.base import BaseModule
from tgs.utils.ops import compute_distance_transform
from tgs.utils.typing import *

class ImageFeature(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        use_rgb: bool = True
        use_feature: bool = True
        use_mask: bool = True
        feature_dim: int = 128
        out_dim: int = 133
        backbone: str = "default"
        freeze_backbone_params: bool = True

    cfg: Config

    def forward(self, rgb, mask=None, feature=None):
        B, Nv, H, W = rgb.shape[:4]
        rgb = rearrange(rgb, "B Nv H W C -> (B Nv) C H W")
        if mask is not None:
            mask = rearrange(mask, "B Nv H W C -> (B Nv) C H W")

        assert feature is not None
        # reshape dino tokens to image-like size
        feature = rearrange(feature, "B (Nv Nt) C -> (B Nv) Nt C", Nv=Nv)
        feature = feature[:, 1:].reshape(B * Nv, H // 14, W // 14, -1).permute(0, 3, 1, 2).contiguous()
        feature = F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=False)

        if mask is not None and mask.is_floating_point():
            mask = mask > 0.5

        image_features = []
        if self.cfg.use_rgb:
            image_features.append(rgb)
        if self.cfg.use_feature:
            image_features.append(feature)
        if self.cfg.use_mask:
            image_features += [mask, compute_distance_transform(mask)]

        # detach features, occur error when with grad
        image_features = torch.cat(image_features, dim=1)  # .detach()
        return rearrange(image_features, "(B Nv) C H W -> B Nv C H W", B=B, Nv=Nv).squeeze(1)
hort/models/tgs/models/networks.py ADDED
@@ -0,0 +1,204 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ import numpy as np
7
+
8
+ from tgs.utils.base import BaseModule
9
+ from tgs.utils.ops import get_activation
10
+ from tgs.utils.typing import *
11
+
12
+ class PointOutLayer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ in_channels: int = 1024
16
+ out_channels: int = 3
17
+ cfg: Config
18
+ def configure(self) -> None:
19
+ super().configure()
20
+ self.point_layer = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
21
+ self.initialize_weights()
22
+
23
+ def initialize_weights(self):
24
+ nn.init.constant_(self.point_layer.weight, 0)
25
+ nn.init.constant_(self.point_layer.bias, 0)
26
+
27
+ def forward(self, x):
28
+ return self.point_layer(x)
29
+
30
+ class TriplaneUpsampleNetwork(BaseModule):
31
+ @dataclass
32
+ class Config(BaseModule.Config):
33
+ in_channels: int = 1024
34
+ out_channels: int = 80
35
+
36
+ cfg: Config
37
+
38
+ def configure(self) -> None:
39
+ super().configure()
40
+ self.upsample = nn.ConvTranspose2d(
41
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
42
+ )
43
+
44
+ def forward(
45
+ self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
46
+ ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
47
+ triplanes_up = rearrange(
48
+ self.upsample(
49
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
50
+ ),
51
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
52
+ Np=3,
53
+ )
54
+ return triplanes_up
55
+
56
+
57
+ class MLP(nn.Module):
58
+ def __init__(
59
+ self,
60
+ dim_in: int,
61
+ dim_out: int,
62
+ n_neurons: int,
63
+ n_hidden_layers: int,
64
+ activation: str = "relu",
65
+ output_activation: Optional[str] = None,
66
+ bias: bool = True,
67
+ ):
68
+ super().__init__()
69
+ layers = [
70
+ self.make_linear(
71
+ dim_in, n_neurons, is_first=True, is_last=False, bias=bias
72
+ ),
73
+ self.make_activation(activation),
74
+ ]
75
+ for i in range(n_hidden_layers - 1):
76
+ layers += [
77
+ self.make_linear(
78
+ n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
79
+ ),
80
+ self.make_activation(activation),
81
+ ]
82
+ layers += [
83
+ self.make_linear(
84
+ n_neurons, dim_out, is_first=False, is_last=True, bias=bias
85
+ )
86
+ ]
87
+ self.layers = nn.Sequential(*layers)
88
+ self.output_activation = get_activation(output_activation)
89
+
90
+ def forward(self, x):
91
+ x = self.layers(x)
92
+ x = self.output_activation(x)
93
+ return x
94
+
95
+ def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
96
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
97
+ return layer
98
+
99
+ def make_activation(self, activation):
100
+ if activation == "relu":
101
+ return nn.ReLU(inplace=True)
102
+ elif activation == "silu":
103
+ return nn.SiLU(inplace=True)
104
+ else:
105
+ raise NotImplementedError
106
+
107
+ class GSProjection(nn.Module):
108
+ def __init__(self,
109
+ in_channels: int = 80,
110
+ sh_degree: int = 3,
111
+ init_scaling: float = -5.0,
112
+ init_density: float = 0.1) -> None:
113
+ super().__init__()
114
+
115
+ self.out_keys = GS_KEYS + ["shs"]
116
+ self.out_channels = GS_CHANNELS + [(sh_degree + 1) ** 2 * 3]
117
+
118
+ self.out_layers = nn.ModuleList()
119
+ for key, ch in zip(self.out_keys, self.out_channels):
120
+ layer = nn.Linear(in_channels, ch)
121
+ # initialize
122
+ nn.init.constant_(layer.weight, 0)
123
+ nn.init.constant_(layer.bias, 0)
124
+
125
+ if key == "scaling":
126
+ nn.init.constant_(layer.bias, init_scaling)
127
+ elif key == "rotation":
128
+ nn.init.constant_(layer.bias, 0)
129
+ nn.init.constant_(layer.bias[0], 1.0)
130
+ elif key == "opacity":
131
+ inverse_sigmoid = lambda x: np.log(x / (1 - x))
132
+ nn.init.constant_(layer.bias, inverse_sigmoid(init_density))
133
+
134
+ self.out_layers.append(layer)
135
+
136
+ def forward(self, x):
137
+ ret = []
138
+ for k, layer in zip(self.out_keys, self.out_layers):
139
+ v = layer(x)
140
+ if k == "rotation":
141
+ v = torch.nn.functional.normalize(v)
142
+ elif k == "scaling":
143
+ v = torch.exp(v)
144
+ # v = v.detach() # FIXME: for DEBUG
145
+ elif k == "opacity":
146
+ v = torch.sigmoid(v)
147
+ # elif k == "shs":
148
+ # v = torch.reshape(v, (v.shape[0], -1, 3))
149
+ ret.append(v)
150
+ ret = torch.cat(ret, dim=-1)
151
+ return ret
152
+
153
+ def get_encoding(n_input_dims: int, config) -> nn.Module:
154
+ raise NotImplementedError
155
+
156
+
157
+ def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
158
+ raise NotImplementedError
159
+
160
+
161
+ # Resnet Blocks for pointnet
162
+ class ResnetBlockFC(nn.Module):
163
+ ''' Fully connected ResNet Block class.
164
+
165
+ Args:
166
+ size_in (int): input dimension
167
+ size_out (int): output dimension
168
+ size_h (int): hidden dimension
169
+ '''
170
+
171
+ def __init__(self, size_in, size_out=None, size_h=None):
172
+ super().__init__()
173
+ # Attributes
174
+ if size_out is None:
175
+ size_out = size_in
176
+
177
+ if size_h is None:
178
+ size_h = min(size_in, size_out)
179
+
180
+ self.size_in = size_in
181
+ self.size_h = size_h
182
+ self.size_out = size_out
183
+ # Submodules
184
+ self.fc_0 = nn.Linear(size_in, size_h)
185
+ self.fc_1 = nn.Linear(size_h, size_out)
186
+ self.actvn = nn.ReLU()
187
+
188
+ if size_in == size_out:
189
+ self.shortcut = None
190
+ else:
191
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
192
+ # Initialization
193
+ nn.init.zeros_(self.fc_1.weight)
194
+
195
+ def forward(self, x):
196
+ net = self.fc_0(self.actvn(x))
197
+ dx = self.fc_1(self.actvn(net))
198
+
199
+ if self.shortcut is not None:
200
+ x_s = self.shortcut(x)
201
+ else:
202
+ x_s = x
203
+
204
+ return x_s + dx
hort/models/tgs/models/pointclouds/LICENSE_POINTNET ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
hort/models/tgs/models/pointclouds/pointnet.py ADDED
@@ -0,0 +1,121 @@
1
+ # modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py
2
+ from dataclasses import dataclass
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch_scatter import scatter_mean, scatter_max
6
+
7
+ from tgs.utils.base import BaseModule
8
+ from tgs.models.networks import ResnetBlockFC
9
+ from tgs.utils.ops import scale_tensor
10
+
11
+ class LocalPoolPointnet(BaseModule):
12
+ ''' PointNet-based encoder network with ResNet blocks for each point.
13
+ Number of input points are fixed.
14
+
15
+ Args:
16
+ c_dim (int): dimension of latent code c
17
+ dim (int): input points dimension
18
+ hidden_dim (int): hidden dimension of the network
19
+ scatter_type (str): feature aggregation when doing local pooling
20
+ plane_resolution (int): defined resolution for plane feature
21
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
22
+ n_blocks (int): number of blocks ResNetBlockFC layers
23
+ '''
24
+
25
+ @dataclass
26
+ class Config(BaseModule.Config):
27
+ input_channels: int = 3
28
+ c_dim: int = 128
29
+ hidden_dim: int = 128
30
+ scatter_type: str = "max"
31
+ plane_size: int = 32
32
+ n_blocks: int = 5
33
+ radius: float = 1.
34
+
35
+ cfg: Config
36
+
37
+ def configure(self) -> None:
38
+ super().configure()
39
+ self.fc_pos = nn.Linear(self.cfg.input_channels, 2 * self.cfg.hidden_dim)
40
+ self.blocks = nn.ModuleList([
41
+ ResnetBlockFC(2 * self.cfg.hidden_dim, self.cfg.hidden_dim) for i in range(self.cfg.n_blocks)
42
+ ])
43
+ self.fc_c = nn.Linear(self.cfg.hidden_dim, self.cfg.c_dim)
44
+
45
+ self.actvn = nn.ReLU()
46
+
47
+ if self.cfg.scatter_type == 'max':
48
+ self.scatter = scatter_max
49
+ elif self.cfg.scatter_type == 'mean':
50
+ self.scatter = scatter_mean
51
+ else:
52
+ raise ValueError('incorrect scatter type')
53
+
54
+
55
+ def generate_plane_features(self, index, c):
56
+ # acquire indices of features in plane
57
+ # xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
58
+ # index = self.coordinate2index(x, self.cfg.plane_size)
59
+
60
+ # scatter plane features from points
61
+ fea_plane = c.new_zeros(index.shape[0], self.cfg.c_dim, self.cfg.plane_size ** 2)
62
+ c = c.permute(0, 2, 1) # B x 512 x T
63
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
64
+ fea_plane = fea_plane.reshape(index.shape[0], self.cfg.c_dim, self.cfg.plane_size, self.cfg.plane_size) # sparce matrix (B x 512 x reso x reso)
65
+
66
+ return fea_plane
67
+
68
+ def pool_local(self, xy, index, c):
69
+ bs, fea_dim = c.shape[0], c.shape[2]
70
+ keys = xy.keys()
71
+
72
+ c_out = 0
73
+ for key in keys:
74
+ # scatter plane features from points
75
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.cfg.plane_size ** 2)
76
+ if self.scatter == scatter_max:
77
+ fea = fea[0]
78
+ # gather feature back to points
79
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
80
+ c_out += fea
81
+ return c_out.permute(0, 2, 1)
82
+
83
+ def coordinate2index(self, x):
84
+ x = (x * self.cfg.plane_size).long()
85
+ index = x[..., 0] + self.cfg.plane_size * x[..., 1]
86
+ assert index.max() < self.cfg.plane_size ** 2
87
+ return index[:, None, :]
88
+
89
+ def forward(self, p):
90
+ batch_size, T, D = p.shape
91
+
92
+ # acquire the index for each point
93
+ coord = {}
94
+ index = {}
95
+
96
+ position = torch.clamp(p[..., :3], -self.cfg.radius + 1e-6, self.cfg.radius - 1e-6)
97
+ position_norm = scale_tensor(position, (-self.cfg.radius, self.cfg.radius), (0, 1))
98
+ coord["xy"] = position_norm[..., [0, 1]]
99
+ coord["xz"] = position_norm[..., [0, 2]]
100
+ coord["yz"] = position_norm[..., [1, 2]]
101
+ index["xy"] = self.coordinate2index(coord["xy"])
102
+ index["xz"] = self.coordinate2index(coord["xz"])
103
+ index["yz"] = self.coordinate2index(coord["yz"])
104
+
105
+ net = self.fc_pos(p)
106
+
107
+ net = self.blocks[0](net)
108
+ for block in self.blocks[1:]:
109
+ pooled = self.pool_local(coord, index, net)
110
+ net = torch.cat([net, pooled], dim=2)
111
+ net = block(net)
112
+
113
+ c = self.fc_c(net)
114
+
115
+ features = torch.stack([
116
+ self.generate_plane_features(index["xy"], c),
117
+ self.generate_plane_features(index["xz"], c),
118
+ self.generate_plane_features(index["yz"], c)
119
+ ], dim=1)
120
+
121
+ return features
hort/models/tgs/models/pointclouds/simplepoint.py ADDED
@@ -0,0 +1,110 @@
1
+ from dataclasses import dataclass, field
2
+ import torch
3
+ from einops import rearrange
4
+
5
+ import tgs
6
+ from tgs.utils.base import BaseModule
7
+ from tgs.utils.typing import *
8
+
9
+ class SimplePointGenerator(BaseModule):
10
+ @dataclass
11
+ class Config(BaseModule.Config):
12
+ camera_embedder_cls: str = ""
13
+ camera_embedder: dict = field(default_factory=dict)
14
+
15
+ image_tokenizer_cls: str = ""
16
+ image_tokenizer: dict = field(default_factory=dict)
17
+
18
+ tokenizer_cls: str = ""
19
+ tokenizer: dict = field(default_factory=dict)
20
+
21
+ backbone_cls: str = ""
22
+ backbone: dict = field(default_factory=dict)
23
+
24
+ post_processor_cls: str = ""
25
+ post_processor: dict = field(default_factory=dict)
26
+
27
+ pointcloud_upsampling_cls: str = ""
28
+ pointcloud_upsampling: dict = field(default_factory=dict)
29
+
30
+ flip_c2w_cond: bool = True
31
+
32
+ cfg: Config
33
+
34
+ def configure(self) -> None:
35
+ super().configure()
36
+
37
+ self.image_tokenizer = tgs.find(self.cfg.image_tokenizer_cls)(
38
+ self.cfg.image_tokenizer
39
+ )
40
+
41
+ assert self.cfg.camera_embedder_cls == 'tgs.models.networks.MLP'
42
+ weights = self.cfg.camera_embedder.pop("weights") if "weights" in self.cfg.camera_embedder else None
43
+ self.camera_embedder = tgs.find(self.cfg.camera_embedder_cls)(**self.cfg.camera_embedder)
44
+ if weights:
45
+ from tgs.utils.misc import load_module_weights
46
+ weights_path, module_name = weights.split(":")
47
+ state_dict = load_module_weights(
48
+ weights_path, module_name=module_name, map_location="cpu"
49
+ )
50
+ self.camera_embedder.load_state_dict(state_dict)
51
+
52
+ self.tokenizer = tgs.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
53
+
54
+ self.backbone = tgs.find(self.cfg.backbone_cls)(self.cfg.backbone)
55
+
56
+ self.post_processor = tgs.find(self.cfg.post_processor_cls)(
57
+ self.cfg.post_processor
58
+ )
59
+
60
+ self.pointcloud_upsampling = tgs.find(self.cfg.pointcloud_upsampling_cls)(self.cfg.pointcloud_upsampling)
61
+
62
+ def forward(self, batch, encoder_hidden_states=None, **kwargs):
63
+ batch_size, n_input_views = batch["rgb_cond"].shape[:2]
64
+
65
+ if encoder_hidden_states is None:
66
+ # Camera modulation
67
+ c2w_cond = batch["c2w_cond"].clone()
68
+ if self.cfg.flip_c2w_cond:
69
+ c2w_cond[..., :3, 1:3] *= -1
70
+ camera_extri = c2w_cond.view(*c2w_cond.shape[:-2], -1)
71
+ camera_intri = batch["intrinsic_normed_cond"].view(
72
+ *batch["intrinsic_normed_cond"].shape[:-2], -1)
73
+ camera_feats = torch.cat([camera_intri, camera_extri], dim=-1)
74
+ # camera_feats = rearrange(camera_feats, 'B Nv C -> (B Nv) C')
75
+
76
+ camera_feats = self.camera_embedder(camera_feats)
77
+
78
+ encoder_hidden_states: Float[Tensor, "B Cit Nit"] = self.image_tokenizer(
79
+ rearrange(batch["rgb_cond"], 'B Nv H W C -> B Nv C H W'),
80
+ modulation_cond=camera_feats,
81
+ )
82
+ encoder_hidden_states = rearrange(
83
+ encoder_hidden_states, 'B Nv C Nt -> B (Nv Nt) C', Nv=n_input_views)
84
+
85
+ tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
86
+
87
+ tokens = self.backbone(
88
+ tokens,
89
+ encoder_hidden_states=encoder_hidden_states,
90
+ modulation_cond=None,
91
+ )
92
+ pointclouds = self.post_processor(self.tokenizer.detokenize(tokens))
93
+
94
+ upsampling_input = {
95
+ "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
96
+ "input_image_tokens_global": encoder_hidden_states[:, :1],
97
+ "c2w_cond": c2w_cond,
98
+ "rgb_cond": batch["rgb_cond"],
99
+ "intrinsic_cond": batch["intrinsic_cond"],
100
+ "intrinsic_normed_cond": batch["intrinsic_normed_cond"],
101
+ "points": pointclouds.float()
102
+ }
103
+ up_results = self.pointcloud_upsampling(upsampling_input)
104
+ up_results.insert(0, pointclouds)
105
+ pointclouds = up_results[-1]
106
+ out = {
107
+ "points": pointclouds,
108
+ "up_results": up_results
109
+ }
110
+ return out
hort/models/tgs/models/renderer.py ADDED
@@ -0,0 +1,427 @@
1
+ from dataclasses import dataclass, field
2
+ from collections import defaultdict
3
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
4
+ from plyfile import PlyData, PlyElement
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ import math
10
+
11
+ from tgs.utils.typing import *
12
+ from tgs.utils.base import BaseModule
13
+ from tgs.utils.ops import trunc_exp
14
+ from tgs.models.networks import MLP
15
+ from tgs.utils.ops import scale_tensor
16
+ from einops import rearrange, reduce
17
+
18
+ inverse_sigmoid = lambda x: np.log(x / (1 - x))
19
+
20
+ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
21
+ Rt = np.zeros((4, 4))
22
+ Rt[:3, :3] = R.transpose()
23
+ Rt[:3, 3] = t
24
+ Rt[3, 3] = 1.0
25
+
26
+ C2W = np.linalg.inv(Rt)
27
+ cam_center = C2W[:3, 3]
28
+ cam_center = (cam_center + translate) * scale
29
+ C2W[:3, 3] = cam_center
30
+ Rt = np.linalg.inv(C2W)
31
+ return np.float32(Rt)
32
+
33
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
34
+ tanHalfFovY = math.tan((fovY / 2))
35
+ tanHalfFovX = math.tan((fovX / 2))
36
+
37
+ top = tanHalfFovY * znear
38
+ bottom = -top
39
+ right = tanHalfFovX * znear
40
+ left = -right
41
+
42
+ P = torch.zeros(4, 4)
43
+
44
+ z_sign = 1.0
45
+
46
+ P[0, 0] = 2.0 * znear / (right - left)
47
+ P[1, 1] = 2.0 * znear / (top - bottom)
48
+ P[0, 2] = (right + left) / (right - left)
49
+ P[1, 2] = (top + bottom) / (top - bottom)
50
+ P[3, 2] = z_sign
51
+ P[2, 2] = z_sign * zfar / (zfar - znear)
52
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
53
+ return P
54
+
55
+ def intrinsic_to_fov(intrinsic, w, h):
56
+ fx, fy = intrinsic[0, 0], intrinsic[1, 1]
57
+ fov_x = 2 * torch.arctan2(w, 2 * fx)
58
+ fov_y = 2 * torch.arctan2(h, 2 * fy)
59
+ return fov_x, fov_y
60
+
61
+
62
+ class Camera:
63
+ def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None:
64
+ self.FoVx = FoVx
65
+ self.FoVy = FoVy
66
+ self.height = height
67
+ self.width = width
68
+ self.world_view_transform = w2c.transpose(0, 1)
69
+
70
+ self.zfar = 100.0
71
+ self.znear = 0.01
72
+
73
+ self.trans = trans
74
+ self.scale = scale
75
+
76
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device)
77
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
78
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
79
+
80
+ @staticmethod
81
+ def from_c2w(c2w, intrinsic, height, width):
82
+ w2c = torch.inverse(c2w)
83
+ FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device))
84
+ return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width)
85
+
86
+ class GaussianModel(NamedTuple):
87
+ xyz: Tensor
88
+ opacity: Tensor
89
+ rotation: Tensor
90
+ scaling: Tensor
91
+ shs: Tensor
92
+
93
+ def construct_list_of_attributes(self):
94
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
95
+ features_dc = self.shs[:, :1]
96
+ features_rest = self.shs[:, 1:]
97
+ for i in range(features_dc.shape[1]*features_dc.shape[2]):
98
+ l.append('f_dc_{}'.format(i))
99
+ for i in range(features_rest.shape[1]*features_rest.shape[2]):
100
+ l.append('f_rest_{}'.format(i))
101
+ l.append('opacity')
102
+ for i in range(self.scaling.shape[1]):
103
+ l.append('scale_{}'.format(i))
104
+ for i in range(self.rotation.shape[1]):
105
+ l.append('rot_{}'.format(i))
106
+ return l
107
+
108
+ def save_ply(self, path):
109
+
110
+ xyz = self.xyz.detach().cpu().numpy()
111
+ normals = np.zeros_like(xyz)
112
+ features_dc = self.shs[:, :1]
113
+ features_rest = self.shs[:, 1:]
114
+ f_dc = features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
115
+ f_rest = features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy()
116
+ opacities = inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3).detach().cpu().numpy())
117
+ scale = np.log(self.scaling.detach().cpu().numpy())
118
+ rotation = self.rotation.detach().cpu().numpy()
119
+
120
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
121
+
122
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
123
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
124
+ elements[:] = list(map(tuple, attributes))
125
+ el = PlyElement.describe(elements, 'vertex')
126
+ PlyData([el]).write(path)
127
+
128
+ class GSLayer(BaseModule):
129
+ @dataclass
130
+ class Config(BaseModule.Config):
131
+ in_channels: int = 128
132
+ feature_channels: dict = field(default_factory=dict)
133
+ xyz_offset: bool = True
134
+ restrict_offset: bool = False
135
+ use_rgb: bool = False
136
+ clip_scaling: Optional[float] = None
137
+ init_scaling: float = -5.0
138
+ init_density: float = 0.1
139
+
140
+ cfg: Config
141
+
142
+ def configure(self, *args, **kwargs) -> None:
143
+ self.out_layers = nn.ModuleList()
144
+ for key, out_ch in self.cfg.feature_channels.items():
145
+ if key == "shs" and self.cfg.use_rgb:
146
+ out_ch = 3
147
+ layer = nn.Linear(self.cfg.in_channels, out_ch)
148
+
149
+ # initialize
150
+ if not (key == "shs" and self.cfg.use_rgb):
151
+ nn.init.constant_(layer.weight, 0)
152
+ nn.init.constant_(layer.bias, 0)
153
+ if key == "scaling":
154
+ nn.init.constant_(layer.bias, self.cfg.init_scaling)
155
+ elif key == "rotation":
156
+ nn.init.constant_(layer.bias, 0)
157
+ nn.init.constant_(layer.bias[0], 1.0)
158
+ elif key == "opacity":
159
+ nn.init.constant_(layer.bias, inverse_sigmoid(self.cfg.init_density))
160
+
161
+ self.out_layers.append(layer)
162
+
163
+ def forward(self, x, pts):
164
+ ret = {}
165
+ for k, layer in zip(self.cfg.feature_channels.keys(), self.out_layers):
166
+ v = layer(x)
167
+ if k == "rotation":
168
+ v = torch.nn.functional.normalize(v)
169
+ elif k == "scaling":
170
+ v = trunc_exp(v)
171
+ if self.cfg.clip_scaling is not None:
172
+ v = torch.clamp(v, min=0, max=self.cfg.clip_scaling)
173
+ elif k == "opacity":
174
+ v = torch.sigmoid(v)
175
+ elif k == "shs":
176
+ if self.cfg.use_rgb:
177
+ v = torch.sigmoid(v)
178
+ v = torch.reshape(v, (v.shape[0], -1, 3))
179
+ elif k == "xyz":
180
+ if self.cfg.restrict_offset:
181
+ max_step = 1.2 / 32
182
+ v = (torch.sigmoid(v) - 0.5) * max_step
183
+ v = v + pts if self.cfg.xyz_offset else pts
184
+ ret[k] = v
185
+
186
+ return GaussianModel(**ret)
187
+
188
+ class GS3DRenderer(BaseModule):
189
+ @dataclass
190
+ class Config(BaseModule.Config):
191
+ mlp_network_config: Optional[dict] = None
192
+ gs_out: dict = field(default_factory=dict)
193
+ sh_degree: int = 3
194
+ scaling_modifier: float = 1.0
195
+ random_background: bool = False
196
+ radius: float = 1.0
197
+ feature_reduction: str = "concat"
198
+ projection_feature_dim: int = 773
199
+ background_color: Tuple[float, float, float] = field(
200
+ default_factory=lambda: (1.0, 1.0, 1.0)
201
+ )
202
+
203
+ cfg: Config
204
+
205
+ def configure(self, *args, **kwargs) -> None:
206
+ if self.cfg.feature_reduction == "mean":
207
+ mlp_in = 80
208
+ elif self.cfg.feature_reduction == "concat":
209
+ mlp_in = 80 * 3
210
+ else:
211
+ raise NotImplementedError
212
+ mlp_in = mlp_in + self.cfg.projection_feature_dim
213
+ if self.cfg.mlp_network_config is not None:
214
+ self.mlp_net = MLP(mlp_in, self.cfg.gs_out.in_channels, **self.cfg.mlp_network_config)
215
+ else:
216
+ self.cfg.gs_out.in_channels = mlp_in
217
+ self.gs_net = GSLayer(self.cfg.gs_out)
218
+
219
+ def forward_gs(self, x, p):
220
+ if self.cfg.mlp_network_config is not None:
221
+ x = self.mlp_net(x)
222
+ return self.gs_net(x, p)
223
+
224
+ def forward_single_view(self,
225
+ gs: GaussianModel,
226
+ viewpoint_camera: Camera,
227
+ background_color: Optional[Float[Tensor, "3"]],
228
+ ret_mask: bool = True,
229
+ ):
230
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
231
+ screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
232
+ try:
233
+ screenspace_points.retain_grad()
234
+ except:
235
+ pass
236
+
237
+ bg_color = background_color
238
+ # Set up rasterization configuration
239
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
240
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
241
+
242
+ raster_settings = GaussianRasterizationSettings(
243
+ image_height=int(viewpoint_camera.height),
244
+ image_width=int(viewpoint_camera.width),
245
+ tanfovx=tanfovx,
246
+ tanfovy=tanfovy,
247
+ bg=bg_color,
248
+ scale_modifier=self.cfg.scaling_modifier,
249
+ viewmatrix=viewpoint_camera.world_view_transform,
250
+ projmatrix=viewpoint_camera.full_proj_transform.float(),
251
+ sh_degree=self.cfg.sh_degree,
252
+ campos=viewpoint_camera.camera_center,
253
+ prefiltered=False,
254
+ debug=False
255
+ )
256
+
257
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
258
+
259
+ means3D = gs.xyz
260
+ means2D = screenspace_points
261
+ opacity = gs.opacity
262
+
263
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
264
+ # scaling / rotation by the rasterizer.
265
+ scales = None
266
+ rotations = None
267
+ cov3D_precomp = None
268
+ scales = gs.scaling
269
+ rotations = gs.rotation
270
+
271
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
272
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
273
+ shs = None
274
+ colors_precomp = None
275
+ if self.gs_net.cfg.use_rgb:
276
+ colors_precomp = gs.shs.squeeze(1)
277
+ else:
278
+ shs = gs.shs
279
+
280
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
281
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
282
+ rendered_image, radii = rasterizer(
283
+ means3D = means3D,
284
+ means2D = means2D,
285
+ shs = shs,
286
+ colors_precomp = colors_precomp,
287
+ opacities = opacity,
288
+ scales = scales,
289
+ rotations = rotations,
290
+ cov3D_precomp = cov3D_precomp)
291
+
292
+ ret = {
293
+ "comp_rgb": rendered_image.permute(1, 2, 0),
294
+ "comp_rgb_bg": bg_color
295
+ }
296
+
297
+ if ret_mask:
298
+ mask_bg_color = torch.zeros(3, dtype=torch.float32, device=self.device)
299
+ raster_settings = GaussianRasterizationSettings(
300
+ image_height=int(viewpoint_camera.height),
301
+ image_width=int(viewpoint_camera.width),
302
+ tanfovx=tanfovx,
303
+ tanfovy=tanfovy,
304
+ bg=mask_bg_color,
305
+ scale_modifier=self.cfg.scaling_modifier,
306
+ viewmatrix=viewpoint_camera.world_view_transform,
307
+ projmatrix=viewpoint_camera.full_proj_transform.float(),
308
+ sh_degree=0,
309
+ campos=viewpoint_camera.camera_center,
310
+ prefiltered=False,
311
+ debug=False
312
+ )
313
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
314
+
315
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
316
+ rendered_mask, radii = rasterizer(
317
+ means3D = means3D,
318
+ means2D = means2D,
319
+ # shs = ,
320
+ colors_precomp = torch.ones_like(means3D),
321
+ opacities = opacity,
322
+ scales = scales,
323
+ rotations = rotations,
324
+ cov3D_precomp = cov3D_precomp)
325
+ ret["comp_mask"] = rendered_mask.permute(1, 2, 0)
326
+
327
+ return ret
328
+
329
+ def query_triplane(
330
+ self,
331
+ positions: Float[Tensor, "*B N 3"],
332
+ triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
333
+ ) -> Tensor:
334
+ batched = positions.ndim == 3
335
+ if not batched:
336
+ # no batch dimension
337
+ triplanes = triplanes[None, ...]
338
+ positions = positions[None, ...]
339
+
340
+ positions = scale_tensor(positions, (-self.cfg.radius, self.cfg.radius), (-1, 1))
341
+ indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
342
+ (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
343
+ dim=-3,
344
+ )
345
+ out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
346
+ rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3),
347
+ rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3),
348
+ align_corners=False,
349
+ mode="bilinear",
350
+ )
351
+ if self.cfg.feature_reduction == "concat":
352
+ out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
353
+ elif self.cfg.feature_reduction == "mean":
354
+ out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean")
355
+ else:
356
+ raise NotImplementedError
357
+
358
+ if not batched:
359
+ out = out.squeeze(0)
360
+
361
+ return out
362
+
363
+ def forward_single_batch(
364
+ self,
365
+ gs_hidden_features: Float[Tensor, "Np Cp"],
366
+ query_points: Float[Tensor, "Np 3"],
367
+ c2ws: Float[Tensor, "Nv 4 4"],
368
+ intrinsics: Float[Tensor, "Nv 4 4"],
369
+ height: int,
370
+ width: int,
371
+ background_color: Optional[Float[Tensor, "3"]],
372
+ ):
373
+ gs: GaussianModel = self.forward_gs(gs_hidden_features, query_points)
374
+ out_list = []
375
+
376
+ for c2w, intrinsic in zip(c2ws, intrinsics):
377
+ out_list.append(self.forward_single_view(
378
+ gs,
379
+ Camera.from_c2w(c2w, intrinsic, height, width),
380
+ background_color
381
+ ))
382
+
383
+ out = defaultdict(list)
384
+ for out_ in out_list:
385
+ for k, v in out_.items():
386
+ out[k].append(v)
387
+ out = {k: torch.stack(v, dim=0) for k, v in out.items()}
388
+ out["3dgs"] = gs
389
+
390
+ return out
391
+
392
+ def forward(self,
393
+ gs_hidden_features: Float[Tensor, "B Np Cp"],
394
+ query_points: Float[Tensor, "B Np 3"],
395
+ c2w: Float[Tensor, "B Nv 4 4"],
396
+ intrinsic: Float[Tensor, "B Nv 4 4"],
397
+ height,
398
+ width,
399
+ additional_features: Optional[Float[Tensor, "B C H W"]] = None,
400
+ background_color: Optional[Float[Tensor, "B 3"]] = None,
401
+ **kwargs):
402
+ batch_size = gs_hidden_features.shape[0]
403
+ out_list = []
404
+ gs_hidden_features = self.query_triplane(query_points, gs_hidden_features)
405
+ if additional_features is not None:
406
+ gs_hidden_features = torch.cat([gs_hidden_features, additional_features], dim=-1)
407
+
408
+ for b in range(batch_size):
409
+ out_list.append(self.forward_single_batch(
410
+ gs_hidden_features[b],
411
+ query_points[b],
412
+ c2w[b],
413
+ intrinsic[b],
414
+ height, width,
415
+ background_color[b] if background_color is not None else None))
416
+
417
+ out = defaultdict(list)
418
+ for out_ in out_list:
419
+ for k, v in out_.items():
420
+ out[k].append(v)
421
+ for k, v in out.items():
422
+ if isinstance(v[0], torch.Tensor):
423
+ out[k] = torch.stack(v, dim=0)
424
+ else:
425
+ out[k] = v
426
+ return out
427
+
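Note: the triplane read-out in query_triplane is the central lookup of this renderer, so a minimal, self-contained sketch of it is shown below. The toy tensor sizes are assumptions; only the plane pairing (xy, xz, yz) and the grid_sample / rearrange pattern mirror the code above.

# Minimal sketch of the triplane lookup used in query_triplane (toy sizes assumed).
import torch
import torch.nn.functional as F
from einops import rearrange

B, C, H, W, N = 2, 80, 64, 64, 1024
triplanes = torch.randn(B, 3, C, H, W)      # three feature planes per sample
positions = torch.rand(B, N, 3) * 2 - 1     # points already scaled to [-1, 1]

# Project each 3D point onto the xy, xz and yz planes.
indices2D = torch.stack(
    (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
    dim=-3,
)                                           # (B, 3, N, 2)

out = F.grid_sample(
    rearrange(triplanes, "B Np C H W -> (B Np) C H W"),
    rearrange(indices2D, "B Np N d -> (B Np) () N d"),
    align_corners=False,
    mode="bilinear",
)                                           # (B*3, C, 1, N)

# "concat" reduction: per-point feature of size 3*C, as with feature_reduction="concat".
feats = rearrange(out, "(B Np) C () N -> B N (Np C)", Np=3)
print(feats.shape)                          # torch.Size([2, 1024, 240])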
hort/models/tgs/models/snowflake/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 AllenXiang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
hort/models/tgs/models/snowflake/SPD.py ADDED
@@ -0,0 +1,68 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Author: Peng Xiang
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from .utils import MLP_Res, MLP_CONV
7
+ from .skip_transformer import SkipTransformer
8
+
9
+
10
+ class SPD(nn.Module):
11
+ def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
12
+ """Snowflake Point Deconvolution"""
13
+ super(SPD, self).__init__()
14
+ self.i = i
15
+ self.up_factor = up_factor
16
+
17
+ self.bounding = bounding
18
+ self.radius = radius
19
+
20
+ self.global_feat = global_feat
21
+ self.ps_dim = 32 if global_feat else 64
22
+
23
+ self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
24
+ self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])
25
+
26
+ self.skip_transformer = SkipTransformer(in_channel=128, dim=64)
27
+
28
+ self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
29
+ self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting
30
+
31
+ self.up_sampler = nn.Upsample(scale_factor=up_factor)
32
+ self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128)
33
+
34
+ self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])
35
+
36
+ def forward(self, pcd_prev, feat_global=None, K_prev=None):
37
+ """
38
+ Args:
39
+ pcd_prev: Tensor, (B, 3, N_prev)
40
+ feat_global: Tensor, (B, dim_feat, 1)
41
+ K_prev: Tensor, (B, 128, N_prev)
42
+
43
+ Returns:
44
+ pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
45
+ K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
46
+ """
47
+ b, _, n_prev = pcd_prev.shape
48
+ feat_1 = self.mlp_1(pcd_prev)
49
+ feat_1 = torch.cat([feat_1,
50
+ torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))),
51
+ feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1
52
+ Q = self.mlp_2(feat_1)
53
+
54
+ H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q)
55
+
56
+ feat_child = self.mlp_ps(H)
57
+ feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor)
58
+ H_up = self.up_sampler(H)
59
+ K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))
60
+
61
+ delta = self.mlp_delta(torch.relu(K_curr))
62
+ if self.bounding:
63
+ delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor)
64
+
65
+ pcd_child = self.up_sampler(pcd_prev)
66
+ pcd_child = pcd_child + delta
67
+
68
+ return pcd_child, K_curr
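Note: the "point-wise splitting" step in SPD can be hard to picture from the forward pass alone, so here is a small self-contained sketch of it. The toy shapes are assumptions; the layer hyper-parameters match the module definitions above.

# Sketch of point-wise splitting: ConvTranspose1d with kernel = stride = up_factor
# turns each per-point feature into up_factor child features, while nn.Upsample
# repeats the parent coordinates that the predicted offsets are added to.
import torch
import torch.nn as nn

up_factor, ps_dim, B, N = 2, 32, 4, 512
feat_child_in = torch.randn(B, ps_dim, N)   # per-point features before splitting
pcd_prev = torch.randn(B, 3, N)             # parent point cloud

ps = nn.ConvTranspose1d(ps_dim, 128, up_factor, up_factor, bias=False)
up_sampler = nn.Upsample(scale_factor=up_factor)

feat_child = ps(feat_child_in)              # (B, 128, N * up_factor)
pcd_child = up_sampler(pcd_prev)            # (B, 3, N * up_factor), each parent duplicated
print(feat_child.shape, pcd_child.shape)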
hort/models/tgs/models/snowflake/SPD_crossattn.py ADDED
@@ -0,0 +1,81 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Author: Peng Xiang
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from .utils import MLP_Res, MLP_CONV
7
+ from .skip_transformer import SkipTransformer
8
+ from .attention import ResidualTransformerBlock
9
+
10
+ class SPD_crossattn(nn.Module):
11
+ def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
12
+ """Snowflake Point Deconvolution"""
13
+ super().__init__()
14
+ self.i = i
15
+ self.up_factor = up_factor
16
+
17
+ self.bounding = bounding
18
+ self.radius = radius
19
+
20
+ self.global_feat = global_feat
21
+ self.ps_dim = 32 if global_feat else 64
22
+
23
+ self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
24
+ self.pcd_image_attn = ResidualTransformerBlock(
25
+ device=torch.device('cuda'),
26
+ dtype=torch.float32,
27
+ n_data=128,
28
+ width=128,
29
+ heads=8,
30
+ init_scale=1.0,
31
+ )
32
+
33
+ self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])
34
+
35
+ self.skip_transformer = SkipTransformer(in_channel=128, dim=64)
36
+
37
+ self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
38
+ self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting
39
+
40
+ self.up_sampler = nn.Upsample(scale_factor=up_factor)
41
+ self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128)
42
+
43
+ self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])
44
+
45
+ def forward(self, pcd_prev, feat_global=None, K_prev=None):
46
+ """
47
+ Args:
48
+ pcd_prev: Tensor, (B, 3, N_prev)
49
+ feat_global: Tensor, (B, dim_feat, 1)
50
+ K_prev: Tensor, (B, 128, N_prev)
51
+
52
+ Returns:
53
+ pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
54
+ K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
55
+ """
56
+ b, _, n_prev = pcd_prev.shape
57
+ feat_1 = self.mlp_1(pcd_prev)
58
+ # feat_1 = torch.cat([feat_1,
59
+ # torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))),
60
+ # feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1
61
+ feat_1 = torch.permute(feat_1, (0, 2, 1))
62
+ feat_global = torch.permute(feat_global, (0, 2, 1))
63
+ feat_1 = self.pcd_image_attn(feat_1, feat_global)
64
+ Q = torch.permute(feat_1, (0, 2, 1))
65
+ # Q = self.mlp_2(feat_1)
66
+
67
+ H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q)
68
+
69
+ feat_child = self.mlp_ps(H)
70
+ feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor)
71
+ H_up = self.up_sampler(H)
72
+ K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))
73
+
74
+ delta = self.mlp_delta(torch.relu(K_curr))
75
+ if self.bounding:
76
+ delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor)
77
+
78
+ pcd_child = self.up_sampler(pcd_prev)
79
+ pcd_child = pcd_child + delta
80
+
81
+ return pcd_child, K_curr
hort/models/tgs/models/snowflake/SPD_pp.py ADDED
@@ -0,0 +1,71 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from .utils import MLP_Res, MLP_CONV
5
+ from .skip_transformer import SkipTransformer
6
+
7
+ class SPD_pp(nn.Module):
8
+ def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
9
+ """Snowflake Point Deconvolution"""
10
+ super(SPD_pp, self).__init__()
11
+ self.i = i
12
+ self.up_factor = up_factor
13
+
14
+ self.bounding = bounding
15
+ self.radius = radius
16
+
17
+ self.global_feat = global_feat
18
+ self.ps_dim = 32 if global_feat else 64
19
+
20
+ self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
21
+ self.mlp_2 = MLP_CONV(
22
+ in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])
23
+
24
+ self.skip_transformer = SkipTransformer(in_channel=128, dim=64)
25
+
26
+ self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
27
+ self.ps = nn.ConvTranspose1d(
28
+ self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting
29
+
30
+ self.up_sampler = nn.Upsample(scale_factor=up_factor)
31
+ self.mlp_delta_feature = MLP_Res(
32
+ in_dim=256, hidden_dim=128, out_dim=128)
33
+
34
+ self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])
35
+
36
+ def forward(self, pcd_prev, feat_cond=None, K_prev=None):
37
+ """
38
+ Args:
39
+ pcd_prev: Tensor, (B, 3, N_prev)
40
+ feat_cond: Tensor, (B, dim_feat, N_prev)
41
+ K_prev: Tensor, (B, 128, N_prev)
42
+
43
+ Returns:
44
+ pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
45
+ K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
46
+ """
47
+ b, _, n_prev = pcd_prev.shape
48
+ feat_1 = self.mlp_1(pcd_prev)
49
+ feat_1 = torch.cat([feat_1,
50
+ torch.max(feat_1, 2, keepdim=True)[
51
+ 0].repeat((1, 1, feat_1.size(2))),
52
+ feat_cond], 1) if self.global_feat else feat_1
53
+ Q = self.mlp_2(feat_1)
54
+
55
+ H = self.skip_transformer(
56
+ pcd_prev, K_prev if K_prev is not None else Q, Q)
57
+
58
+ feat_child = self.mlp_ps(H)
59
+ feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor)
60
+ H_up = self.up_sampler(H)
61
+ K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))
62
+
63
+ delta = self.mlp_delta(torch.relu(K_curr))
64
+ if self.bounding:
65
+ # (B, 3, N_prev * up_factor)
66
+ delta = torch.tanh(delta) / self.radius**self.i
67
+
68
+ pcd_child = self.up_sampler(pcd_prev)
69
+ pcd_child = pcd_child + delta
70
+
71
+ return pcd_child, K_curr
hort/models/tgs/models/snowflake/attention.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from typing import Optional
5
+ from typing import Callable, Iterable, Sequence, Union
6
+
7
+
8
+
9
+
10
+ def checkpoint(
11
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12
+ inputs: Sequence[torch.Tensor],
13
+ params: Iterable[torch.Tensor],
14
+ flag: bool,
15
+ ):
16
+ """
17
+ Evaluate a function without caching intermediate activations, allowing for
18
+ reduced memory at the expense of extra compute in the backward pass.
19
+ :param func: the function to evaluate.
20
+ :param inputs: the argument sequence to pass to `func`.
21
+ :param params: a sequence of parameters `func` depends on but does not
22
+ explicitly take as arguments.
23
+ :param flag: if False, disable gradient checkpointing.
24
+ """
25
+ if flag:
26
+ args = tuple(inputs) + tuple(params)
27
+ return CheckpointFunction.apply(func, len(inputs), *args)
28
+ else:
29
+ return func(*inputs)
30
+
31
+
32
+ class CheckpointFunction(torch.autograd.Function):
33
+ @staticmethod
34
+ def forward(ctx, run_function, length, *args):
35
+ ctx.run_function = run_function
36
+ ctx.input_tensors = list(args[:length])
37
+ ctx.input_params = list(args[length:])
38
+ with torch.no_grad():
39
+ output_tensors = ctx.run_function(*ctx.input_tensors)
40
+ return output_tensors
41
+
42
+ @staticmethod
43
+ def backward(ctx, *output_grads):
44
+ ctx.input_tensors = [x.detach().requires_grad_(True)
45
+ for x in ctx.input_tensors]
46
+ with torch.enable_grad():
47
+ # Fixes a bug where the first op in run_function modifies the
48
+ # Tensor storage in place, which is not allowed for detach()'d
49
+ # Tensors.
50
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
51
+ output_tensors = ctx.run_function(*shallow_copies)
52
+ input_grads = torch.autograd.grad(
53
+ output_tensors,
54
+ ctx.input_tensors + ctx.input_params,
55
+ output_grads,
56
+ allow_unused=True,
57
+ )
58
+ del ctx.input_tensors
59
+ del ctx.input_params
60
+ del output_tensors
61
+ return (None, None) + input_grads
62
+
63
+
64
+ def init_linear(l, stddev):
65
+ nn.init.normal_(l.weight, std=stddev)
66
+ if l.bias is not None:
67
+ nn.init.constant_(l.bias, 0.0)
68
+
69
+ class MLP(nn.Module):
70
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
71
+ super().__init__()
72
+ self.width = width
73
+ self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
74
+ self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
75
+ self.gelu = nn.GELU()
76
+ init_linear(self.c_fc, init_scale)
77
+ init_linear(self.c_proj, init_scale)
78
+
79
+ def forward(self, x):
80
+ return self.c_proj(self.gelu(self.c_fc(x)))
81
+
82
+ class QKVMultiheadCrossAttention(nn.Module):
83
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int):
84
+ super().__init__()
85
+ self.device = device
86
+ self.dtype = dtype
87
+ self.heads = heads
88
+ self.n_data = n_data
89
+
90
+ def forward(self, q, kv):
91
+ _, n_ctx, _ = q.shape
92
+ bs, n_data, width = kv.shape
93
+ attn_ch = width // self.heads // 2
94
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
95
+ q = q.view(bs, n_ctx, self.heads, -1)
96
+ kv = kv.view(bs, n_data, self.heads, -1)
97
+ k, v = torch.split(kv, attn_ch, dim=-1)
98
+ weight = torch.einsum(
99
+ "bthc,bshc->bhts", q * scale, k * scale
100
+ ) # More stable with f16 than dividing afterwards
101
+ wdtype = weight.dtype
102
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
103
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
104
+
105
+
106
+
107
+ class QKVMultiheadAttention(nn.Module):
108
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
109
+ super().__init__()
110
+ self.device = device
111
+ self.dtype = dtype
112
+ self.heads = heads
113
+ self.n_ctx = n_ctx
114
+
115
+ def forward(self, qkv):
116
+ bs, n_ctx, width = qkv.shape
117
+ attn_ch = width // self.heads // 3
118
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
119
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
120
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
121
+ weight = torch.einsum(
122
+ "bthc,bshc->bhts", q * scale, k * scale
123
+ ) # More stable with f16 than dividing afterwards
124
+ wdtype = weight.dtype
125
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
126
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
127
+
128
+
129
+
130
+ class MultiheadCrossAttention(nn.Module):
131
+ def __init__(
132
+ self,
133
+ *,
134
+ device: torch.device,
135
+ dtype: torch.dtype,
136
+ n_data: int,
137
+ width: int,
138
+ heads: int,
139
+ init_scale: float,
140
+ data_width: Optional[int] = None,
141
+ ):
142
+ super().__init__()
143
+ self.n_data = n_data
144
+ self.width = width
145
+ self.heads = heads
146
+ self.data_width = width if data_width is None else data_width
147
+ self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
148
+ self.c_kv = nn.Linear(self.data_width, width * 2,
149
+ device=device, dtype=dtype)
150
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
151
+ self.attention = QKVMultiheadCrossAttention(
152
+ device=device, dtype=dtype, heads=heads, n_data=n_data
153
+ )
154
+ init_linear(self.c_q, init_scale)
155
+ init_linear(self.c_kv, init_scale)
156
+ init_linear(self.c_proj, init_scale)
157
+
158
+ def forward(self, x, data):
159
+ x = self.c_q(x)
160
+ data = self.c_kv(data)
161
+ x = checkpoint(self.attention, (x, data), (), True)
162
+ x = self.c_proj(x)
163
+ return x
164
+
165
+
166
+ class MultiheadAttention(nn.Module):
167
+ def __init__(
168
+ self,
169
+ *,
170
+ device: torch.device,
171
+ dtype: torch.dtype,
172
+ n_ctx: int,
173
+ width: int,
174
+ heads: int,
175
+ init_scale: float,
176
+ ):
177
+ super().__init__()
178
+ self.n_ctx = n_ctx
179
+ self.width = width
180
+ self.heads = heads
181
+ self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
182
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
183
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
184
+ init_linear(self.c_qkv, init_scale)
185
+ init_linear(self.c_proj, init_scale)
186
+
187
+ def forward(self, x):
188
+ x = self.c_qkv(x)
189
+ x = checkpoint(self.attention, (x,), (), True)
190
+ x = self.c_proj(x)
191
+ return x
192
+
193
+
194
+ class ResidualTransformerBlock(nn.Module):
195
+ def __init__(
196
+ self,
197
+ *,
198
+ device: torch.device,
199
+ dtype: torch.dtype,
200
+ n_data: int,
201
+ width: int,
202
+ heads: int,
203
+ data_width: Optional[int] = None,
204
+ init_scale: float = 1.0,
205
+ ):
206
+ super().__init__()
207
+
208
+ if data_width is None:
209
+ data_width = width
210
+
211
+ self.attn_cross = MultiheadCrossAttention(
212
+ device=device,
213
+ dtype=dtype,
214
+ n_data=n_data,
215
+ width=width,
216
+ heads=heads,
217
+ data_width=data_width,
218
+ init_scale=init_scale,
219
+ )
220
+ self.attn_self = MultiheadAttention(
221
+ device=device,
222
+ dtype=dtype,
223
+ n_ctx=n_data,
224
+ width=width,
225
+ heads=heads,
226
+ init_scale=init_scale,
227
+ )
228
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
229
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
230
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
231
+ self.mlp = MLP(device=device, dtype=dtype,
232
+ width=width, init_scale=init_scale)
233
+ self.ln_4 = nn.LayerNorm(width, device=device, dtype=dtype)
234
+
235
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
236
+ x = x + self.attn_cross(self.ln_1(x), self.ln_2(data))
237
+ x = x + self.attn_self(self.ln_3(x))
238
+ x = x + self.mlp(self.ln_4(x))
239
+ return x
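Note: a small, self-contained sketch of the attention math in QKVMultiheadCrossAttention is given below; the packed kv tensor is split into k and v per head, and q and k are pre-scaled by 1/sqrt(sqrt(d)) before the einsum (the fp16-stability trick commented in the code). The toy sizes are assumptions; only the split/einsum pattern mirrors the class above.

import math
import torch

B, heads, n_ctx, n_data, width = 2, 8, 128, 257, 128
attn_ch = width // heads                    # per-head channel after projection
q = torch.randn(B, n_ctx, width)            # output of c_q
kv = torch.randn(B, n_data, width * 2)      # output of c_kv: k and v packed together

scale = 1 / math.sqrt(math.sqrt(attn_ch))
q = q.view(B, n_ctx, heads, -1)
kv = kv.view(B, n_data, heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)

weight = torch.softmax(torch.einsum("bthc,bshc->bhts", q * scale, k * scale), dim=-1)
out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(B, n_ctx, -1)
print(out.shape)                            # torch.Size([2, 128, 128])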
hort/models/tgs/models/snowflake/model_spdpp.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from tgs.utils.base import BaseModule
6
+ from tgs.utils.typing import *
7
+ from dataclasses import dataclass, field
8
+
9
+ from pytorch3d.renderer import (
10
+ AlphaCompositor,
11
+ NormWeightedCompositor,
12
+ PointsRasterizationSettings,
13
+ PointsRasterizer,
14
+ PointsRenderer)
15
+ from pytorch3d.renderer.cameras import CamerasBase
16
+ from pytorch3d.structures import Pointclouds
17
+ from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection
18
+
19
+ from .utils import fps_subsample
20
+ from einops import rearrange
21
+
22
+ from .utils import MLP_CONV
23
+ from .SPD import SPD
24
+ from .SPD_crossattn import SPD_crossattn
25
+ from .SPD_pp import SPD_pp
26
+
27
+ SPD_BLOCK = {
28
+ 'SPD': SPD,
29
+ 'SPD_crossattn': SPD_crossattn,
30
+ 'SPD_PP': SPD_pp,
31
+ }
32
+
33
+
34
+ def homoify(points):
35
+ """
36
+ Convert a batch of points to homogeneous coordinates.
37
+ Args:
38
+ points: e.g. (B, N, 3) or (N, 3)
39
+ Returns:
40
+ homoified points: e.g., (B, N, 4)
41
+ """
42
+ points_dim = points.shape[:-1] + (1,)
43
+ ones = points.new_ones(points_dim)
44
+
45
+ return torch.cat([points, ones], dim=-1)
46
+
47
+
48
+ def dehomoify(points):
49
+ """
50
+ Convert a batch of homogeneous points to cartesian coordinates.
51
+ Args:
52
+ homogeneous points: (B, N, 4/3) or (N, 4/3)
53
+ Returns:
54
+ cartesian points: (B, N, 3/2)
55
+ """
56
+ return points[..., :-1] / points[..., -1:]
57
+
58
+
59
+ def mask_generation(points: Float[Tensor, "B Np 3"],
60
+ intrinsics: Float[Tensor, "B 3 3"],
61
+ input_img: Float[Tensor, "B C H W"],
62
+ raster_point_radius: float = 0.01, # point size
63
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
64
+ bin_size: int = 0):
65
+ """
66
+ points: (B, Np, 3)
67
+ """
68
+ B, C, H, W = input_img.shape
69
+ device = intrinsics.device
70
+
71
+ cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
72
+ cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1)
73
+
74
+ raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size)
75
+
76
+ image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device)
77
+ cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size)
78
+
79
+ rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
80
+ fragments = rasterize(Pointclouds(points))
81
+
82
+ fragments_idx: Tensor = fragments.idx.long()
83
+ mask = (fragments_idx[..., 0] > -1)
84
+
85
+ return mask.float()
86
+
87
+
88
+ def points_projection(points: Float[Tensor, "B Np 3"],
89
+ intrinsics: Float[Tensor, "B 3 3"],
90
+ local_features: Float[Tensor, "B C H W"],
91
+ raster_point_radius: float = 0.0075, # point size
92
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
93
+ bin_size: int = 0):
94
+ """
95
+ points: (B, Np, 3)
96
+ """
97
+ B, C, H, W = local_features.shape
98
+ device = local_features.device
99
+ cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
100
+ cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1)
101
+
102
+ raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size)
103
+ Np = points.shape[1]
104
+ R = raster_settings.points_per_pixel
105
+ image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device)
106
+ cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size)
107
+ rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
108
+ fragments = rasterize(Pointclouds(points))
109
+ fragments_idx: Tensor = fragments.idx.long()
110
+ visible_pixels = (fragments_idx > -1) # (B, H, W, R)
111
+ points_to_visible_pixels = fragments_idx[visible_pixels]
112
+ # Reshape local features to (B, H, W, R, C)
113
+ local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C)
114
+ # Get local features corresponding to visible points
115
+ local_features_proj = torch.zeros(B * Np, C, device=device)
116
+ local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
117
+ local_features_proj = local_features_proj.reshape(B, Np, C)
118
+ return local_features_proj
119
+
120
+
121
+ def points_projection_v2(input_xyz_points, cam_intr, feature_maps):
122
+ input_points = input_xyz_points.clone()
123
+ batch_size = input_points.shape[0]
124
+ xyz = input_points[:, :, :3]
125
+ homo_xyz = homoify(xyz)
126
+ homo_xyz_2d = torch.matmul(cam_intr, homo_xyz.transpose(1, 2)).transpose(1, 2)
127
+ xyz_2d = (homo_xyz_2d[:, :, :2] / homo_xyz_2d[:, :, [2]]).unsqueeze(2)
128
+ uv_2d = xyz_2d / 224 * 2 - 1
129
+ sample_feat = torch.nn.functional.grid_sample(feature_maps, uv_2d, align_corners=True)[:, :, :, 0].transpose(1, 2)
130
+ uv_2d = uv_2d.squeeze(2).reshape((-1, 2))
131
+ validity = (uv_2d[:, 0] >= -1.0) & (uv_2d[:, 0] <= 1.0) & (uv_2d[:, 1] >= -1.0) & (uv_2d[:, 1] <= 1.0)
132
+ validity = validity.unsqueeze(1)
133
+
134
+ return sample_feat
135
+
136
+
137
+ class Decoder(nn.Module):
138
+ def __init__(self, input_channels=1152, dim_feat=512, num_p0=512,
139
+ radius=1, bounding=True, up_factors=None,
140
+ SPD_type='SPD',
141
+ token_type='image_token'
142
+ ):
143
+ super(Decoder, self).__init__()
144
+ # self.decoder_coarse = SeedGenerator(dim_feat=dim_feat, num_pc=num_p0)
145
+ if up_factors is None:
146
+ up_factors = [1]
147
+ else:
148
+ up_factors = up_factors
149
+ uppers = []
150
+ self.num_p0 = num_p0
151
+ self.mlp_feat_cond = MLP_CONV(in_channel=input_channels,
152
+ layer_dims=[dim_feat*2, dim_feat])
153
+
154
+ for i, factor in enumerate(up_factors):
155
+ uppers.append(
156
+ SPD_BLOCK[SPD_type](dim_feat=dim_feat, up_factor=factor,
157
+ i=i, bounding=bounding, radius=radius))
158
+ self.uppers = nn.ModuleList(uppers)
159
+ self.token_type = token_type
160
+
161
+ def calculate_pcl_token(self, pcl_token, up_factor):
162
+ up_token = F.interpolate(pcl_token, scale_factor=up_factor, mode='nearest')
163
+ return up_token
164
+
165
+ def calculate_image_token(self, pcd, input_image_tokens, batch):
166
+ """
167
+ Args: pcd (B, Np, 3) query points; input_image_tokens (B, C, N_tokens) DINOv2 image tokens, CLS token first; batch dict providing 'scale', 'trans' and 'intrinsic_cond'
168
+ """
169
+ batch_size = input_image_tokens.shape[0]
170
+ h_cond, w_cond = 224, 224
171
+ input_image_tokens = input_image_tokens.permute(0, 2, 1)
172
+ local_features = input_image_tokens[:, 1:].reshape(batch_size, h_cond // 14, w_cond // 14, -1).permute(0, 3, 1, 2)
173
+ # local_features = F.interpolate(local_features, size=(h_cond, w_cond), mode='bilinear', align_corners=False)
174
+ local_features_proj = points_projection_v2(pcd * batch['scale'] + batch['trans'].unsqueeze(1), batch['intrinsic_cond'], local_features)
175
+ local_features_proj = local_features_proj.permute(0, 2, 1).contiguous()
176
+
177
+ return local_features_proj
178
+
179
+ def forward(self, x):
180
+ """
181
+ Args:
182
+ points: Tensor, (b, num_p0, 3)
183
+ feat_cond: Tensor, (b, dim_feat) dinov2: 325x768
184
+ # partial_coarse: Tensor, (b, n_coarse, 3)
185
+ """
186
+ points = x['points']
187
+ if self.token_type == 'pcl_token':
188
+ feat_cond = x['pcl_token']
189
+ elif self.token_type == 'image_token':
190
+ feat_cond = x['input_image_tokens']
191
+ feat_cond = self.mlp_feat_cond(feat_cond)
192
+ arr_pcd = []
193
+ feat_prev = None
194
+
195
+ pcd = torch.permute(points, (0, 2, 1)).contiguous()
196
+ pcl_up_scale = 1
197
+ for upper in self.uppers:
198
+ if self.token_type == 'pcl_token':
199
+ up_cond = self.calculate_pcl_token(
200
+ feat_cond, pcl_up_scale)
201
+ pcl_up_scale *= upper.up_factor
202
+ elif self.token_type == 'image_token':
203
+ up_cond = self.calculate_image_token(points, feat_cond, x)
204
+ pcd, feat_prev = upper(pcd, up_cond, feat_prev)
205
+ points = torch.permute(pcd, (0, 2, 1)).contiguous()
206
+ arr_pcd.append(points)
207
+ return arr_pcd
208
+
209
+
210
+ class SnowflakeModelSPDPP(BaseModule):
211
+ """
212
+ apply PC^2 / PCL token to decoder
213
+ """
214
+ @dataclass
215
+ class Config(BaseModule.Config):
216
+ input_channels: int = 1152
217
+ dim_feat: int = 128
218
+ num_p0: int = 512
219
+ radius: float = 1
220
+ bounding: bool = True
221
+ use_fps: bool = True
222
+ up_factors: List[int] = field(default_factory=lambda: [2, 2])
223
+ image_full_token_cond: bool = False
224
+ SPD_type: str = 'SPD_PP'
225
+ token_type: str = 'pcl_token'
226
+ cfg: Config
227
+
228
+ def configure(self) -> None:
229
+ super().configure()
230
+ self.decoder = Decoder(input_channels=self.cfg.input_channels,
231
+ dim_feat=self.cfg.dim_feat, num_p0=self.cfg.num_p0,
232
+ radius=self.cfg.radius, up_factors=self.cfg.up_factors, bounding=self.cfg.bounding,
233
+ SPD_type=self.cfg.SPD_type,
234
+ token_type=self.cfg.token_type
235
+ )
236
+
237
+ def forward(self, x):
238
+ results = self.decoder(x)
239
+ return results
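Note: the projection used in points_projection_v2 above maps pixel coordinates into grid_sample's [-1, 1] range; a self-contained sketch follows. The toy shapes, the random points and the assumed (B, 3, 4) camera matrix are illustrative only; the hard-coded 224x224 image size matches the code above.

import torch
import torch.nn.functional as F

B, N, C, H, W = 2, 512, 64, 16, 16
points = torch.rand(B, N, 3) + torch.tensor([0.0, 0.0, 0.5])   # keep z in front of the camera
feature_maps = torch.randn(B, C, H, W)

K = torch.tensor([[200.0, 0.0, 112.0, 0.0],
                  [0.0, 200.0, 112.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0]])
cam_intr = K.expand(B, 3, 4)                # assumed (B, 3, 4) projection, matching the matmul above

homo_xyz = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)     # (B, N, 4)
homo_2d = torch.matmul(cam_intr, homo_xyz.transpose(1, 2)).transpose(1, 2)   # (B, N, 3)
xy_pix = homo_2d[:, :, :2] / homo_2d[:, :, 2:3]                              # pixel coords in a 224x224 image
uv = (xy_pix / 224 * 2 - 1).unsqueeze(2)                                     # (B, N, 1, 2) in [-1, 1]

sampled = F.grid_sample(feature_maps, uv, align_corners=True)[:, :, :, 0]    # (B, C, N)
print(sampled.transpose(1, 2).shape)                                         # (B, N, C)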
hort/models/tgs/models/snowflake/pointnet2.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
5
+
6
+
7
+ class PointNet2ClassificationSSG(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self._build_model()
11
+
12
+ def _build_model(self):
13
+ self.SA_modules = nn.ModuleList()
14
+ self.SA_modules.append(
15
+ PointnetSAModule(
16
+ npoint=512,
17
+ radius=0.2,
18
+ nsample=64,
19
+ mlp=[3, 64, 64, 128],
20
+ use_xyz=True,
21
+ )
22
+ )
23
+ self.SA_modules.append(
24
+ PointnetSAModule(
25
+ npoint=128,
26
+ radius=0.4,
27
+ nsample=64,
28
+ mlp=[128, 128, 128, 256],
29
+ use_xyz=True,
30
+ )
31
+ )
32
+ self.SA_modules.append(
33
+ PointnetSAModule(
34
+ mlp=[256, 256, 512, 1024], use_xyz=True,
35
+ )
36
+ )
37
+
38
+ self.fc_layer = nn.Sequential(
39
+ nn.Linear(1024, 512, bias=False),
40
+ nn.BatchNorm1d(512),
41
+ nn.ReLU(True),
42
+ nn.Linear(512, 256, bias=False),
43
+ nn.BatchNorm1d(256),
44
+ nn.ReLU(True),
45
+ nn.Dropout(0.5),
46
+ nn.Linear(256, 40),
47
+ )
48
+
49
+ def _break_up_pc(self, pc):
50
+ xyz = pc[..., 0:3].contiguous()
51
+ features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None
52
+
53
+ return xyz, features
54
+
55
+ def forward(self, pointcloud):
56
+ r"""
57
+ Forward pass of the network
58
+
59
+ Parameters
60
+ ----------
61
+ pointcloud: Variable(torch.cuda.FloatTensor)
62
+ (B, N, 3 + input_channels) tensor
63
+ Point cloud to run predicts on
64
+ Each point in the point-cloud MUST
65
+ be formated as (x, y, z, features...)
66
+ """
67
+ xyz, features = self._break_up_pc(pointcloud)
68
+
69
+ for module in self.SA_modules:
70
+ xyz, features = module(xyz, features)
71
+
72
+ return self.fc_layer(features.squeeze(-1))
73
+
74
+
75
+ class PointNet2SemSegSSG(PointNet2ClassificationSSG):
76
+ def _build_model(self):
77
+ self.SA_modules = nn.ModuleList()
78
+ self.SA_modules.append(
79
+ PointnetSAModule(
80
+ npoint=256,
81
+ radius=0.05,
82
+ nsample=32,
83
+ mlp=[1, 32, 64],
84
+ use_xyz=True,
85
+ )
86
+ )
87
+ self.SA_modules.append(
88
+ PointnetSAModule(
89
+ npoint=64,
90
+ radius=0.10,
91
+ nsample=32,
92
+ mlp=[64, 128, 256],
93
+ use_xyz=True,
94
+ )
95
+ )
96
+ self.SA_modules.append(
97
+ PointnetSAModule(
98
+ npoint=16,
99
+ radius=0.20,
100
+ nsample=32,
101
+ mlp=[256, 512, 768],
102
+ use_xyz=True,
103
+ )
104
+ )
105
+
106
+ def forward(self, pointcloud):
107
+ r"""
108
+ Forward pass of the network
109
+
110
+ Parameters
111
+ ----------
112
+ pointcloud: Variable(torch.cuda.FloatTensor)
113
+ (B, N, 3 + input_channels) tensor
114
+ Point cloud to run predicts on
115
+ Each point in the point-cloud MUST
116
+ be formated as (x, y, z, features...)
117
+ """
118
+ xyz, features = self._break_up_pc(pointcloud)
119
+
120
+ l_xyz, l_features = [xyz], [features]
121
+ for i in range(len(self.SA_modules)):
122
+ li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
123
+ l_xyz.append(li_xyz)
124
+ l_features.append(li_features)
125
+
126
+ return l_features[-1].transpose(2, 1)
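Note: a quick sketch of the (B, N, 3 + C) input convention expected by the PointNet++ modules above, mirroring _break_up_pc: the first three channels are xyz, the rest are per-point features transposed to (B, C, N). Toy sizes are assumptions.

import torch

B, N, C = 2, 1024, 6
pointcloud = torch.randn(B, N, 3 + C)

xyz = pointcloud[..., 0:3].contiguous()                       # (B, N, 3)
features = pointcloud[..., 3:].transpose(1, 2).contiguous()   # (B, C, N)
print(xyz.shape, features.shape)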
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import pointnet2_ops.pointnet2_modules
2
+ import pointnet2_ops.pointnet2_utils
3
+ from pointnet2_ops._version import __version__
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h ADDED
@@ -0,0 +1,5 @@
1
+ #pragma once
2
+ #include <torch/extension.h>
3
+
4
+ at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5
+ const int nsample);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h ADDED
@@ -0,0 +1,41 @@
1
+ #ifndef _CUDA_UTILS_H
2
+ #define _CUDA_UTILS_H
3
+
4
+ #include <ATen/ATen.h>
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <cmath>
7
+
8
+ #include <cuda.h>
9
+ #include <cuda_runtime.h>
10
+
11
+ #include <vector>
12
+
13
+ #define TOTAL_THREADS 512
14
+
15
+ inline int opt_n_threads(int work_size) {
16
+ const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
17
+
18
+ return max(min(1 << pow_2, TOTAL_THREADS), 1);
19
+ }
20
+
21
+ inline dim3 opt_block_config(int x, int y) {
22
+ const int x_threads = opt_n_threads(x);
23
+ const int y_threads =
24
+ max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25
+ dim3 block_config(x_threads, y_threads, 1);
26
+
27
+ return block_config;
28
+ }
29
+
30
+ #define CUDA_CHECK_ERRORS() \
31
+ do { \
32
+ cudaError_t err = cudaGetLastError(); \
33
+ if (cudaSuccess != err) { \
34
+ fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
35
+ cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
36
+ __FILE__); \
37
+ exit(-1); \
38
+ } \
39
+ } while (0)
40
+
41
+ #endif
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h ADDED
@@ -0,0 +1,5 @@
1
+ #pragma once
2
+ #include <torch/extension.h>
3
+
4
+ at::Tensor group_points(at::Tensor points, at::Tensor idx);
5
+ at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h ADDED
@@ -0,0 +1,10 @@
1
+ #pragma once
2
+
3
+ #include <torch/extension.h>
4
+ #include <vector>
5
+
6
+ std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
7
+ at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
8
+ at::Tensor weight);
9
+ at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
10
+ at::Tensor weight, const int m);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h ADDED
@@ -0,0 +1,6 @@
1
+ #pragma once
2
+ #include <torch/extension.h>
3
+
4
+ at::Tensor gather_points(at::Tensor points, at::Tensor idx);
5
+ at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6
+ at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h ADDED
@@ -0,0 +1,25 @@
1
+ #pragma once
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <torch/extension.h>
4
+
5
+ #define CHECK_CUDA(x) \
6
+ do { \
7
+ AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
8
+ } while (0)
9
+
10
+ #define CHECK_CONTIGUOUS(x) \
11
+ do { \
12
+ AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
13
+ } while (0)
14
+
15
+ #define CHECK_IS_INT(x) \
16
+ do { \
17
+ AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
18
+ #x " must be an int tensor"); \
19
+ } while (0)
20
+
21
+ #define CHECK_IS_FLOAT(x) \
22
+ do { \
23
+ AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
24
+ #x " must be a float tensor"); \
25
+ } while (0)
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp ADDED
@@ -0,0 +1,32 @@
1
+ #include "ball_query.h"
2
+ #include "utils.h"
3
+
4
+ void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5
+ int nsample, const float *new_xyz,
6
+ const float *xyz, int *idx);
7
+
8
+ at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
9
+ const int nsample) {
10
+ CHECK_CONTIGUOUS(new_xyz);
11
+ CHECK_CONTIGUOUS(xyz);
12
+ CHECK_IS_FLOAT(new_xyz);
13
+ CHECK_IS_FLOAT(xyz);
14
+
15
+ if (new_xyz.is_cuda()) {
16
+ CHECK_CUDA(xyz);
17
+ }
18
+
19
+ at::Tensor idx =
20
+ torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
21
+ at::device(new_xyz.device()).dtype(at::ScalarType::Int));
22
+
23
+ if (new_xyz.is_cuda()) {
24
+ query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
25
+ radius, nsample, new_xyz.data_ptr<float>(),
26
+ xyz.data_ptr<float>(), idx.data_ptr<int>());
27
+ } else {
28
+ AT_ASSERT(false, "CPU not supported");
29
+ }
30
+
31
+ return idx;
32
+ }
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu ADDED
@@ -0,0 +1,54 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+
5
+ #include "cuda_utils.h"
6
+
7
+ // input: new_xyz(b, m, 3) xyz(b, n, 3)
8
+ // output: idx(b, m, nsample)
9
+ __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10
+ int nsample,
11
+ const float *__restrict__ new_xyz,
12
+ const float *__restrict__ xyz,
13
+ int *__restrict__ idx) {
14
+ int batch_index = blockIdx.x;
15
+ xyz += batch_index * n * 3;
16
+ new_xyz += batch_index * m * 3;
17
+ idx += m * nsample * batch_index;
18
+
19
+ int index = threadIdx.x;
20
+ int stride = blockDim.x;
21
+
22
+ float radius2 = radius * radius;
23
+ for (int j = index; j < m; j += stride) {
24
+ float new_x = new_xyz[j * 3 + 0];
25
+ float new_y = new_xyz[j * 3 + 1];
26
+ float new_z = new_xyz[j * 3 + 2];
27
+ for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
28
+ float x = xyz[k * 3 + 0];
29
+ float y = xyz[k * 3 + 1];
30
+ float z = xyz[k * 3 + 2];
31
+ float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
32
+ (new_z - z) * (new_z - z);
33
+ if (d2 < radius2) {
34
+ if (cnt == 0) {
35
+ for (int l = 0; l < nsample; ++l) {
36
+ idx[j * nsample + l] = k;
37
+ }
38
+ }
39
+ idx[j * nsample + cnt] = k;
40
+ ++cnt;
41
+ }
42
+ }
43
+ }
44
+ }
45
+
46
+ void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47
+ int nsample, const float *new_xyz,
48
+ const float *xyz, int *idx) {
49
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50
+ query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
51
+ b, n, m, radius, nsample, new_xyz, xyz, idx);
52
+
53
+ //CUDA_CHECK_ERRORS();
54
+ }
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp ADDED
@@ -0,0 +1,19 @@
1
+ #include "ball_query.h"
2
+ #include "group_points.h"
3
+ #include "interpolate.h"
4
+ #include "sampling.h"
5
+
6
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7
+ m.def("gather_points", &gather_points);
8
+ m.def("gather_points_grad", &gather_points_grad);
9
+ m.def("furthest_point_sampling", &furthest_point_sampling);
10
+
11
+ m.def("three_nn", &three_nn);
12
+ m.def("three_interpolate", &three_interpolate);
13
+ m.def("three_interpolate_grad", &three_interpolate_grad);
14
+
15
+ m.def("ball_query", &ball_query);
16
+
17
+ m.def("group_points", &group_points);
18
+ m.def("group_points_grad", &group_points_grad);
19
+ }
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp ADDED
@@ -0,0 +1,62 @@
1
+ #include "group_points.h"
2
+ #include "utils.h"
3
+
4
+ void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
5
+ const float *points, const int *idx,
6
+ float *out);
7
+
8
+ void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
9
+ int nsample, const float *grad_out,
10
+ const int *idx, float *grad_points);
11
+
12
+ at::Tensor group_points(at::Tensor points, at::Tensor idx) {
13
+ CHECK_CONTIGUOUS(points);
14
+ CHECK_CONTIGUOUS(idx);
15
+ CHECK_IS_FLOAT(points);
16
+ CHECK_IS_INT(idx);
17
+
18
+ if (points.is_cuda()) {
19
+ CHECK_CUDA(idx);
20
+ }
21
+
22
+ at::Tensor output =
23
+ torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
24
+ at::device(points.device()).dtype(at::ScalarType::Float));
25
+
26
+ if (points.is_cuda()) {
27
+ group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28
+ idx.size(1), idx.size(2),
29
+ points.data_ptr<float>(), idx.data_ptr<int>(),
30
+ output.data_ptr<float>());
31
+ } else {
32
+ AT_ASSERT(false, "CPU not supported");
33
+ }
34
+
35
+ return output;
36
+ }
37
+
38
+ at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
39
+ CHECK_CONTIGUOUS(grad_out);
40
+ CHECK_CONTIGUOUS(idx);
41
+ CHECK_IS_FLOAT(grad_out);
42
+ CHECK_IS_INT(idx);
43
+
44
+ if (grad_out.is_cuda()) {
45
+ CHECK_CUDA(idx);
46
+ }
47
+
48
+ at::Tensor output =
49
+ torch::zeros({grad_out.size(0), grad_out.size(1), n},
50
+ at::device(grad_out.device()).dtype(at::ScalarType::Float));
51
+
52
+ if (grad_out.is_cuda()) {
53
+ group_points_grad_kernel_wrapper(
54
+ grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
55
+ grad_out.data_ptr<float>(), idx.data_ptr<int>(),
56
+ output.data_ptr<float>());
57
+ } else {
58
+ AT_ASSERT(false, "CPU not supported");
59
+ }
60
+
61
+ return output;
62
+ }
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu ADDED
@@ -0,0 +1,75 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+
4
+ #include "cuda_utils.h"
5
+
6
+ // input: points(b, c, n) idx(b, npoints, nsample)
7
+ // output: out(b, c, npoints, nsample)
8
+ __global__ void group_points_kernel(int b, int c, int n, int npoints,
9
+ int nsample,
10
+ const float *__restrict__ points,
11
+ const int *__restrict__ idx,
12
+ float *__restrict__ out) {
13
+ int batch_index = blockIdx.x;
14
+ points += batch_index * n * c;
15
+ idx += batch_index * npoints * nsample;
16
+ out += batch_index * npoints * nsample * c;
17
+
18
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
19
+ const int stride = blockDim.y * blockDim.x;
20
+ for (int i = index; i < c * npoints; i += stride) {
21
+ const int l = i / npoints;
22
+ const int j = i % npoints;
23
+ for (int k = 0; k < nsample; ++k) {
24
+ int ii = idx[j * nsample + k];
25
+ out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26
+ }
27
+ }
28
+ }
29
+
30
+ void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31
+ const float *points, const int *idx,
32
+ float *out) {
33
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34
+
35
+ group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36
+ b, c, n, npoints, nsample, points, idx, out);
37
+
38
+ //CUDA_CHECK_ERRORS();
39
+ }
40
+
41
+ // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42
+ // output: grad_points(b, c, n)
43
+ __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44
+ int nsample,
45
+ const float *__restrict__ grad_out,
46
+ const int *__restrict__ idx,
47
+ float *__restrict__ grad_points) {
48
+ int batch_index = blockIdx.x;
49
+ grad_out += batch_index * npoints * nsample * c;
50
+ idx += batch_index * npoints * nsample;
51
+ grad_points += batch_index * n * c;
52
+
53
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
54
+ const int stride = blockDim.y * blockDim.x;
55
+ for (int i = index; i < c * npoints; i += stride) {
56
+ const int l = i / npoints;
57
+ const int j = i % npoints;
58
+ for (int k = 0; k < nsample; ++k) {
59
+ int ii = idx[j * nsample + k];
60
+ atomicAdd(grad_points + l * n + ii,
61
+ grad_out[(l * npoints + j) * nsample + k]);
62
+ }
63
+ }
64
+ }
65
+
66
+ void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67
+ int nsample, const float *grad_out,
68
+ const int *idx, float *grad_points) {
69
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70
+
71
+ group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72
+ b, c, n, npoints, nsample, grad_out, idx, grad_points);
73
+
74
+ //CUDA_CHECK_ERRORS();
75
+ }
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp ADDED
@@ -0,0 +1,99 @@
1
+ #include "interpolate.h"
2
+ #include "utils.h"
3
+
4
+ void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5
+ const float *known, float *dist2, int *idx);
6
+ void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
7
+ const float *points, const int *idx,
8
+ const float *weight, float *out);
9
+ void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
10
+ const float *grad_out,
11
+ const int *idx, const float *weight,
12
+ float *grad_points);
13
+
14
+ std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
15
+ CHECK_CONTIGUOUS(unknowns);
16
+ CHECK_CONTIGUOUS(knows);
17
+ CHECK_IS_FLOAT(unknowns);
18
+ CHECK_IS_FLOAT(knows);
19
+
20
+ if (unknowns.is_cuda()) {
21
+ CHECK_CUDA(knows);
22
+ }
23
+
24
+ at::Tensor idx =
25
+ torch::zeros({unknowns.size(0), unknowns.size(1), 3},
26
+ at::device(unknowns.device()).dtype(at::ScalarType::Int));
27
+ at::Tensor dist2 =
28
+ torch::zeros({unknowns.size(0), unknowns.size(1), 3},
29
+ at::device(unknowns.device()).dtype(at::ScalarType::Float));
30
+
31
+ if (unknowns.is_cuda()) {
32
+ three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
33
+ unknowns.data_ptr<float>(), knows.data_ptr<float>(),
34
+ dist2.data_ptr<float>(), idx.data_ptr<int>());
35
+ } else {
36
+ AT_ASSERT(false, "CPU not supported");
37
+ }
38
+
39
+ return {dist2, idx};
40
+ }
41
+
42
+ at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
43
+ at::Tensor weight) {
44
+ CHECK_CONTIGUOUS(points);
45
+ CHECK_CONTIGUOUS(idx);
46
+ CHECK_CONTIGUOUS(weight);
47
+ CHECK_IS_FLOAT(points);
48
+ CHECK_IS_INT(idx);
49
+ CHECK_IS_FLOAT(weight);
50
+
51
+ if (points.is_cuda()) {
52
+ CHECK_CUDA(idx);
53
+ CHECK_CUDA(weight);
54
+ }
55
+
56
+ at::Tensor output =
57
+ torch::zeros({points.size(0), points.size(1), idx.size(1)},
58
+ at::device(points.device()).dtype(at::ScalarType::Float));
59
+
60
+ if (points.is_cuda()) {
61
+ three_interpolate_kernel_wrapper(
62
+ points.size(0), points.size(1), points.size(2), idx.size(1),
63
+ points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
64
+ output.data_ptr<float>());
65
+ } else {
66
+ AT_ASSERT(false, "CPU not supported");
67
+ }
68
+
69
+ return output;
70
+ }
71
+ at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
72
+ at::Tensor weight, const int m) {
73
+ CHECK_CONTIGUOUS(grad_out);
74
+ CHECK_CONTIGUOUS(idx);
75
+ CHECK_CONTIGUOUS(weight);
76
+ CHECK_IS_FLOAT(grad_out);
77
+ CHECK_IS_INT(idx);
78
+ CHECK_IS_FLOAT(weight);
79
+
80
+ if (grad_out.is_cuda()) {
81
+ CHECK_CUDA(idx);
82
+ CHECK_CUDA(weight);
83
+ }
84
+
85
+ at::Tensor output =
86
+ torch::zeros({grad_out.size(0), grad_out.size(1), m},
87
+ at::device(grad_out.device()).dtype(at::ScalarType::Float));
88
+
89
+ if (grad_out.is_cuda()) {
90
+ three_interpolate_grad_kernel_wrapper(
91
+ grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
92
+ grad_out.data_ptr<float>(), idx.data_ptr<int>(),
93
+ weight.data_ptr<float>(), output.data_ptr<float>());
94
+ } else {
95
+ AT_ASSERT(false, "CPU not supported");
96
+ }
97
+
98
+ return output;
99
+ }
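
These bindings back `three_nn` and `three_interpolate`. Below is a minimal sketch of the inverse-distance-weighted interpolation recipe they support (the same pattern used by `PointnetFPModule` further below), assuming the extension is built and a CUDA device is available.

import torch
from pointnet2_ops.pointnet2_utils import three_nn, three_interpolate

B, n, m, C = 2, 2048, 512, 64
unknown = torch.rand(B, n, 3, device="cuda")   # positions to interpolate onto
known = torch.rand(B, m, 3, device="cuda")     # positions carrying features
feats = torch.rand(B, C, m, device="cuda")

dist, idx = three_nn(unknown, known)           # (B, n, 3) distances and indices
weight = 1.0 / (dist + 1e-8)
weight = weight / weight.sum(dim=2, keepdim=True)     # inverse-distance weights
dense_feats = three_interpolate(feats, idx, weight)   # (B, C, n)
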
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu ADDED
@@ -0,0 +1,154 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+
5
+ #include "cuda_utils.h"
6
+
7
+ // input: unknown(b, n, 3) known(b, m, 3)
8
+ // output: dist2(b, n, 3), idx(b, n, 3)
9
+ __global__ void three_nn_kernel(int b, int n, int m,
10
+ const float *__restrict__ unknown,
11
+ const float *__restrict__ known,
12
+ float *__restrict__ dist2,
13
+ int *__restrict__ idx) {
14
+ int batch_index = blockIdx.x;
15
+ unknown += batch_index * n * 3;
16
+ known += batch_index * m * 3;
17
+ dist2 += batch_index * n * 3;
18
+ idx += batch_index * n * 3;
19
+
20
+ int index = threadIdx.x;
21
+ int stride = blockDim.x;
22
+ for (int j = index; j < n; j += stride) {
23
+ float ux = unknown[j * 3 + 0];
24
+ float uy = unknown[j * 3 + 1];
25
+ float uz = unknown[j * 3 + 2];
26
+
27
+ double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28
+ int besti1 = 0, besti2 = 0, besti3 = 0;
29
+ for (int k = 0; k < m; ++k) {
30
+ float x = known[k * 3 + 0];
31
+ float y = known[k * 3 + 1];
32
+ float z = known[k * 3 + 2];
33
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34
+ if (d < best1) {
35
+ best3 = best2;
36
+ besti3 = besti2;
37
+ best2 = best1;
38
+ besti2 = besti1;
39
+ best1 = d;
40
+ besti1 = k;
41
+ } else if (d < best2) {
42
+ best3 = best2;
43
+ besti3 = besti2;
44
+ best2 = d;
45
+ besti2 = k;
46
+ } else if (d < best3) {
47
+ best3 = d;
48
+ besti3 = k;
49
+ }
50
+ }
51
+ dist2[j * 3 + 0] = best1;
52
+ dist2[j * 3 + 1] = best2;
53
+ dist2[j * 3 + 2] = best3;
54
+
55
+ idx[j * 3 + 0] = besti1;
56
+ idx[j * 3 + 1] = besti2;
57
+ idx[j * 3 + 2] = besti3;
58
+ }
59
+ }
60
+
61
+ void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62
+ const float *known, float *dist2, int *idx) {
63
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64
+ three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65
+ dist2, idx);
66
+
67
+ //CUDA_CHECK_ERRORS();
68
+ }
69
+
70
+ // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
71
+ // output: out(b, c, n)
72
+ __global__ void three_interpolate_kernel(int b, int c, int m, int n,
73
+ const float *__restrict__ points,
74
+ const int *__restrict__ idx,
75
+ const float *__restrict__ weight,
76
+ float *__restrict__ out) {
77
+ int batch_index = blockIdx.x;
78
+ points += batch_index * m * c;
79
+
80
+ idx += batch_index * n * 3;
81
+ weight += batch_index * n * 3;
82
+
83
+ out += batch_index * n * c;
84
+
85
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
86
+ const int stride = blockDim.y * blockDim.x;
87
+ for (int i = index; i < c * n; i += stride) {
88
+ const int l = i / n;
89
+ const int j = i % n;
90
+ float w1 = weight[j * 3 + 0];
91
+ float w2 = weight[j * 3 + 1];
92
+ float w3 = weight[j * 3 + 2];
93
+
94
+ int i1 = idx[j * 3 + 0];
95
+ int i2 = idx[j * 3 + 1];
96
+ int i3 = idx[j * 3 + 2];
97
+
98
+ out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
99
+ points[l * m + i3] * w3;
100
+ }
101
+ }
102
+
103
+ void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
104
+ const float *points, const int *idx,
105
+ const float *weight, float *out) {
106
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
107
+ three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
108
+ b, c, m, n, points, idx, weight, out);
109
+
110
+ //CUDA_CHECK_ERRORS();
111
+ }
112
+
113
+ // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
114
+ // output: grad_points(b, c, m)
115
+
116
+ __global__ void three_interpolate_grad_kernel(
117
+ int b, int c, int n, int m, const float *__restrict__ grad_out,
118
+ const int *__restrict__ idx, const float *__restrict__ weight,
119
+ float *__restrict__ grad_points) {
120
+ int batch_index = blockIdx.x;
121
+ grad_out += batch_index * n * c;
122
+ idx += batch_index * n * 3;
123
+ weight += batch_index * n * 3;
124
+ grad_points += batch_index * m * c;
125
+
126
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
127
+ const int stride = blockDim.y * blockDim.x;
128
+ for (int i = index; i < c * n; i += stride) {
129
+ const int l = i / n;
130
+ const int j = i % n;
131
+ float w1 = weight[j * 3 + 0];
132
+ float w2 = weight[j * 3 + 1];
133
+ float w3 = weight[j * 3 + 2];
134
+
135
+ int i1 = idx[j * 3 + 0];
136
+ int i2 = idx[j * 3 + 1];
137
+ int i3 = idx[j * 3 + 2];
138
+
139
+ atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
140
+ atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
141
+ atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
142
+ }
143
+ }
144
+
145
+ void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
146
+ const float *grad_out,
147
+ const int *idx, const float *weight,
148
+ float *grad_points) {
149
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
150
+ three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
151
+ b, c, n, m, grad_out, idx, weight, grad_points);
152
+
153
+ CUDA_CHECK_ERRORS();
154
+ }
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp ADDED
@@ -0,0 +1,87 @@
1
+ #include "sampling.h"
2
+ #include "utils.h"
3
+
4
+ void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5
+ const float *points, const int *idx,
6
+ float *out);
7
+ void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8
+ const float *grad_out, const int *idx,
9
+ float *grad_points);
10
+
11
+ void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12
+ const float *dataset, float *temp,
13
+ int *idxs);
14
+
15
+ at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16
+ CHECK_CONTIGUOUS(points);
17
+ CHECK_CONTIGUOUS(idx);
18
+ CHECK_IS_FLOAT(points);
19
+ CHECK_IS_INT(idx);
20
+
21
+ if (points.is_cuda()) {
22
+ CHECK_CUDA(idx);
23
+ }
24
+
25
+ at::Tensor output =
26
+ torch::zeros({points.size(0), points.size(1), idx.size(1)},
27
+ at::device(points.device()).dtype(at::ScalarType::Float));
28
+
29
+ if (points.is_cuda()) {
30
+ gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31
+ idx.size(1), points.data_ptr<float>(),
32
+ idx.data_ptr<int>(), output.data_ptr<float>());
33
+ } else {
34
+ AT_ASSERT(false, "CPU not supported");
35
+ }
36
+
37
+ return output;
38
+ }
39
+
40
+ at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41
+ const int n) {
42
+ CHECK_CONTIGUOUS(grad_out);
43
+ CHECK_CONTIGUOUS(idx);
44
+ CHECK_IS_FLOAT(grad_out);
45
+ CHECK_IS_INT(idx);
46
+
47
+ if (grad_out.is_cuda()) {
48
+ CHECK_CUDA(idx);
49
+ }
50
+
51
+ at::Tensor output =
52
+ torch::zeros({grad_out.size(0), grad_out.size(1), n},
53
+ at::device(grad_out.device()).dtype(at::ScalarType::Float));
54
+
55
+ if (grad_out.is_cuda()) {
56
+ gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57
+ idx.size(1), grad_out.data_ptr<float>(),
58
+ idx.data_ptr<int>(),
59
+ output.data_ptr<float>());
60
+ } else {
61
+ AT_ASSERT(false, "CPU not supported");
62
+ }
63
+
64
+ return output;
65
+ }
66
+ at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
67
+ CHECK_CONTIGUOUS(points);
68
+ CHECK_IS_FLOAT(points);
69
+
70
+ at::Tensor output =
71
+ torch::zeros({points.size(0), nsamples},
72
+ at::device(points.device()).dtype(at::ScalarType::Int));
73
+
74
+ at::Tensor tmp =
75
+ torch::full({points.size(0), points.size(1)}, 1e10,
76
+ at::device(points.device()).dtype(at::ScalarType::Float));
77
+
78
+ if (points.is_cuda()) {
79
+ furthest_point_sampling_kernel_wrapper(
80
+ points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
81
+ tmp.data_ptr<float>(), output.data_ptr<int>());
82
+ } else {
83
+ AT_ASSERT(false, "CPU not supported");
84
+ }
85
+
86
+ return output;
87
+ }
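
A minimal sketch of farthest point sampling plus gathering through these bindings, mirroring the `fps_subsample` helper defined later in snowflake/utils.py (assumes the built extension and a CUDA device).

import torch
from pointnet2_ops.pointnet2_utils import furthest_point_sample, gather_operation

pcd = torch.rand(2, 16384, 3, device="cuda")                    # (B, N, 3)
idx = furthest_point_sample(pcd, 2048)                          # (B, 2048) int32 indices
sub = gather_operation(pcd.transpose(1, 2).contiguous(), idx)   # (B, 3, 2048)
sub = sub.transpose(1, 2).contiguous()                          # back to (B, 2048, 3)
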
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu ADDED
@@ -0,0 +1,229 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+
4
+ #include "cuda_utils.h"
5
+
6
+ // input: points(b, c, n) idx(b, m)
7
+ // output: out(b, c, m)
8
+ __global__ void gather_points_kernel(int b, int c, int n, int m,
9
+ const float *__restrict__ points,
10
+ const int *__restrict__ idx,
11
+ float *__restrict__ out) {
12
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
13
+ for (int l = blockIdx.y; l < c; l += gridDim.y) {
14
+ for (int j = threadIdx.x; j < m; j += blockDim.x) {
15
+ int a = idx[i * m + j];
16
+ out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17
+ }
18
+ }
19
+ }
20
+ }
21
+
22
+ void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23
+ const float *points, const int *idx,
24
+ float *out) {
25
+ gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26
+ at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27
+ points, idx, out);
28
+
29
+ //CUDA_CHECK_ERRORS();
30
+ }
31
+
32
+ // input: grad_out(b, c, m) idx(b, m)
33
+ // output: grad_points(b, c, n)
34
+ __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35
+ const float *__restrict__ grad_out,
36
+ const int *__restrict__ idx,
37
+ float *__restrict__ grad_points) {
38
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
39
+ for (int l = blockIdx.y; l < c; l += gridDim.y) {
40
+ for (int j = threadIdx.x; j < m; j += blockDim.x) {
41
+ int a = idx[i * m + j];
42
+ atomicAdd(grad_points + (i * c + l) * n + a,
43
+ grad_out[(i * c + l) * m + j]);
44
+ }
45
+ }
46
+ }
47
+ }
48
+
49
+ void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50
+ const float *grad_out, const int *idx,
51
+ float *grad_points) {
52
+ gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53
+ at::cuda::getCurrentCUDAStream()>>>(
54
+ b, c, n, npoints, grad_out, idx, grad_points);
55
+
56
+ //CUDA_CHECK_ERRORS();
57
+ }
58
+
59
+ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60
+ int idx1, int idx2) {
61
+ const float v1 = dists[idx1], v2 = dists[idx2];
62
+ const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63
+ dists[idx1] = max(v1, v2);
64
+ dists_i[idx1] = v2 > v1 ? i2 : i1;
65
+ }
66
+
67
+ // Input dataset: (b, n, 3), tmp: (b, n)
68
+ // Ouput idxs (b, m)
69
+ template <unsigned int block_size>
70
+ __global__ void furthest_point_sampling_kernel(
71
+ int b, int n, int m, const float *__restrict__ dataset,
72
+ float *__restrict__ temp, int *__restrict__ idxs) {
73
+ if (m <= 0) return;
74
+ __shared__ float dists[block_size];
75
+ __shared__ int dists_i[block_size];
76
+
77
+ int batch_index = blockIdx.x;
78
+ dataset += batch_index * n * 3;
79
+ temp += batch_index * n;
80
+ idxs += batch_index * m;
81
+
82
+ int tid = threadIdx.x;
83
+ const int stride = block_size;
84
+
85
+ int old = 0;
86
+ if (threadIdx.x == 0) idxs[0] = old;
87
+
88
+ __syncthreads();
89
+ for (int j = 1; j < m; j++) {
90
+ int besti = 0;
91
+ float best = -1;
92
+ float x1 = dataset[old * 3 + 0];
93
+ float y1 = dataset[old * 3 + 1];
94
+ float z1 = dataset[old * 3 + 2];
95
+ for (int k = tid; k < n; k += stride) {
96
+ float x2, y2, z2;
97
+ x2 = dataset[k * 3 + 0];
98
+ y2 = dataset[k * 3 + 1];
99
+ z2 = dataset[k * 3 + 2];
100
+ float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101
+ if (mag <= 1e-3) continue;
102
+
103
+ float d =
104
+ (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105
+
106
+ float d2 = min(d, temp[k]);
107
+ temp[k] = d2;
108
+ besti = d2 > best ? k : besti;
109
+ best = d2 > best ? d2 : best;
110
+ }
111
+ dists[tid] = best;
112
+ dists_i[tid] = besti;
113
+ __syncthreads();
114
+
115
+ if (block_size >= 512) {
116
+ if (tid < 256) {
117
+ __update(dists, dists_i, tid, tid + 256);
118
+ }
119
+ __syncthreads();
120
+ }
121
+ if (block_size >= 256) {
122
+ if (tid < 128) {
123
+ __update(dists, dists_i, tid, tid + 128);
124
+ }
125
+ __syncthreads();
126
+ }
127
+ if (block_size >= 128) {
128
+ if (tid < 64) {
129
+ __update(dists, dists_i, tid, tid + 64);
130
+ }
131
+ __syncthreads();
132
+ }
133
+ if (block_size >= 64) {
134
+ if (tid < 32) {
135
+ __update(dists, dists_i, tid, tid + 32);
136
+ }
137
+ __syncthreads();
138
+ }
139
+ if (block_size >= 32) {
140
+ if (tid < 16) {
141
+ __update(dists, dists_i, tid, tid + 16);
142
+ }
143
+ __syncthreads();
144
+ }
145
+ if (block_size >= 16) {
146
+ if (tid < 8) {
147
+ __update(dists, dists_i, tid, tid + 8);
148
+ }
149
+ __syncthreads();
150
+ }
151
+ if (block_size >= 8) {
152
+ if (tid < 4) {
153
+ __update(dists, dists_i, tid, tid + 4);
154
+ }
155
+ __syncthreads();
156
+ }
157
+ if (block_size >= 4) {
158
+ if (tid < 2) {
159
+ __update(dists, dists_i, tid, tid + 2);
160
+ }
161
+ __syncthreads();
162
+ }
163
+ if (block_size >= 2) {
164
+ if (tid < 1) {
165
+ __update(dists, dists_i, tid, tid + 1);
166
+ }
167
+ __syncthreads();
168
+ }
169
+
170
+ old = dists_i[0];
171
+ if (tid == 0) idxs[j] = old;
172
+ }
173
+ }
174
+
175
+ void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176
+ const float *dataset, float *temp,
177
+ int *idxs) {
178
+ unsigned int n_threads = opt_n_threads(n);
179
+
180
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181
+
182
+ switch (n_threads) {
183
+ case 512:
184
+ furthest_point_sampling_kernel<512>
185
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186
+ break;
187
+ case 256:
188
+ furthest_point_sampling_kernel<256>
189
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190
+ break;
191
+ case 128:
192
+ furthest_point_sampling_kernel<128>
193
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194
+ break;
195
+ case 64:
196
+ furthest_point_sampling_kernel<64>
197
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198
+ break;
199
+ case 32:
200
+ furthest_point_sampling_kernel<32>
201
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202
+ break;
203
+ case 16:
204
+ furthest_point_sampling_kernel<16>
205
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206
+ break;
207
+ case 8:
208
+ furthest_point_sampling_kernel<8>
209
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210
+ break;
211
+ case 4:
212
+ furthest_point_sampling_kernel<4>
213
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214
+ break;
215
+ case 2:
216
+ furthest_point_sampling_kernel<2>
217
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218
+ break;
219
+ case 1:
220
+ furthest_point_sampling_kernel<1>
221
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222
+ break;
223
+ default:
224
+ furthest_point_sampling_kernel<512>
225
+ <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226
+ }
227
+
228
+ //CUDA_CHECK_ERRORS();
229
+ }
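
The templated kernel above runs one block per batch element and reduces each thread's farthest candidate in shared memory. Here is a simplified single-cloud NumPy reference of the sequential loop it parallelizes (a sketch only: no batching, and it omits the kernel's near-origin point skip).

import numpy as np

def fps_reference(points: np.ndarray, m: int) -> np.ndarray:
    """points: (n, 3); returns indices of m samples, always starting from index 0."""
    n = points.shape[0]
    min_dist = np.full(n, np.inf)
    idxs = np.zeros(m, dtype=np.int64)
    last = 0
    for j in range(1, m):
        # distance of every point to the set selected so far
        min_dist = np.minimum(min_dist, ((points - points[last]) ** 2).sum(axis=1))
        last = int(min_dist.argmax())   # pick the farthest remaining point
        idxs[j] = last
    return idxs
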
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "3.0.0"
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py ADDED
@@ -0,0 +1,209 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pointnet2_ops import pointnet2_utils
7
+
8
+
9
+ def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
10
+ layers = []
11
+ for i in range(1, len(mlp_spec)):
12
+ layers.append(
13
+ nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
14
+ )
15
+ if bn:
16
+ layers.append(nn.BatchNorm2d(mlp_spec[i]))
17
+ layers.append(nn.ReLU(True))
18
+
19
+ return nn.Sequential(*layers)
20
+
21
+
22
+ class _PointnetSAModuleBase(nn.Module):
23
+ def __init__(self):
24
+ super(_PointnetSAModuleBase, self).__init__()
25
+ self.npoint = None
26
+ self.groupers = None
27
+ self.mlps = None
28
+
29
+ def forward(
30
+ self, xyz: torch.Tensor, features: Optional[torch.Tensor]
31
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
32
+ r"""
33
+ Parameters
34
+ ----------
35
+ xyz : torch.Tensor
36
+ (B, N, 3) tensor of the xyz coordinates of the features
37
+ features : torch.Tensor
38
+ (B, C, N) tensor of the descriptors of the features
39
+
40
+ Returns
41
+ -------
42
+ new_xyz : torch.Tensor
43
+ (B, npoint, 3) tensor of the new features' xyz
44
+ new_features : torch.Tensor
45
+ (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
46
+ """
47
+
48
+ new_features_list = []
49
+
50
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
51
+ new_xyz = (
52
+ pointnet2_utils.gather_operation(
53
+ xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
54
+ )
55
+ .transpose(1, 2)
56
+ .contiguous()
57
+ if self.npoint is not None
58
+ else None
59
+ )
60
+
61
+ for i in range(len(self.groupers)):
62
+ new_features = self.groupers[i](
63
+ xyz, new_xyz, features
64
+ ) # (B, C, npoint, nsample)
65
+
66
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
67
+ new_features = F.max_pool2d(
68
+ new_features, kernel_size=[1, new_features.size(3)]
69
+ ) # (B, mlp[-1], npoint, 1)
70
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
71
+
72
+ new_features_list.append(new_features)
73
+
74
+ return new_xyz, torch.cat(new_features_list, dim=1)
75
+
76
+
77
+ class PointnetSAModuleMSG(_PointnetSAModuleBase):
78
+ r"""Pointnet set abstraction layer with multiscale grouping
79
+
80
+ Parameters
81
+ ----------
82
+ npoint : int
83
+ Number of features
84
+ radii : list of float32
85
+ list of radii to group with
86
+ nsamples : list of int32
87
+ Number of samples in each ball query
88
+ mlps : list of list of int32
89
+ Spec of the pointnet before the global max_pool for each scale
90
+ bn : bool
91
+ Use batchnorm
92
+ """
93
+
94
+ def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
95
+ # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
96
+ super(PointnetSAModuleMSG, self).__init__()
97
+
98
+ assert len(radii) == len(nsamples) == len(mlps)
99
+
100
+ self.npoint = npoint
101
+ self.groupers = nn.ModuleList()
102
+ self.mlps = nn.ModuleList()
103
+ for i in range(len(radii)):
104
+ radius = radii[i]
105
+ nsample = nsamples[i]
106
+ self.groupers.append(
107
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
108
+ if npoint is not None
109
+ else pointnet2_utils.GroupAll(use_xyz)
110
+ )
111
+ mlp_spec = mlps[i]
112
+ if use_xyz:
113
+ mlp_spec[0] += 3
114
+
115
+ self.mlps.append(build_shared_mlp(mlp_spec, bn))
116
+
117
+
118
+ class PointnetSAModule(PointnetSAModuleMSG):
119
+ r"""Pointnet set abstraction layer
120
+
121
+ Parameters
122
+ ----------
123
+ npoint : int
124
+ Number of features
125
+ radius : float
126
+ Radius of ball
127
+ nsample : int
128
+ Number of samples in the ball query
129
+ mlp : list
130
+ Spec of the pointnet before the global max_pool
131
+ bn : bool
132
+ Use batchnorm
133
+ """
134
+
135
+ def __init__(
136
+ self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
137
+ ):
138
+ # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
139
+ super(PointnetSAModule, self).__init__(
140
+ mlps=[mlp],
141
+ npoint=npoint,
142
+ radii=[radius],
143
+ nsamples=[nsample],
144
+ bn=bn,
145
+ use_xyz=use_xyz,
146
+ )
147
+
148
+
149
+ class PointnetFPModule(nn.Module):
150
+ r"""Propagates the features of one set to another
151
+
152
+ Parameters
153
+ ----------
154
+ mlp : list
155
+ Pointnet module parameters
156
+ bn : bool
157
+ Use batchnorm
158
+ """
159
+
160
+ def __init__(self, mlp, bn=True):
161
+ # type: (PointnetFPModule, List[int], bool) -> None
162
+ super(PointnetFPModule, self).__init__()
163
+ self.mlp = build_shared_mlp(mlp, bn=bn)
164
+
165
+ def forward(self, unknown, known, unknow_feats, known_feats):
166
+ # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
167
+ r"""
168
+ Parameters
169
+ ----------
170
+ unknown : torch.Tensor
171
+ (B, n, 3) tensor of the xyz positions of the unknown features
172
+ known : torch.Tensor
173
+ (B, m, 3) tensor of the xyz positions of the known features
174
+ unknow_feats : torch.Tensor
175
+ (B, C1, n) tensor of the features to be propagated to
176
+ known_feats : torch.Tensor
177
+ (B, C2, m) tensor of features to be propagated
178
+
179
+ Returns
180
+ -------
181
+ new_features : torch.Tensor
182
+ (B, mlp[-1], n) tensor of the features of the unknown features
183
+ """
184
+
185
+ if known is not None:
186
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
187
+ dist_recip = 1.0 / (dist + 1e-8)
188
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
189
+ weight = dist_recip / norm
190
+
191
+ interpolated_feats = pointnet2_utils.three_interpolate(
192
+ known_feats, idx, weight
193
+ )
194
+ else:
195
+ interpolated_feats = known_feats.expand(
196
+ *(known_feats.size()[0:2] + [unknown.size(1)])
197
+ )
198
+
199
+ if unknow_feats is not None:
200
+ new_features = torch.cat(
201
+ [interpolated_feats, unknow_feats], dim=1
202
+ ) # (B, C2 + C1, n)
203
+ else:
204
+ new_features = interpolated_feats
205
+
206
+ new_features = new_features.unsqueeze(-1)
207
+ new_features = self.mlp(new_features)
208
+
209
+ return new_features.squeeze(-1)
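
A hedged instantiation sketch for the modules above; shapes follow their docstrings, and it assumes a CUDA device plus the built pointnet2_ops extension.

import torch
from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetFPModule

sa = PointnetSAModule(mlp=[6, 64, 128], npoint=512, radius=0.2, nsample=32).cuda()
fp = PointnetFPModule(mlp=[128 + 6, 128]).cuda()

xyz = torch.rand(2, 4096, 3, device="cuda")     # point positions
feats = torch.rand(2, 6, 4096, device="cuda")   # per-point descriptors
new_xyz, new_feats = sa(xyz, feats)             # (2, 512, 3), (2, 128, 512)
up_feats = fp(xyz, new_xyz, feats, new_feats)   # (2, 128, 4096)
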
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py ADDED
@@ -0,0 +1,391 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import warnings
4
+ from torch.autograd import Function
5
+ from typing import *
6
+
7
+ try:
8
+ import pointnet2_ops._ext as _ext
9
+ except ImportError:
10
+ from torch.utils.cpp_extension import load
11
+ import glob
12
+ import os.path as osp
13
+ import os
14
+
15
+ warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
16
+
17
+ _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
18
+ _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
19
+ osp.join(_ext_src_root, "src", "*.cu")
20
+ )
21
+ _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
22
+
23
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
24
+ _ext = load(
25
+ "_ext",
26
+ sources=_ext_sources,
27
+ extra_include_paths=[osp.join(_ext_src_root, "include")],
28
+ extra_cflags=["-O3"],
29
+ extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
30
+ with_cuda=True,
31
+ )
32
+
33
+
34
+ class FurthestPointSampling(Function):
35
+ @staticmethod
36
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
37
+ def forward(ctx, xyz, npoint):
38
+ # type: (Any, torch.Tensor, int) -> torch.Tensor
39
+ r"""
40
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
41
+ minimum distance
42
+
43
+ Parameters
44
+ ----------
45
+ xyz : torch.Tensor
46
+ (B, N, 3) tensor where N > npoint
47
+ npoint : int32
48
+ number of features in the sampled set
49
+
50
+ Returns
51
+ -------
52
+ torch.Tensor
53
+ (B, npoint) tensor containing the set
54
+ """
55
+ out = _ext.furthest_point_sampling(xyz, npoint)
56
+
57
+ ctx.mark_non_differentiable(out)
58
+
59
+ return out
60
+
61
+ @staticmethod
62
+ @torch.amp.custom_bwd(device_type="cuda")
63
+ def backward(ctx, grad_out):
64
+ return ()
65
+
66
+
67
+ furthest_point_sample = FurthestPointSampling.apply
68
+
69
+
70
+ class GatherOperation(Function):
71
+ @staticmethod
72
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
73
+ def forward(ctx, features, idx):
74
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
75
+ r"""
76
+
77
+ Parameters
78
+ ----------
79
+ features : torch.Tensor
80
+ (B, C, N) tensor
81
+
82
+ idx : torch.Tensor
83
+ (B, npoint) tensor of the features to gather
84
+
85
+ Returns
86
+ -------
87
+ torch.Tensor
88
+ (B, C, npoint) tensor
89
+ """
90
+
91
+ ctx.save_for_backward(idx, features)
92
+
93
+ return _ext.gather_points(features, idx)
94
+
95
+ @staticmethod
96
+ @torch.amp.custom_bwd(device_type="cuda")
97
+ def backward(ctx, grad_out):
98
+ idx, features = ctx.saved_tensors
99
+ N = features.size(2)
100
+
101
+ grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
102
+ return grad_features, None
103
+
104
+
105
+ gather_operation = GatherOperation.apply
106
+
107
+
108
+ class ThreeNN(Function):
109
+ @staticmethod
110
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
111
+ def forward(ctx, unknown, known):
112
+ # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
113
+ r"""
114
+ Find the three nearest neighbors of unknown in known
115
+ Parameters
116
+ ----------
117
+ unknown : torch.Tensor
118
+ (B, n, 3) tensor of the unknown features
119
+ known : torch.Tensor
120
+ (B, m, 3) tensor of the known features
121
+
122
+ Returns
123
+ -------
124
+ dist : torch.Tensor
125
+ (B, n, 3) l2 distance to the three nearest neighbors
126
+ idx : torch.Tensor
127
+ (B, n, 3) index of 3 nearest neighbors
128
+ """
129
+ dist2, idx = _ext.three_nn(unknown, known)
130
+ dist = torch.sqrt(dist2)
131
+
132
+ ctx.mark_non_differentiable(dist, idx)
133
+
134
+ return dist, idx
135
+
136
+ @staticmethod
137
+ @torch.amp.custom_bwd(device_type="cuda")
138
+ def backward(ctx, grad_dist, grad_idx):
139
+ return ()
140
+
141
+
142
+ three_nn = ThreeNN.apply
143
+
144
+
145
+ class ThreeInterpolate(Function):
146
+ @staticmethod
147
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
148
+ def forward(ctx, features, idx, weight):
149
+ # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
150
+ r"""
151
+ Performs weighted linear interpolation on 3 features
152
+ Parameters
153
+ ----------
154
+ features : torch.Tensor
155
+ (B, c, m) Features descriptors to be interpolated from
156
+ idx : torch.Tensor
157
+ (B, n, 3) three nearest neighbors of the target features in features
158
+ weight : torch.Tensor
159
+ (B, n, 3) weights
160
+
161
+ Returns
162
+ -------
163
+ torch.Tensor
164
+ (B, c, n) tensor of the interpolated features
165
+ """
166
+ ctx.save_for_backward(idx, weight, features)
167
+
168
+ return _ext.three_interpolate(features, idx, weight)
169
+
170
+ @staticmethod
171
+ @torch.amp.custom_bwd(device_type="cuda")
172
+ def backward(ctx, grad_out):
173
+ # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
174
+ r"""
175
+ Parameters
176
+ ----------
177
+ grad_out : torch.Tensor
178
+ (B, c, n) tensor with gradients of outputs
179
+
180
+ Returns
181
+ -------
182
+ grad_features : torch.Tensor
183
+ (B, c, m) tensor with gradients of features
184
+
185
+ None
186
+
187
+ None
188
+ """
189
+ idx, weight, features = ctx.saved_tensors
190
+ m = features.size(2)
191
+
192
+ grad_features = _ext.three_interpolate_grad(
193
+ grad_out.contiguous(), idx, weight, m
194
+ )
195
+
196
+ return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
197
+
198
+
199
+ three_interpolate = ThreeInterpolate.apply
200
+
201
+
202
+ class GroupingOperation(Function):
203
+ @staticmethod
204
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
205
+ def forward(ctx, features, idx):
206
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
207
+ r"""
208
+
209
+ Parameters
210
+ ----------
211
+ features : torch.Tensor
212
+ (B, C, N) tensor of features to group
213
+ idx : torch.Tensor
214
+ (B, npoint, nsample) tensor containing the indices of features to group with
215
+
216
+ Returns
217
+ -------
218
+ torch.Tensor
219
+ (B, C, npoint, nsample) tensor
220
+ """
221
+ ctx.save_for_backward(idx, features)
222
+
223
+ return _ext.group_points(features, idx)
224
+
225
+ @staticmethod
226
+ @torch.amp.custom_bwd(device_type="cuda")
227
+ def backward(ctx, grad_out):
228
+ # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
229
+ r"""
230
+
231
+ Parameters
232
+ ----------
233
+ grad_out : torch.Tensor
234
+ (B, C, npoint, nsample) tensor of the gradients of the output from forward
235
+
236
+ Returns
237
+ -------
238
+ torch.Tensor
239
+ (B, C, N) gradient of the features
240
+ None
241
+ """
242
+ idx, features = ctx.saved_tensors
243
+ N = features.size(2)
244
+
245
+ grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
246
+
247
+ return grad_features, torch.zeros_like(idx)
248
+
249
+
250
+ grouping_operation = GroupingOperation.apply
251
+
252
+
253
+ class BallQuery(Function):
254
+ @staticmethod
255
+ @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
256
+ def forward(ctx, radius, nsample, xyz, new_xyz):
257
+ # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
258
+ r"""
259
+
260
+ Parameters
261
+ ----------
262
+ radius : float
263
+ radius of the balls
264
+ nsample : int
265
+ maximum number of features in the balls
266
+ xyz : torch.Tensor
267
+ (B, N, 3) xyz coordinates of the features
268
+ new_xyz : torch.Tensor
269
+ (B, npoint, 3) centers of the ball query
270
+
271
+ Returns
272
+ -------
273
+ torch.Tensor
274
+ (B, npoint, nsample) tensor with the indices of the features that form the query balls
275
+ """
276
+ output = _ext.ball_query(new_xyz, xyz, radius, nsample)
277
+
278
+ ctx.mark_non_differentiable(output)
279
+
280
+ return output
281
+
282
+ @staticmethod
283
+ @torch.amp.custom_bwd(device_type="cuda")
284
+ def backward(ctx, grad_out):
285
+ return ()
286
+
287
+
288
+ ball_query = BallQuery.apply
289
+
290
+
291
+ class QueryAndGroup(nn.Module):
292
+ r"""
293
+ Groups with a ball query of radius
294
+
295
+ Parameters
296
+ ----------
297
+ radius : float32
298
+ Radius of ball
299
+ nsample : int32
300
+ Maximum number of features to gather in the ball
301
+ """
302
+
303
+ def __init__(self, radius, nsample, use_xyz=True):
304
+ # type: (QueryAndGroup, float, int, bool) -> None
305
+ super(QueryAndGroup, self).__init__()
306
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
307
+
308
+ def forward(self, xyz, new_xyz, features=None):
309
+ # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
310
+ r"""
311
+ Parameters
312
+ ----------
313
+ xyz : torch.Tensor
314
+ xyz coordinates of the features (B, N, 3)
315
+ new_xyz : torch.Tensor
316
+ centroids (B, npoint, 3)
317
+ features : torch.Tensor
318
+ Descriptors of the features (B, C, N)
319
+
320
+ Returns
321
+ -------
322
+ new_features : torch.Tensor
323
+ (B, 3 + C, npoint, nsample) tensor
324
+ """
325
+
326
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
327
+ xyz_trans = xyz.transpose(1, 2).contiguous()
328
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
329
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
330
+
331
+ if features is not None:
332
+ grouped_features = grouping_operation(features, idx)
333
+ if self.use_xyz:
334
+ new_features = torch.cat(
335
+ [grouped_xyz, grouped_features], dim=1
336
+ ) # (B, C + 3, npoint, nsample)
337
+ else:
338
+ new_features = grouped_features
339
+ else:
340
+ assert (
341
+ self.use_xyz
342
+ ), "features is None, so use_xyz must be True; otherwise there is nothing to group!"
343
+ new_features = grouped_xyz
344
+
345
+ return new_features
346
+
347
+
348
+ class GroupAll(nn.Module):
349
+ r"""
350
+ Groups all features
351
+
352
+ Parameters
353
+ ----------
354
+ """
355
+
356
+ def __init__(self, use_xyz=True):
357
+ # type: (GroupAll, bool) -> None
358
+ super(GroupAll, self).__init__()
359
+ self.use_xyz = use_xyz
360
+
361
+ def forward(self, xyz, new_xyz, features=None):
362
+ # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
363
+ r"""
364
+ Parameters
365
+ ----------
366
+ xyz : torch.Tensor
367
+ xyz coordinates of the features (B, N, 3)
368
+ new_xyz : torch.Tensor
369
+ Ignored
370
+ features : torch.Tensor
371
+ Descriptors of the features (B, C, N)
372
+
373
+ Returns
374
+ -------
375
+ new_features : torch.Tensor
376
+ (B, C + 3, 1, N) tensor
377
+ """
378
+
379
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
380
+ if features is not None:
381
+ grouped_features = features.unsqueeze(2)
382
+ if self.use_xyz:
383
+ new_features = torch.cat(
384
+ [grouped_xyz, grouped_features], dim=1
385
+ ) # (B, 3 + C, 1, N)
386
+ else:
387
+ new_features = grouped_features
388
+ else:
389
+ new_features = grouped_xyz
390
+
391
+ return new_features
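
A minimal sketch of the ball-query grouping primitives defined above, assuming a CUDA device and the built (or JIT-compiled) extension; the query centroids here are just a slice of the input and purely illustrative.

import torch
from pointnet2_ops.pointnet2_utils import ball_query, QueryAndGroup

xyz = torch.rand(2, 1024, 3, device="cuda")
new_xyz = xyz[:, :128, :].contiguous()       # illustrative query centroids
idx = ball_query(0.2, 32, xyz, new_xyz)      # (2, 128, 32) neighbour indices
grouper = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
feats = torch.rand(2, 16, 1024, device="cuda")
grouped = grouper(xyz, new_xyz, feats)       # (2, 16 + 3, 128, 32)
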
hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py ADDED
@@ -0,0 +1,41 @@
1
+ import glob
2
+ import os
3
+ import os.path as osp
4
+
5
+ from setuptools import find_packages, setup
6
+ import torch
7
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8
+
9
+ this_dir = osp.dirname(osp.abspath(__file__))
10
+ _ext_src_root = osp.join("pointnet2_ops", "_ext-src")
11
+ _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
12
+ osp.join(_ext_src_root, "src", "*.cu")
13
+ )
14
+ _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
15
+
16
+ requirements = ["torch>=1.4"]
17
+
18
+ exec(open(osp.join("pointnet2_ops", "_version.py")).read())
19
+
20
+ # os.environ["TORCH_CUDA_ARCH_LIST"] = ".".join(map(str, torch.cuda.get_device_capability()))
21
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
22
+ setup(
23
+ name="pointnet2_ops",
24
+ version=__version__,
25
+ author="Erik Wijmans",
26
+ packages=find_packages(),
27
+ install_requires=requirements,
28
+ ext_modules=[
29
+ CUDAExtension(
30
+ name="pointnet2_ops._ext",
31
+ sources=_ext_sources,
32
+ extra_compile_args={
33
+ "cxx": ["-O3"],
34
+ "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
35
+ },
36
+ include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
37
+ )
38
+ ],
39
+ cmdclass={"build_ext": BuildExtension},
40
+ include_package_data=True,
41
+ )
hort/models/tgs/models/snowflake/skip_transformer.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Author: Peng Xiang
3
+
4
+ import torch
5
+ from torch import nn, einsum
6
+ from .utils import MLP_Res, grouping_operation, query_knn
7
+
8
+
9
+ class SkipTransformer(nn.Module):
10
+ def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
11
+ super(SkipTransformer, self).__init__()
12
+ self.mlp_v = MLP_Res(in_dim=in_channel*2, hidden_dim=in_channel, out_dim=in_channel)
13
+ self.n_knn = n_knn
14
+ self.conv_key = nn.Conv1d(in_channel, dim, 1)
15
+ self.conv_query = nn.Conv1d(in_channel, dim, 1)
16
+ self.conv_value = nn.Conv1d(in_channel, dim, 1)
17
+
18
+ self.pos_mlp = nn.Sequential(
19
+ nn.Conv2d(3, pos_hidden_dim, 1),
20
+ nn.BatchNorm2d(pos_hidden_dim),
21
+ nn.ReLU(),
22
+ nn.Conv2d(pos_hidden_dim, dim, 1)
23
+ )
24
+
25
+ self.attn_mlp = nn.Sequential(
26
+ nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
27
+ nn.BatchNorm2d(dim * attn_hidden_multiplier),
28
+ nn.ReLU(),
29
+ nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
30
+ )
31
+
32
+ self.conv_end = nn.Conv1d(dim, in_channel, 1)
33
+
34
+ def forward(self, pos, key, query, include_self=True):
35
+ """
36
+ Args:
37
+ pos: (B, 3, N)
38
+ key: (B, in_channel, N)
39
+ query: (B, in_channel, N)
40
+ include_self: boolean
41
+
42
+ Returns:
43
+ Tensor: (B, in_channel, N), shape context feature
44
+ """
45
+ value = self.mlp_v(torch.cat([key, query], 1))
46
+ identity = value
47
+ key = self.conv_key(key)
48
+ query = self.conv_query(query)
49
+ value = self.conv_value(value)
50
+ b, dim, n = value.shape
51
+
52
+ pos_flipped = pos.permute(0, 2, 1).contiguous()
53
+ idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped, include_self=include_self)
54
+
55
+ key = grouping_operation(key, idx_knn) # b, dim, n, n_knn
56
+ qk_rel = query.reshape((b, -1, n, 1)) - key
57
+
58
+ pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn
59
+ pos_embedding = self.pos_mlp(pos_rel)
60
+
61
+ attention = self.attn_mlp(qk_rel + pos_embedding) # b, dim, n, n_knn
62
+ attention = torch.softmax(attention, -1)
63
+
64
+ value = value.reshape((b, -1, n, 1)) + pos_embedding  # b, dim, n, n_knn
65
+
66
+ agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n
67
+ y = self.conv_end(agg)
68
+
69
+ return y + identity
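
A shape-level sketch for SkipTransformer. It relies on query_knn and grouping_operation from utils.py below, so it needs a CUDA device and the built extension; the import path shown is illustrative and should be adjusted to how the repo sits on your PYTHONPATH.

import torch
from hort.models.tgs.models.snowflake.skip_transformer import SkipTransformer  # illustrative path

st = SkipTransformer(in_channel=128, dim=64, n_knn=16).cuda()
pos = torch.rand(2, 3, 512, device="cuda")      # point positions
key = torch.rand(2, 128, 512, device="cuda")    # previous-level features
query = torch.rand(2, 128, 512, device="cuda")  # current-level features
out = st(pos, key, query)                       # (2, 128, 512) context features
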
hort/models/tgs/models/snowflake/utils.py ADDED
@@ -0,0 +1,741 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Author: Peng Xiang
3
+
4
+ import types
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from torch import nn, einsum
9
+ from pointnet2_ops.pointnet2_utils import furthest_point_sample, \
10
+ gather_operation, ball_query, three_nn, three_interpolate, grouping_operation
11
+
12
+ class Conv1d(nn.Module):
13
+ def __init__(self, in_channel, out_channel, kernel_size=1, stride=1, if_bn=True, activation_fn=torch.relu):
14
+ super(Conv1d, self).__init__()
15
+ self.conv = nn.Conv1d(in_channel, out_channel, kernel_size, stride=stride)
16
+ self.if_bn = if_bn
17
+ self.bn = nn.BatchNorm1d(out_channel)
18
+ self.activation_fn = activation_fn
19
+
20
+ def forward(self, input):
21
+ out = self.conv(input)
22
+ if self.if_bn:
23
+ out = self.bn(out)
24
+
25
+ if self.activation_fn is not None:
26
+ out = self.activation_fn(out)
27
+
28
+ return out
29
+
30
+ class Conv2d(nn.Module):
31
+ def __init__(self, in_channel, out_channel, kernel_size=(1, 1), stride=(1, 1), if_bn=True, activation_fn=torch.relu):
32
+ super(Conv2d, self).__init__()
33
+ self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride)
34
+ self.if_bn = if_bn
35
+ self.bn = nn.BatchNorm2d(out_channel)
36
+ self.activation_fn = activation_fn
37
+
38
+ def forward(self, input):
39
+ out = self.conv(input)
40
+ if self.if_bn:
41
+ out = self.bn(out)
42
+
43
+ if self.activation_fn is not None:
44
+ out = self.activation_fn(out)
45
+
46
+ return out
47
+
48
+ class MLP(nn.Module):
49
+ def __init__(self, in_channel, layer_dims, bn=None):
50
+ super(MLP, self).__init__()
51
+ layers = []
52
+ last_channel = in_channel
53
+ for out_channel in layer_dims[:-1]:
54
+ layers.append(nn.Linear(last_channel, out_channel))
55
+ if bn:
56
+ layers.append(nn.BatchNorm1d(out_channel))
57
+ layers.append(nn.ReLU())
58
+ last_channel = out_channel
59
+ layers.append(nn.Linear(last_channel, layer_dims[-1]))
60
+ self.mlp = nn.Sequential(*layers)
61
+
62
+ def forward(self, inputs):
63
+ return self.mlp(inputs)
64
+
65
+ class MLP_CONV(nn.Module):
66
+ def __init__(self, in_channel, layer_dims, bn=None):
67
+ super(MLP_CONV, self).__init__()
68
+ layers = []
69
+ last_channel = in_channel
70
+ for out_channel in layer_dims[:-1]:
71
+ layers.append(nn.Conv1d(last_channel, out_channel, 1))
72
+ if bn:
73
+ layers.append(nn.BatchNorm1d(out_channel))
74
+ layers.append(nn.ReLU())
75
+ last_channel = out_channel
76
+ layers.append(nn.Conv1d(last_channel, layer_dims[-1], 1))
77
+ self.mlp = nn.Sequential(*layers)
78
+
79
+ def forward(self, inputs):
80
+ return self.mlp(inputs)
81
+
82
+ class MLP_Res(nn.Module):
83
+ def __init__(self, in_dim=128, hidden_dim=None, out_dim=128):
84
+ super(MLP_Res, self).__init__()
85
+ if hidden_dim is None:
86
+ hidden_dim = in_dim
87
+ self.conv_1 = nn.Conv1d(in_dim, hidden_dim, 1)
88
+ self.conv_2 = nn.Conv1d(hidden_dim, out_dim, 1)
89
+ self.conv_shortcut = nn.Conv1d(in_dim, out_dim, 1)
90
+
91
+ def forward(self, x):
92
+ """
93
+ Args:
94
+ x: (B, out_dim, n)
95
+ """
96
+ shortcut = self.conv_shortcut(x)
97
+ out = self.conv_2(torch.relu(self.conv_1(x))) + shortcut
98
+ return out
99
+
100
+
101
+ def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
102
+ """
103
+ Args:
104
+ xyz: Tensor, (B, 3, N)
105
+ points: Tensor, (B, f, N)
106
+ npoint: int
107
+ nsample: int
108
+ radius: float
109
+ use_xyz: boolean
110
+
111
+ Returns:
112
+ new_xyz: Tensor, (B, 3, npoint)
113
+ new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
114
+ idx_local: Tensor, (B, npoint, nsample)
115
+ grouped_xyz: Tensor, (B, 3, npoint, nsample)
116
+
117
+ """
118
+ xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3)
119
+ new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint)
120
+
121
+ idx = ball_query(radius, nsample, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous()) # (B, npoint, nsample)
122
+ grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample)
123
+ grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, nsample)
124
+
125
+ if points is not None:
126
+ grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample)
127
+ if use_xyz:
128
+ new_points = torch.cat([grouped_xyz, grouped_points], 1)
129
+ else:
130
+ new_points = grouped_points
131
+ else:
132
+ new_points = grouped_xyz
133
+
134
+ return new_xyz, new_points, idx, grouped_xyz
135
+
136
+
137
+ def sample_and_group_all(xyz, points, use_xyz=True):
138
+ """
139
+ Args:
140
+ xyz: Tensor, (B, 3, nsample)
141
+ points: Tensor, (B, f, nsample)
142
+ use_xyz: boolean
143
+
144
+ Returns:
145
+ new_xyz: Tensor, (B, 3, 1)
146
+ new_points: Tensor, (B, f|f+3|3, 1, nsample)
147
+ idx: Tensor, (B, 1, nsample)
148
+ grouped_xyz: Tensor, (B, 3, 1, nsample)
149
+ """
150
+ b, _, nsample = xyz.shape
151
+ device = xyz.device
152
+ new_xyz = torch.zeros((1, 3, 1), dtype=torch.float, device=device).repeat(b, 1, 1)
153
+ grouped_xyz = xyz.reshape((b, 3, 1, nsample))
154
+ idx = torch.arange(nsample, device=device).reshape(1, 1, nsample).repeat(b, 1, 1)
155
+ if points is not None:
156
+ if use_xyz:
157
+ new_points = torch.cat([xyz, points], 1)
158
+ else:
159
+ new_points = points
160
+ new_points = new_points.unsqueeze(2)
161
+ else:
162
+ new_points = grouped_xyz
163
+
164
+ return new_xyz, new_points, idx, grouped_xyz
165
+
166
+
167
+ class PointNet_SA_Module(nn.Module):
168
+ def __init__(self, npoint, nsample, radius, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True):
169
+ """
170
+ Args:
171
+ npoint: int, number of points to sample
172
+ nsample: int, number of points in each local region
173
+ radius: float
174
+ in_channel: int, input channel of features(points)
175
+ mlp: list of int,
176
+ """
177
+ super(PointNet_SA_Module, self).__init__()
178
+ self.npoint = npoint
179
+ self.nsample = nsample
180
+ self.radius = radius
181
+ self.mlp = mlp
182
+ self.group_all = group_all
183
+ self.use_xyz = use_xyz
184
+ if use_xyz:
185
+ in_channel += 3
186
+
187
+ last_channel = in_channel
188
+ self.mlp_conv = []
189
+ for out_channel in mlp:
190
+ self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn))
191
+ last_channel = out_channel
192
+
193
+ self.mlp_conv = nn.Sequential(*self.mlp_conv)
194
+
195
+ def forward(self, xyz, points):
196
+ """
197
+ Args:
198
+ xyz: Tensor, (B, 3, N)
199
+ points: Tensor, (B, f, N)
200
+
201
+ Returns:
202
+ new_xyz: Tensor, (B, 3, npoint)
203
+ new_points: Tensor, (B, mlp[-1], npoint)
204
+ """
205
+ if self.group_all:
206
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
207
+ else:
208
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group(xyz, points, self.npoint, self.nsample, self.radius, self.use_xyz)
209
+
210
+ new_points = self.mlp_conv(new_points)
211
+ new_points = torch.max(new_points, 3)[0]
212
+
213
+ return new_xyz, new_points
214
+
215
+
216
+ class PointNet_FP_Module(nn.Module):
217
+ def __init__(self, in_channel, mlp, use_points1=False, in_channel_points1=None, if_bn=True):
218
+ """
219
+ Args:
220
+ in_channel: int, input channel of points2
221
+ mlp: list of int
222
+ use_points1: boolean, if use points
223
+ in_channel_points1: int, input channel of points1
224
+ """
225
+ super(PointNet_FP_Module, self).__init__()
226
+ self.use_points1 = use_points1
227
+
228
+ if use_points1:
229
+ in_channel += in_channel_points1
230
+
231
+ last_channel = in_channel
232
+ self.mlp_conv = []
233
+ for out_channel in mlp:
234
+ self.mlp_conv.append(Conv1d(last_channel, out_channel, if_bn=if_bn))
235
+ last_channel = out_channel
236
+
237
+ self.mlp_conv = nn.Sequential(*self.mlp_conv)
238
+
239
+ def forward(self, xyz1, xyz2, points1, points2):
240
+ """
241
+ Args:
242
+ xyz1: Tensor, (B, 3, N)
243
+ xyz2: Tensor, (B, 3, M)
244
+ points1: Tensor, (B, in_channel, N)
245
+ points2: Tensor, (B, in_channel, M)
246
+
247
+ Returns:
248
+ new_points: Tensor, (B, mlp[-1], N)
249
+ """
250
+ dist, idx = three_nn(xyz1.permute(0, 2, 1).contiguous(), xyz2.permute(0, 2, 1).contiguous())
251
+ dist = torch.clamp_min(dist, 1e-10) # (B, N, 3)
252
+ recip_dist = 1.0/dist
253
+ norm = torch.sum(recip_dist, 2, keepdim=True).repeat((1, 1, 3))
254
+ weight = recip_dist / norm
255
+ interpolated_points = three_interpolate(points2, idx, weight) # B, in_channel, N
256
+
257
+ if self.use_points1:
258
+ new_points = torch.cat([interpolated_points, points1], 1)
259
+ else:
260
+ new_points = interpolated_points
261
+
262
+ new_points = self.mlp_conv(new_points)
263
+ return new_points
264
+
265
+
266
+ def square_distance(src, dst):
267
+ """
268
+ Calculate the squared Euclidean distance between each pair of points.
269
+
270
+ src^T * dst = xn * xm + yn * ym + zn * zm;
271
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
272
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
273
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
274
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
275
+
276
+ Input:
277
+ src: source points, [B, N, C]
278
+ dst: target points, [B, M, C]
279
+ Output:
280
+ dist: per-point square distance, [B, N, M]
281
+ """
282
+ B, N, _ = src.shape
283
+ _, M, _ = dst.shape
284
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # B, N, M
285
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
286
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
287
+ return dist
288
+
289
+
290
+ def query_knn(nsample, xyz, new_xyz, include_self=True):
291
+ """Find k-NN of new_xyz in xyz"""
292
+ pad = 0 if include_self else 1
293
+ sqrdists = square_distance(new_xyz, xyz) # B, S, N
294
+ idx = torch.argsort(sqrdists, dim=-1, descending=False)[:, :, pad: nsample+pad]
295
+ return idx.int()
296
+
297
+
298
+ def sample_and_group_knn(xyz, points, npoint, k, use_xyz=True, idx=None):
299
+ """
300
+ Args:
301
+ xyz: Tensor, (B, 3, N)
302
+ points: Tensor, (B, f, N)
303
+ npoint: int
304
+ k: int, number of nearest neighbours to group
305
+ idx: Tensor or None, optional precomputed kNN indices (B, npoint, k)
306
+ use_xyz: boolean
307
+
308
+ Returns:
309
+ new_xyz: Tensor, (B, 3, npoint)
310
+ new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
311
+ idx_local: Tensor, (B, npoint, nsample)
312
+ grouped_xyz: Tensor, (B, 3, npoint, nsample)
313
+
314
+ """
315
+ xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3)
316
+ new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint)
317
+ if idx is None:
318
+ idx = query_knn(k, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous())
319
+ grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample)
320
+ grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, k)
321
+
322
+ if points is not None:
323
+ grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample)
324
+ if use_xyz:
325
+ new_points = torch.cat([grouped_xyz, grouped_points], 1)
326
+ else:
327
+ new_points = grouped_points
328
+ else:
329
+ new_points = grouped_xyz
330
+
331
+ return new_xyz, new_points, idx, grouped_xyz
332
+
333
+
334
+ class PointNet_SA_Module_KNN(nn.Module):
335
+ def __init__(self, npoint, nsample, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True, if_idx=False):
336
+ """
337
+ Args:
338
+ npoint: int, number of points to sample
339
+ nsample: int, number of points in each local region
340
+ if_idx: boolean, whether to also return the kNN indices
341
+ in_channel: int, input channel of features(points)
342
+ mlp: list of int,
343
+ """
344
+ super(PointNet_SA_Module_KNN, self).__init__()
345
+ self.npoint = npoint
346
+ self.nsample = nsample
347
+ self.mlp = mlp
348
+ self.group_all = group_all
349
+ self.use_xyz = use_xyz
350
+ self.if_idx = if_idx
351
+ if use_xyz:
352
+ in_channel += 3
353
+
354
+ last_channel = in_channel
355
+ self.mlp_conv = []
356
+ for out_channel in mlp[:-1]:
357
+ self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn))
358
+ last_channel = out_channel
359
+ self.mlp_conv.append(Conv2d(last_channel, mlp[-1], if_bn=False, activation_fn=None))
360
+ self.mlp_conv = nn.Sequential(*self.mlp_conv)
361
+
362
+ def forward(self, xyz, points, idx=None):
363
+ """
364
+ Args:
365
+ xyz: Tensor, (B, 3, N)
366
+ points: Tensor, (B, f, N)
367
+
368
+ Returns:
369
+ new_xyz: Tensor, (B, 3, npoint)
370
+ new_points: Tensor, (B, mlp[-1], npoint)
371
+ """
372
+ if self.group_all:
373
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
374
+ else:
375
+ new_xyz, new_points, idx, grouped_xyz = sample_and_group_knn(xyz, points, self.npoint, self.nsample, self.use_xyz, idx=idx)
376
+
377
+ new_points = self.mlp_conv(new_points)
378
+ new_points = torch.max(new_points, 3)[0]
379
+
380
+ if self.if_idx:
381
+ return new_xyz, new_points, idx
382
+ else:
383
+ return new_xyz, new_points
384
+
385
+
386
+ def fps_subsample(pcd, n_points=2048):
387
+ """
388
+ Args
389
+ pcd: (b, 16384, 3)
390
+
391
+ returns
392
+ new_pcd: (b, n_points, 3)
393
+ """
394
+ new_pcd = gather_operation(pcd.permute(0, 2, 1).contiguous(), furthest_point_sample(pcd, n_points))
395
+ new_pcd = new_pcd.permute(0, 2, 1).contiguous()
396
+ return new_pcd
397
+
398
+
399
+ class Transformer(nn.Module):
400
+ def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
401
+ super(Transformer, self).__init__()
402
+ self.n_knn = n_knn
403
+ self.conv_key = nn.Conv1d(dim, dim, 1)
404
+ self.conv_query = nn.Conv1d(dim, dim, 1)
405
+ self.conv_value = nn.Conv1d(dim, dim, 1)
406
+
407
+ self.pos_mlp = nn.Sequential(
408
+ nn.Conv2d(3, pos_hidden_dim, 1),
409
+ nn.BatchNorm2d(pos_hidden_dim),
410
+ nn.ReLU(),
411
+ nn.Conv2d(pos_hidden_dim, dim, 1)
412
+ )
413
+
414
+ self.attn_mlp = nn.Sequential(
415
+ nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
416
+ nn.BatchNorm2d(dim * attn_hidden_multiplier),
417
+ nn.ReLU(),
418
+ nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
419
+ )
420
+
421
+ self.linear_start = nn.Conv1d(in_channel, dim, 1)
422
+ self.linear_end = nn.Conv1d(dim, in_channel, 1)
423
+
424
+ def forward(self, x, pos):
425
+ """feed forward of transformer
426
+ Args:
427
+ x: Tensor of features, (B, in_channel, n)
428
+ pos: Tensor of positions, (B, 3, n)
429
+
430
+ Returns:
431
+ y: Tensor of features with attention, (B, in_channel, n)
432
+ """
433
+
434
+ identity = x
435
+
436
+ x = self.linear_start(x)
437
+ b, dim, n = x.shape
438
+
439
+ pos_flipped = pos.permute(0, 2, 1).contiguous()
440
+ idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped)
441
+ key = self.conv_key(x)
442
+ value = self.conv_value(x)
443
+ query = self.conv_query(x)
444
+
445
+ key = grouping_operation(key, idx_knn) # b, dim, n, n_knn
446
+ qk_rel = query.reshape((b, -1, n, 1)) - key
447
+
448
+ pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn
449
+ pos_embedding = self.pos_mlp(pos_rel) # b, dim, n, n_knn
450
+
451
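+ # Descriptive note (added): vector attention over the k nearest neighbours; an MLP on
+ # (q - k) plus the relative-position embedding yields per-channel weights,
+ # softmax-normalised over the neighbour dimension.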
+ attention = self.attn_mlp(qk_rel + pos_embedding)
452
+ attention = torch.softmax(attention, -1)
453
+
454
+ value = value.reshape((b, -1, n, 1)) + pos_embedding
455
+
456
+ agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n
457
+ y = self.linear_end(agg)
458
+
459
+ return y+identity
460
+
461
+
462
+ class CouplingLayer(nn.Module):
463
+
464
+ def __init__(self, d, intermediate_dim, swap=False):
465
+ nn.Module.__init__(self)
466
+ self.d = d - (d // 2)
467
+ self.swap = swap
468
+ self.net_s_t = nn.Sequential(
469
+ nn.Linear(self.d, intermediate_dim),
470
+ nn.ReLU(inplace=True),
471
+ nn.Linear(intermediate_dim, intermediate_dim),
472
+ nn.ReLU(inplace=True),
473
+ nn.Linear(intermediate_dim, (d - self.d) * 2),
474
+ )
475
+
476
+ def forward(self, x, logpx=None, reverse=False):
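+ # Descriptive note (added): affine coupling (RealNVP-style). The first self.d dims
+ # condition a sigmoid-bounded scale and a shift applied to the remaining dims; the
+ # log-determinant (sum of log-scales) is subtracted from logpx in the forward
+ # direction and added in the reverse direction.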
477
+
478
+ if self.swap:
479
+ x = torch.cat([x[:, self.d:], x[:, :self.d]], 1)
480
+
481
+ in_dim = self.d
482
+ out_dim = x.shape[1] - self.d
483
+
484
+ s_t = self.net_s_t(x[:, :in_dim])
485
+ scale = torch.sigmoid(s_t[:, :out_dim] + 2.)
486
+ shift = s_t[:, out_dim:]
487
+
488
+ logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True)
489
+
490
+ if not reverse:
491
+ y1 = x[:, self.d:] * scale + shift
492
+ delta_logp = -logdetjac
493
+ else:
494
+ y1 = (x[:, self.d:] - shift) / scale
495
+ delta_logp = logdetjac
496
+
497
+ y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1)
498
+
499
+ if logpx is None:
500
+ return y
501
+ else:
502
+ return y, logpx + delta_logp
503
+
504
+
505
+ class SequentialFlow(nn.Module):
506
+ """A generalized nn.Sequential container for normalizing flows.
507
+ """
508
+
509
+ def __init__(self, layersList):
510
+ super(SequentialFlow, self).__init__()
511
+ self.chain = nn.ModuleList(layersList)
512
+
513
+ def forward(self, x, logpx=None, reverse=False, inds=None):
514
+ if inds is None:
515
+ if reverse:
516
+ inds = range(len(self.chain) - 1, -1, -1)
517
+ else:
518
+ inds = range(len(self.chain))
519
+
520
+ if logpx is None:
521
+ for i in inds:
522
+ x = self.chain[i](x, reverse=reverse)
523
+ return x
524
+ else:
525
+ for i in inds:
526
+ x, logpx = self.chain[i](x, logpx, reverse=reverse)
527
+ return x, logpx
528
+
529
+
530
+ def build_latent_flow(args):
531
+ chain = []
532
+ for i in range(args.latent_flow_depth):
533
+ chain.append(CouplingLayer(args.latent_dim, args.latent_flow_hidden_dim, swap=(i % 2 == 0)))
534
+ return SequentialFlow(chain)
535
+
536
+
537
+ ##################
538
+ ## SpectralNorm ##
539
+ ##################
540
+
541
+ POWER_ITERATION_FN = "spectral_norm_power_iteration"
542
+
543
+
544
+ class SpectralNorm(object):
545
+ def __init__(self, name='weight', dim=0, eps=1e-12):
546
+ self.name = name
547
+ self.dim = dim
548
+ self.eps = eps
549
+
550
+ def compute_weight(self, module, n_power_iterations):
551
+ if n_power_iterations < 0:
552
+ raise ValueError(
553
+ 'Expected n_power_iterations to be non-negative, but '
554
+ 'got n_power_iterations={}'.format(n_power_iterations)
555
+ )
556
+
557
+ weight = getattr(module, self.name + '_orig')
558
+ u = getattr(module, self.name + '_u')
559
+ v = getattr(module, self.name + '_v')
560
+ weight_mat = weight
561
+ if self.dim != 0:
562
+ # permute dim to front
563
+ weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim])
564
+ height = weight_mat.size(0)
565
+ weight_mat = weight_mat.reshape(height, -1)
566
+ with torch.no_grad():
567
+ for _ in range(n_power_iterations):
568
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
569
+ # are the first left and right singular vectors.
570
+ # This power iteration produces approximations of `u` and `v`.
571
+ v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
572
+ u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
573
+ setattr(module, self.name + '_u', u)
574
+ setattr(module, self.name + '_v', v)
575
+
576
+ sigma = torch.dot(u, torch.matmul(weight_mat, v))
577
+ weight = weight / sigma
578
+ setattr(module, self.name, weight)
579
+
580
+ def remove(self, module):
581
+ weight = getattr(module, self.name)
582
+ delattr(module, self.name)
583
+ delattr(module, self.name + '_u')
584
+ delattr(module, self.name + '_orig')
585
+ module.register_parameter(self.name, torch.nn.Parameter(weight))
586
+
587
+ def get_update_method(self, module):
588
+ def update_fn(module, n_power_iterations):
589
+ self.compute_weight(module, n_power_iterations)
590
+
591
+ return update_fn
592
+
593
+ def __call__(self, module, unused_inputs):
594
+ del unused_inputs
595
+ self.compute_weight(module, n_power_iterations=0)
596
+
597
+ # requires_grad might be either True or False during inference.
598
+ if not module.training:
599
+ r_g = getattr(module, self.name + '_orig').requires_grad
600
+ setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g))
601
+
602
+ @staticmethod
603
+ def apply(module, name, dim, eps):
604
+ fn = SpectralNorm(name, dim, eps)
605
+ weight = module._parameters[name]
606
+ height = weight.size(dim)
607
+
608
+ u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
609
+ v = F.normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps)
610
+ delattr(module, fn.name)
611
+ module.register_parameter(fn.name + "_orig", weight)
612
+ # We still need to assign weight back as fn.name because all sorts of
613
+ # things may assume that it exists, e.g., when initializing weights.
614
+ # However, we can't directly assign as it could be an nn.Parameter and
615
+ # gets added as a parameter. Instead, we register weight.data as a
616
+ # buffer, which will cause weight to be included in the state dict
617
+ # and also supports nn.init due to shared storage.
618
+ module.register_buffer(fn.name, weight.data)
619
+ module.register_buffer(fn.name + "_u", u)
620
+ module.register_buffer(fn.name + "_v", v)
621
+
622
+ setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module))
623
+
624
+ module.register_forward_pre_hook(fn)
625
+ return fn
626
+
627
+
628
+ def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12):
629
+ r"""Applies spectral normalization to a parameter in the given module.
630
+ .. math::
631
+ \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
632
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
633
+ Spectral normalization stabilizes the training of discriminators (critics)
634
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
635
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
636
+ power iteration method. If the dimension of the weight tensor is greater
637
+ than 2, it is reshaped to 2D in power iteration method to get spectral
638
+ norm. This is implemented via a hook that calculates spectral norm and
639
+ rescales weight before every :meth:`~Module.forward` call.
640
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
641
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
642
+ Args:
643
+ module (nn.Module): containing module
644
+ name (str, optional): name of weight parameter
645
+ n_power_iterations (int, optional): number of power iterations to
646
+ calculate the spectral norm
647
+ dim (int, optional): dimension corresponding to number of outputs,
648
+ the default is 0, except for modules that are instances of
649
+ ConvTranspose1/2/3d, when it is 1
650
+ eps (float, optional): epsilon for numerical stability in
651
+ calculating norms
652
+ Returns:
653
+ The original module with the spectral norm hook
654
+ Example::
655
+ >>> m = inplace_spectral_norm(nn.Linear(20, 40))
656
+ Linear (20 -> 40)
657
+ >>> m.weight_u.size()
658
+ torch.Size([20])
659
+ """
660
+ if dim is None:
661
+ if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
662
+ dim = 1
663
+ else:
664
+ dim = 0
665
+ SpectralNorm.apply(module, name, dim=dim, eps=eps)
666
+ return module
667
+
668
+
669
+ def remove_spectral_norm(module, name='weight'):
670
+ r"""Removes the spectral normalization reparameterization from a module.
671
+ Args:
672
+ module (nn.Module): containing module
673
+ name (str, optional): name of weight parameter
674
+ Example:
675
+ >>> m = inplace_spectral_norm(nn.Linear(40, 10))
676
+ >>> remove_spectral_norm(m)
677
+ """
678
+ for k, hook in module._forward_pre_hooks.items():
679
+ if isinstance(hook, SpectralNorm) and hook.name == name:
680
+ hook.remove(module)
681
+ del module._forward_pre_hooks[k]
682
+ return module
683
+
684
+ raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
685
+
686
+
687
+ def add_spectral_norm(model, logger=None):
688
+ """Applies spectral norm to all modules within the scope of a CNF."""
689
+
690
+ def apply_spectral_norm(module):
691
+ if 'weight' in module._parameters:
692
+ if logger: logger.info("Adding spectral norm to {}".format(module))
693
+ inplace_spectral_norm(module, 'weight')
694
+
695
+ def find_coupling_layer(module):
696
+ if isinstance(module, CouplingLayer):
697
+ module.apply(apply_spectral_norm)
698
+ else:
699
+ for child in module.children():
700
+ find_coupling_layer(child)
701
+
702
+ find_coupling_layer(model)
703
+
704
+
705
+ def spectral_norm_power_iteration(model, n_power_iterations=1):
706
+
707
+ def recursive_power_iteration(module):
708
+ if hasattr(module, POWER_ITERATION_FN):
709
+ getattr(module, POWER_ITERATION_FN)(n_power_iterations)
710
+
711
+ model.apply(recursive_power_iteration)
712
+
713
+ def reparameterize_gaussian(mean, logvar):
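+ # Descriptive note (added): reparameterisation trick, z = mean + std * eps with
+ # eps ~ N(0, I), keeping samples differentiable w.r.t. mean and logvar.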
714
+ std = torch.exp(0.5 * logvar)
715
+ eps = torch.randn(std.size()).to(mean)
716
+ return mean + std * eps
717
+
718
+
719
+
720
+ def gaussian_entropy(logvar):
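+ # Descriptive note (added): entropy of a diagonal Gaussian,
+ # H = 0.5 * sum(logvar) + 0.5 * D * (1 + log(2 * pi)).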
721
+ const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
722
+ ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
723
+ return ent
724
+
725
+
726
+ def standard_normal_logprob(z):
727
+ dim = z.size(-1)
728
+ log_z = -0.5 * dim * np.log(2 * np.pi)
729
+ return log_z - z.pow(2) / 2
730
+
731
+ def truncated_normal_(tensor, mean=0, std=1, trunc_std=2):
732
+ """
733
+ Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
734
+ """
735
+ size = tensor.shape
736
+ tmp = tensor.new_empty(size + (4,)).normal_()
737
+ valid = (tmp < trunc_std) & (tmp > -trunc_std)
738
+ ind = valid.max(-1, keepdim=True)[1]
739
+ tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
740
+ tensor.data.mul_(std).add_(mean)
741
+ return tensor
hort/models/tgs/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1179 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch DINOv2 model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+ from dataclasses import dataclass
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BackboneOutput,
32
+ BaseModelOutput,
33
+ BaseModelOutputWithPooling,
34
+ ImageClassifierOutput,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import (
38
+ find_pruneable_heads_and_indices,
39
+ prune_linear_layer,
40
+ )
41
+ from transformers.utils import (
42
+ add_code_sample_docstrings,
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers.utils.backbone_utils import BackboneMixin
49
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
50
+
51
+ from tgs.models.transformers import MemoryEfficientAttentionMixin
52
+ from tgs.utils.typing import *
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ # General docstring
58
+ _CONFIG_FOR_DOC = "Dinov2Config"
59
+
60
+ # Base docstring
61
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
62
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
63
+
64
+ # Image classification docstring
65
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
66
+
67
+
68
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
69
+ "facebook/dinov2-base",
70
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
71
+ ]
72
+
73
+
74
+ class Dinov2Embeddings(nn.Module):
75
+ """
76
+ Construct the CLS token, mask token, position and patch embeddings.
77
+ """
78
+
79
+ def __init__(self, config: Dinov2Config) -> None:
80
+ super().__init__()
81
+
82
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
83
+ # register the mask token as a buffer since it is not optimized,
85
+ # which avoids needing find_unused_parameters=True in DDP
85
+ # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
86
+ self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
87
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
88
+ num_patches = self.patch_embeddings.num_patches
89
+ self.position_embeddings = nn.Parameter(
90
+ torch.randn(1, num_patches + 1, config.hidden_size)
91
+ )
92
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
93
+ self.config = config
94
+
95
+ def interpolate_pos_encoding(
96
+ self, embeddings: torch.Tensor, height: int, width: int
97
+ ) -> torch.Tensor:
98
+ """
99
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher
100
+ resolution images.
101
+
102
+ Source:
103
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
104
+ """
105
+
106
+ num_patches = embeddings.shape[1] - 1
107
+ num_positions = self.position_embeddings.shape[1] - 1
108
+ if num_patches == num_positions and height == width:
109
+ return self.position_embeddings
110
+ class_pos_embed = self.position_embeddings[:, 0]
111
+ patch_pos_embed = self.position_embeddings[:, 1:]
112
+ dim = embeddings.shape[-1]
113
+ height = height // self.config.patch_size
114
+ width = width // self.config.patch_size
115
+ # we add a small number to avoid floating point error in the interpolation
116
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
117
+ height, width = height + 0.1, width + 0.1
118
+ patch_pos_embed = patch_pos_embed.reshape(
119
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
120
+ )
121
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
122
+ patch_pos_embed = nn.functional.interpolate(
123
+ patch_pos_embed,
124
+ scale_factor=(
125
+ height / math.sqrt(num_positions),
126
+ width / math.sqrt(num_positions),
127
+ ),
128
+ mode="bicubic",
129
+ align_corners=False,
130
+ )
131
+ if (
132
+ int(height) != patch_pos_embed.shape[-2]
133
+ or int(width) != patch_pos_embed.shape[-1]
134
+ ):
135
+ raise ValueError(
136
+ "Width or height does not match with the interpolated position embeddings"
137
+ )
138
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
139
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
140
+
141
+ def forward(
142
+ self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
143
+ ) -> torch.Tensor:
144
+ batch_size, _, height, width = pixel_values.shape
145
+ patch_embeddings = self.patch_embeddings(pixel_values)
146
+ embeddings = patch_embeddings
147
+
148
+ if bool_masked_pos is not None:
149
+ embeddings = torch.where(
150
+ bool_masked_pos.unsqueeze(-1),
151
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
152
+ embeddings,
153
+ )
154
+
155
+ # add the [CLS] token to the embedded patch tokens
156
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
157
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
158
+
159
+ # add positional encoding to each token
160
+ embeddings = embeddings + self.interpolate_pos_encoding(
161
+ embeddings, height, width
162
+ )
163
+
164
+ embeddings = self.dropout(embeddings)
165
+
166
+ return embeddings
167
+
168
+
169
+ class Dinov2PatchEmbeddings(nn.Module):
170
+ """
171
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
172
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
173
+ Transformer.
174
+ """
175
+
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ image_size, patch_size = config.image_size, config.patch_size
179
+ num_channels, hidden_size = config.num_channels, config.hidden_size
180
+
181
+ image_size = (
182
+ image_size
183
+ if isinstance(image_size, collections.abc.Iterable)
184
+ else (image_size, image_size)
185
+ )
186
+ patch_size = (
187
+ patch_size
188
+ if isinstance(patch_size, collections.abc.Iterable)
189
+ else (patch_size, patch_size)
190
+ )
191
+ num_patches = (image_size[1] // patch_size[1]) * (
192
+ image_size[0] // patch_size[0]
193
+ )
194
+ self.image_size = image_size
195
+ self.patch_size = patch_size
196
+ self.num_channels = num_channels
197
+ self.num_patches = num_patches
198
+
199
+ self.projection = nn.Conv2d(
200
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
201
+ )
202
+
203
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
204
+ num_channels = pixel_values.shape[1]
205
+ if num_channels != self.num_channels:
206
+ raise ValueError(
207
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
208
+ f" Expected {self.num_channels} but got {num_channels}."
209
+ )
210
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
211
+ return embeddings
212
+
213
+
214
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
215
+ class Dinov2SelfAttention(nn.Module):
216
+ def __init__(self, config: Dinov2Config) -> None:
217
+ super().__init__()
218
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
219
+ config, "embedding_size"
220
+ ):
221
+ raise ValueError(
222
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
223
+ f"heads {config.num_attention_heads}."
224
+ )
225
+
226
+ self.num_attention_heads = config.num_attention_heads
227
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
228
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
229
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
230
+
231
+ self.query = nn.Linear(
232
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
233
+ )
234
+ self.key = nn.Linear(
235
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
236
+ )
237
+ self.value = nn.Linear(
238
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
239
+ )
240
+
241
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
242
+
243
+ self.use_memory_efficient_attention_xformers: bool = False
244
+
245
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
246
+ new_x_shape = x.size()[:-1] + (
247
+ self.num_attention_heads,
248
+ self.attention_head_size,
249
+ )
250
+ x = x.view(new_x_shape)
251
+ return x.permute(0, 2, 1, 3)
252
+
253
+ def forward(
254
+ self,
255
+ hidden_states,
256
+ head_mask: Optional[torch.Tensor] = None,
257
+ output_attentions: bool = False,
258
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
259
+ mixed_query_layer = self.query(hidden_states)
260
+
261
+ if self.use_memory_efficient_attention_xformers:
262
+ import xformers
263
+ assert head_mask is None and not output_attentions
264
+ new_size = hidden_states.size()[:-1] + (
265
+ self.num_attention_heads,
266
+ self.attention_head_size,
267
+ )
268
+ key_layer = self.key(hidden_states).view(new_size)
269
+ value_layer = self.value(hidden_states).view(new_size)
270
+ query_layer = mixed_query_layer.view(new_size)
271
+ context_layer = xformers.ops.memory_efficient_attention(
272
+ query_layer, key_layer, value_layer, p=self.attention_probs_dropout_prob
273
+ )
274
+ context_layer = context_layer.view(*hidden_states.size()[:-1], -1)
275
+ else:
276
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
277
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
278
+ query_layer = self.transpose_for_scores(mixed_query_layer)
279
+
280
+ try:
281
+ context_layer = F.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=head_mask, dropout_p=(self.dropout.p if self.training else 0.0), scale=1/math.sqrt(self.attention_head_size))
282
+ except:
283
+ # Take the dot product between "query" and "key" to get the raw attention scores.
284
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
285
+
286
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
287
+
288
+ # Normalize the attention scores to probabilities.
289
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
290
+
291
+ # This is actually dropping out entire tokens to attend to, which might
292
+ # seem a bit unusual, but is taken from the original Transformer paper.
293
+ attention_probs = self.dropout(attention_probs)
294
+
295
+ # Mask heads if we want to
296
+ if head_mask is not None:
297
+ attention_probs = attention_probs * head_mask
298
+
299
+ context_layer = torch.matmul(attention_probs, value_layer)
300
+
301
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
302
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
303
+ context_layer = context_layer.view(new_context_layer_shape)
304
+
305
+ outputs = (
306
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
307
+ )
308
+
309
+ return outputs
310
+
311
+ def set_use_memory_efficient_attention_xformers(
312
+ self, valid: bool, attention_op: Optional[Callable] = None
313
+ ):
314
+ self.use_memory_efficient_attention_xformers = valid
315
+
316
+
317
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
318
+ class Dinov2SelfOutput(nn.Module):
319
+ """
320
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
321
+ layernorm applied before each block.
322
+ """
323
+
324
+ def __init__(self, config: Dinov2Config) -> None:
325
+ super().__init__()
326
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
327
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
328
+
329
+ def forward(
330
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
331
+ ) -> torch.Tensor:
332
+ hidden_states = self.dense(hidden_states)
333
+ hidden_states = self.dropout(hidden_states)
334
+
335
+ return hidden_states
336
+
337
+
338
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
339
+ class Dinov2Attention(nn.Module):
340
+ def __init__(self, config: Dinov2Config) -> None:
341
+ super().__init__()
342
+ self.attention = Dinov2SelfAttention(config)
343
+ self.output = Dinov2SelfOutput(config)
344
+ self.pruned_heads = set()
345
+
346
+ def prune_heads(self, heads: Set[int]) -> None:
347
+ if len(heads) == 0:
348
+ return
349
+ heads, index = find_pruneable_heads_and_indices(
350
+ heads,
351
+ self.attention.num_attention_heads,
352
+ self.attention.attention_head_size,
353
+ self.pruned_heads,
354
+ )
355
+
356
+ # Prune linear layers
357
+ self.attention.query = prune_linear_layer(self.attention.query, index)
358
+ self.attention.key = prune_linear_layer(self.attention.key, index)
359
+ self.attention.value = prune_linear_layer(self.attention.value, index)
360
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
361
+
362
+ # Update hyper params and store pruned heads
363
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
364
+ heads
365
+ )
366
+ self.attention.all_head_size = (
367
+ self.attention.attention_head_size * self.attention.num_attention_heads
368
+ )
369
+ self.pruned_heads = self.pruned_heads.union(heads)
370
+
371
+ def forward(
372
+ self,
373
+ hidden_states: torch.Tensor,
374
+ head_mask: Optional[torch.Tensor] = None,
375
+ output_attentions: bool = False,
376
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
377
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
378
+
379
+ attention_output = self.output(self_outputs[0], hidden_states)
380
+
381
+ outputs = (attention_output,) + self_outputs[
382
+ 1:
383
+ ] # add attentions if we output them
384
+ return outputs
385
+
386
+
387
+ class Dinov2LayerScale(nn.Module):
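+ # Descriptive note (added): LayerScale, a learnable per-channel scale applied to each
+ # residual branch output, initialised to config.layerscale_value (following CaiT).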
388
+ def __init__(self, config) -> None:
389
+ super().__init__()
390
+ self.lambda1 = nn.Parameter(
391
+ config.layerscale_value * torch.ones(config.hidden_size)
392
+ )
393
+
394
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
395
+ return hidden_state * self.lambda1
396
+
397
+
398
+ # Copied from transformers.models.beit.modeling_beit.drop_path
399
+ def drop_path(
400
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
401
+ ) -> torch.Tensor:
402
+ """
403
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
404
+
405
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
406
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
407
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
408
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
409
+ argument.
410
+ """
411
+ if drop_prob == 0.0 or not training:
412
+ return input
413
+ keep_prob = 1 - drop_prob
414
+ shape = (input.shape[0],) + (1,) * (
415
+ input.ndim - 1
416
+ ) # work with diff dim tensors, not just 2D ConvNets
417
+ random_tensor = keep_prob + torch.rand(
418
+ shape, dtype=input.dtype, device=input.device
419
+ )
420
+ random_tensor.floor_() # binarize
421
+ output = input.div(keep_prob) * random_tensor
422
+ return output
423
+
424
+
425
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
426
+ class Dinov2DropPath(nn.Module):
427
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
428
+
429
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
430
+ super().__init__()
431
+ self.drop_prob = drop_prob
432
+
433
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
434
+ return drop_path(hidden_states, self.drop_prob, self.training)
435
+
436
+ def extra_repr(self) -> str:
437
+ return "p={}".format(self.drop_prob)
438
+
439
+
440
+ class Dinov2MLP(nn.Module):
441
+ def __init__(self, config) -> None:
442
+ super().__init__()
443
+ in_features = out_features = config.hidden_size
444
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
445
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
446
+ if isinstance(config.hidden_act, str):
447
+ self.activation = ACT2FN[config.hidden_act]
448
+ else:
449
+ self.activation = config.hidden_act
450
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
451
+
452
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
453
+ hidden_state = self.fc1(hidden_state)
454
+ hidden_state = self.activation(hidden_state)
455
+ hidden_state = self.fc2(hidden_state)
456
+ return hidden_state
457
+
458
+
459
+ class Dinov2SwiGLUFFN(nn.Module):
460
+ def __init__(self, config) -> None:
461
+ super().__init__()
462
+ in_features = out_features = config.hidden_size
463
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
464
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
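+ # Descriptive note (added): SwiGLU splits the input projection into a gate and a value
+ # branch, so the hidden width is scaled by 2/3 to keep the parameter count comparable
+ # to a plain MLP, then rounded up to a multiple of 8 for hardware-friendly shapes.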
465
+
466
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
467
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
468
+
469
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
470
+ hidden_state = self.weights_in(hidden_state)
471
+ x1, x2 = hidden_state.chunk(2, dim=-1)
472
+ hidden = nn.functional.silu(x1) * x2
473
+ return self.weights_out(hidden)
474
+
475
+
476
+ class Dinov2Layer(nn.Module, MemoryEfficientAttentionMixin):
477
+ """This corresponds to the Block class in the original implementation."""
478
+
479
+ def __init__(self, config: Dinov2Config) -> None:
480
+ super().__init__()
481
+
482
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
483
+ self.norm1_modulation = None
484
+ self.attention = Dinov2Attention(config)
485
+ self.layer_scale1 = Dinov2LayerScale(config)
486
+ self.drop_path1 = (
487
+ Dinov2DropPath(config.drop_path_rate)
488
+ if config.drop_path_rate > 0.0
489
+ else nn.Identity()
490
+ )
491
+
492
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
493
+ self.norm2_modulation = None
494
+
495
+ if config.use_swiglu_ffn:
496
+ self.mlp = Dinov2SwiGLUFFN(config)
497
+ else:
498
+ self.mlp = Dinov2MLP(config)
499
+ self.layer_scale2 = Dinov2LayerScale(config)
500
+ self.drop_path2 = (
501
+ Dinov2DropPath(config.drop_path_rate)
502
+ if config.drop_path_rate > 0.0
503
+ else nn.Identity()
504
+ )
505
+
506
+ def forward(
507
+ self,
508
+ hidden_states: torch.Tensor,
509
+ head_mask: Optional[torch.Tensor] = None,
510
+ modulation_cond: Optional[torch.Tensor] = None,
511
+ output_attentions: bool = False,
512
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
513
+ hidden_states_norm = self.norm1(hidden_states)
514
+ if self.norm1_modulation is not None:
515
+ assert modulation_cond is not None
516
+ hidden_states_norm = self.norm1_modulation(
517
+ hidden_states_norm, modulation_cond
518
+ )
519
+ self_attention_outputs = self.attention(
520
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
521
+ head_mask,
522
+ output_attentions=output_attentions,
523
+ )
524
+ attention_output = self_attention_outputs[0]
525
+
526
+ attention_output = self.layer_scale1(attention_output)
527
+ outputs = self_attention_outputs[
528
+ 1:
529
+ ] # add self attentions if we output attention weights
530
+
531
+ # first residual connection
532
+ hidden_states = attention_output + hidden_states
533
+
534
+ # in Dinov2, layernorm is also applied after self-attention
535
+ layer_output = self.norm2(hidden_states)
536
+ if self.norm2_modulation is not None:
537
+ assert modulation_cond is not None
538
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
539
+ layer_output = self.mlp(layer_output)
540
+ layer_output = self.layer_scale2(layer_output)
541
+
542
+ # second residual connection
543
+ layer_output = layer_output + hidden_states
544
+
545
+ outputs = (layer_output,) + outputs
546
+
547
+ return outputs
548
+
549
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
550
+ self.norm1_modulation = norm1_mod
551
+ self.norm2_modulation = norm2_mod
552
+
553
+
554
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
555
+ class Dinov2Encoder(nn.Module, MemoryEfficientAttentionMixin):
556
+ def __init__(self, config: Dinov2Config) -> None:
557
+ super().__init__()
558
+ self.config = config
559
+ self.layer = nn.ModuleList(
560
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
561
+ )
562
+ self.gradient_checkpointing = False
563
+
564
+ def forward(
565
+ self,
566
+ hidden_states: torch.Tensor,
567
+ head_mask: Optional[torch.Tensor] = None,
568
+ modulation_cond: Optional[torch.Tensor] = None,
569
+ output_attentions: bool = False,
570
+ output_hidden_states: bool = False,
571
+ return_dict: bool = True,
572
+ ) -> Union[tuple, BaseModelOutput]:
573
+ all_hidden_states = () if output_hidden_states else None
574
+ all_self_attentions = () if output_attentions else None
575
+
576
+ for i, layer_module in enumerate(self.layer):
577
+ if output_hidden_states:
578
+ all_hidden_states = all_hidden_states + (hidden_states,)
579
+
580
+ layer_head_mask = head_mask[i] if head_mask is not None else None
581
+
582
+ if self.gradient_checkpointing and self.training:
583
+
584
+ def create_custom_forward(module):
585
+ def custom_forward(*inputs):
586
+ return module(*inputs, output_attentions)
587
+
588
+ return custom_forward
589
+
590
+ layer_outputs = torch.utils.checkpoint.checkpoint(
591
+ create_custom_forward(layer_module),
592
+ hidden_states,
593
+ layer_head_mask,
594
+ modulation_cond,
595
+ )
596
+ else:
597
+ layer_outputs = layer_module(
598
+ hidden_states, layer_head_mask, modulation_cond, output_attentions
599
+ )
600
+
601
+ hidden_states = layer_outputs[0]
602
+
603
+ if output_attentions:
604
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
605
+
606
+ if output_hidden_states:
607
+ all_hidden_states = all_hidden_states + (hidden_states,)
608
+
609
+ if not return_dict:
610
+ return tuple(
611
+ v
612
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
613
+ if v is not None
614
+ )
615
+ return BaseModelOutput(
616
+ last_hidden_state=hidden_states,
617
+ hidden_states=all_hidden_states,
618
+ attentions=all_self_attentions,
619
+ )
620
+
621
+
622
+ class Dinov2PreTrainedModel(PreTrainedModel):
623
+ """
624
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
625
+ models.
626
+ """
627
+
628
+ config_class = Dinov2Config
629
+ base_model_prefix = "dinov2"
630
+ main_input_name = "pixel_values"
631
+ supports_gradient_checkpointing = True
632
+
633
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
634
+ """Initialize the weights"""
635
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
636
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
637
+ # `trunc_normal_cpu` not implemented in `half` issues
638
+ module.weight.data = nn.init.trunc_normal_(
639
+ module.weight.data.to(torch.float32),
640
+ mean=0.0,
641
+ std=self.config.initializer_range,
642
+ ).to(module.weight.dtype)
643
+ if module.bias is not None:
644
+ module.bias.data.zero_()
645
+ elif isinstance(module, nn.LayerNorm):
646
+ module.bias.data.zero_()
647
+ module.weight.data.fill_(1.0)
648
+ elif isinstance(module, Dinov2Embeddings):
649
+ module.position_embeddings.data = nn.init.trunc_normal_(
650
+ module.position_embeddings.data.to(torch.float32),
651
+ mean=0.0,
652
+ std=self.config.initializer_range,
653
+ ).to(module.position_embeddings.dtype)
654
+
655
+ module.cls_token.data = nn.init.trunc_normal_(
656
+ module.cls_token.data.to(torch.float32),
657
+ mean=0.0,
658
+ std=self.config.initializer_range,
659
+ ).to(module.cls_token.dtype)
660
+
661
+ def _set_gradient_checkpointing(
662
+ self, module: Dinov2Encoder, value: bool = False
663
+ ) -> None:
664
+ if isinstance(module, Dinov2Encoder):
665
+ module.gradient_checkpointing = value
666
+
667
+
668
+ DINOV2_START_DOCSTRING = r"""
669
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
670
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
671
+ behavior.
672
+
673
+ Parameters:
674
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
675
+ Initializing with a config file does not load the weights associated with the model, only the
676
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
677
+ """
678
+
679
+ DINOV2_BASE_INPUTS_DOCSTRING = r"""
680
+ Args:
681
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
682
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
683
+ [`BitImageProcessor.preprocess`] for details.
684
+
685
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
686
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
687
+ pre-training.
688
+
689
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
690
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
691
+
692
+ - 1 indicates the head is **not masked**,
693
+ - 0 indicates the head is **masked**.
694
+
695
+ output_attentions (`bool`, *optional*):
696
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
697
+ tensors for more detail.
698
+ output_hidden_states (`bool`, *optional*):
699
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
700
+ more detail.
701
+ return_dict (`bool`, *optional*):
702
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
703
+ """
704
+
705
+ DINOV2_INPUTS_DOCSTRING = r"""
706
+ Args:
707
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
708
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
709
+ [`BitImageProcessor.preprocess`] for details.
710
+
711
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
712
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
713
+
714
+ - 1 indicates the head is **not masked**,
715
+ - 0 indicates the head is **masked**.
716
+
717
+ output_attentions (`bool`, *optional*):
718
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
719
+ tensors for more detail.
720
+ output_hidden_states (`bool`, *optional*):
721
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
722
+ more detail.
723
+ return_dict (`bool`, *optional*):
724
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
725
+ """
726
+
727
+ @dataclass
728
+ class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
729
+ patch_embeddings: Optional[torch.FloatTensor] = None
730
+
731
+
732
+ @add_start_docstrings(
733
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
734
+ DINOV2_START_DOCSTRING,
735
+ )
736
+ class Dinov2Model(Dinov2PreTrainedModel, MemoryEfficientAttentionMixin):
737
+ def __init__(self, config: Dinov2Config):
738
+ super().__init__(config)
739
+ self.config = config
740
+
741
+ self.embeddings = Dinov2Embeddings(config)
742
+ self.encoder = Dinov2Encoder(config)
743
+
744
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
745
+
746
+ # Initialize weights and apply final processing
747
+ self.post_init()
748
+
749
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
750
+ return self.embeddings.patch_embeddings
751
+
752
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
753
+ """
754
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
755
+ class PreTrainedModel
756
+ """
757
+ for layer, heads in heads_to_prune.items():
758
+ self.encoder.layer[layer].attention.prune_heads(heads)
759
+
760
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
761
+ @add_code_sample_docstrings(
762
+ checkpoint=_CHECKPOINT_FOR_DOC,
763
+ output_type=BaseModelOutputWithPooling,
764
+ config_class=_CONFIG_FOR_DOC,
765
+ modality="vision",
766
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
767
+ )
768
+ def forward(
769
+ self,
770
+ pixel_values: Optional[torch.Tensor] = None,
771
+ bool_masked_pos: Optional[torch.Tensor] = None,
772
+ head_mask: Optional[torch.Tensor] = None,
773
+ modulation_cond: Optional[torch.Tensor] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
778
+ output_attentions = (
779
+ output_attentions
780
+ if output_attentions is not None
781
+ else self.config.output_attentions
782
+ )
783
+ output_hidden_states = (
784
+ output_hidden_states
785
+ if output_hidden_states is not None
786
+ else self.config.output_hidden_states
787
+ )
788
+ return_dict = (
789
+ return_dict if return_dict is not None else self.config.use_return_dict
790
+ )
791
+
792
+ if pixel_values is None:
793
+ raise ValueError("You have to specify pixel_values")
794
+
795
+ # Prepare head mask if needed
796
+ # 1.0 in head_mask indicate we keep the head
797
+ # attention_probs has shape bsz x n_heads x N x N
798
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
799
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
800
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
801
+
802
+ embedding_output = self.embeddings(
803
+ pixel_values, bool_masked_pos=bool_masked_pos
804
+ )
805
+
806
+ encoder_outputs = self.encoder(
807
+ embedding_output,
808
+ head_mask=head_mask,
809
+ modulation_cond=modulation_cond,
810
+ output_attentions=output_attentions,
811
+ output_hidden_states=output_hidden_states,
812
+ return_dict=return_dict,
813
+ )
814
+ sequence_output = encoder_outputs[0]
815
+ sequence_output = self.layernorm(sequence_output)
816
+ pooled_output = sequence_output[:, 0, :]
817
+
818
+ if not return_dict:
819
+ head_outputs = (sequence_output, pooled_output)
820
+ return head_outputs + encoder_outputs[1:]
821
+
822
+ return CustomBaseModelOutputWithPooling(
823
+ last_hidden_state=sequence_output,
824
+ pooler_output=pooled_output,
825
+ hidden_states=encoder_outputs.hidden_states,
826
+ attentions=encoder_outputs.attentions,
827
+ patch_embeddings=embedding_output
828
+ )
829
+
830
+ def set_gradient_checkpointing(self, value: bool = False) -> None:
831
+ self._set_gradient_checkpointing(self.encoder, value)
832
+
833
+
834
+ @add_start_docstrings(
835
+ """
836
+ Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
837
+ of the [CLS] token) e.g. for ImageNet.
838
+ """,
839
+ DINOV2_START_DOCSTRING,
840
+ )
841
+ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
842
+ def __init__(self, config: Dinov2Config) -> None:
843
+ super().__init__(config)
844
+
845
+ self.num_labels = config.num_labels
846
+ self.dinov2 = Dinov2Model(config)
847
+
848
+ # Classifier head
849
+ self.classifier = (
850
+ nn.Linear(config.hidden_size * 2, config.num_labels)
851
+ if config.num_labels > 0
852
+ else nn.Identity()
853
+ )
854
+
855
+ # Initialize weights and apply final processing
856
+ self.post_init()
857
+
858
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
859
+ @add_code_sample_docstrings(
860
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
861
+ output_type=ImageClassifierOutput,
862
+ config_class=_CONFIG_FOR_DOC,
863
+ )
864
+ def forward(
865
+ self,
866
+ pixel_values: Optional[torch.Tensor] = None,
867
+ head_mask: Optional[torch.Tensor] = None,
868
+ labels: Optional[torch.Tensor] = None,
869
+ output_attentions: Optional[bool] = None,
870
+ output_hidden_states: Optional[bool] = None,
871
+ return_dict: Optional[bool] = None,
872
+ ) -> Union[tuple, ImageClassifierOutput]:
873
+ r"""
874
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
875
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
876
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
877
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
878
+ """
879
+ return_dict = (
880
+ return_dict if return_dict is not None else self.config.use_return_dict
881
+ )
882
+
883
+ outputs = self.dinov2(
884
+ pixel_values,
885
+ head_mask=head_mask,
886
+ output_attentions=output_attentions,
887
+ output_hidden_states=output_hidden_states,
888
+ return_dict=return_dict,
889
+ )
890
+
891
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
892
+
893
+ cls_token = sequence_output[:, 0]
894
+ patch_tokens = sequence_output[:, 1:]
895
+
896
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
897
+
898
+ logits = self.classifier(linear_input)
899
+
900
+ loss = None
901
+ if labels is not None:
902
+ # move labels to correct device to enable model parallelism
903
+ labels = labels.to(logits.device)
904
+ if self.config.problem_type is None:
905
+ if self.num_labels == 1:
906
+ self.config.problem_type = "regression"
907
+ elif self.num_labels > 1 and (
908
+ labels.dtype == torch.long or labels.dtype == torch.int
909
+ ):
910
+ self.config.problem_type = "single_label_classification"
911
+ else:
912
+ self.config.problem_type = "multi_label_classification"
913
+
914
+ if self.config.problem_type == "regression":
915
+ loss_fct = MSELoss()
916
+ if self.num_labels == 1:
917
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
918
+ else:
919
+ loss = loss_fct(logits, labels)
920
+ elif self.config.problem_type == "single_label_classification":
921
+ loss_fct = CrossEntropyLoss()
922
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
923
+ elif self.config.problem_type == "multi_label_classification":
924
+ loss_fct = BCEWithLogitsLoss()
925
+ loss = loss_fct(logits, labels)
926
+
927
+ if not return_dict:
928
+ output = (logits,) + outputs[2:]
929
+ return ((loss,) + output) if loss is not None else output
930
+
931
+ return ImageClassifierOutput(
932
+ loss=loss,
933
+ logits=logits,
934
+ hidden_states=outputs.hidden_states,
935
+ attentions=outputs.attentions,
936
+ )
937
+
938
+
939
+ @add_start_docstrings(
940
+ """
941
+ Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
942
+ """,
943
+ DINOV2_START_DOCSTRING,
944
+ )
945
+ class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
946
+ def __init__(self, config):
947
+ super().__init__(config)
948
+ super()._init_backbone(config)
949
+
950
+ self.num_features = [
951
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
952
+ ]
953
+ self.embeddings = Dinov2Embeddings(config)
954
+ self.encoder = Dinov2Encoder(config)
955
+
956
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
957
+
958
+ # Initialize weights and apply final processing
959
+ self.post_init()
960
+
961
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
962
+ return self.embeddings.patch_embeddings
963
+
964
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
965
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
966
+ def forward(
967
+ self,
968
+ pixel_values: torch.Tensor,
969
+ output_hidden_states: Optional[bool] = None,
970
+ output_attentions: Optional[bool] = None,
971
+ return_dict: Optional[bool] = None,
972
+ ) -> BackboneOutput:
973
+ """
974
+ Returns:
975
+
976
+ Examples:
977
+
978
+ ```python
979
+ >>> from transformers import AutoImageProcessor, AutoBackbone
980
+ >>> import torch
981
+ >>> from PIL import Image
982
+ >>> import requests
983
+
984
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
985
+ >>> image = Image.open(requests.get(url, stream=True).raw)
986
+
987
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
988
+ >>> model = AutoBackbone.from_pretrained(
989
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
990
+ ... )
991
+
992
+ >>> inputs = processor(image, return_tensors="pt")
993
+
994
+ >>> outputs = model(**inputs)
995
+ >>> feature_maps = outputs.feature_maps
996
+ >>> list(feature_maps[-1].shape)
997
+ [1, 768, 16, 16]
998
+ ```"""
999
+ return_dict = (
1000
+ return_dict if return_dict is not None else self.config.use_return_dict
1001
+ )
1002
+ output_hidden_states = (
1003
+ output_hidden_states
1004
+ if output_hidden_states is not None
1005
+ else self.config.output_hidden_states
1006
+ )
1007
+ output_attentions = (
1008
+ output_attentions
1009
+ if output_attentions is not None
1010
+ else self.config.output_attentions
1011
+ )
1012
+
1013
+ embedding_output = self.embeddings(pixel_values)
1014
+
1015
+ outputs = self.encoder(
1016
+ embedding_output,
1017
+ output_hidden_states=True,
1018
+ output_attentions=output_attentions,
1019
+ return_dict=return_dict,
1020
+ )
1021
+
1022
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
1023
+
1024
+ feature_maps = ()
1025
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
1026
+ if stage in self.out_features:
1027
+ if self.config.apply_layernorm:
1028
+ hidden_state = self.layernorm(hidden_state)
1029
+ if self.config.reshape_hidden_states:
1030
+ batch_size, _, height, width = pixel_values.shape
1031
+ patch_size = self.config.patch_size
1032
+ hidden_state = hidden_state[:, 1:, :].reshape(
1033
+ batch_size, width // patch_size, height // patch_size, -1
1034
+ )
1035
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1036
+ feature_maps += (hidden_state,)
1037
+
1038
+ if not return_dict:
1039
+ if output_hidden_states:
1040
+ output = (feature_maps,) + outputs[1:]
1041
+ else:
1042
+ output = (feature_maps,) + outputs[2:]
1043
+ return output
1044
+
1045
+ return BackboneOutput(
1046
+ feature_maps=feature_maps,
1047
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1048
+ attentions=outputs.attentions if output_attentions else None,
1049
+ )
1050
+
1051
+
1052
+
1053
+ class CustomPatchEmbeddings(nn.Module):
1054
+ """
1055
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
1056
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
1057
+ Transformer.
1058
+ """
1059
+
1060
+ def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int):
1061
+ super().__init__()
1062
+
1063
+ image_size = (
1064
+ image_size
1065
+ if isinstance(image_size, collections.abc.Iterable)
1066
+ else (image_size, image_size)
1067
+ )
1068
+ patch_size = (
1069
+ patch_size
1070
+ if isinstance(patch_size, collections.abc.Iterable)
1071
+ else (patch_size, patch_size)
1072
+ )
1073
+ num_patches = (image_size[1] // patch_size[1]) * (
1074
+ image_size[0] // patch_size[0]
1075
+ )
1076
+ self.image_size = image_size
1077
+ self.patch_size = patch_size
1078
+ self.num_channels = num_channels
1079
+ self.num_patches = num_patches
1080
+
1081
+ self.projection = nn.Conv2d(
1082
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
1083
+ )
1084
+
1085
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
1086
+ num_channels = pixel_values.shape[1]
1087
+ if num_channels != self.num_channels:
1088
+ raise ValueError(
1089
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
1090
+ f" Expected {self.num_channels} but got {num_channels}."
1091
+ )
1092
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
1093
+ return embeddings
1094
+
1095
+
1096
+ class CustomEmbeddings(nn.Module):
1097
+ """
1098
+ Construct the CLS token, mask token, position and patch embeddings.
1099
+ """
1100
+
1101
+ def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int) -> None:
1102
+ super().__init__()
1103
+
1104
+ self.image_size = image_size
1105
+ self.patch_size = patch_size
1106
+ self.num_channels = num_channels
1107
+ self.hidden_size = hidden_size
1108
+
1109
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
1110
+
1111
+ self.patch_embeddings = CustomPatchEmbeddings(image_size, patch_size, num_channels, hidden_size)
1112
+ num_patches = self.patch_embeddings.num_patches
1113
+ self.position_embeddings = nn.Parameter(
1114
+ torch.randn(1, num_patches + 1, self.hidden_size)
1115
+ )
1116
+
1117
+ def interpolate_pos_encoding(
1118
+ self, embeddings: torch.Tensor, height: int, width: int
1119
+ ) -> torch.Tensor:
1120
+ """
1121
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher
1122
+ resolution images.
1123
+
1124
+ Source:
1125
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
1126
+ """
1127
+
1128
+ num_patches = embeddings.shape[1] - 1
1129
+ num_positions = self.position_embeddings.shape[1] - 1
1130
+ if num_patches == num_positions and height == width:
1131
+ return self.position_embeddings
1132
+ class_pos_embed = self.position_embeddings[:, 0]
1133
+ patch_pos_embed = self.position_embeddings[:, 1:]
1134
+ dim = embeddings.shape[-1]
1135
+ height = height // self.patch_size
1136
+ width = width // self.patch_size
1137
+ # we add a small number to avoid floating point error in the interpolation
1138
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
1139
+ height, width = height + 0.1, width + 0.1
1140
+ patch_pos_embed = patch_pos_embed.reshape(
1141
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
1142
+ )
1143
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
1144
+ patch_pos_embed = nn.functional.interpolate(
1145
+ patch_pos_embed,
1146
+ scale_factor=(
1147
+ height / math.sqrt(num_positions),
1148
+ width / math.sqrt(num_positions),
1149
+ ),
1150
+ mode="bicubic",
1151
+ align_corners=False,
1152
+ )
1153
+ if (
1154
+ int(height) != patch_pos_embed.shape[-2]
1155
+ or int(width) != patch_pos_embed.shape[-1]
1156
+ ):
1157
+ raise ValueError(
1158
+ "Width or height does not match with the interpolated position embeddings"
1159
+ )
1160
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1161
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
1162
+
1163
+ def forward(
1164
+ self, pixel_values: torch.Tensor,
1165
+ ) -> torch.Tensor:
1166
+ batch_size, _, height, width = pixel_values.shape
1167
+ patch_embeddings = self.patch_embeddings(pixel_values)
1168
+ embeddings = patch_embeddings
1169
+
1170
+ # add the [CLS] token to the embedded patch tokens
1171
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1172
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
1173
+
1174
+ # add positional encoding to each token
1175
+ embeddings = embeddings + self.interpolate_pos_encoding(
1176
+ embeddings, height, width
1177
+ )
1178
+
1179
+ return embeddings
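
For orientation, the sketch below exercises the two embedding modules above on dummy tensors. The concrete sizes are illustrative (a DINOv2-base-style patch size of 14 and hidden size of 768), and the import path is an assumption inferred from the `image.py` imports below, not something stated in this commit.

```python
import torch

# Assumed module path for the file above (inferred from image.py's imports).
from tgs.models.tokenizers.dinov2 import CustomEmbeddings

# Illustrative sizes only; the real values come from the DINOv2 config used by the repo.
embed = CustomEmbeddings(image_size=224, patch_size=14, num_channels=3, hidden_size=768)

x = torch.randn(2, 3, 224, 224)                 # (batch_size, num_channels, height, width)
tokens = embed(x)                               # CLS token + one token per 14x14 patch
assert tokens.shape == (2, 1 + (224 // 14) ** 2, 768)

# At a higher resolution the stored 16x16 position grid is bicubically resized
# by interpolate_pos_encoding, so the same weights still apply.
x_hr = torch.randn(2, 3, 448, 448)
assert embed(x_hr).shape == (2, 1 + (448 // 14) ** 2, 768)
```

Note that `interpolate_pos_encoding` skips the resize only when the token count matches the stored positions and the image is square, so non-square or up-scaled inputs always go through the bicubic interpolation.
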
hort/models/tgs/models/tokenizers/image.py ADDED
@@ -0,0 +1,123 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from tgs.utils.base import BaseModule
+ from tgs.models.tokenizers.dinov2 import Dinov2Model
+ from tgs.models.transformers import Modulation
+ from tgs.utils.typing import *
+
+ class DINOV2SingleImageTokenizer(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         pretrained_model_name_or_path: str = "facebook/dinov2-base"
+         width: int = 224
+         height: int = 224
+         modulation: bool = False
+         modulation_zero_init: bool = False
+         modulation_single_layer: bool = False
+         modulation_cond_dim: int = 16
+         freeze_backbone_params: bool = True
+         enable_memory_efficient_attention: bool = False
+         enable_gradient_checkpointing: bool = False
+         use_patch_embeddings: bool = False
+         patch_embeddings_aggr_method: str = 'concat'
+
+     cfg: Config
+
+     def configure(self) -> None:
+         super().configure()
+         model: Dinov2Model
+
+         if self.cfg.freeze_backbone_params:
+             # freeze dino backbone parameters
+             self.register_non_module(
+                 "model",
+                 Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to(
+                     self.device
+                 ),
+             )
+
+             model = self.non_module("model")
+             for p in model.parameters():
+                 p.requires_grad_(False)
+             model.eval()
+         else:
+             self.model = Dinov2Model.from_pretrained(
+                 self.cfg.pretrained_model_name_or_path
+             ).to(self.device)
+             model = self.model
+
+         model.set_use_memory_efficient_attention_xformers(
+             self.cfg.enable_memory_efficient_attention
+         )
+         model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing)
+
+         # add modulation
+         if self.cfg.modulation:
+             modulations = []
+             for layer in model.encoder.layer:
+                 norm1_modulation = Modulation(
+                     model.config.hidden_size,
+                     self.cfg.modulation_cond_dim,
+                     zero_init=self.cfg.modulation_zero_init,
+                     single_layer=self.cfg.modulation_single_layer,
+                 )
+                 norm2_modulation = Modulation(
+                     model.config.hidden_size,
+                     self.cfg.modulation_cond_dim,
+                     zero_init=self.cfg.modulation_zero_init,
+                     single_layer=self.cfg.modulation_single_layer,
+                 )
+                 layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
+                 modulations += [norm1_modulation, norm2_modulation]
+             self.modulations = nn.ModuleList(modulations)
+
+     def forward(
+         self,
+         images: Float[Tensor, "B *N C H W"],
+         modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
+     ) -> Float[Tensor, "B *N Ct Nt"]:
+         model: Dinov2Model
+         if self.cfg.freeze_backbone_params:
+             model = self.non_module("model")
+         else:
+             model = self.model
+
+         packed = False
+         if images.ndim == 4:
+             packed = True
+             images = images.unsqueeze(1)
+             if modulation_cond is not None:
+                 assert modulation_cond.ndim == 2
+                 modulation_cond = modulation_cond.unsqueeze(1)
+
+         batch_size, n_input_views = images.shape[:2]
+         out = model(
+             rearrange(images, "B N C H W -> (B N) C H W"),
+             modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
+             if modulation_cond is not None
+             else None,
+         )
+         local_features, global_features = out.last_hidden_state, out.pooler_output
+         if self.cfg.use_patch_embeddings:
+             patch_embeddings = out.patch_embeddings
+             if self.cfg.patch_embeddings_aggr_method == 'concat':
+                 local_features = torch.cat([local_features, patch_embeddings], dim=1)
+             elif self.cfg.patch_embeddings_aggr_method == 'add':
+                 local_features = local_features + patch_embeddings
+             else:
+                 raise NotImplementedError
+         local_features = local_features.permute(0, 2, 1)
+         local_features = rearrange(
+             local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
+         )
+         if packed:
+             local_features = local_features.squeeze(1)
+
+         return local_features
+
+     def detokenize(self, *args, **kwargs):
+         raise NotImplementedError
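
To make the tokenizer's output convention concrete, here is a minimal, self-contained sketch of the shape handling in `forward`, with the DINOv2 call replaced by a random tensor and all sizes chosen purely for illustration.

```python
import torch
from einops import rearrange

# Toy walk-through of DINOV2SingleImageTokenizer.forward's shape contract:
# a "packed" (B, C, H, W) input is treated as a single view, the backbone runs
# on (B*N, C, H, W), and the (B*N, Nt, Ct) tokens come back as (B, N, Ct, Nt),
# squeezed to (B, Ct, Nt) when the input was packed.
B, C, H, W = 2, 3, 224, 224
images = torch.randn(B, C, H, W)                        # packed single-view input
images = images.unsqueeze(1)                            # (B, N=1, C, H, W)
flat = rearrange(images, "B N C H W -> (B N) C H W")

Nt, Ct = 257, 768                                       # tokens per view, feature dim (illustrative)
last_hidden_state = torch.randn(flat.shape[0], Nt, Ct)  # stand-in for the DINOv2 output
local = last_hidden_state.permute(0, 2, 1)              # (B*N, Ct, Nt)
local = rearrange(local, "(B N) Ct Nt -> B N Ct Nt", B=B)
local = local.squeeze(1)                                # packed input -> (B, Ct, Nt)
assert local.shape == (B, Ct, Nt)
```

Note that `pooler_output` is read into `global_features` but never returned; only the per-token local features (optionally concatenated or summed with `patch_embeddings`) are passed downstream.
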