Spaces: Running on Zero

Commit: init test without models

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +2 -0
- .gitignore +139 -0
- app.py +202 -0
- assets/test1.png +3 -0
- assets/test2.png +3 -0
- assets/test3.jpg +0 -0
- assets/test4.jpeg +0 -0
- assets/test5.jpeg +0 -0
- hort/models/__init__.py +114 -0
- hort/models/network/pointnet.py +36 -0
- hort/models/tgs/__init__.py +9 -0
- hort/models/tgs/data.py +265 -0
- hort/models/tgs/models/__init__.py +0 -0
- hort/models/tgs/models/image_feature.py +48 -0
- hort/models/tgs/models/networks.py +204 -0
- hort/models/tgs/models/pointclouds/LICENSE_POINTNET +21 -0
- hort/models/tgs/models/pointclouds/pointnet.py +121 -0
- hort/models/tgs/models/pointclouds/simplepoint.py +110 -0
- hort/models/tgs/models/renderer.py +427 -0
- hort/models/tgs/models/snowflake/LICENSE +21 -0
- hort/models/tgs/models/snowflake/SPD.py +68 -0
- hort/models/tgs/models/snowflake/SPD_crossattn.py +81 -0
- hort/models/tgs/models/snowflake/SPD_pp.py +71 -0
- hort/models/tgs/models/snowflake/attention.py +239 -0
- hort/models/tgs/models/snowflake/model_spdpp.py +239 -0
- hort/models/tgs/models/snowflake/pointnet2.py +126 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py +3 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h +5 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h +41 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h +5 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h +10 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h +6 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h +25 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp +32 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu +54 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp +19 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp +62 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu +75 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp +99 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu +154 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp +87 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu +229 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py +1 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py +209 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py +391 -0
- hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py +41 -0
- hort/models/tgs/models/snowflake/skip_transformer.py +69 -0
- hort/models/tgs/models/snowflake/utils.py +741 -0
- hort/models/tgs/models/tokenizers/dinov2.py +1179 -0
- hort/models/tgs/models/tokenizers/image.py +123 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/test1.png filter=lfs diff=lfs merge=lfs -text
+assets/test2.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class

# pyc
*.pyc

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode

*.swp
*.h5
*.mp4
app.py
ADDED
@@ -0,0 +1,202 @@
import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"

import gradio as gr
import spaces
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from pathlib import Path
import argparse
import json
from torchvision import transforms
from typing import Dict, Optional
from PIL import Image, ImageDraw
from lang_sam import LangSAM

from wilor.models import load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset
from hort.models import load_hort
from hort.utils.renderer import Renderer, cam_crop_to_new
from hort.utils.img_utils import process_bbox, generate_patch_image, PerspectiveCamera

LIGHT_PURPLE = (0.25098039, 0.274117647, 0.65882353)
STEEL_BLUE = (0.2745098, 0.5098039, 0.7058824)

# Download and load checkpoints
wilor_model, wilor_model_cfg = load_wilor(checkpoint_path='./pretrained_models/wilor_final.ckpt', cfg_path='./pretrained_models/model_config.yaml')
hand_detector = YOLO('./pretrained_models/detector.pt')
# Setup the renderer
renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)
# Setup the SAM model
sam_model = LangSAM(sam_type="sam2.1_hiera_large")
# Setup the HORT model
hort_model = load_hort("./pretrained_models/hort_final.pth.tar")

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wilor_model = wilor_model.to(device)
hand_detector = hand_detector.to(device)
hort_model = hort_model.to(device)
wilor_model.eval()
hort_model.eval()

image_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

@spaces.GPU()
def run_model(image, conf, IoU_threshold=0.5):
    img_cv2 = image[..., ::-1]
    img_pil = Image.fromarray(image)

    pred_obj = sam_model.predict([img_pil], ["manipulated object"])
    pred_hand = sam_model.predict([img_pil], ["hand"])

    bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2))
    mask_obj = pred_obj[0]["masks"][0]
    bbox_hand = pred_hand[0]["boxes"][0].reshape((-1, 2))
    mask_hand = pred_hand[0]["masks"][0]

    tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    box_size = br - tl
    bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
    ho_bbox = process_bbox(bbox)

    detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]

    bboxes = []
    is_right = []
    for det in detections:
        Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
        is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())

    if len(bboxes) == 1:
        boxes = np.stack(bboxes)
        right = np.stack(is_right)
        if not right:
            new_x1 = img_cv2.shape[1] - boxes[0][2]
            new_x2 = img_cv2.shape[1] - boxes[0][0]
            boxes[0][0] = new_x1
            boxes[0][2] = new_x2
            ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
            img_cv2 = cv2.flip(img_cv2, 1)
            right[0] = 1.
        crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)

        dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

        for batch in dataloader:
            batch = recursive_to(batch, device)

            with torch.no_grad():
                out = wilor_model(batch)

            pred_cam = out['pred_cam']
            box_center = batch["box_center"].float()
            box_size = batch["box_size"].float()
            img_size = batch["img_size"].float()
            scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
            pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy()

            batch_size = batch['img'].shape[0]
            for n in range(batch_size):
                verts = out['pred_vertices'][n].detach().cpu().numpy()
                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()

                is_right = batch['right'][n].cpu().numpy()
                palm = (verts[95] + verts[22]) / 2
                cam_t = pred_cam_t_full[n]

                img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
                camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
                cam_intr = camera.intrinsics

                metas = dict()
                metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
                metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
                metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
                metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    pc_results = hort_model(img_input, metas)
                objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
                pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3

        reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}

        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions
    else:
        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), None


def render_reconstruction(image, conf, IoU_threshold=0.3):
    input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=0.5)
    if num_dets == 1:
        # Render front view
        misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
        cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)

        # Overlay image
        input_img = np.concatenate([input_img, np.ones_like(input_img[:, :, :1])], axis=2)  # Add alpha channel
        input_img_overlay = input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]

        return input_img_overlay, f'{num_dets} hands detected'
    else:
        return input_img, f'{num_dets} hands detected'


header = ('''
<div class="embed_hidden" style="text-align: center;">
<h1> <b>HORT</b>: Monocular Hand-held Objects Reconstruction with Transformers</h1>
<h3>
<a href="https://zerchen.github.io/" target="_blank" rel="noopener noreferrer">Zerui Chen</a><sup>1</sup>,
<a href="https://rolpotamias.github.io" target="_blank" rel="noopener noreferrer">Rolandos Alexandros Potamias</a><sup>2</sup>,
<br>
<a href="https://cshizhe.github.io/" target="_blank" rel="noopener noreferrer">Shizhe Chen</a><sup>1</sup>,
<a href="https://cordeliaschmid.github.io/" target="_blank" rel="noopener noreferrer">Cordelia Schmid</a><sup>1</sup>
</h3>
<h3>
<sup>1</sup>Inria, Ecole normale supérieure, CNRS, PSL Research University;
<sup>2</sup>Imperial College London
</h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/abs/2503.21313'><img src='https://img.shields.io/badge/Arxiv-2503.21313-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
<a href='https://arxiv.org/pdf/2503.21313'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
<a href='https://zerchen.github.io/projects/hort.html'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
<a href='https://github.com/zerchen/hort'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
''')


with gr.Blocks(title="HORT: Monocular Hand-held Objects Reconstruction with Transformers", css=".gradio-container") as demo:

    gr.Markdown(header)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image", type="numpy")
            threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold')
            submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            reconstruction = gr.Image(label="Reconstructions", type="numpy")
            hands_detected = gr.Textbox(label="Hands Detected")

    submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])

    with gr.Row():
        example_images = gr.Examples([
            ['/home/user/app/assets/test1.png'],
            ['./demo_img/app/assets/test2.png'],
            ['./demo_img/app/assets/test3.jpg'],
            ['./demo_img/app/assets/test4.jpeg'],
            ['./demo_img/app/assets/test5.jpeg']
        ],
        inputs=input_image)

demo.launch(debug=True)
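The script above wires everything into a Gradio Blocks app that is started with demo.launch(debug=True). As a rough local smoke test (a sketch only: it assumes the checkpoints under ./pretrained_models and the LFS-tracked assets are present, a CUDA GPU is available, and that importing app is acceptable even though it loads every model at import time), the pipeline can also be driven directly:

# Hypothetical smoke test; not part of the commit.
import cv2
from app import render_reconstruction

image = cv2.cvtColor(cv2.imread("assets/test1.png"), cv2.COLOR_BGR2RGB)   # RGB numpy image, as gr.Image provides
overlay, message = render_reconstruction(image, conf=0.3)
print(message)   # e.g. "1 hands detected"
cv2.imwrite("overlay.png", cv2.cvtColor((overlay * 255).clip(0, 255).astype("uint8"), cv2.COLOR_RGB2BGR))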
assets/test1.png
ADDED (binary image, tracked with Git LFS)

assets/test2.png
ADDED (binary image, tracked with Git LFS)

assets/test3.jpg
ADDED (binary image)

assets/test4.jpeg
ADDED (binary image)

assets/test5.jpeg
ADDED (binary image)
hort/models/__init__.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import sys
import os.path as osp
import numpy as np
from yacs.config import CfgNode as CN
this_dir = osp.dirname(__file__)
sys.path.insert(0, this_dir)
import tgs
from network.pointnet import PointNetEncoder

hort_cfg = CN()
hort_cfg.image_tokenizer_cls = "tgs.models.tokenizers.image.DINOV2SingleImageTokenizer"
hort_cfg.image_tokenizer = CN()
hort_cfg.image_tokenizer.pretrained_model_name_or_path = "facebook/dinov2-large"
hort_cfg.image_tokenizer.width = 224
hort_cfg.image_tokenizer.height = 224
hort_cfg.image_tokenizer.modulation = False
hort_cfg.image_tokenizer.modulation_zero_init = True
hort_cfg.image_tokenizer.modulation_cond_dim = 1024
hort_cfg.image_tokenizer.freeze_backbone_params = False
hort_cfg.image_tokenizer.enable_memory_efficient_attention = False
hort_cfg.image_tokenizer.enable_gradient_checkpointing = False

hort_cfg.tokenizer_cls = "tgs.models.tokenizers.point.PointLearnablePositionalEmbedding"
hort_cfg.tokenizer = CN()
hort_cfg.tokenizer.num_pcl = 2049
hort_cfg.tokenizer.num_channels = 512

hort_cfg.backbone_cls = "tgs.models.transformers.Transformer1D"
hort_cfg.backbone = CN()
hort_cfg.backbone.in_channels = 512
hort_cfg.backbone.num_attention_heads = 8
hort_cfg.backbone.attention_head_dim = 64
hort_cfg.backbone.num_layers = 10
hort_cfg.backbone.cross_attention_dim = 1024
hort_cfg.backbone.norm_type = "layer_norm"
hort_cfg.backbone.enable_memory_efficient_attention = False
hort_cfg.backbone.gradient_checkpointing = False

hort_cfg.post_processor_cls = "tgs.models.networks.PointOutLayer"
hort_cfg.post_processor = CN()
hort_cfg.post_processor.in_channels = 512
hort_cfg.post_processor.out_channels = 3

hort_cfg.pointcloud_upsampler_cls = "tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP"
hort_cfg.pointcloud_upsampler = CN()
hort_cfg.pointcloud_upsampler.input_channels = 1024
hort_cfg.pointcloud_upsampler.dim_feat = 128
hort_cfg.pointcloud_upsampler.num_p0 = 2048
hort_cfg.pointcloud_upsampler.radius = 1
hort_cfg.pointcloud_upsampler.bounding = True
hort_cfg.pointcloud_upsampler.use_fps = True
hort_cfg.pointcloud_upsampler.up_factors = [2, 4]
hort_cfg.pointcloud_upsampler.token_type = "image_token"


class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.image_tokenizer = tgs.find(hort_cfg.image_tokenizer_cls)(hort_cfg.image_tokenizer)
        self.pointnet = PointNetEncoder(67, 1024)
        self.tokenizer = tgs.find(hort_cfg.tokenizer_cls)(hort_cfg.tokenizer)
        self.backbone = tgs.find(hort_cfg.backbone_cls)(hort_cfg.backbone)
        self.post_processor = tgs.find(hort_cfg.post_processor_cls)(hort_cfg.post_processor)
        self.post_processor_trans = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 3))
        self.pointcloud_upsampler = tgs.find(hort_cfg.pointcloud_upsampler_cls)(hort_cfg.pointcloud_upsampler)

    def forward(self, input_img, metas):
        with torch.no_grad():
            batch_size = input_img.shape[0]

            encoder_hidden_states = self.image_tokenizer(input_img, None)  # B * C * Nt
            encoder_hidden_states = encoder_hidden_states.transpose(2, 1)  # B * Nt * C

            palm_norm_hand_verts_3d = metas['right_hand_verts_3d'] - metas['right_hand_palm'].unsqueeze(1)
            point_idx = torch.arange(778).view(1, 778, 1).expand(batch_size, -1, -1).to(input_img.device) / 778.
            palm_norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, point_idx], -1)
            tip_norm_hand_verts_3d = (metas['right_hand_verts_3d'].unsqueeze(2) - metas['right_hand_joints_3d'].unsqueeze(1)).reshape((batch_size, 778, -1))
            norm_hand_verts_3d = torch.cat([palm_norm_hand_verts_3d, tip_norm_hand_verts_3d], -1)
            hand_feats = self.pointnet(norm_hand_verts_3d)

            tokens = self.tokenizer(batch_size)
            tokens = self.backbone(tokens, torch.cat([encoder_hidden_states, hand_feats.unsqueeze(1)], 1), modulation_cond=None)
            tokens = self.tokenizer.detokenize(tokens)

            pointclouds = self.post_processor(tokens[:, :2048, :])
            pred_obj_trans = self.post_processor_trans(tokens[:, -1, :])

            upsampling_input = {
                "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
                "intrinsic_cond": metas['cam_intr'],
                "points": pointclouds,
                "hand_points": metas["right_hand_verts_3d"],
                "trans": pred_obj_trans + metas['right_hand_palm'],
                "scale": 0.3
            }
            up_results = self.pointcloud_upsampler(upsampling_input)
            pointclouds_up = up_results[-1]

            pc_results = {}
            pc_results['pointclouds'] = pointclouds
            pc_results['objtrans'] = pred_obj_trans
            pc_results['handpalm'] = metas['right_hand_palm']
            pc_results['pointclouds_up'] = pointclouds_up

            return pc_results

def load_hort(ckpt_path):
    hort_model = model()
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))["network"]
    ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
    hort_model.load_state_dict(ckpt)
    return hort_model
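A minimal sketch of how load_hort and model.forward are meant to be called, mirroring the call site in app.py from this same commit. The dummy tensors only illustrate the expected shapes (778 MANO vertices, 21 joints, a 3x3 intrinsic matrix, a normalized 224x224 crop); it assumes the checkpoint file and a CUDA device are available:

# Illustrative only; requires ./pretrained_models/hort_final.pth.tar and CUDA.
import torch
from hort.models import load_hort

hort_model = load_hort("./pretrained_models/hort_final.pth.tar").cuda().eval()

img = torch.randn(1, 3, 224, 224).cuda()                    # normalized RGB crop
metas = {
    "right_hand_verts_3d": torch.randn(1, 778, 3).cuda(),   # hand vertices in camera space
    "right_hand_joints_3d": torch.randn(1, 21, 3).cuda(),   # hand joints in camera space
    "right_hand_palm": torch.randn(1, 3).cuda(),            # palm center in camera space
    "cam_intr": torch.eye(3).unsqueeze(0).cuda(),           # pinhole intrinsics
}
with torch.no_grad():
    pc_results = hort_model(img, metas)
print(pc_results["pointclouds"].shape, pc_results["pointclouds_up"].shape)  # coarse and upsampled object points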
hort/models/network/pointnet.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn


class PointNetEncoder(nn.Module):
    """Encoder for Pointcloud
    """
    def __init__(self, in_channels: int=3, output_channels: int=768):
        super().__init__()

        block_channel = [64, 128, 256, 512]
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, block_channel[0]),
            nn.LayerNorm(block_channel[0]),
            nn.ReLU(),
            nn.Linear(block_channel[0], block_channel[1]),
            nn.LayerNorm(block_channel[1]),
            nn.ReLU(),
            nn.Linear(block_channel[1], block_channel[2]),
            nn.LayerNorm(block_channel[2]),
            nn.ReLU(),
            nn.Linear(block_channel[2], block_channel[3]),
            nn.LayerNorm(block_channel[3]),
            nn.ReLU(),
        )

        self.final_projection = nn.Sequential(
            nn.Linear(block_channel[-1], output_channels),
            nn.LayerNorm(output_channels)
        )

    def forward(self, x):
        x = self.mlp(x)
        x = torch.max(x, 1)[0]
        x = self.final_projection(x)
        return x
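The encoder is a shared per-point MLP followed by a max-pool over the point dimension. A small sketch follows; the 67-channel input matches how hort/models/__init__.py builds its hand features (3 palm-relative coordinates, 1 normalized vertex index, and 21 x 3 joint-relative offsets), and the import path is spelled out only for illustration:

import torch
from hort.models.network.pointnet import PointNetEncoder  # illustrative import path

encoder = PointNetEncoder(in_channels=67, output_channels=1024)
points = torch.randn(2, 778, 67)   # two hands, 778 vertices, 67 features per vertex
feats = encoder(points)            # per-point MLP, then max over the 778 vertices
print(feats.shape)                 # torch.Size([2, 1024])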
hort/models/tgs/__init__.py
ADDED
@@ -0,0 +1,9 @@
import importlib
from tgs.utils.typing import *

def find(cls_string) -> Type:
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls
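find is a small dynamic-import helper: every *_cls string in the configs elsewhere in this commit is resolved through it. A minimal sketch of what it does:

import tgs
from tgs.models.networks import PointOutLayer

cls = tgs.find("tgs.models.networks.PointOutLayer")
assert cls is PointOutLayer   # find() is just importlib.import_module followed by getattr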
hort/models/tgs/data.py
ADDED
@@ -0,0 +1,265 @@
import json
import math
from dataclasses import dataclass, field

import os
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

from tgs.utils.config import parse_structured
from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays
from tgs.utils.typing import *


def _parse_scene_list_single(scene_list_path: str):
    if scene_list_path.endswith(".json"):
        with open(scene_list_path) as f:
            all_scenes = json.loads(f.read())
    elif scene_list_path.endswith(".txt"):
        with open(scene_list_path) as f:
            all_scenes = [p.strip() for p in f.readlines()]
    else:
        all_scenes = [scene_list_path]

    return all_scenes


def _parse_scene_list(scene_list_path: Union[str, List[str]]):
    all_scenes = []
    if isinstance(scene_list_path, str):
        scene_list_path = [scene_list_path]
    for scene_list_path_ in scene_list_path:
        all_scenes += _parse_scene_list_single(scene_list_path_)
    return all_scenes

@dataclass
class CustomImageDataModuleConfig:
    image_list: Any = ""
    background_color: Tuple[float, float, float] = field(
        default_factory=lambda: (1.0, 1.0, 1.0)
    )

    relative_pose: bool = False
    cond_height: int = 512
    cond_width: int = 512
    cond_camera_distance: float = 1.6
    cond_fovy_deg: float = 40.0
    cond_elevation_deg: float = 0.0
    cond_azimuth_deg: float = 0.0
    num_workers: int = 16

    eval_height: int = 512
    eval_width: int = 512
    eval_batch_size: int = 1
    eval_elevation_deg: float = 0.0
    eval_camera_distance: float = 1.6
    eval_fovy_deg: float = 40.0
    n_test_views: int = 120
    num_views_output: int = 120
    only_3dgs: bool = False

class CustomImageOrbitDataset(Dataset):
    def __init__(self, cfg: Any) -> None:
        super().__init__()
        self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg)

        self.n_views = self.cfg.n_test_views
        assert self.n_views % self.cfg.num_views_output == 0

        self.all_scenes = _parse_scene_list(self.cfg.image_list)

        azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[
            : self.n_views
        ]
        elevation_deg: Float[Tensor, "B"] = torch.full_like(
            azimuth_deg, self.cfg.eval_elevation_deg
        )
        camera_distances: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_camera_distance
        )

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )

        # default scene center at origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.n_views, 1)

        fovy_deg: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_fovy_deg
        )
        fovy = fovy_deg * math.pi / 180

        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
        up = F.normalize(torch.cross(right, lookat), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = (
            0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
        )
        directions_unit_focal = get_ray_directions(
            H=self.cfg.eval_height,
            W=self.cfg.eval_width,
            focal=1.0,
        )
        directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
            None, :, :, :
        ].repeat(self.n_views, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )
        # must use normalize=True to normalize directions here
        rays_o, rays_d = get_rays(directions, c2w, keepdim=True)

        intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov(
            self.cfg.eval_fovy_deg * math.pi / 180,
            H=self.cfg.eval_height,
            W=self.cfg.eval_width,
            bs=self.n_views,
        )
        intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone()
        intrinsic_normed[..., 0, 2] /= self.cfg.eval_width
        intrinsic_normed[..., 1, 2] /= self.cfg.eval_height
        intrinsic_normed[..., 0, 0] /= self.cfg.eval_width
        intrinsic_normed[..., 1, 1] /= self.cfg.eval_height

        self.rays_o, self.rays_d = rays_o, rays_d
        self.intrinsic = intrinsic
        self.intrinsic_normed = intrinsic_normed
        self.c2w = c2w
        self.camera_positions = camera_positions

        self.background_color = torch.as_tensor(self.cfg.background_color)

        # condition
        self.intrinsic_cond = get_intrinsic_from_fov(
            np.deg2rad(self.cfg.cond_fovy_deg),
            H=self.cfg.cond_height,
            W=self.cfg.cond_width,
        )
        self.intrinsic_normed_cond = self.intrinsic_cond.clone()
        self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width
        self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height
        self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width
        self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height


        if self.cfg.relative_pose:
            self.c2w_cond = torch.as_tensor(
                [
                    [0, 0, 1, self.cfg.cond_camera_distance],
                    [1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 0, 1],
                ]
            ).float()
        else:
            cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180
            cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180
            cond_camera_position: Float[Tensor, "3"] = torch.as_tensor(
                [
                    self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.cos(cond_azimuth),
                    self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.sin(cond_azimuth),
                    self.cfg.cond_camera_distance * np.sin(cond_elevation),
                ], dtype=torch.float32
            )

            cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position)
            cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)
            cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1)
            cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up), dim=-1)
            cond_up = F.normalize(torch.cross(cond_right, cond_lookat), dim=-1)
            cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat(
                [torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]],
                dim=-1,
            )
            cond_c2w: Float[Tensor, "4 4"] = torch.cat(
                [cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0
            )
            cond_c2w[3, 3] = 1.0
            self.c2w_cond = cond_c2w

    def __len__(self):
        if self.cfg.only_3dgs:
            return len(self.all_scenes)
        else:
            return len(self.all_scenes) * self.n_views // self.cfg.num_views_output

    def __getitem__(self, index):
        if self.cfg.only_3dgs:
            scene_index = index
            view_index = [0]
        else:
            scene_index = index * self.cfg.num_views_output // self.n_views
            view_start = index % (self.n_views // self.cfg.num_views_output)
            view_index = list(range(self.n_views))[view_start * self.cfg.num_views_output :
                                                   (view_start + 1) * self.cfg.num_views_output]

        img_path = self.all_scenes[scene_index]
        img_cond = torch.from_numpy(
            np.asarray(
                Image.fromarray(imageio.v2.imread(img_path))
                .convert("RGBA")
                .resize((self.cfg.cond_width, self.cfg.cond_height))
            )
            / 255.0
        ).float()
        mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:]
        rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[
            :, :, :3
        ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond)

        out = {
            "rgb_cond": rgb_cond.unsqueeze(0),
            "c2w_cond": self.c2w_cond.unsqueeze(0),
            "mask_cond": mask_cond.unsqueeze(0),
            "intrinsic_cond": self.intrinsic_cond.unsqueeze(0),
            "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0),
            "view_index": torch.as_tensor(view_index),
            "rays_o": self.rays_o[view_index],
            "rays_d": self.rays_d[view_index],
            "intrinsic": self.intrinsic[view_index],
            "intrinsic_normed": self.intrinsic_normed[view_index],
            "c2w": self.c2w[view_index],
            "camera_positions": self.camera_positions[view_index],
        }
        out["c2w"][..., :3, 1:3] *= -1
        out["c2w_cond"][..., :3, 1:3] *= -1
        instance_id = os.path.split(img_path)[-1].split('.')[0]
        out["index"] = torch.as_tensor(scene_index)
        out["background_color"] = self.background_color
        out["instance_id"] = instance_id
        return out

    def collate(self, batch):
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
        return batch
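The orbit cameras above are placed by the usual spherical-to-Cartesian conversion (right-handed frame, x back, y right, z up). A quick worked example with the dataset defaults (distance 1.6, elevation 0) for a single view a quarter of the way around the orbit:

import math
import torch

distance, elevation, azimuth = 1.6, 0.0, math.pi / 2
position = torch.tensor([
    distance * math.cos(elevation) * math.cos(azimuth),
    distance * math.cos(elevation) * math.sin(azimuth),
    distance * math.sin(elevation),
])
print(position)   # approximately (0.0, 1.6, 0.0): the camera sits on the +y axis, looking back at the origin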
hort/models/tgs/models/__init__.py
ADDED
File without changes
hort/models/tgs/models/image_feature.py
ADDED
@@ -0,0 +1,48 @@
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from einops import rearrange

from tgs.utils.base import BaseModule
from tgs.utils.ops import compute_distance_transform
from tgs.utils.typing import *

class ImageFeature(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        use_rgb: bool = True
        use_feature: bool = True
        use_mask: bool = True
        feature_dim: int = 128
        out_dim: int = 133
        backbone: str = "default"
        freeze_backbone_params: bool = True

    cfg: Config

    def forward(self, rgb, mask=None, feature=None):
        B, Nv, H, W = rgb.shape[:4]
        rgb = rearrange(rgb, "B Nv H W C -> (B Nv) C H W")
        if mask is not None:
            mask = rearrange(mask, "B Nv H W C -> (B Nv) C H W")

        assert feature is not None
        # reshape dino tokens to image-like size
        feature = rearrange(feature, "B (Nv Nt) C -> (B Nv) Nt C", Nv=Nv)
        feature = feature[:, 1:].reshape(B * Nv, H // 14, W // 14, -1).permute(0, 3, 1, 2).contiguous()
        feature = F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=False)

        if mask is not None and mask.is_floating_point():
            mask = mask > 0.5

        image_features = []
        if self.cfg.use_rgb:
            image_features.append(rgb)
        if self.cfg.use_feature:
            image_features.append(feature)
        if self.cfg.use_mask:
            image_features += [mask, compute_distance_transform(mask)]

        # detach features, occur error when with grad
        image_features = torch.cat(image_features, dim=1)  # .detach()
        return rearrange(image_features, "(B Nv) C H W -> B Nv C H W", B=B, Nv=Nv).squeeze(1)
hort/models/tgs/models/networks.py
ADDED
@@ -0,0 +1,204 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange
import numpy as np

from tgs.utils.base import BaseModule
from tgs.utils.ops import get_activation
from tgs.utils.typing import *

class PointOutLayer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 3
    cfg: Config
    def configure(self) -> None:
        super().configure()
        self.point_layer = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        nn.init.constant_(self.point_layer.weight, 0)
        nn.init.constant_(self.point_layer.bias, 0)

    def forward(self, x):
        return self.point_layer(x)

class TriplaneUpsampleNetwork(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 80

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.upsample = nn.ConvTranspose2d(
            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
        )

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        triplanes_up = rearrange(
            self.upsample(
                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
            ),
            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
            Np=3,
        )
        return triplanes_up


class MLP(nn.Module):
    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
    ):
        super().__init__()
        layers = [
            self.make_linear(
                dim_in, n_neurons, is_first=True, is_last=False, bias=bias
            ),
            self.make_activation(activation),
        ]
        for i in range(n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
                ),
                self.make_activation(activation),
            ]
        layers += [
            self.make_linear(
                n_neurons, dim_out, is_first=False, is_last=True, bias=bias
            )
        ]
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        x = self.layers(x)
        x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
        layer = nn.Linear(dim_in, dim_out, bias=bias)
        return layer

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

class GSProjection(nn.Module):
    def __init__(self,
                 in_channels: int = 80,
                 sh_degree: int = 3,
                 init_scaling: float = -5.0,
                 init_density: float = 0.1) -> None:
        super().__init__()

        self.out_keys = GS_KEYS + ["shs"]
        self.out_channels = GS_CHANNELS + [(sh_degree + 1) ** 2 * 3]

        self.out_layers = nn.ModuleList()
        for key, ch in zip(self.out_keys, self.out_channels):
            layer = nn.Linear(in_channels, ch)
            # initialize
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

            if key == "scaling":
                nn.init.constant_(layer.bias, init_scaling)
            elif key == "rotation":
                nn.init.constant_(layer.bias, 0)
                nn.init.constant_(layer.bias[0], 1.0)
            elif key == "opacity":
                inverse_sigmoid = lambda x: np.log(x / (1 - x))
                nn.init.constant_(layer.bias, inverse_sigmoid(init_density))

            self.out_layers.append(layer)

    def forward(self, x):
        ret = []
        for k, layer in zip(self.out_keys, self.out_layers):
            v = layer(x)
            if k == "rotation":
                v = torch.nn.functional.normalize(v)
            elif k == "scaling":
                v = torch.exp(v)
                # v = v.detach() # FIXME: for DEBUG
            elif k == "opacity":
                v = torch.sigmoid(v)
            # elif k == "shs":
            #     v = torch.reshape(v, (v.shape[0], -1, 3))
            ret.append(v)
        ret = torch.cat(ret, dim=-1)
        return ret

def get_encoding(n_input_dims: int, config) -> nn.Module:
    raise NotImplementedError


def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
    raise NotImplementedError


# Resnet Blocks for pointnet
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in

        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
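A brief sketch of two building blocks that other files in this commit rely on. The dimensions are illustrative: simplepoint.py concatenates the flattened normalized intrinsics (9 values) with the flattened c2w matrix (16 values) before the camera embedder, and ResnetBlockFC is the per-point block inside LocalPoolPointnet; note that whether MLP accepts output_activation=None depends on tgs.utils.ops.get_activation, which is not part of this diff:

import torch
from tgs.models.networks import MLP, ResnetBlockFC

# Camera-embedder-style MLP; the 768-d output size is an illustrative choice.
camera_embedder = MLP(dim_in=25, dim_out=768, n_neurons=128, n_hidden_layers=2, activation="relu")
print(camera_embedder(torch.randn(4, 25)).shape)    # torch.Size([4, 768])

# Residual FC block: zero-initialized second layer, linear shortcut when sizes differ.
block = ResnetBlockFC(size_in=256, size_out=128)
print(block(torch.randn(4, 2048, 256)).shape)       # torch.Size([4, 2048, 128])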
hort/models/tgs/models/pointclouds/LICENSE_POINTNET
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
hort/models/tgs/models/pointclouds/pointnet.py
ADDED
@@ -0,0 +1,121 @@
# modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch_scatter import scatter_mean, scatter_max

from tgs.utils.base import BaseModule
from tgs.models.networks import ResnetBlockFC
from tgs.utils.ops import scale_tensor

class LocalPoolPointnet(BaseModule):
    ''' PointNet-based encoder network with ResNet blocks for each point.
        The number of input points is fixed.

    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
        scatter_type (str): feature aggregation when doing local pooling
        plane_resolution (int): defined resolution for plane feature
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC layers
    '''

    @dataclass
    class Config(BaseModule.Config):
        input_channels: int = 3
        c_dim: int = 128
        hidden_dim: int = 128
        scatter_type: str = "max"
        plane_size: int = 32
        n_blocks: int = 5
        radius: float = 1.

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.fc_pos = nn.Linear(self.cfg.input_channels, 2 * self.cfg.hidden_dim)
        self.blocks = nn.ModuleList([
            ResnetBlockFC(2 * self.cfg.hidden_dim, self.cfg.hidden_dim) for i in range(self.cfg.n_blocks)
        ])
        self.fc_c = nn.Linear(self.cfg.hidden_dim, self.cfg.c_dim)

        self.actvn = nn.ReLU()

        if self.cfg.scatter_type == 'max':
            self.scatter = scatter_max
        elif self.cfg.scatter_type == 'mean':
            self.scatter = scatter_mean
        else:
            raise ValueError('incorrect scatter type')


    def generate_plane_features(self, index, c):
        # acquire indices of features in plane
        # xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
        # index = self.coordinate2index(x, self.cfg.plane_size)

        # scatter plane features from points
        fea_plane = c.new_zeros(index.shape[0], self.cfg.c_dim, self.cfg.plane_size ** 2)
        c = c.permute(0, 2, 1)  # B x 512 x T
        fea_plane = scatter_mean(c, index, out=fea_plane)  # B x 512 x reso^2
        fea_plane = fea_plane.reshape(index.shape[0], self.cfg.c_dim, self.cfg.plane_size, self.cfg.plane_size)  # sparse matrix (B x 512 x reso x reso)

        return fea_plane

    def pool_local(self, xy, index, c):
        bs, fea_dim = c.shape[0], c.shape[2]
        keys = xy.keys()

        c_out = 0
        for key in keys:
            # scatter plane features from points
            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.cfg.plane_size ** 2)
            if self.scatter == scatter_max:
                fea = fea[0]
            # gather feature back to points
            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
            c_out += fea
        return c_out.permute(0, 2, 1)

    def coordinate2index(self, x):
        x = (x * self.cfg.plane_size).long()
        index = x[..., 0] + self.cfg.plane_size * x[..., 1]
        assert index.max() < self.cfg.plane_size ** 2
        return index[:, None, :]

    def forward(self, p):
        batch_size, T, D = p.shape

        # acquire the index for each point
        coord = {}
        index = {}

        position = torch.clamp(p[..., :3], -self.cfg.radius + 1e-6, self.cfg.radius - 1e-6)
        position_norm = scale_tensor(position, (-self.cfg.radius, self.cfg.radius), (0, 1))
        coord["xy"] = position_norm[..., [0, 1]]
        coord["xz"] = position_norm[..., [0, 2]]
        coord["yz"] = position_norm[..., [1, 2]]
        index["xy"] = self.coordinate2index(coord["xy"])
        index["xz"] = self.coordinate2index(coord["xz"])
        index["yz"] = self.coordinate2index(coord["yz"])

        net = self.fc_pos(p)

        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)

        features = torch.stack([
            self.generate_plane_features(index["xy"], c),
            self.generate_plane_features(index["xz"], c),
            self.generate_plane_features(index["yz"], c)
        ], dim=1)

        return features
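coordinate2index flattens each normalized (u, v) coordinate into a cell of the plane_size x plane_size feature grid before the scatter step. A worked example with the default plane_size of 32 from the Config above:

import torch

plane_size = 32
uv = torch.tensor([[[0.50, 0.25]]])           # one batch, one point, normalized to [0, 1)
cell = (uv * plane_size).long()               # -> [[[16, 8]]]
index = cell[..., 0] + plane_size * cell[..., 1]
print(index)                                  # tensor([[272]]) since 16 + 32 * 8 = 272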
hort/models/tgs/models/pointclouds/simplepoint.py
ADDED
@@ -0,0 +1,110 @@
from dataclasses import dataclass, field
import torch
from einops import rearrange

import tgs
from tgs.utils.base import BaseModule
from tgs.utils.typing import *

class SimplePointGenerator(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        camera_embedder_cls: str = ""
        camera_embedder: dict = field(default_factory=dict)

        image_tokenizer_cls: str = ""
        image_tokenizer: dict = field(default_factory=dict)

        tokenizer_cls: str = ""
        tokenizer: dict = field(default_factory=dict)

        backbone_cls: str = ""
        backbone: dict = field(default_factory=dict)

        post_processor_cls: str = ""
        post_processor: dict = field(default_factory=dict)

        pointcloud_upsampling_cls: str = ""
        pointcloud_upsampling: dict = field(default_factory=dict)

        flip_c2w_cond: bool = True

    cfg: Config

    def configure(self) -> None:
        super().configure()

        self.image_tokenizer = tgs.find(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )

        assert self.cfg.camera_embedder_cls == 'tgs.models.networks.MLP'
        weights = self.cfg.camera_embedder.pop("weights") if "weights" in self.cfg.camera_embedder else None
        self.camera_embedder = tgs.find(self.cfg.camera_embedder_cls)(**self.cfg.camera_embedder)
        if weights:
            from tgs.utils.misc import load_module_weights
            weights_path, module_name = weights.split(":")
            state_dict = load_module_weights(
                weights_path, module_name=module_name, map_location="cpu"
            )
            self.camera_embedder.load_state_dict(state_dict)

        self.tokenizer = tgs.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer)

        self.backbone = tgs.find(self.cfg.backbone_cls)(self.cfg.backbone)

        self.post_processor = tgs.find(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )

        self.pointcloud_upsampling = tgs.find(self.cfg.pointcloud_upsampling_cls)(self.cfg.pointcloud_upsampling)

    def forward(self, batch, encoder_hidden_states=None, **kwargs):
        batch_size, n_input_views = batch["rgb_cond"].shape[:2]

        if encoder_hidden_states is None:
            # Camera modulation
            c2w_cond = batch["c2w_cond"].clone()
            if self.cfg.flip_c2w_cond:
                c2w_cond[..., :3, 1:3] *= -1
            camera_extri = c2w_cond.view(*c2w_cond.shape[:-2], -1)
            camera_intri = batch["intrinsic_normed_cond"].view(
                *batch["intrinsic_normed_cond"].shape[:-2], -1)
            camera_feats = torch.cat([camera_intri, camera_extri], dim=-1)
            # camera_feats = rearrange(camera_feats, 'B Nv C -> (B Nv) C')

            camera_feats = self.camera_embedder(camera_feats)

            encoder_hidden_states: Float[Tensor, "B Cit Nit"] = self.image_tokenizer(
                rearrange(batch["rgb_cond"], 'B Nv H W C -> B Nv C H W'),
                modulation_cond=camera_feats,
            )
            encoder_hidden_states = rearrange(
                encoder_hidden_states, 'B Nv C Nt -> B (Nv Nt) C', Nv=n_input_views)

        tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

        tokens = self.backbone(
            tokens,
            encoder_hidden_states=encoder_hidden_states,
            modulation_cond=None,
        )
        pointclouds = self.post_processor(self.tokenizer.detokenize(tokens))

        upsampling_input = {
            "input_image_tokens": encoder_hidden_states.permute(0, 2, 1),
            "input_image_tokens_global": encoder_hidden_states[:, :1],
            "c2w_cond": c2w_cond,
            "rgb_cond": batch["rgb_cond"],
            "intrinsic_cond": batch["intrinsic_cond"],
            "intrinsic_normed_cond": batch["intrinsic_normed_cond"],
            "points": pointclouds.float()
        }
        up_results = self.pointcloud_upsampling(upsampling_input)
        up_results.insert(0, pointclouds)
        pointclouds = up_results[-1]
        out = {
            "points": pointclouds,
            "up_results": up_results
|
109 |
+
}
|
110 |
+
return out
|
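SimplePointGenerator modulates the image tokenizer with flattened camera matrices. A small sketch of just that conditioning-vector construction, using hypothetical identity cameras and a single input view:

import torch

# Hypothetical conditioning input: batch of 2, one input view.
B, Nv = 2, 1
c2w_cond = torch.eye(4).expand(B, Nv, 4, 4).clone()                # camera-to-world poses
intrinsic_normed_cond = torch.eye(3).expand(B, Nv, 3, 3).clone()   # normalized intrinsics

# Flip the camera y/z axes, as flip_c2w_cond does (a common camera-convention flip).
c2w_cond[..., :3, 1:3] *= -1

# Flatten both matrices per view and concatenate into one conditioning vector.
camera_extri = c2w_cond.view(B, Nv, -1)                # (B, Nv, 16)
camera_intri = intrinsic_normed_cond.view(B, Nv, -1)   # (B, Nv, 9)
camera_feats = torch.cat([camera_intri, camera_extri], dim=-1)
print(camera_feats.shape)  # torch.Size([2, 1, 25]) -> fed to the camera embedder MLP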
hort/models/tgs/models/renderer.py
ADDED
@@ -0,0 +1,427 @@
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from collections import defaultdict
|
3 |
+
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
4 |
+
from plyfile import PlyData, PlyElement
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import numpy as np
|
9 |
+
import math
|
10 |
+
|
11 |
+
from tgs.utils.typing import *
|
12 |
+
from tgs.utils.base import BaseModule
|
13 |
+
from tgs.utils.ops import trunc_exp
|
14 |
+
from tgs.models.networks import MLP
|
15 |
+
from tgs.utils.ops import scale_tensor
|
16 |
+
from einops import rearrange, reduce
|
17 |
+
|
18 |
+
inverse_sigmoid = lambda x: np.log(x / (1 - x))
|
19 |
+
|
20 |
+
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
|
21 |
+
Rt = np.zeros((4, 4))
|
22 |
+
Rt[:3, :3] = R.transpose()
|
23 |
+
Rt[:3, 3] = t
|
24 |
+
Rt[3, 3] = 1.0
|
25 |
+
|
26 |
+
C2W = np.linalg.inv(Rt)
|
27 |
+
cam_center = C2W[:3, 3]
|
28 |
+
cam_center = (cam_center + translate) * scale
|
29 |
+
C2W[:3, 3] = cam_center
|
30 |
+
Rt = np.linalg.inv(C2W)
|
31 |
+
return np.float32(Rt)
|
32 |
+
|
33 |
+
def getProjectionMatrix(znear, zfar, fovX, fovY):
|
34 |
+
tanHalfFovY = math.tan((fovY / 2))
|
35 |
+
tanHalfFovX = math.tan((fovX / 2))
|
36 |
+
|
37 |
+
top = tanHalfFovY * znear
|
38 |
+
bottom = -top
|
39 |
+
right = tanHalfFovX * znear
|
40 |
+
left = -right
|
41 |
+
|
42 |
+
P = torch.zeros(4, 4)
|
43 |
+
|
44 |
+
z_sign = 1.0
|
45 |
+
|
46 |
+
P[0, 0] = 2.0 * znear / (right - left)
|
47 |
+
P[1, 1] = 2.0 * znear / (top - bottom)
|
48 |
+
P[0, 2] = (right + left) / (right - left)
|
49 |
+
P[1, 2] = (top + bottom) / (top - bottom)
|
50 |
+
P[3, 2] = z_sign
|
51 |
+
P[2, 2] = z_sign * zfar / (zfar - znear)
|
52 |
+
P[2, 3] = -(zfar * znear) / (zfar - znear)
|
53 |
+
return P
|
54 |
+
|
55 |
+
def intrinsic_to_fov(intrinsic, w, h):
|
56 |
+
fx, fy = intrinsic[0, 0], intrinsic[1, 1]
|
57 |
+
fov_x = 2 * torch.arctan2(w, 2 * fx)
|
58 |
+
fov_y = 2 * torch.arctan2(h, 2 * fy)
|
59 |
+
return fov_x, fov_y
|
60 |
+
|
61 |
+
|
62 |
+
class Camera:
|
63 |
+
def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None:
|
64 |
+
self.FoVx = FoVx
|
65 |
+
self.FoVy = FoVy
|
66 |
+
self.height = height
|
67 |
+
self.width = width
|
68 |
+
self.world_view_transform = w2c.transpose(0, 1)
|
69 |
+
|
70 |
+
self.zfar = 100.0
|
71 |
+
self.znear = 0.01
|
72 |
+
|
73 |
+
self.trans = trans
|
74 |
+
self.scale = scale
|
75 |
+
|
76 |
+
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device)
|
77 |
+
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
78 |
+
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def from_c2w(c2w, intrinsic, height, width):
|
82 |
+
w2c = torch.inverse(c2w)
|
83 |
+
FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device))
|
84 |
+
return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width)
|
85 |
+
|
86 |
+
class GaussianModel(NamedTuple):
|
87 |
+
xyz: Tensor
|
88 |
+
opacity: Tensor
|
89 |
+
rotation: Tensor
|
90 |
+
scaling: Tensor
|
91 |
+
shs: Tensor
|
92 |
+
|
93 |
+
def construct_list_of_attributes(self):
|
94 |
+
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
|
95 |
+
features_dc = self.shs[:, :1]
|
96 |
+
features_rest = self.shs[:, 1:]
|
97 |
+
for i in range(features_dc.shape[1]*features_dc.shape[2]):
|
98 |
+
l.append('f_dc_{}'.format(i))
|
99 |
+
for i in range(features_rest.shape[1]*features_rest.shape[2]):
|
100 |
+
l.append('f_rest_{}'.format(i))
|
101 |
+
l.append('opacity')
|
102 |
+
for i in range(self.scaling.shape[1]):
|
103 |
+
l.append('scale_{}'.format(i))
|
104 |
+
for i in range(self.rotation.shape[1]):
|
105 |
+
l.append('rot_{}'.format(i))
|
106 |
+
return l
|
107 |
+
|
108 |
+
def save_ply(self, path):
|
109 |
+
|
110 |
+
xyz = self.xyz.detach().cpu().numpy()
|
111 |
+
normals = np.zeros_like(xyz)
|
112 |
+
features_dc = self.shs[:, :1]
|
113 |
+
features_rest = self.shs[:, 1:]
|
114 |
+
f_dc = features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
|
115 |
+
f_rest = features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy()
|
116 |
+
opacities = inverse_sigmoid(torch.clamp(self.opacity, 1e-3, 1 - 1e-3).detach().cpu().numpy())
|
117 |
+
scale = np.log(self.scaling.detach().cpu().numpy())
|
118 |
+
rotation = self.rotation.detach().cpu().numpy()
|
119 |
+
|
120 |
+
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
|
121 |
+
|
122 |
+
elements = np.empty(xyz.shape[0], dtype=dtype_full)
|
123 |
+
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
|
124 |
+
elements[:] = list(map(tuple, attributes))
|
125 |
+
el = PlyElement.describe(elements, 'vertex')
|
126 |
+
PlyData([el]).write(path)
|
127 |
+
|
128 |
+
class GSLayer(BaseModule):
|
129 |
+
@dataclass
|
130 |
+
class Config(BaseModule.Config):
|
131 |
+
in_channels: int = 128
|
132 |
+
feature_channels: dict = field(default_factory=dict)
|
133 |
+
xyz_offset: bool = True
|
134 |
+
restrict_offset: bool = False
|
135 |
+
use_rgb: bool = False
|
136 |
+
clip_scaling: Optional[float] = None
|
137 |
+
init_scaling: float = -5.0
|
138 |
+
init_density: float = 0.1
|
139 |
+
|
140 |
+
cfg: Config
|
141 |
+
|
142 |
+
def configure(self, *args, **kwargs) -> None:
|
143 |
+
self.out_layers = nn.ModuleList()
|
144 |
+
for key, out_ch in self.cfg.feature_channels.items():
|
145 |
+
if key == "shs" and self.cfg.use_rgb:
|
146 |
+
out_ch = 3
|
147 |
+
layer = nn.Linear(self.cfg.in_channels, out_ch)
|
148 |
+
|
149 |
+
# initialize
|
150 |
+
if not (key == "shs" and self.cfg.use_rgb):
|
151 |
+
nn.init.constant_(layer.weight, 0)
|
152 |
+
nn.init.constant_(layer.bias, 0)
|
153 |
+
if key == "scaling":
|
154 |
+
nn.init.constant_(layer.bias, self.cfg.init_scaling)
|
155 |
+
elif key == "rotation":
|
156 |
+
nn.init.constant_(layer.bias, 0)
|
157 |
+
nn.init.constant_(layer.bias[0], 1.0)
|
158 |
+
elif key == "opacity":
|
159 |
+
nn.init.constant_(layer.bias, inverse_sigmoid(self.cfg.init_density))
|
160 |
+
|
161 |
+
self.out_layers.append(layer)
|
162 |
+
|
163 |
+
def forward(self, x, pts):
|
164 |
+
ret = {}
|
165 |
+
for k, layer in zip(self.cfg.feature_channels.keys(), self.out_layers):
|
166 |
+
v = layer(x)
|
167 |
+
if k == "rotation":
|
168 |
+
v = torch.nn.functional.normalize(v)
|
169 |
+
elif k == "scaling":
|
170 |
+
v = trunc_exp(v)
|
171 |
+
if self.cfg.clip_scaling is not None:
|
172 |
+
v = torch.clamp(v, min=0, max=self.cfg.clip_scaling)
|
173 |
+
elif k == "opacity":
|
174 |
+
v = torch.sigmoid(v)
|
175 |
+
elif k == "shs":
|
176 |
+
if self.cfg.use_rgb:
|
177 |
+
v = torch.sigmoid(v)
|
178 |
+
v = torch.reshape(v, (v.shape[0], -1, 3))
|
179 |
+
elif k == "xyz":
|
180 |
+
if self.cfg.restrict_offset:
|
181 |
+
max_step = 1.2 / 32
|
182 |
+
v = (torch.sigmoid(v) - 0.5) * max_step
|
183 |
+
v = v + pts if self.cfg.xyz_offset else pts
|
184 |
+
ret[k] = v
|
185 |
+
|
186 |
+
return GaussianModel(**ret)
|
187 |
+
|
188 |
+
class GS3DRenderer(BaseModule):
|
189 |
+
@dataclass
|
190 |
+
class Config(BaseModule.Config):
|
191 |
+
mlp_network_config: Optional[dict] = None
|
192 |
+
gs_out: dict = field(default_factory=dict)
|
193 |
+
sh_degree: int = 3
|
194 |
+
scaling_modifier: float = 1.0
|
195 |
+
random_background: bool = False
|
196 |
+
radius: float = 1.0
|
197 |
+
feature_reduction: str = "concat"
|
198 |
+
projection_feature_dim: int = 773
|
199 |
+
background_color: Tuple[float, float, float] = field(
|
200 |
+
default_factory=lambda: (1.0, 1.0, 1.0)
|
201 |
+
)
|
202 |
+
|
203 |
+
cfg: Config
|
204 |
+
|
205 |
+
def configure(self, *args, **kwargs) -> None:
|
206 |
+
if self.cfg.feature_reduction == "mean":
|
207 |
+
mlp_in = 80
|
208 |
+
elif self.cfg.feature_reduction == "concat":
|
209 |
+
mlp_in = 80 * 3
|
210 |
+
else:
|
211 |
+
raise NotImplementedError
|
212 |
+
mlp_in = mlp_in + self.cfg.projection_feature_dim
|
213 |
+
if self.cfg.mlp_network_config is not None:
|
214 |
+
self.mlp_net = MLP(mlp_in, self.cfg.gs_out.in_channels, **self.cfg.mlp_network_config)
|
215 |
+
else:
|
216 |
+
self.cfg.gs_out.in_channels = mlp_in
|
217 |
+
self.gs_net = GSLayer(self.cfg.gs_out)
|
218 |
+
|
219 |
+
def forward_gs(self, x, p):
|
220 |
+
if self.cfg.mlp_network_config is not None:
|
221 |
+
x = self.mlp_net(x)
|
222 |
+
return self.gs_net(x, p)
|
223 |
+
|
224 |
+
def forward_single_view(self,
|
225 |
+
gs: GaussianModel,
|
226 |
+
viewpoint_camera: Camera,
|
227 |
+
background_color: Optional[Float[Tensor, "3"]],
|
228 |
+
ret_mask: bool = True,
|
229 |
+
):
|
230 |
+
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
|
231 |
+
screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
|
232 |
+
try:
|
233 |
+
screenspace_points.retain_grad()
|
234 |
+
except:
|
235 |
+
pass
|
236 |
+
|
237 |
+
bg_color = background_color
|
238 |
+
# Set up rasterization configuration
|
239 |
+
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
240 |
+
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
241 |
+
|
242 |
+
raster_settings = GaussianRasterizationSettings(
|
243 |
+
image_height=int(viewpoint_camera.height),
|
244 |
+
image_width=int(viewpoint_camera.width),
|
245 |
+
tanfovx=tanfovx,
|
246 |
+
tanfovy=tanfovy,
|
247 |
+
bg=bg_color,
|
248 |
+
scale_modifier=self.cfg.scaling_modifier,
|
249 |
+
viewmatrix=viewpoint_camera.world_view_transform,
|
250 |
+
projmatrix=viewpoint_camera.full_proj_transform.float(),
|
251 |
+
sh_degree=self.cfg.sh_degree,
|
252 |
+
campos=viewpoint_camera.camera_center,
|
253 |
+
prefiltered=False,
|
254 |
+
debug=False
|
255 |
+
)
|
256 |
+
|
257 |
+
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
258 |
+
|
259 |
+
means3D = gs.xyz
|
260 |
+
means2D = screenspace_points
|
261 |
+
opacity = gs.opacity
|
262 |
+
|
263 |
+
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
264 |
+
# scaling / rotation by the rasterizer.
|
265 |
+
scales = None
|
266 |
+
rotations = None
|
267 |
+
cov3D_precomp = None
|
268 |
+
scales = gs.scaling
|
269 |
+
rotations = gs.rotation
|
270 |
+
|
271 |
+
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
272 |
+
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
273 |
+
shs = None
|
274 |
+
colors_precomp = None
|
275 |
+
if self.gs_net.cfg.use_rgb:
|
276 |
+
colors_precomp = gs.shs.squeeze(1)
|
277 |
+
else:
|
278 |
+
shs = gs.shs
|
279 |
+
|
280 |
+
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
281 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
|
282 |
+
rendered_image, radii = rasterizer(
|
283 |
+
means3D = means3D,
|
284 |
+
means2D = means2D,
|
285 |
+
shs = shs,
|
286 |
+
colors_precomp = colors_precomp,
|
287 |
+
opacities = opacity,
|
288 |
+
scales = scales,
|
289 |
+
rotations = rotations,
|
290 |
+
cov3D_precomp = cov3D_precomp)
|
291 |
+
|
292 |
+
ret = {
|
293 |
+
"comp_rgb": rendered_image.permute(1, 2, 0),
|
294 |
+
"comp_rgb_bg": bg_color
|
295 |
+
}
|
296 |
+
|
297 |
+
if ret_mask:
|
298 |
+
mask_bg_color = torch.zeros(3, dtype=torch.float32, device=self.device)
|
299 |
+
raster_settings = GaussianRasterizationSettings(
|
300 |
+
image_height=int(viewpoint_camera.height),
|
301 |
+
image_width=int(viewpoint_camera.width),
|
302 |
+
tanfovx=tanfovx,
|
303 |
+
tanfovy=tanfovy,
|
304 |
+
bg=mask_bg_color,
|
305 |
+
scale_modifier=self.cfg.scaling_modifier,
|
306 |
+
viewmatrix=viewpoint_camera.world_view_transform,
|
307 |
+
projmatrix=viewpoint_camera.full_proj_transform.float(),
|
308 |
+
sh_degree=0,
|
309 |
+
campos=viewpoint_camera.camera_center,
|
310 |
+
prefiltered=False,
|
311 |
+
debug=False
|
312 |
+
)
|
313 |
+
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
314 |
+
|
315 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
|
316 |
+
rendered_mask, radii = rasterizer(
|
317 |
+
means3D = means3D,
|
318 |
+
means2D = means2D,
|
319 |
+
# shs = ,
|
320 |
+
colors_precomp = torch.ones_like(means3D),
|
321 |
+
opacities = opacity,
|
322 |
+
scales = scales,
|
323 |
+
rotations = rotations,
|
324 |
+
cov3D_precomp = cov3D_precomp)
|
325 |
+
ret["comp_mask"] = rendered_mask.permute(1, 2, 0)
|
326 |
+
|
327 |
+
return ret
|
328 |
+
|
329 |
+
def query_triplane(
|
330 |
+
self,
|
331 |
+
positions: Float[Tensor, "*B N 3"],
|
332 |
+
triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
|
333 |
+
) -> Dict[str, Tensor]:
|
334 |
+
batched = positions.ndim == 3
|
335 |
+
if not batched:
|
336 |
+
# no batch dimension
|
337 |
+
triplanes = triplanes[None, ...]
|
338 |
+
positions = positions[None, ...]
|
339 |
+
|
340 |
+
positions = scale_tensor(positions, (-self.cfg.radius, self.cfg.radius), (-1, 1))
|
341 |
+
indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
|
342 |
+
(positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
|
343 |
+
dim=-3,
|
344 |
+
)
|
345 |
+
out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
|
346 |
+
rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3),
|
347 |
+
rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3),
|
348 |
+
align_corners=False,
|
349 |
+
mode="bilinear",
|
350 |
+
)
|
351 |
+
if self.cfg.feature_reduction == "concat":
|
352 |
+
out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
|
353 |
+
elif self.cfg.feature_reduction == "mean":
|
354 |
+
out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean")
|
355 |
+
else:
|
356 |
+
raise NotImplementedError
|
357 |
+
|
358 |
+
if not batched:
|
359 |
+
out = out.squeeze(0)
|
360 |
+
|
361 |
+
return out
|
362 |
+
|
363 |
+
def forward_single_batch(
|
364 |
+
self,
|
365 |
+
gs_hidden_features: Float[Tensor, "Np Cp"],
|
366 |
+
query_points: Float[Tensor, "Np 3"],
|
367 |
+
c2ws: Float[Tensor, "Nv 4 4"],
|
368 |
+
intrinsics: Float[Tensor, "Nv 4 4"],
|
369 |
+
height: int,
|
370 |
+
width: int,
|
371 |
+
background_color: Optional[Float[Tensor, "3"]],
|
372 |
+
):
|
373 |
+
gs: GaussianModel = self.forward_gs(gs_hidden_features, query_points)
|
374 |
+
out_list = []
|
375 |
+
|
376 |
+
for c2w, intrinsic in zip(c2ws, intrinsics):
|
377 |
+
out_list.append(self.forward_single_view(
|
378 |
+
gs,
|
379 |
+
Camera.from_c2w(c2w, intrinsic, height, width),
|
380 |
+
background_color
|
381 |
+
))
|
382 |
+
|
383 |
+
out = defaultdict(list)
|
384 |
+
for out_ in out_list:
|
385 |
+
for k, v in out_.items():
|
386 |
+
out[k].append(v)
|
387 |
+
out = {k: torch.stack(v, dim=0) for k, v in out.items()}
|
388 |
+
out["3dgs"] = gs
|
389 |
+
|
390 |
+
return out
|
391 |
+
|
392 |
+
def forward(self,
|
393 |
+
gs_hidden_features: Float[Tensor, "B Np Cp"],
|
394 |
+
query_points: Float[Tensor, "B Np 3"],
|
395 |
+
c2w: Float[Tensor, "B Nv 4 4"],
|
396 |
+
intrinsic: Float[Tensor, "B Nv 4 4"],
|
397 |
+
height,
|
398 |
+
width,
|
399 |
+
additional_features: Optional[Float[Tensor, "B C H W"]] = None,
|
400 |
+
background_color: Optional[Float[Tensor, "B 3"]] = None,
|
401 |
+
**kwargs):
|
402 |
+
batch_size = gs_hidden_features.shape[0]
|
403 |
+
out_list = []
|
404 |
+
gs_hidden_features = self.query_triplane(query_points, gs_hidden_features)
|
405 |
+
if additional_features is not None:
|
406 |
+
gs_hidden_features = torch.cat([gs_hidden_features, additional_features], dim=-1)
|
407 |
+
|
408 |
+
for b in range(batch_size):
|
409 |
+
out_list.append(self.forward_single_batch(
|
410 |
+
gs_hidden_features[b],
|
411 |
+
query_points[b],
|
412 |
+
c2w[b],
|
413 |
+
intrinsic[b],
|
414 |
+
height, width,
|
415 |
+
background_color[b] if background_color is not None else None))
|
416 |
+
|
417 |
+
out = defaultdict(list)
|
418 |
+
for out_ in out_list:
|
419 |
+
for k, v in out_.items():
|
420 |
+
out[k].append(v)
|
421 |
+
for k, v in out.items():
|
422 |
+
if isinstance(v[0], torch.Tensor):
|
423 |
+
out[k] = torch.stack(v, dim=0)
|
424 |
+
else:
|
425 |
+
out[k] = v
|
426 |
+
return out
|
427 |
+
|
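renderer.py above builds splatting cameras from c2w poses and pinhole intrinsics. Below is a standalone sketch, with a hypothetical focal length and image size, of the two conversions it relies on: intrinsics to field of view (as in intrinsic_to_fov) and a symmetric perspective matrix consistent with getProjectionMatrix.

import math
import torch

# Hypothetical pinhole intrinsics for a 224 x 224 crop.
W = H = 224
fx = fy = 500.0
intrinsic = torch.tensor([[fx, 0.0, W / 2],
                          [0.0, fy, H / 2],
                          [0.0, 0.0, 1.0]])

# Field of view from focal length, mirroring intrinsic_to_fov.
fov_x = 2 * math.atan2(W, 2 * fx)
fov_y = 2 * math.atan2(H, 2 * fy)
print(math.degrees(fov_x), math.degrees(fov_y))  # ~25.3 degrees each

# A symmetric frustum built from those FoVs maps [znear, zfar] to [0, 1] depth,
# matching the z-row of getProjectionMatrix.
znear, zfar = 0.01, 100.0
tan_half_x, tan_half_y = math.tan(fov_x / 2), math.tan(fov_y / 2)
P = torch.zeros(4, 4)
P[0, 0] = 1.0 / tan_half_x
P[1, 1] = 1.0 / tan_half_y
P[2, 2] = zfar / (zfar - znear)
P[2, 3] = -(zfar * znear) / (zfar - znear)
P[3, 2] = 1.0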
hort/models/tgs/models/snowflake/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 AllenXiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
hort/models/tgs/models/snowflake/SPD.py
ADDED
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# @Author: Peng Xiang

import torch
import torch.nn as nn
from .utils import MLP_Res, MLP_CONV
from .skip_transformer import SkipTransformer


class SPD(nn.Module):
    def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
        """Snowflake Point Deconvolution"""
        super(SPD, self).__init__()
        self.i = i
        self.up_factor = up_factor

        self.bounding = bounding
        self.radius = radius

        self.global_feat = global_feat
        self.ps_dim = 32 if global_feat else 64

        self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
        self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])

        self.skip_transformer = SkipTransformer(in_channel=128, dim=64)

        self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
        self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False)  # point-wise splitting

        self.up_sampler = nn.Upsample(scale_factor=up_factor)
        self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128)

        self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])

    def forward(self, pcd_prev, feat_global=None, K_prev=None):
        """
        Args:
            pcd_prev: Tensor, (B, 3, N_prev)
            feat_global: Tensor, (B, dim_feat, 1)
            K_prev: Tensor, (B, 128, N_prev)

        Returns:
            pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
            K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
        """
        b, _, n_prev = pcd_prev.shape
        feat_1 = self.mlp_1(pcd_prev)
        feat_1 = torch.cat([feat_1,
                            torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))),
                            feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1
        Q = self.mlp_2(feat_1)

        H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q)

        feat_child = self.mlp_ps(H)
        feat_child = self.ps(feat_child)  # (B, 128, N_prev * up_factor)
        H_up = self.up_sampler(H)
        K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))

        delta = self.mlp_delta(torch.relu(K_curr))
        if self.bounding:
            delta = torch.tanh(delta) / self.radius**self.i  # (B, 3, N_prev * up_factor)

        pcd_child = self.up_sampler(pcd_prev)
        pcd_child = pcd_child + delta

        return pcd_child, K_curr
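The heart of SPD is point-wise splitting plus a bounded displacement. A shape-level sketch with hypothetical sizes, and a random offset standing in for the learned mlp_delta:

import torch
import torch.nn as nn

# Hypothetical sizes: batch 2, 512 parent points, upsampling factor 2, stage i=1.
B, N_prev, up_factor, radius, i = 2, 512, 2, 1.0, 1
pcd_prev = torch.randn(B, 3, N_prev)

# Point-wise splitting: each parent point is repeated up_factor times...
up_sampler = nn.Upsample(scale_factor=up_factor)
pcd_child = up_sampler(pcd_prev)                 # (B, 3, N_prev * up_factor)

# ...and nudged by a bounded displacement (random here; mlp_delta in SPD).
delta = torch.randn(B, 3, N_prev * up_factor)
delta = torch.tanh(delta) / radius ** i          # each offset squashed into (-1, 1) / radius**i
pcd_child = pcd_child + delta
print(pcd_child.shape)  # torch.Size([2, 3, 1024])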
hort/models/tgs/models/snowflake/SPD_crossattn.py
ADDED
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# @Author: Peng Xiang

import torch
import torch.nn as nn
from .utils import MLP_Res, MLP_CONV
from .skip_transformer import SkipTransformer
from .attention import ResidualTransformerBlock

class SPD_crossattn(nn.Module):
    def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
        """Snowflake Point Deconvolution"""
        super().__init__()
        self.i = i
        self.up_factor = up_factor

        self.bounding = bounding
        self.radius = radius

        self.global_feat = global_feat
        self.ps_dim = 32 if global_feat else 64

        self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
        self.pcd_image_attn = ResidualTransformerBlock(
            device=torch.device('cuda'),
            dtype=torch.float32,
            n_data=128,
            width=128,
            heads=8,
            init_scale=1.0,
        )

        self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])

        self.skip_transformer = SkipTransformer(in_channel=128, dim=64)

        self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
        self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False)  # point-wise splitting

        self.up_sampler = nn.Upsample(scale_factor=up_factor)
        self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128)

        self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])

    def forward(self, pcd_prev, feat_global=None, K_prev=None):
        """
        Args:
            pcd_prev: Tensor, (B, 3, N_prev)
            feat_global: Tensor, (B, dim_feat, 1)
            K_prev: Tensor, (B, 128, N_prev)

        Returns:
            pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
            K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
        """
        b, _, n_prev = pcd_prev.shape
        feat_1 = self.mlp_1(pcd_prev)
        # feat_1 = torch.cat([feat_1,
        #                     torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))),
        #                     feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1
        feat_1 = torch.permute(feat_1, (0, 2, 1))
        feat_global = torch.permute(feat_global, (0, 2, 1))
        feat_1 = self.pcd_image_attn(feat_1, feat_global)
        Q = torch.permute(feat_1, (0, 2, 1))
        # Q = self.mlp_2(feat_1)

        H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q)

        feat_child = self.mlp_ps(H)
        feat_child = self.ps(feat_child)  # (B, 128, N_prev * up_factor)
        H_up = self.up_sampler(H)
        K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))

        delta = self.mlp_delta(torch.relu(K_curr))
        if self.bounding:
            delta = torch.tanh(delta) / self.radius**self.i  # (B, 3, N_prev * up_factor)

        pcd_child = self.up_sampler(pcd_prev)
        pcd_child = pcd_child + delta

        return pcd_child, K_curr
hort/models/tgs/models/snowflake/SPD_pp.py
ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
from .utils import MLP_Res, MLP_CONV
from .skip_transformer import SkipTransformer

class SPD_pp(nn.Module):
    def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True):
        """Snowflake Point Deconvolution"""
        super(SPD_pp, self).__init__()
        self.i = i
        self.up_factor = up_factor

        self.bounding = bounding
        self.radius = radius

        self.global_feat = global_feat
        self.ps_dim = 32 if global_feat else 64

        self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128])
        self.mlp_2 = MLP_CONV(
            in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128])

        self.skip_transformer = SkipTransformer(in_channel=128, dim=64)

        self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim])
        self.ps = nn.ConvTranspose1d(
            self.ps_dim, 128, up_factor, up_factor, bias=False)  # point-wise splitting

        self.up_sampler = nn.Upsample(scale_factor=up_factor)
        self.mlp_delta_feature = MLP_Res(
            in_dim=256, hidden_dim=128, out_dim=128)

        self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3])

    def forward(self, pcd_prev, feat_cond=None, K_prev=None):
        """
        Args:
            pcd_prev: Tensor, (B, 3, N_prev)
            feat_cond: Tensor, (B, dim_feat, N_prev)
            K_prev: Tensor, (B, 128, N_prev)

        Returns:
            pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
            K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor)
        """
        b, _, n_prev = pcd_prev.shape
        feat_1 = self.mlp_1(pcd_prev)
        feat_1 = torch.cat([feat_1,
                            torch.max(feat_1, 2, keepdim=True)[
                                0].repeat((1, 1, feat_1.size(2))),
                            feat_cond], 1) if self.global_feat else feat_1
        Q = self.mlp_2(feat_1)

        H = self.skip_transformer(
            pcd_prev, K_prev if K_prev is not None else Q, Q)

        feat_child = self.mlp_ps(H)
        feat_child = self.ps(feat_child)  # (B, 128, N_prev * up_factor)
        H_up = self.up_sampler(H)
        K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1))

        delta = self.mlp_delta(torch.relu(K_curr))
        if self.bounding:
            # (B, 3, N_prev * up_factor)
            delta = torch.tanh(delta) / self.radius**self.i

        pcd_child = self.up_sampler(pcd_prev)
        pcd_child = pcd_child + delta

        return pcd_child, K_curr
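Compared with SPD, SPD_pp concatenates a per-point conditioning map rather than one broadcast global vector. A quick shape check of the tensor it feeds into mlp_2, with hypothetical sizes:

import torch

# Hypothetical sizes matching SPD_pp's forward.
B, N_prev, dim_feat = 2, 512, 128
feat_1 = torch.randn(B, 128, N_prev)          # per-point features from mlp_1
feat_cond = torch.randn(B, dim_feat, N_prev)  # one conditioning token per point

global_pool = torch.max(feat_1, 2, keepdim=True)[0].repeat(1, 1, N_prev)
fused = torch.cat([feat_1, global_pool, feat_cond], dim=1)
print(fused.shape)  # torch.Size([2, 384, 512]) -> mlp_2 maps 128*2 + dim_feat down to 128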
hort/models/tgs/models/snowflake/attention.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
import math
|
5 |
+
from typing import Optional
|
6 |
+
from typing import Callable, Iterable, Sequence, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
def checkpoint(
|
11 |
+
func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
|
12 |
+
inputs: Sequence[torch.Tensor],
|
13 |
+
params: Iterable[torch.Tensor],
|
14 |
+
flag: bool,
|
15 |
+
):
|
16 |
+
"""
|
17 |
+
Evaluate a function without caching intermediate activations, allowing for
|
18 |
+
reduced memory at the expense of extra compute in the backward pass.
|
19 |
+
:param func: the function to evaluate.
|
20 |
+
:param inputs: the argument sequence to pass to `func`.
|
21 |
+
:param params: a sequence of parameters `func` depends on but does not
|
22 |
+
explicitly take as arguments.
|
23 |
+
:param flag: if False, disable gradient checkpointing.
|
24 |
+
"""
|
25 |
+
if flag:
|
26 |
+
args = tuple(inputs) + tuple(params)
|
27 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
28 |
+
else:
|
29 |
+
return func(*inputs)
|
30 |
+
|
31 |
+
|
32 |
+
class CheckpointFunction(torch.autograd.Function):
|
33 |
+
@staticmethod
|
34 |
+
def forward(ctx, run_function, length, *args):
|
35 |
+
ctx.run_function = run_function
|
36 |
+
ctx.input_tensors = list(args[:length])
|
37 |
+
ctx.input_params = list(args[length:])
|
38 |
+
with torch.no_grad():
|
39 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
40 |
+
return output_tensors
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def backward(ctx, *output_grads):
|
44 |
+
ctx.input_tensors = [x.detach().requires_grad_(True)
|
45 |
+
for x in ctx.input_tensors]
|
46 |
+
with torch.enable_grad():
|
47 |
+
# Fixes a bug where the first op in run_function modifies the
|
48 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
49 |
+
# Tensors.
|
50 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
51 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
52 |
+
input_grads = torch.autograd.grad(
|
53 |
+
output_tensors,
|
54 |
+
ctx.input_tensors + ctx.input_params,
|
55 |
+
output_grads,
|
56 |
+
allow_unused=True,
|
57 |
+
)
|
58 |
+
del ctx.input_tensors
|
59 |
+
del ctx.input_params
|
60 |
+
del output_tensors
|
61 |
+
return (None, None) + input_grads
|
62 |
+
|
63 |
+
|
64 |
+
def init_linear(l, stddev):
|
65 |
+
nn.init.normal_(l.weight, std=stddev)
|
66 |
+
if l.bias is not None:
|
67 |
+
nn.init.constant_(l.bias, 0.0)
|
68 |
+
|
69 |
+
class MLP(nn.Module):
|
70 |
+
def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
|
71 |
+
super().__init__()
|
72 |
+
self.width = width
|
73 |
+
self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
|
74 |
+
self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
|
75 |
+
self.gelu = nn.GELU()
|
76 |
+
init_linear(self.c_fc, init_scale)
|
77 |
+
init_linear(self.c_proj, init_scale)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
return self.c_proj(self.gelu(self.c_fc(x)))
|
81 |
+
|
82 |
+
class QKVMultiheadCrossAttention(nn.Module):
|
83 |
+
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int):
|
84 |
+
super().__init__()
|
85 |
+
self.device = device
|
86 |
+
self.dtype = dtype
|
87 |
+
self.heads = heads
|
88 |
+
self.n_data = n_data
|
89 |
+
|
90 |
+
def forward(self, q, kv):
|
91 |
+
_, n_ctx, _ = q.shape
|
92 |
+
bs, n_data, width = kv.shape
|
93 |
+
attn_ch = width // self.heads // 2
|
94 |
+
scale = 1 / math.sqrt(math.sqrt(attn_ch))
|
95 |
+
q = q.view(bs, n_ctx, self.heads, -1)
|
96 |
+
kv = kv.view(bs, n_data, self.heads, -1)
|
97 |
+
k, v = torch.split(kv, attn_ch, dim=-1)
|
98 |
+
weight = torch.einsum(
|
99 |
+
"bthc,bshc->bhts", q * scale, k * scale
|
100 |
+
) # More stable with f16 than dividing afterwards
|
101 |
+
wdtype = weight.dtype
|
102 |
+
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
|
103 |
+
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
class QKVMultiheadAttention(nn.Module):
|
108 |
+
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
|
109 |
+
super().__init__()
|
110 |
+
self.device = device
|
111 |
+
self.dtype = dtype
|
112 |
+
self.heads = heads
|
113 |
+
self.n_ctx = n_ctx
|
114 |
+
|
115 |
+
def forward(self, qkv):
|
116 |
+
bs, n_ctx, width = qkv.shape
|
117 |
+
attn_ch = width // self.heads // 3
|
118 |
+
scale = 1 / math.sqrt(math.sqrt(attn_ch))
|
119 |
+
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
120 |
+
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
121 |
+
weight = torch.einsum(
|
122 |
+
"bthc,bshc->bhts", q * scale, k * scale
|
123 |
+
) # More stable with f16 than dividing afterwards
|
124 |
+
wdtype = weight.dtype
|
125 |
+
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
|
126 |
+
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
class MultiheadCrossAttention(nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
*,
|
134 |
+
device: torch.device,
|
135 |
+
dtype: torch.dtype,
|
136 |
+
n_data: int,
|
137 |
+
width: int,
|
138 |
+
heads: int,
|
139 |
+
init_scale: float,
|
140 |
+
data_width: Optional[int] = None,
|
141 |
+
):
|
142 |
+
super().__init__()
|
143 |
+
self.n_data = n_data
|
144 |
+
self.width = width
|
145 |
+
self.heads = heads
|
146 |
+
self.data_width = width if data_width is None else data_width
|
147 |
+
self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
|
148 |
+
self.c_kv = nn.Linear(self.data_width, width * 2,
|
149 |
+
device=device, dtype=dtype)
|
150 |
+
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
151 |
+
self.attention = QKVMultiheadCrossAttention(
|
152 |
+
device=device, dtype=dtype, heads=heads, n_data=n_data
|
153 |
+
)
|
154 |
+
init_linear(self.c_q, init_scale)
|
155 |
+
init_linear(self.c_kv, init_scale)
|
156 |
+
init_linear(self.c_proj, init_scale)
|
157 |
+
|
158 |
+
def forward(self, x, data):
|
159 |
+
x = self.c_q(x)
|
160 |
+
data = self.c_kv(data)
|
161 |
+
x = checkpoint(self.attention, (x, data), (), True)
|
162 |
+
x = self.c_proj(x)
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
class MultiheadAttention(nn.Module):
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
*,
|
170 |
+
device: torch.device,
|
171 |
+
dtype: torch.dtype,
|
172 |
+
n_ctx: int,
|
173 |
+
width: int,
|
174 |
+
heads: int,
|
175 |
+
init_scale: float,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
self.n_ctx = n_ctx
|
179 |
+
self.width = width
|
180 |
+
self.heads = heads
|
181 |
+
self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
|
182 |
+
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
183 |
+
self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
|
184 |
+
init_linear(self.c_qkv, init_scale)
|
185 |
+
init_linear(self.c_proj, init_scale)
|
186 |
+
|
187 |
+
def forward(self, x):
|
188 |
+
x = self.c_qkv(x)
|
189 |
+
x = checkpoint(self.attention, (x,), (), True)
|
190 |
+
x = self.c_proj(x)
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
class ResidualTransformerBlock(nn.Module):
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
*,
|
198 |
+
device: torch.device,
|
199 |
+
dtype: torch.dtype,
|
200 |
+
n_data: int,
|
201 |
+
width: int,
|
202 |
+
heads: int,
|
203 |
+
data_width: Optional[int] = None,
|
204 |
+
init_scale: float = 1.0,
|
205 |
+
):
|
206 |
+
super().__init__()
|
207 |
+
|
208 |
+
if data_width is None:
|
209 |
+
data_width = width
|
210 |
+
|
211 |
+
self.attn_cross = MultiheadCrossAttention(
|
212 |
+
device=device,
|
213 |
+
dtype=dtype,
|
214 |
+
n_data=n_data,
|
215 |
+
width=width,
|
216 |
+
heads=heads,
|
217 |
+
data_width=data_width,
|
218 |
+
init_scale=init_scale,
|
219 |
+
)
|
220 |
+
self.attn_self = MultiheadAttention(
|
221 |
+
device=device,
|
222 |
+
dtype=dtype,
|
223 |
+
n_ctx=n_data,
|
224 |
+
width=width,
|
225 |
+
heads=heads,
|
226 |
+
init_scale=init_scale,
|
227 |
+
)
|
228 |
+
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
|
229 |
+
self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
|
230 |
+
self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
|
231 |
+
self.mlp = MLP(device=device, dtype=dtype,
|
232 |
+
width=width, init_scale=init_scale)
|
233 |
+
self.ln_4 = nn.LayerNorm(width, device=device, dtype=dtype)
|
234 |
+
|
235 |
+
def forward(self, x: torch.Tensor, data: torch.Tensor):
|
236 |
+
x = x + self.attn_cross(self.ln_1(x), self.ln_2(data))
|
237 |
+
x = x + self.attn_self(self.ln_3(x))
|
238 |
+
x = x + self.mlp(self.ln_4(x))
|
239 |
+
return x
|
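QKVMultiheadCrossAttention above attends point tokens to image tokens with a pair of einsums and symmetric 1/sqrt(sqrt(d)) scaling. A standalone numeric sketch of that attention core, with hypothetical sizes:

import math
import torch

# Hypothetical sizes: batch 2, 128 query (point) tokens, 257 data (image) tokens,
# width 128 split over 8 heads.
bs, n_ctx, n_data, heads, width = 2, 128, 257, 8, 128
attn_ch = width // heads  # channels per head

q = torch.randn(bs, n_ctx, heads, attn_ch)
k = torch.randn(bs, n_data, heads, attn_ch)
v = torch.randn(bs, n_data, heads, attn_ch)

# Scale q and k by 1/sqrt(sqrt(d)) each so the product carries the usual 1/sqrt(d);
# the file notes this is more stable in fp16 than dividing the logits afterwards.
scale = 1 / math.sqrt(math.sqrt(attn_ch))
weight = torch.einsum("bthc,bshc->bhts", q * scale, k * scale)  # (bs, heads, n_ctx, n_data)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
print(out.shape)  # torch.Size([2, 128, 128])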
hort/models/tgs/models/snowflake/model_spdpp.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from tgs.utils.base import BaseModule
|
6 |
+
from tgs.utils.typing import *
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
|
9 |
+
from pytorch3d.renderer import (
|
10 |
+
AlphaCompositor,
|
11 |
+
NormWeightedCompositor,
|
12 |
+
PointsRasterizationSettings,
|
13 |
+
PointsRasterizer,
|
14 |
+
PointsRenderer)
|
15 |
+
from pytorch3d.renderer.cameras import CamerasBase
|
16 |
+
from pytorch3d.structures import Pointclouds
|
17 |
+
from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection
|
18 |
+
|
19 |
+
from .utils import fps_subsample
|
20 |
+
from einops import rearrange
|
21 |
+
|
22 |
+
from .utils import MLP_CONV
|
23 |
+
from .SPD import SPD
|
24 |
+
from .SPD_crossattn import SPD_crossattn
|
25 |
+
from .SPD_pp import SPD_pp
|
26 |
+
|
27 |
+
SPD_BLOCK = {
|
28 |
+
'SPD': SPD,
|
29 |
+
'SPD_crossattn': SPD_crossattn,
|
30 |
+
'SPD_PP': SPD_pp,
|
31 |
+
}
|
32 |
+
|
33 |
+
|
34 |
+
def homoify(points):
|
35 |
+
"""
|
36 |
+
Convert a batch of points to homogeneous coordinates.
|
37 |
+
Args:
|
38 |
+
points: e.g. (B, N, 3) or (N, 3)
|
39 |
+
Returns:
|
40 |
+
homoified points: e.g., (B, N, 4)
|
41 |
+
"""
|
42 |
+
points_dim = points.shape[:-1] + (1,)
|
43 |
+
ones = points.new_ones(points_dim)
|
44 |
+
|
45 |
+
return torch.cat([points, ones], dim=-1)
|
46 |
+
|
47 |
+
|
48 |
+
def dehomoify(points):
|
49 |
+
"""
|
50 |
+
Convert a batch of homogeneous points to cartesian coordinates.
|
51 |
+
Args:
|
52 |
+
homogeneous points: (B, N, 4/3) or (N, 4/3)
|
53 |
+
Returns:
|
54 |
+
cartesian points: (B, N, 3/2)
|
55 |
+
"""
|
56 |
+
return points[..., :-1] / points[..., -1:]
|
57 |
+
|
58 |
+
|
59 |
+
def mask_generation(points: Float[Tensor, "B Np 3"],
|
60 |
+
intrinsics: Float[Tensor, "B 3 3"],
|
61 |
+
input_img: Float[Tensor, "B C H W"],
|
62 |
+
raster_point_radius: float = 0.01, # point size
|
63 |
+
raster_points_per_pixel: int = 1, # a single point per pixel, for now
|
64 |
+
bin_size: int = 0):
|
65 |
+
"""
|
66 |
+
points: (B, Np, 3)
|
67 |
+
"""
|
68 |
+
B, C, H, W = input_img.shape
|
69 |
+
device = intrinsics.device
|
70 |
+
|
71 |
+
cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
|
72 |
+
cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1)
|
73 |
+
|
74 |
+
raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size)
|
75 |
+
|
76 |
+
image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device)
|
77 |
+
cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size)
|
78 |
+
|
79 |
+
rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
80 |
+
fragments = rasterize(Pointclouds(points))
|
81 |
+
|
82 |
+
fragments_idx: Tensor = fragments.idx.long()
|
83 |
+
mask = (fragments_idx[..., 0] > -1)
|
84 |
+
|
85 |
+
return mask.float()
|
86 |
+
|
87 |
+
|
88 |
+
def points_projection(points: Float[Tensor, "B Np 3"],
|
89 |
+
intrinsics: Float[Tensor, "B 3 3"],
|
90 |
+
local_features: Float[Tensor, "B C H W"],
|
91 |
+
raster_point_radius: float = 0.0075, # point size
|
92 |
+
raster_points_per_pixel: int = 1, # a single point per pixel, for now
|
93 |
+
bin_size: int = 0):
|
94 |
+
"""
|
95 |
+
points: (B, Np, 3)
|
96 |
+
"""
|
97 |
+
B, C, H, W = local_features.shape
|
98 |
+
device = local_features.device
|
99 |
+
cam_R = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
|
100 |
+
cam_t = torch.zeros(3).to(device).unsqueeze(0).repeat(B, 1)
|
101 |
+
|
102 |
+
raster_settings = PointsRasterizationSettings(image_size=(H, W), radius=raster_point_radius, points_per_pixel=raster_points_per_pixel, bin_size=bin_size)
|
103 |
+
Np = points.shape[1]
|
104 |
+
R = raster_settings.points_per_pixel
|
105 |
+
image_size = torch.as_tensor([H, W]).view(1, 2).expand(B, -1).to(device)
|
106 |
+
cameras = cameras_from_opencv_projection(cam_R, cam_t, intrinsics, image_size)
|
107 |
+
rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
108 |
+
fragments = rasterize(Pointclouds(points))
|
109 |
+
fragments_idx: Tensor = fragments.idx.long()
|
110 |
+
visible_pixels = (fragments_idx > -1) # (B, H, W, R)
|
111 |
+
points_to_visible_pixels = fragments_idx[visible_pixels]
|
112 |
+
# Reshape local features to (B, H, W, R, C)
|
113 |
+
local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C)
|
114 |
+
# Get local features corresponding to visible points
|
115 |
+
local_features_proj = torch.zeros(B * Np, C, device=device)
|
116 |
+
local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
|
117 |
+
local_features_proj = local_features_proj.reshape(B, Np, C)
|
118 |
+
return local_features_proj
|
119 |
+
|
120 |
+
|
121 |
+
def points_projection_v2(input_xyz_points, cam_intr, feature_maps):
|
122 |
+
input_points = input_xyz_points.clone()
|
123 |
+
batch_size = input_points.shape[0]
|
124 |
+
xyz = input_points[:, :, :3]
|
125 |
+
homo_xyz = homoify(xyz)
|
126 |
+
homo_xyz_2d = torch.matmul(cam_intr, homo_xyz.transpose(1, 2)).transpose(1, 2)
|
127 |
+
xyz_2d = (homo_xyz_2d[:, :, :2] / homo_xyz_2d[:, :, [2]]).unsqueeze(2)
|
128 |
+
uv_2d = xyz_2d / 224 * 2 - 1
|
129 |
+
sample_feat = torch.nn.functional.grid_sample(feature_maps, uv_2d, align_corners=True)[:, :, :, 0].transpose(1, 2)
|
130 |
+
uv_2d = uv_2d.squeeze(2).reshape((-1, 2))
|
131 |
+
validity = (uv_2d[:, 0] >= -1.0) & (uv_2d[:, 0] <= 1.0) & (uv_2d[:, 1] >= -1.0) & (uv_2d[:, 1] <= 1.0)
|
132 |
+
validity = validity.unsqueeze(1)
|
133 |
+
|
134 |
+
return sample_feat
|
135 |
+
|
136 |
+
|
137 |
+
class Decoder(nn.Module):
|
138 |
+
def __init__(self, input_channels=1152, dim_feat=512, num_p0=512,
|
139 |
+
radius=1, bounding=True, up_factors=None,
|
140 |
+
SPD_type='SPD',
|
141 |
+
token_type='image_token'
|
142 |
+
):
|
143 |
+
super(Decoder, self).__init__()
|
144 |
+
# self.decoder_coarse = SeedGenerator(dim_feat=dim_feat, num_pc=num_p0)
|
145 |
+
if up_factors is None:
|
146 |
+
up_factors = [1]
|
147 |
+
else:
|
148 |
+
up_factors = up_factors
|
149 |
+
uppers = []
|
150 |
+
self.num_p0 = num_p0
|
151 |
+
self.mlp_feat_cond = MLP_CONV(in_channel=input_channels,
|
152 |
+
layer_dims=[dim_feat*2, dim_feat])
|
153 |
+
|
154 |
+
for i, factor in enumerate(up_factors):
|
155 |
+
uppers.append(
|
156 |
+
SPD_BLOCK[SPD_type](dim_feat=dim_feat, up_factor=factor,
|
157 |
+
i=i, bounding=bounding, radius=radius))
|
158 |
+
self.uppers = nn.ModuleList(uppers)
|
159 |
+
self.token_type = token_type
|
160 |
+
|
161 |
+
def calculate_pcl_token(self, pcl_token, up_factor):
|
162 |
+
up_token = F.interpolate(pcl_token, scale_factor=up_factor, mode='nearest')
|
163 |
+
return up_token
|
164 |
+
|
165 |
+
def calculate_image_token(self, pcd, input_image_tokens, batch):
|
166 |
+
"""
|
167 |
+
Args:
|
168 |
+
"""
|
169 |
+
batch_size = input_image_tokens.shape[0]
|
170 |
+
h_cond, w_cond = 224, 224
|
171 |
+
input_image_tokens = input_image_tokens.permute(0, 2, 1)
|
172 |
+
local_features = input_image_tokens[:, 1:].reshape(batch_size, h_cond // 14, w_cond // 14, -1).permute(0, 3, 1, 2)
|
173 |
+
# local_features = F.interpolate(local_features, size=(h_cond, w_cond), mode='bilinear', align_corners=False)
|
174 |
+
local_features_proj = points_projection_v2(pcd * batch['scale'] + batch['trans'].unsqueeze(1), batch['intrinsic_cond'], local_features)
|
175 |
+
local_features_proj = local_features_proj.permute(0, 2, 1).contiguous()
|
176 |
+
|
177 |
+
return local_features_proj
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
"""
|
181 |
+
Args:
|
182 |
+
points: Tensor, (b, num_p0, 3)
|
183 |
+
feat_cond: Tensor, (b, dim_feat) dinov2: 325x768
|
184 |
+
# partial_coarse: Tensor, (b, n_coarse, 3)
|
185 |
+
"""
|
186 |
+
points = x['points']
|
187 |
+
if self.token_type == 'pcl_token':
|
188 |
+
feat_cond = x['pcl_token']
|
189 |
+
elif self.token_type == 'image_token':
|
190 |
+
feat_cond = x['input_image_tokens']
|
191 |
+
feat_cond = self.mlp_feat_cond(feat_cond)
|
192 |
+
arr_pcd = []
|
193 |
+
feat_prev = None
|
194 |
+
|
195 |
+
pcd = torch.permute(points, (0, 2, 1)).contiguous()
|
196 |
+
pcl_up_scale = 1
|
197 |
+
for upper in self.uppers:
|
198 |
+
if self.token_type == 'pcl_token':
|
199 |
+
up_cond = self.calculate_pcl_token(
|
200 |
+
feat_cond, pcl_up_scale)
|
201 |
+
pcl_up_scale *= upper.up_factor
|
202 |
+
elif self.token_type == 'image_token':
|
203 |
+
up_cond = self.calculate_image_token(points, feat_cond, x)
|
204 |
+
pcd, feat_prev = upper(pcd, up_cond, feat_prev)
|
205 |
+
points = torch.permute(pcd, (0, 2, 1)).contiguous()
|
206 |
+
arr_pcd.append(points)
|
207 |
+
return arr_pcd
|
208 |
+
|
209 |
+
|
210 |
+
class SnowflakeModelSPDPP(BaseModule):
|
211 |
+
"""
|
212 |
+
apply PC^2 / PCL token to decoder
|
213 |
+
"""
|
214 |
+
@dataclass
|
215 |
+
class Config(BaseModule.Config):
|
216 |
+
input_channels: int = 1152
|
217 |
+
dim_feat: int = 128
|
218 |
+
num_p0: int = 512
|
219 |
+
radius: float = 1
|
220 |
+
bounding: bool = True
|
221 |
+
use_fps: bool = True
|
222 |
+
up_factors: List[int] = field(default_factory=lambda: [2, 2])
|
223 |
+
image_full_token_cond: bool = False
|
224 |
+
SPD_type: str = 'SPD_PP'
|
225 |
+
token_type: str = 'pcl_token'
|
226 |
+
cfg: Config
|
227 |
+
|
228 |
+
def configure(self) -> None:
|
229 |
+
super().configure()
|
230 |
+
self.decoder = Decoder(input_channels=self.cfg.input_channels,
|
231 |
+
dim_feat=self.cfg.dim_feat, num_p0=self.cfg.num_p0,
|
232 |
+
radius=self.cfg.radius, up_factors=self.cfg.up_factors, bounding=self.cfg.bounding,
|
233 |
+
SPD_type=self.cfg.SPD_type,
|
234 |
+
token_type=self.cfg.token_type
|
235 |
+
)
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
results = self.decoder(x)
|
239 |
+
return results
|
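points_projection_v2 above conditions each point on image patch features (e.g. DINOv2 tokens) by projecting it through the camera intrinsics and sampling the feature map with grid_sample. A standalone sketch of that project-and-sample step, with hypothetical intrinsics, points, and a 16x16 feature grid:

import torch
import torch.nn.functional as F

# Hypothetical inputs: 2 clouds of 1024 points in front of the camera,
# a 768-channel 16x16 feature map, and a 224x224 image plane.
B, N, C, res = 2, 1024, 768, 16
points = torch.rand(B, N, 3) * 0.2 + torch.tensor([0.0, 0.0, 0.5])  # keep z > 0
feature_maps = torch.randn(B, C, res, res)
intrinsics = torch.tensor([[250.0, 0.0, 112.0],
                           [0.0, 250.0, 112.0],
                           [0.0, 0.0, 1.0]]).expand(B, 3, 3)

# Pinhole projection: uv = K @ xyz, then divide by depth.
uv_hom = torch.matmul(intrinsics, points.transpose(1, 2)).transpose(1, 2)  # (B, N, 3)
uv = uv_hom[:, :, :2] / uv_hom[:, :, 2:3]                                  # pixel coords

# Normalize pixel coordinates to [-1, 1] for grid_sample and gather per-point features.
grid = (uv / 224 * 2 - 1).unsqueeze(2)                                  # (B, N, 1, 2)
per_point_feat = F.grid_sample(feature_maps, grid, align_corners=True)  # (B, C, N, 1)
per_point_feat = per_point_feat[..., 0].transpose(1, 2)                 # (B, N, C)
print(per_point_feat.shape)  # torch.Size([2, 1024, 768])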
hort/models/tgs/models/snowflake/pointnet2.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule


class PointNet2ClassificationSSG(nn.Module):
    def __init__(self):
        super().__init__()
        self._build_model()

    def _build_model(self):
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                npoint=512,
                radius=0.2,
                nsample=64,
                mlp=[3, 64, 64, 128],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=128,
                radius=0.4,
                nsample=64,
                mlp=[128, 128, 128, 256],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                mlp=[256, 256, 512, 1024], use_xyz=True,
            )
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(256, 40),
        )

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None

        return xyz, features

    def forward(self, pointcloud):
        r"""
        Forward pass of the network

        Parameters
        ----------
        pointcloud: Variable(torch.cuda.FloatTensor)
            (B, N, 3 + input_channels) tensor
            Point cloud to run predictions on
            Each point in the point cloud MUST
            be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)

        for module in self.SA_modules:
            xyz, features = module(xyz, features)

        return self.fc_layer(features.squeeze(-1))


class PointNet2SemSegSSG(PointNet2ClassificationSSG):
    def _build_model(self):
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                npoint=256,
                radius=0.05,
                nsample=32,
                mlp=[1, 32, 64],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=64,
                radius=0.10,
                nsample=32,
                mlp=[64, 128, 256],
                use_xyz=True,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=16,
                radius=0.20,
                nsample=32,
                mlp=[256, 512, 768],
                use_xyz=True,
            )
        )

    def forward(self, pointcloud):
        r"""
        Forward pass of the network

        Parameters
        ----------
        pointcloud: Variable(torch.cuda.FloatTensor)
            (B, N, 3 + input_channels) tensor
            Point cloud to run predictions on
            Each point in the point cloud MUST
            be formatted as (x, y, z, features...)
        """
        xyz, features = self._break_up_pc(pointcloud)

        l_xyz, l_features = [xyz], [features]
        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        return l_features[-1].transpose(2, 1)
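A quick usage sketch for the two networks above (assuming the pointnet2_ops extension is built and a GPU is available; the random inputs are only there to show the expected shapes):

import torch

cls_net = PointNet2ClassificationSSG().cuda()
pc = torch.randn(8, 4096, 6).cuda()      # (B, N, 3 + 3): xyz plus 3 feature channels, matching mlp=[3, ...]
logits = cls_net(pc)                     # (8, 40) classification scores

seg_net = PointNet2SemSegSSG().cuda()
pc = torch.randn(8, 4096, 4).cuda()      # (B, N, 3 + 1): xyz plus 1 feature channel, matching mlp=[1, ...]
tokens = seg_net(pc)                     # (8, 16, 768): 16 abstracted points with 768-dim features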
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py
ADDED
@@ -0,0 +1,3 @@
import pointnet2_ops.pointnet2_modules
import pointnet2_ops.pointnet2_utils
from pointnet2_ops._version import __version__
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h
ADDED
@@ -0,0 +1,5 @@
#pragma once
#include <torch/extension.h>

at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
                      const int nsample);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h
ADDED
@@ -0,0 +1,41 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

#define TOTAL_THREADS 512

inline int opt_n_threads(int work_size) {
  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

  return max(min(1 << pow_2, TOTAL_THREADS), 1);
}

inline dim3 opt_block_config(int x, int y) {
  const int x_threads = opt_n_threads(x);
  const int y_threads =
      max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
  dim3 block_config(x_threads, y_threads, 1);

  return block_config;
}

#define CUDA_CHECK_ERRORS()                                           \
  do {                                                                \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
      fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n",  \
              cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
              __FILE__);                                              \
      exit(-1);                                                       \
    }                                                                 \
  } while (0)

#endif
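opt_n_threads above rounds the work size down to a power of two (via a truncated log2) and clamps the result to [1, TOTAL_THREADS]; every kernel launch below sizes its blocks this way. A small Python mirror of that rule, purely to make the rounding explicit (not part of the library):

import math

def opt_n_threads(work_size, total_threads=512):
    # roughly the largest power of two <= work_size, clamped to [1, total_threads]
    pow_2 = int(math.log(work_size, 2.0)) if work_size > 0 else 0
    return max(min(1 << pow_2, total_threads), 1)

assert opt_n_threads(1000) == 512 and opt_n_threads(100) == 64 and opt_n_threads(3) == 2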
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h
ADDED
@@ -0,0 +1,5 @@
#pragma once
#include <torch/extension.h>

at::Tensor group_points(at::Tensor points, at::Tensor idx);
at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h
ADDED
@@ -0,0 +1,10 @@
#pragma once

#include <torch/extension.h>
#include <vector>

std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
                             at::Tensor weight);
at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
                                  at::Tensor weight, const int m);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h
ADDED
@@ -0,0 +1,6 @@
#pragma once
#include <torch/extension.h>

at::Tensor gather_points(at::Tensor points, at::Tensor idx);
at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h
ADDED
@@ -0,0 +1,25 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#define CHECK_CUDA(x)                                     \
  do {                                                    \
    AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor");  \
  } while (0)

#define CHECK_CONTIGUOUS(x)                                           \
  do {                                                                \
    AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor");  \
  } while (0)

#define CHECK_IS_INT(x)                                \
  do {                                                 \
    AT_ASSERT(x.scalar_type() == at::ScalarType::Int,  \
              #x " must be an int tensor");            \
  } while (0)

#define CHECK_IS_FLOAT(x)                                \
  do {                                                   \
    AT_ASSERT(x.scalar_type() == at::ScalarType::Float,  \
              #x " must be a float tensor");             \
  } while (0)
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp
ADDED
@@ -0,0 +1,32 @@
#include "ball_query.h"
#include "utils.h"

void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
                                     int nsample, const float *new_xyz,
                                     const float *xyz, int *idx);

at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
                      const int nsample) {
  CHECK_CONTIGUOUS(new_xyz);
  CHECK_CONTIGUOUS(xyz);
  CHECK_IS_FLOAT(new_xyz);
  CHECK_IS_FLOAT(xyz);

  if (new_xyz.is_cuda()) {
    CHECK_CUDA(xyz);
  }

  at::Tensor idx =
      torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
                   at::device(new_xyz.device()).dtype(at::ScalarType::Int));

  if (new_xyz.is_cuda()) {
    query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
                                    radius, nsample, new_xyz.data_ptr<float>(),
                                    xyz.data_ptr<float>(), idx.data_ptr<int>());
  } else {
    AT_ASSERT(false, "CPU not supported");
  }

  return idx;
}
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
ADDED
@@ -0,0 +1,54 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"

// input: new_xyz(b, m, 3) xyz(b, n, 3)
// output: idx(b, m, nsample)
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
                                        int nsample,
                                        const float *__restrict__ new_xyz,
                                        const float *__restrict__ xyz,
                                        int *__restrict__ idx) {
  int batch_index = blockIdx.x;
  xyz += batch_index * n * 3;
  new_xyz += batch_index * m * 3;
  idx += m * nsample * batch_index;

  int index = threadIdx.x;
  int stride = blockDim.x;

  float radius2 = radius * radius;
  for (int j = index; j < m; j += stride) {
    float new_x = new_xyz[j * 3 + 0];
    float new_y = new_xyz[j * 3 + 1];
    float new_z = new_xyz[j * 3 + 2];
    for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
      float x = xyz[k * 3 + 0];
      float y = xyz[k * 3 + 1];
      float z = xyz[k * 3 + 2];
      float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
                 (new_z - z) * (new_z - z);
      if (d2 < radius2) {
        if (cnt == 0) {
          for (int l = 0; l < nsample; ++l) {
            idx[j * nsample + l] = k;
          }
        }
        idx[j * nsample + cnt] = k;
        ++cnt;
      }
    }
  }
}

void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
                                     int nsample, const float *new_xyz,
                                     const float *xyz, int *idx) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
      b, n, m, radius, nsample, new_xyz, xyz, idx);

  //CUDA_CHECK_ERRORS();
}
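The kernel above writes, for every query point, the indices of up to nsample neighbours that fall inside the radius, padding the tail of each row with the first hit (rows with no hit stay zero). A brute-force PyTorch sketch of the same contract, useful for spot-checking on tiny inputs (a reference only, not the library API):

import torch

def ball_query_reference(new_xyz, xyz, radius, nsample):
    # new_xyz: (B, M, 3) query centers, xyz: (B, N, 3) source points -> (B, M, nsample) int32 indices
    B, M, _ = new_xyz.shape
    idx = torch.zeros(B, M, nsample, dtype=torch.int32, device=xyz.device)
    dist2 = torch.cdist(new_xyz, xyz) ** 2
    for b in range(B):
        for j in range(M):
            hits = torch.nonzero(dist2[b, j] < radius * radius).flatten()
            if hits.numel() == 0:
                continue
            idx[b, j, :] = int(hits[0])              # pad with the first neighbour, like the kernel
            take = hits[:nsample].to(torch.int32)
            idx[b, j, :take.numel()] = take
    return idx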
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp
ADDED
@@ -0,0 +1,19 @@
#include "ball_query.h"
#include "group_points.h"
#include "interpolate.h"
#include "sampling.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gather_points", &gather_points);
  m.def("gather_points_grad", &gather_points_grad);
  m.def("furthest_point_sampling", &furthest_point_sampling);

  m.def("three_nn", &three_nn);
  m.def("three_interpolate", &three_interpolate);
  m.def("three_interpolate_grad", &three_interpolate_grad);

  m.def("ball_query", &ball_query);

  m.def("group_points", &group_points);
  m.def("group_points_grad", &group_points_grad);
}
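These bindings are compiled into the private module pointnet2_ops._ext; user code normally goes through the autograd wrappers defined later in this commit (pointnet2_utils.py). A short sketch of the typical call chain, assuming the extension is built and the tensors live on the GPU:

import torch
from pointnet2_ops import pointnet2_utils

xyz = torch.rand(2, 1024, 3).cuda()                                   # (B, N, 3)
fps_idx = pointnet2_utils.furthest_point_sample(xyz, 256)             # (B, 256) sampled indices
new_xyz = pointnet2_utils.gather_operation(
    xyz.transpose(1, 2).contiguous(), fps_idx)                        # (B, 3, 256) sampled coordinates
group_idx = pointnet2_utils.ball_query(
    0.2, 32, xyz, new_xyz.transpose(1, 2).contiguous())               # (B, 256, 32) neighbourhoods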
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp
ADDED
@@ -0,0 +1,62 @@
1 |
+
#include "group_points.h"
|
2 |
+
#include "utils.h"
|
3 |
+
|
4 |
+
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
|
5 |
+
const float *points, const int *idx,
|
6 |
+
float *out);
|
7 |
+
|
8 |
+
void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
|
9 |
+
int nsample, const float *grad_out,
|
10 |
+
const int *idx, float *grad_points);
|
11 |
+
|
12 |
+
at::Tensor group_points(at::Tensor points, at::Tensor idx) {
|
13 |
+
CHECK_CONTIGUOUS(points);
|
14 |
+
CHECK_CONTIGUOUS(idx);
|
15 |
+
CHECK_IS_FLOAT(points);
|
16 |
+
CHECK_IS_INT(idx);
|
17 |
+
|
18 |
+
if (points.is_cuda()) {
|
19 |
+
CHECK_CUDA(idx);
|
20 |
+
}
|
21 |
+
|
22 |
+
at::Tensor output =
|
23 |
+
torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
|
24 |
+
at::device(points.device()).dtype(at::ScalarType::Float));
|
25 |
+
|
26 |
+
if (points.is_cuda()) {
|
27 |
+
group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
|
28 |
+
idx.size(1), idx.size(2),
|
29 |
+
points.data_ptr<float>(), idx.data_ptr<int>(),
|
30 |
+
output.data_ptr<float>());
|
31 |
+
} else {
|
32 |
+
AT_ASSERT(false, "CPU not supported");
|
33 |
+
}
|
34 |
+
|
35 |
+
return output;
|
36 |
+
}
|
37 |
+
|
38 |
+
at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
|
39 |
+
CHECK_CONTIGUOUS(grad_out);
|
40 |
+
CHECK_CONTIGUOUS(idx);
|
41 |
+
CHECK_IS_FLOAT(grad_out);
|
42 |
+
CHECK_IS_INT(idx);
|
43 |
+
|
44 |
+
if (grad_out.is_cuda()) {
|
45 |
+
CHECK_CUDA(idx);
|
46 |
+
}
|
47 |
+
|
48 |
+
at::Tensor output =
|
49 |
+
torch::zeros({grad_out.size(0), grad_out.size(1), n},
|
50 |
+
at::device(grad_out.device()).dtype(at::ScalarType::Float));
|
51 |
+
|
52 |
+
if (grad_out.is_cuda()) {
|
53 |
+
group_points_grad_kernel_wrapper(
|
54 |
+
grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
|
55 |
+
grad_out.data_ptr<float>(), idx.data_ptr<int>(),
|
56 |
+
output.data_ptr<float>());
|
57 |
+
} else {
|
58 |
+
AT_ASSERT(false, "CPU not supported");
|
59 |
+
}
|
60 |
+
|
61 |
+
return output;
|
62 |
+
}
|
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu
ADDED
@@ -0,0 +1,75 @@
1 |
+
#include <stdio.h>
|
2 |
+
#include <stdlib.h>
|
3 |
+
|
4 |
+
#include "cuda_utils.h"
|
5 |
+
|
6 |
+
// input: points(b, c, n) idx(b, npoints, nsample)
|
7 |
+
// output: out(b, c, npoints, nsample)
|
8 |
+
__global__ void group_points_kernel(int b, int c, int n, int npoints,
|
9 |
+
int nsample,
|
10 |
+
const float *__restrict__ points,
|
11 |
+
const int *__restrict__ idx,
|
12 |
+
float *__restrict__ out) {
|
13 |
+
int batch_index = blockIdx.x;
|
14 |
+
points += batch_index * n * c;
|
15 |
+
idx += batch_index * npoints * nsample;
|
16 |
+
out += batch_index * npoints * nsample * c;
|
17 |
+
|
18 |
+
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
19 |
+
const int stride = blockDim.y * blockDim.x;
|
20 |
+
for (int i = index; i < c * npoints; i += stride) {
|
21 |
+
const int l = i / npoints;
|
22 |
+
const int j = i % npoints;
|
23 |
+
for (int k = 0; k < nsample; ++k) {
|
24 |
+
int ii = idx[j * nsample + k];
|
25 |
+
out[(l * npoints + j) * nsample + k] = points[l * n + ii];
|
26 |
+
}
|
27 |
+
}
|
28 |
+
}
|
29 |
+
|
30 |
+
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
|
31 |
+
const float *points, const int *idx,
|
32 |
+
float *out) {
|
33 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
34 |
+
|
35 |
+
group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
|
36 |
+
b, c, n, npoints, nsample, points, idx, out);
|
37 |
+
|
38 |
+
//CUDA_CHECK_ERRORS();
|
39 |
+
}
|
40 |
+
|
41 |
+
// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
|
42 |
+
// output: grad_points(b, c, n)
|
43 |
+
__global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
|
44 |
+
int nsample,
|
45 |
+
const float *__restrict__ grad_out,
|
46 |
+
const int *__restrict__ idx,
|
47 |
+
float *__restrict__ grad_points) {
|
48 |
+
int batch_index = blockIdx.x;
|
49 |
+
grad_out += batch_index * npoints * nsample * c;
|
50 |
+
idx += batch_index * npoints * nsample;
|
51 |
+
grad_points += batch_index * n * c;
|
52 |
+
|
53 |
+
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
54 |
+
const int stride = blockDim.y * blockDim.x;
|
55 |
+
for (int i = index; i < c * npoints; i += stride) {
|
56 |
+
const int l = i / npoints;
|
57 |
+
const int j = i % npoints;
|
58 |
+
for (int k = 0; k < nsample; ++k) {
|
59 |
+
int ii = idx[j * nsample + k];
|
60 |
+
atomicAdd(grad_points + l * n + ii,
|
61 |
+
grad_out[(l * npoints + j) * nsample + k]);
|
62 |
+
}
|
63 |
+
}
|
64 |
+
}
|
65 |
+
|
66 |
+
void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
|
67 |
+
int nsample, const float *grad_out,
|
68 |
+
const int *idx, float *grad_points) {
|
69 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
70 |
+
|
71 |
+
group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
|
72 |
+
b, c, n, npoints, nsample, grad_out, idx, grad_points);
|
73 |
+
|
74 |
+
//CUDA_CHECK_ERRORS();
|
75 |
+
}
|
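group_points simply gathers points[b, c, idx[b, j, k]] into out[b, c, j, k]; on small inputs the same result can be reproduced with torch.gather, which makes a handy correctness check (a reference sketch only):

import torch

def group_points_reference(points, idx):
    # points: (B, C, N) float, idx: (B, npoints, nsample) int -> (B, C, npoints, nsample)
    B, C, N = points.shape
    _, npoints, nsample = idx.shape
    flat = idx.reshape(B, 1, npoints * nsample).expand(B, C, -1).long()
    return torch.gather(points, 2, flat).reshape(B, C, npoints, nsample)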
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp
ADDED
@@ -0,0 +1,99 @@
1 |
+
#include "interpolate.h"
|
2 |
+
#include "utils.h"
|
3 |
+
|
4 |
+
void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
|
5 |
+
const float *known, float *dist2, int *idx);
|
6 |
+
void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
|
7 |
+
const float *points, const int *idx,
|
8 |
+
const float *weight, float *out);
|
9 |
+
void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
|
10 |
+
const float *grad_out,
|
11 |
+
const int *idx, const float *weight,
|
12 |
+
float *grad_points);
|
13 |
+
|
14 |
+
std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
|
15 |
+
CHECK_CONTIGUOUS(unknowns);
|
16 |
+
CHECK_CONTIGUOUS(knows);
|
17 |
+
CHECK_IS_FLOAT(unknowns);
|
18 |
+
CHECK_IS_FLOAT(knows);
|
19 |
+
|
20 |
+
if (unknowns.is_cuda()) {
|
21 |
+
CHECK_CUDA(knows);
|
22 |
+
}
|
23 |
+
|
24 |
+
at::Tensor idx =
|
25 |
+
torch::zeros({unknowns.size(0), unknowns.size(1), 3},
|
26 |
+
at::device(unknowns.device()).dtype(at::ScalarType::Int));
|
27 |
+
at::Tensor dist2 =
|
28 |
+
torch::zeros({unknowns.size(0), unknowns.size(1), 3},
|
29 |
+
at::device(unknowns.device()).dtype(at::ScalarType::Float));
|
30 |
+
|
31 |
+
if (unknowns.is_cuda()) {
|
32 |
+
three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
|
33 |
+
unknowns.data_ptr<float>(), knows.data_ptr<float>(),
|
34 |
+
dist2.data_ptr<float>(), idx.data_ptr<int>());
|
35 |
+
} else {
|
36 |
+
AT_ASSERT(false, "CPU not supported");
|
37 |
+
}
|
38 |
+
|
39 |
+
return {dist2, idx};
|
40 |
+
}
|
41 |
+
|
42 |
+
at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
|
43 |
+
at::Tensor weight) {
|
44 |
+
CHECK_CONTIGUOUS(points);
|
45 |
+
CHECK_CONTIGUOUS(idx);
|
46 |
+
CHECK_CONTIGUOUS(weight);
|
47 |
+
CHECK_IS_FLOAT(points);
|
48 |
+
CHECK_IS_INT(idx);
|
49 |
+
CHECK_IS_FLOAT(weight);
|
50 |
+
|
51 |
+
if (points.is_cuda()) {
|
52 |
+
CHECK_CUDA(idx);
|
53 |
+
CHECK_CUDA(weight);
|
54 |
+
}
|
55 |
+
|
56 |
+
at::Tensor output =
|
57 |
+
torch::zeros({points.size(0), points.size(1), idx.size(1)},
|
58 |
+
at::device(points.device()).dtype(at::ScalarType::Float));
|
59 |
+
|
60 |
+
if (points.is_cuda()) {
|
61 |
+
three_interpolate_kernel_wrapper(
|
62 |
+
points.size(0), points.size(1), points.size(2), idx.size(1),
|
63 |
+
points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
|
64 |
+
output.data_ptr<float>());
|
65 |
+
} else {
|
66 |
+
AT_ASSERT(false, "CPU not supported");
|
67 |
+
}
|
68 |
+
|
69 |
+
return output;
|
70 |
+
}
|
71 |
+
at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
|
72 |
+
at::Tensor weight, const int m) {
|
73 |
+
CHECK_CONTIGUOUS(grad_out);
|
74 |
+
CHECK_CONTIGUOUS(idx);
|
75 |
+
CHECK_CONTIGUOUS(weight);
|
76 |
+
CHECK_IS_FLOAT(grad_out);
|
77 |
+
CHECK_IS_INT(idx);
|
78 |
+
CHECK_IS_FLOAT(weight);
|
79 |
+
|
80 |
+
if (grad_out.is_cuda()) {
|
81 |
+
CHECK_CUDA(idx);
|
82 |
+
CHECK_CUDA(weight);
|
83 |
+
}
|
84 |
+
|
85 |
+
at::Tensor output =
|
86 |
+
torch::zeros({grad_out.size(0), grad_out.size(1), m},
|
87 |
+
at::device(grad_out.device()).dtype(at::ScalarType::Float));
|
88 |
+
|
89 |
+
if (grad_out.is_cuda()) {
|
90 |
+
three_interpolate_grad_kernel_wrapper(
|
91 |
+
grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
|
92 |
+
grad_out.data_ptr<float>(), idx.data_ptr<int>(),
|
93 |
+
weight.data_ptr<float>(), output.data_ptr<float>());
|
94 |
+
} else {
|
95 |
+
AT_ASSERT(false, "CPU not supported");
|
96 |
+
}
|
97 |
+
|
98 |
+
return output;
|
99 |
+
}
|
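three_nn finds, for every query point, its three nearest neighbours in the known set, and three_interpolate blends the neighbours' features with the given weights. Equivalent (but slow) PyTorch reference versions of both, written under the shape contract spelled out in the comments above:

import torch

def three_nn_reference(unknown, known):
    # unknown: (B, n, 3), known: (B, m, 3) -> squared distances and indices of the 3 nearest known points
    d2 = torch.cdist(unknown, known) ** 2
    dist2, idx = torch.topk(d2, k=3, dim=-1, largest=False)
    return dist2, idx.int()

def three_interpolate_reference(points, idx, weight):
    # points: (B, c, m), idx: (B, n, 3), weight: (B, n, 3) -> (B, c, n) weighted sum of 3 neighbours
    B, c, m = points.shape
    n = idx.shape[1]
    gathered = torch.gather(
        points, 2, idx.long().reshape(B, 1, n * 3).expand(B, c, -1)).reshape(B, c, n, 3)
    return (gathered * weight.unsqueeze(1)).sum(-1)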
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
ADDED
@@ -0,0 +1,154 @@
1 |
+
#include <math.h>
|
2 |
+
#include <stdio.h>
|
3 |
+
#include <stdlib.h>
|
4 |
+
|
5 |
+
#include "cuda_utils.h"
|
6 |
+
|
7 |
+
// input: unknown(b, n, 3) known(b, m, 3)
|
8 |
+
// output: dist2(b, n, 3), idx(b, n, 3)
|
9 |
+
__global__ void three_nn_kernel(int b, int n, int m,
|
10 |
+
const float *__restrict__ unknown,
|
11 |
+
const float *__restrict__ known,
|
12 |
+
float *__restrict__ dist2,
|
13 |
+
int *__restrict__ idx) {
|
14 |
+
int batch_index = blockIdx.x;
|
15 |
+
unknown += batch_index * n * 3;
|
16 |
+
known += batch_index * m * 3;
|
17 |
+
dist2 += batch_index * n * 3;
|
18 |
+
idx += batch_index * n * 3;
|
19 |
+
|
20 |
+
int index = threadIdx.x;
|
21 |
+
int stride = blockDim.x;
|
22 |
+
for (int j = index; j < n; j += stride) {
|
23 |
+
float ux = unknown[j * 3 + 0];
|
24 |
+
float uy = unknown[j * 3 + 1];
|
25 |
+
float uz = unknown[j * 3 + 2];
|
26 |
+
|
27 |
+
double best1 = 1e40, best2 = 1e40, best3 = 1e40;
|
28 |
+
int besti1 = 0, besti2 = 0, besti3 = 0;
|
29 |
+
for (int k = 0; k < m; ++k) {
|
30 |
+
float x = known[k * 3 + 0];
|
31 |
+
float y = known[k * 3 + 1];
|
32 |
+
float z = known[k * 3 + 2];
|
33 |
+
float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
|
34 |
+
if (d < best1) {
|
35 |
+
best3 = best2;
|
36 |
+
besti3 = besti2;
|
37 |
+
best2 = best1;
|
38 |
+
besti2 = besti1;
|
39 |
+
best1 = d;
|
40 |
+
besti1 = k;
|
41 |
+
} else if (d < best2) {
|
42 |
+
best3 = best2;
|
43 |
+
besti3 = besti2;
|
44 |
+
best2 = d;
|
45 |
+
besti2 = k;
|
46 |
+
} else if (d < best3) {
|
47 |
+
best3 = d;
|
48 |
+
besti3 = k;
|
49 |
+
}
|
50 |
+
}
|
51 |
+
dist2[j * 3 + 0] = best1;
|
52 |
+
dist2[j * 3 + 1] = best2;
|
53 |
+
dist2[j * 3 + 2] = best3;
|
54 |
+
|
55 |
+
idx[j * 3 + 0] = besti1;
|
56 |
+
idx[j * 3 + 1] = besti2;
|
57 |
+
idx[j * 3 + 2] = besti3;
|
58 |
+
}
|
59 |
+
}
|
60 |
+
|
61 |
+
void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
|
62 |
+
const float *known, float *dist2, int *idx) {
|
63 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
64 |
+
three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
|
65 |
+
dist2, idx);
|
66 |
+
|
67 |
+
//CUDA_CHECK_ERRORS();
|
68 |
+
}
|
69 |
+
|
70 |
+
// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
|
71 |
+
// output: out(b, c, n)
|
72 |
+
__global__ void three_interpolate_kernel(int b, int c, int m, int n,
|
73 |
+
const float *__restrict__ points,
|
74 |
+
const int *__restrict__ idx,
|
75 |
+
const float *__restrict__ weight,
|
76 |
+
float *__restrict__ out) {
|
77 |
+
int batch_index = blockIdx.x;
|
78 |
+
points += batch_index * m * c;
|
79 |
+
|
80 |
+
idx += batch_index * n * 3;
|
81 |
+
weight += batch_index * n * 3;
|
82 |
+
|
83 |
+
out += batch_index * n * c;
|
84 |
+
|
85 |
+
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
86 |
+
const int stride = blockDim.y * blockDim.x;
|
87 |
+
for (int i = index; i < c * n; i += stride) {
|
88 |
+
const int l = i / n;
|
89 |
+
const int j = i % n;
|
90 |
+
float w1 = weight[j * 3 + 0];
|
91 |
+
float w2 = weight[j * 3 + 1];
|
92 |
+
float w3 = weight[j * 3 + 2];
|
93 |
+
|
94 |
+
int i1 = idx[j * 3 + 0];
|
95 |
+
int i2 = idx[j * 3 + 1];
|
96 |
+
int i3 = idx[j * 3 + 2];
|
97 |
+
|
98 |
+
out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
|
99 |
+
points[l * m + i3] * w3;
|
100 |
+
}
|
101 |
+
}
|
102 |
+
|
103 |
+
void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
|
104 |
+
const float *points, const int *idx,
|
105 |
+
const float *weight, float *out) {
|
106 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
107 |
+
three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
|
108 |
+
b, c, m, n, points, idx, weight, out);
|
109 |
+
|
110 |
+
//CUDA_CHECK_ERRORS();
|
111 |
+
}
|
112 |
+
|
113 |
+
// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
|
114 |
+
// output: grad_points(b, c, m)
|
115 |
+
|
116 |
+
__global__ void three_interpolate_grad_kernel(
|
117 |
+
int b, int c, int n, int m, const float *__restrict__ grad_out,
|
118 |
+
const int *__restrict__ idx, const float *__restrict__ weight,
|
119 |
+
float *__restrict__ grad_points) {
|
120 |
+
int batch_index = blockIdx.x;
|
121 |
+
grad_out += batch_index * n * c;
|
122 |
+
idx += batch_index * n * 3;
|
123 |
+
weight += batch_index * n * 3;
|
124 |
+
grad_points += batch_index * m * c;
|
125 |
+
|
126 |
+
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
127 |
+
const int stride = blockDim.y * blockDim.x;
|
128 |
+
for (int i = index; i < c * n; i += stride) {
|
129 |
+
const int l = i / n;
|
130 |
+
const int j = i % n;
|
131 |
+
float w1 = weight[j * 3 + 0];
|
132 |
+
float w2 = weight[j * 3 + 1];
|
133 |
+
float w3 = weight[j * 3 + 2];
|
134 |
+
|
135 |
+
int i1 = idx[j * 3 + 0];
|
136 |
+
int i2 = idx[j * 3 + 1];
|
137 |
+
int i3 = idx[j * 3 + 2];
|
138 |
+
|
139 |
+
atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
|
140 |
+
atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
|
141 |
+
atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
|
142 |
+
}
|
143 |
+
}
|
144 |
+
|
145 |
+
void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
|
146 |
+
const float *grad_out,
|
147 |
+
const int *idx, const float *weight,
|
148 |
+
float *grad_points) {
|
149 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
150 |
+
three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
|
151 |
+
b, c, n, m, grad_out, idx, weight, grad_points);
|
152 |
+
|
153 |
+
CUDA_CHECK_ERRORS();
|
154 |
+
}
|
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp
ADDED
@@ -0,0 +1,87 @@
1 |
+
#include "sampling.h"
|
2 |
+
#include "utils.h"
|
3 |
+
|
4 |
+
void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
|
5 |
+
const float *points, const int *idx,
|
6 |
+
float *out);
|
7 |
+
void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
|
8 |
+
const float *grad_out, const int *idx,
|
9 |
+
float *grad_points);
|
10 |
+
|
11 |
+
void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
|
12 |
+
const float *dataset, float *temp,
|
13 |
+
int *idxs);
|
14 |
+
|
15 |
+
at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
|
16 |
+
CHECK_CONTIGUOUS(points);
|
17 |
+
CHECK_CONTIGUOUS(idx);
|
18 |
+
CHECK_IS_FLOAT(points);
|
19 |
+
CHECK_IS_INT(idx);
|
20 |
+
|
21 |
+
if (points.is_cuda()) {
|
22 |
+
CHECK_CUDA(idx);
|
23 |
+
}
|
24 |
+
|
25 |
+
at::Tensor output =
|
26 |
+
torch::zeros({points.size(0), points.size(1), idx.size(1)},
|
27 |
+
at::device(points.device()).dtype(at::ScalarType::Float));
|
28 |
+
|
29 |
+
if (points.is_cuda()) {
|
30 |
+
gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
|
31 |
+
idx.size(1), points.data_ptr<float>(),
|
32 |
+
idx.data_ptr<int>(), output.data_ptr<float>());
|
33 |
+
} else {
|
34 |
+
AT_ASSERT(false, "CPU not supported");
|
35 |
+
}
|
36 |
+
|
37 |
+
return output;
|
38 |
+
}
|
39 |
+
|
40 |
+
at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
|
41 |
+
const int n) {
|
42 |
+
CHECK_CONTIGUOUS(grad_out);
|
43 |
+
CHECK_CONTIGUOUS(idx);
|
44 |
+
CHECK_IS_FLOAT(grad_out);
|
45 |
+
CHECK_IS_INT(idx);
|
46 |
+
|
47 |
+
if (grad_out.is_cuda()) {
|
48 |
+
CHECK_CUDA(idx);
|
49 |
+
}
|
50 |
+
|
51 |
+
at::Tensor output =
|
52 |
+
torch::zeros({grad_out.size(0), grad_out.size(1), n},
|
53 |
+
at::device(grad_out.device()).dtype(at::ScalarType::Float));
|
54 |
+
|
55 |
+
if (grad_out.is_cuda()) {
|
56 |
+
gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
|
57 |
+
idx.size(1), grad_out.data_ptr<float>(),
|
58 |
+
idx.data_ptr<int>(),
|
59 |
+
output.data_ptr<float>());
|
60 |
+
} else {
|
61 |
+
AT_ASSERT(false, "CPU not supported");
|
62 |
+
}
|
63 |
+
|
64 |
+
return output;
|
65 |
+
}
|
66 |
+
at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
|
67 |
+
CHECK_CONTIGUOUS(points);
|
68 |
+
CHECK_IS_FLOAT(points);
|
69 |
+
|
70 |
+
at::Tensor output =
|
71 |
+
torch::zeros({points.size(0), nsamples},
|
72 |
+
at::device(points.device()).dtype(at::ScalarType::Int));
|
73 |
+
|
74 |
+
at::Tensor tmp =
|
75 |
+
torch::full({points.size(0), points.size(1)}, 1e10,
|
76 |
+
at::device(points.device()).dtype(at::ScalarType::Float));
|
77 |
+
|
78 |
+
if (points.is_cuda()) {
|
79 |
+
furthest_point_sampling_kernel_wrapper(
|
80 |
+
points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
|
81 |
+
tmp.data_ptr<float>(), output.data_ptr<int>());
|
82 |
+
} else {
|
83 |
+
AT_ASSERT(false, "CPU not supported");
|
84 |
+
}
|
85 |
+
|
86 |
+
return output;
|
87 |
+
}
|
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu
ADDED
@@ -0,0 +1,229 @@
1 |
+
#include <stdio.h>
|
2 |
+
#include <stdlib.h>
|
3 |
+
|
4 |
+
#include "cuda_utils.h"
|
5 |
+
|
6 |
+
// input: points(b, c, n) idx(b, m)
|
7 |
+
// output: out(b, c, m)
|
8 |
+
__global__ void gather_points_kernel(int b, int c, int n, int m,
|
9 |
+
const float *__restrict__ points,
|
10 |
+
const int *__restrict__ idx,
|
11 |
+
float *__restrict__ out) {
|
12 |
+
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
13 |
+
for (int l = blockIdx.y; l < c; l += gridDim.y) {
|
14 |
+
for (int j = threadIdx.x; j < m; j += blockDim.x) {
|
15 |
+
int a = idx[i * m + j];
|
16 |
+
out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
|
17 |
+
}
|
18 |
+
}
|
19 |
+
}
|
20 |
+
}
|
21 |
+
|
22 |
+
void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
|
23 |
+
const float *points, const int *idx,
|
24 |
+
float *out) {
|
25 |
+
gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
|
26 |
+
at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
|
27 |
+
points, idx, out);
|
28 |
+
|
29 |
+
//CUDA_CHECK_ERRORS();
|
30 |
+
}
|
31 |
+
|
32 |
+
// input: grad_out(b, c, m) idx(b, m)
|
33 |
+
// output: grad_points(b, c, n)
|
34 |
+
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
|
35 |
+
const float *__restrict__ grad_out,
|
36 |
+
const int *__restrict__ idx,
|
37 |
+
float *__restrict__ grad_points) {
|
38 |
+
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
39 |
+
for (int l = blockIdx.y; l < c; l += gridDim.y) {
|
40 |
+
for (int j = threadIdx.x; j < m; j += blockDim.x) {
|
41 |
+
int a = idx[i * m + j];
|
42 |
+
atomicAdd(grad_points + (i * c + l) * n + a,
|
43 |
+
grad_out[(i * c + l) * m + j]);
|
44 |
+
}
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
|
50 |
+
const float *grad_out, const int *idx,
|
51 |
+
float *grad_points) {
|
52 |
+
gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
|
53 |
+
at::cuda::getCurrentCUDAStream()>>>(
|
54 |
+
b, c, n, npoints, grad_out, idx, grad_points);
|
55 |
+
|
56 |
+
//CUDA_CHECK_ERRORS();
|
57 |
+
}
|
58 |
+
|
59 |
+
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
|
60 |
+
int idx1, int idx2) {
|
61 |
+
const float v1 = dists[idx1], v2 = dists[idx2];
|
62 |
+
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
|
63 |
+
dists[idx1] = max(v1, v2);
|
64 |
+
dists_i[idx1] = v2 > v1 ? i2 : i1;
|
65 |
+
}
|
66 |
+
|
67 |
+
// Input dataset: (b, n, 3), tmp: (b, n)
|
68 |
+
// Ouput idxs (b, m)
|
69 |
+
template <unsigned int block_size>
|
70 |
+
__global__ void furthest_point_sampling_kernel(
|
71 |
+
int b, int n, int m, const float *__restrict__ dataset,
|
72 |
+
float *__restrict__ temp, int *__restrict__ idxs) {
|
73 |
+
if (m <= 0) return;
|
74 |
+
__shared__ float dists[block_size];
|
75 |
+
__shared__ int dists_i[block_size];
|
76 |
+
|
77 |
+
int batch_index = blockIdx.x;
|
78 |
+
dataset += batch_index * n * 3;
|
79 |
+
temp += batch_index * n;
|
80 |
+
idxs += batch_index * m;
|
81 |
+
|
82 |
+
int tid = threadIdx.x;
|
83 |
+
const int stride = block_size;
|
84 |
+
|
85 |
+
int old = 0;
|
86 |
+
if (threadIdx.x == 0) idxs[0] = old;
|
87 |
+
|
88 |
+
__syncthreads();
|
89 |
+
for (int j = 1; j < m; j++) {
|
90 |
+
int besti = 0;
|
91 |
+
float best = -1;
|
92 |
+
float x1 = dataset[old * 3 + 0];
|
93 |
+
float y1 = dataset[old * 3 + 1];
|
94 |
+
float z1 = dataset[old * 3 + 2];
|
95 |
+
for (int k = tid; k < n; k += stride) {
|
96 |
+
float x2, y2, z2;
|
97 |
+
x2 = dataset[k * 3 + 0];
|
98 |
+
y2 = dataset[k * 3 + 1];
|
99 |
+
z2 = dataset[k * 3 + 2];
|
100 |
+
float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
|
101 |
+
if (mag <= 1e-3) continue;
|
102 |
+
|
103 |
+
float d =
|
104 |
+
(x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
|
105 |
+
|
106 |
+
float d2 = min(d, temp[k]);
|
107 |
+
temp[k] = d2;
|
108 |
+
besti = d2 > best ? k : besti;
|
109 |
+
best = d2 > best ? d2 : best;
|
110 |
+
}
|
111 |
+
dists[tid] = best;
|
112 |
+
dists_i[tid] = besti;
|
113 |
+
__syncthreads();
|
114 |
+
|
115 |
+
if (block_size >= 512) {
|
116 |
+
if (tid < 256) {
|
117 |
+
__update(dists, dists_i, tid, tid + 256);
|
118 |
+
}
|
119 |
+
__syncthreads();
|
120 |
+
}
|
121 |
+
if (block_size >= 256) {
|
122 |
+
if (tid < 128) {
|
123 |
+
__update(dists, dists_i, tid, tid + 128);
|
124 |
+
}
|
125 |
+
__syncthreads();
|
126 |
+
}
|
127 |
+
if (block_size >= 128) {
|
128 |
+
if (tid < 64) {
|
129 |
+
__update(dists, dists_i, tid, tid + 64);
|
130 |
+
}
|
131 |
+
__syncthreads();
|
132 |
+
}
|
133 |
+
if (block_size >= 64) {
|
134 |
+
if (tid < 32) {
|
135 |
+
__update(dists, dists_i, tid, tid + 32);
|
136 |
+
}
|
137 |
+
__syncthreads();
|
138 |
+
}
|
139 |
+
if (block_size >= 32) {
|
140 |
+
if (tid < 16) {
|
141 |
+
__update(dists, dists_i, tid, tid + 16);
|
142 |
+
}
|
143 |
+
__syncthreads();
|
144 |
+
}
|
145 |
+
if (block_size >= 16) {
|
146 |
+
if (tid < 8) {
|
147 |
+
__update(dists, dists_i, tid, tid + 8);
|
148 |
+
}
|
149 |
+
__syncthreads();
|
150 |
+
}
|
151 |
+
if (block_size >= 8) {
|
152 |
+
if (tid < 4) {
|
153 |
+
__update(dists, dists_i, tid, tid + 4);
|
154 |
+
}
|
155 |
+
__syncthreads();
|
156 |
+
}
|
157 |
+
if (block_size >= 4) {
|
158 |
+
if (tid < 2) {
|
159 |
+
__update(dists, dists_i, tid, tid + 2);
|
160 |
+
}
|
161 |
+
__syncthreads();
|
162 |
+
}
|
163 |
+
if (block_size >= 2) {
|
164 |
+
if (tid < 1) {
|
165 |
+
__update(dists, dists_i, tid, tid + 1);
|
166 |
+
}
|
167 |
+
__syncthreads();
|
168 |
+
}
|
169 |
+
|
170 |
+
old = dists_i[0];
|
171 |
+
if (tid == 0) idxs[j] = old;
|
172 |
+
}
|
173 |
+
}
|
174 |
+
|
175 |
+
void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
|
176 |
+
const float *dataset, float *temp,
|
177 |
+
int *idxs) {
|
178 |
+
unsigned int n_threads = opt_n_threads(n);
|
179 |
+
|
180 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
181 |
+
|
182 |
+
switch (n_threads) {
|
183 |
+
case 512:
|
184 |
+
furthest_point_sampling_kernel<512>
|
185 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
186 |
+
break;
|
187 |
+
case 256:
|
188 |
+
furthest_point_sampling_kernel<256>
|
189 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
190 |
+
break;
|
191 |
+
case 128:
|
192 |
+
furthest_point_sampling_kernel<128>
|
193 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
194 |
+
break;
|
195 |
+
case 64:
|
196 |
+
furthest_point_sampling_kernel<64>
|
197 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
198 |
+
break;
|
199 |
+
case 32:
|
200 |
+
furthest_point_sampling_kernel<32>
|
201 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
202 |
+
break;
|
203 |
+
case 16:
|
204 |
+
furthest_point_sampling_kernel<16>
|
205 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
206 |
+
break;
|
207 |
+
case 8:
|
208 |
+
furthest_point_sampling_kernel<8>
|
209 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
210 |
+
break;
|
211 |
+
case 4:
|
212 |
+
furthest_point_sampling_kernel<4>
|
213 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
214 |
+
break;
|
215 |
+
case 2:
|
216 |
+
furthest_point_sampling_kernel<2>
|
217 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
218 |
+
break;
|
219 |
+
case 1:
|
220 |
+
furthest_point_sampling_kernel<1>
|
221 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
222 |
+
break;
|
223 |
+
default:
|
224 |
+
furthest_point_sampling_kernel<512>
|
225 |
+
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
|
226 |
+
}
|
227 |
+
|
228 |
+
//CUDA_CHECK_ERRORS();
|
229 |
+
}
|
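The kernel above implements greedy farthest point sampling: it always seeds with index 0 and then repeatedly picks the point with the largest distance to the set chosen so far. A plain PyTorch version of the same loop for intuition (it ignores the kernel's near-origin skip, i.e. the mag <= 1e-3 early-out, and is far slower):

import torch

def furthest_point_sample_reference(xyz, npoint):
    # xyz: (B, N, 3) -> (B, npoint) indices of a greedy max-min sample
    B, N, _ = xyz.shape
    idx = torch.zeros(B, npoint, dtype=torch.long, device=xyz.device)
    closest = torch.full((B, N), 1e10, device=xyz.device)
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)   # the kernel also seeds with index 0
    for j in range(npoint):
        idx[:, j] = farthest
        centroid = xyz[torch.arange(B, device=xyz.device), farthest].unsqueeze(1)   # (B, 1, 3)
        d = ((xyz - centroid) ** 2).sum(-1)
        closest = torch.minimum(closest, d)
        farthest = closest.argmax(-1)
    return idx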
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py
ADDED
@@ -0,0 +1 @@
__version__ = "3.0.0"
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from pointnet2_ops import pointnet2_utils
|
7 |
+
|
8 |
+
|
9 |
+
def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
|
10 |
+
layers = []
|
11 |
+
for i in range(1, len(mlp_spec)):
|
12 |
+
layers.append(
|
13 |
+
nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
|
14 |
+
)
|
15 |
+
if bn:
|
16 |
+
layers.append(nn.BatchNorm2d(mlp_spec[i]))
|
17 |
+
layers.append(nn.ReLU(True))
|
18 |
+
|
19 |
+
return nn.Sequential(*layers)
|
20 |
+
|
21 |
+
|
22 |
+
class _PointnetSAModuleBase(nn.Module):
|
23 |
+
def __init__(self):
|
24 |
+
super(_PointnetSAModuleBase, self).__init__()
|
25 |
+
self.npoint = None
|
26 |
+
self.groupers = None
|
27 |
+
self.mlps = None
|
28 |
+
|
29 |
+
def forward(
|
30 |
+
self, xyz: torch.Tensor, features: Optional[torch.Tensor]
|
31 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
32 |
+
r"""
|
33 |
+
Parameters
|
34 |
+
----------
|
35 |
+
xyz : torch.Tensor
|
36 |
+
(B, N, 3) tensor of the xyz coordinates of the features
|
37 |
+
features : torch.Tensor
|
38 |
+
(B, C, N) tensor of the descriptors of the the features
|
39 |
+
|
40 |
+
Returns
|
41 |
+
-------
|
42 |
+
new_xyz : torch.Tensor
|
43 |
+
(B, npoint, 3) tensor of the new features' xyz
|
44 |
+
new_features : torch.Tensor
|
45 |
+
(B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
|
46 |
+
"""
|
47 |
+
|
48 |
+
new_features_list = []
|
49 |
+
|
50 |
+
xyz_flipped = xyz.transpose(1, 2).contiguous()
|
51 |
+
new_xyz = (
|
52 |
+
pointnet2_utils.gather_operation(
|
53 |
+
xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
|
54 |
+
)
|
55 |
+
.transpose(1, 2)
|
56 |
+
.contiguous()
|
57 |
+
if self.npoint is not None
|
58 |
+
else None
|
59 |
+
)
|
60 |
+
|
61 |
+
for i in range(len(self.groupers)):
|
62 |
+
new_features = self.groupers[i](
|
63 |
+
xyz, new_xyz, features
|
64 |
+
) # (B, C, npoint, nsample)
|
65 |
+
|
66 |
+
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
|
67 |
+
new_features = F.max_pool2d(
|
68 |
+
new_features, kernel_size=[1, new_features.size(3)]
|
69 |
+
) # (B, mlp[-1], npoint, 1)
|
70 |
+
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
|
71 |
+
|
72 |
+
new_features_list.append(new_features)
|
73 |
+
|
74 |
+
return new_xyz, torch.cat(new_features_list, dim=1)
|
75 |
+
|
76 |
+
|
77 |
+
class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
78 |
+
r"""Pointnet set abstrction layer with multiscale grouping
|
79 |
+
|
80 |
+
Parameters
|
81 |
+
----------
|
82 |
+
npoint : int
|
83 |
+
Number of features
|
84 |
+
radii : list of float32
|
85 |
+
list of radii to group with
|
86 |
+
nsamples : list of int32
|
87 |
+
Number of samples in each ball query
|
88 |
+
mlps : list of list of int32
|
89 |
+
Spec of the pointnet before the global max_pool for each scale
|
90 |
+
bn : bool
|
91 |
+
Use batchnorm
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
|
95 |
+
# type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
|
96 |
+
super(PointnetSAModuleMSG, self).__init__()
|
97 |
+
|
98 |
+
assert len(radii) == len(nsamples) == len(mlps)
|
99 |
+
|
100 |
+
self.npoint = npoint
|
101 |
+
self.groupers = nn.ModuleList()
|
102 |
+
self.mlps = nn.ModuleList()
|
103 |
+
for i in range(len(radii)):
|
104 |
+
radius = radii[i]
|
105 |
+
nsample = nsamples[i]
|
106 |
+
self.groupers.append(
|
107 |
+
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
|
108 |
+
if npoint is not None
|
109 |
+
else pointnet2_utils.GroupAll(use_xyz)
|
110 |
+
)
|
111 |
+
mlp_spec = mlps[i]
|
112 |
+
if use_xyz:
|
113 |
+
mlp_spec[0] += 3
|
114 |
+
|
115 |
+
self.mlps.append(build_shared_mlp(mlp_spec, bn))
|
116 |
+
|
117 |
+
|
118 |
+
class PointnetSAModule(PointnetSAModuleMSG):
|
119 |
+
r"""Pointnet set abstrction layer
|
120 |
+
|
121 |
+
Parameters
|
122 |
+
----------
|
123 |
+
npoint : int
|
124 |
+
Number of features
|
125 |
+
radius : float
|
126 |
+
Radius of ball
|
127 |
+
nsample : int
|
128 |
+
Number of samples in the ball query
|
129 |
+
mlp : list
|
130 |
+
Spec of the pointnet before the global max_pool
|
131 |
+
bn : bool
|
132 |
+
Use batchnorm
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(
|
136 |
+
self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
|
137 |
+
):
|
138 |
+
# type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
|
139 |
+
super(PointnetSAModule, self).__init__(
|
140 |
+
mlps=[mlp],
|
141 |
+
npoint=npoint,
|
142 |
+
radii=[radius],
|
143 |
+
nsamples=[nsample],
|
144 |
+
bn=bn,
|
145 |
+
use_xyz=use_xyz,
|
146 |
+
)
|
147 |
+
|
148 |
+
|
149 |
+
class PointnetFPModule(nn.Module):
|
150 |
+
r"""Propigates the features of one set to another
|
151 |
+
|
152 |
+
Parameters
|
153 |
+
----------
|
154 |
+
mlp : list
|
155 |
+
Pointnet module parameters
|
156 |
+
bn : bool
|
157 |
+
Use batchnorm
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(self, mlp, bn=True):
|
161 |
+
# type: (PointnetFPModule, List[int], bool) -> None
|
162 |
+
super(PointnetFPModule, self).__init__()
|
163 |
+
self.mlp = build_shared_mlp(mlp, bn=bn)
|
164 |
+
|
165 |
+
def forward(self, unknown, known, unknow_feats, known_feats):
|
166 |
+
# type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
|
167 |
+
r"""
|
168 |
+
Parameters
|
169 |
+
----------
|
170 |
+
unknown : torch.Tensor
|
171 |
+
(B, n, 3) tensor of the xyz positions of the unknown features
|
172 |
+
known : torch.Tensor
|
173 |
+
(B, m, 3) tensor of the xyz positions of the known features
|
174 |
+
unknow_feats : torch.Tensor
|
175 |
+
(B, C1, n) tensor of the features to be propigated to
|
176 |
+
known_feats : torch.Tensor
|
177 |
+
(B, C2, m) tensor of features to be propigated
|
178 |
+
|
179 |
+
Returns
|
180 |
+
-------
|
181 |
+
new_features : torch.Tensor
|
182 |
+
(B, mlp[-1], n) tensor of the features of the unknown features
|
183 |
+
"""
|
184 |
+
|
185 |
+
if known is not None:
|
186 |
+
dist, idx = pointnet2_utils.three_nn(unknown, known)
|
187 |
+
dist_recip = 1.0 / (dist + 1e-8)
|
188 |
+
norm = torch.sum(dist_recip, dim=2, keepdim=True)
|
189 |
+
weight = dist_recip / norm
|
190 |
+
|
191 |
+
interpolated_feats = pointnet2_utils.three_interpolate(
|
192 |
+
known_feats, idx, weight
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
interpolated_feats = known_feats.expand(
|
196 |
+
*(known_feats.size()[0:2] + [unknown.size(1)])
|
197 |
+
)
|
198 |
+
|
199 |
+
if unknow_feats is not None:
|
200 |
+
new_features = torch.cat(
|
201 |
+
[interpolated_feats, unknow_feats], dim=1
|
202 |
+
) # (B, C2 + C1, n)
|
203 |
+
else:
|
204 |
+
new_features = interpolated_feats
|
205 |
+
|
206 |
+
new_features = new_features.unsqueeze(-1)
|
207 |
+
new_features = self.mlp(new_features)
|
208 |
+
|
209 |
+
return new_features.squeeze(-1)
|
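A usage sketch for the set-abstraction / feature-propagation pair defined above (shapes follow the docstrings; it assumes a built extension and CUDA tensors):

import torch
from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetFPModule

xyz = torch.rand(2, 1024, 3).cuda()                      # (B, N, 3) coordinates
feats = torch.rand(2, 16, 1024).cuda()                   # (B, C, N) per-point descriptors

sa = PointnetSAModule(npoint=256, radius=0.2, nsample=32,
                      mlp=[16, 32, 64], use_xyz=True).cuda()
new_xyz, new_feats = sa(xyz, feats)                      # (2, 256, 3), (2, 64, 256)

fp = PointnetFPModule(mlp=[64 + 16, 64]).cuda()
up_feats = fp(xyz, new_xyz, feats, new_feats)            # (2, 64, 1024) features propagated back to all points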
hort/models/tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py
ADDED
@@ -0,0 +1,391 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import warnings
|
4 |
+
from torch.autograd import Function
|
5 |
+
from typing import *
|
6 |
+
|
7 |
+
try:
|
8 |
+
import pointnet2_ops._ext as _ext
|
9 |
+
except ImportError:
|
10 |
+
from torch.utils.cpp_extension import load
|
11 |
+
import glob
|
12 |
+
import os.path as osp
|
13 |
+
import os
|
14 |
+
|
15 |
+
warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
|
16 |
+
|
17 |
+
_ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
|
18 |
+
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
19 |
+
osp.join(_ext_src_root, "src", "*.cu")
|
20 |
+
)
|
21 |
+
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
22 |
+
|
23 |
+
os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
|
24 |
+
_ext = load(
|
25 |
+
"_ext",
|
26 |
+
sources=_ext_sources,
|
27 |
+
extra_include_paths=[osp.join(_ext_src_root, "include")],
|
28 |
+
extra_cflags=["-O3"],
|
29 |
+
extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
|
30 |
+
with_cuda=True,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class FurthestPointSampling(Function):
|
35 |
+
@staticmethod
|
36 |
+
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
37 |
+
def forward(ctx, xyz, npoint):
|
38 |
+
# type: (Any, torch.Tensor, int) -> torch.Tensor
|
39 |
+
r"""
|
40 |
+
Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
41 |
+
minimum distance
|
42 |
+
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
xyz : torch.Tensor
|
46 |
+
(B, N, 3) tensor where N > npoint
|
47 |
+
npoint : int32
|
48 |
+
number of features in the sampled set
|
49 |
+
|
50 |
+
Returns
|
51 |
+
-------
|
52 |
+
torch.Tensor
|
53 |
+
(B, npoint) tensor containing the set
|
54 |
+
"""
|
55 |
+
out = _ext.furthest_point_sampling(xyz, npoint)
|
56 |
+
|
57 |
+
ctx.mark_non_differentiable(out)
|
58 |
+
|
59 |
+
return out
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
63 |
+
def backward(ctx, grad_out):
|
64 |
+
return ()
|
65 |
+
|
66 |
+
|
67 |
+
furthest_point_sample = FurthestPointSampling.apply
|
68 |
+
|
69 |
+
|
70 |
+
class GatherOperation(Function):
|
71 |
+
@staticmethod
|
72 |
+
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
73 |
+
def forward(ctx, features, idx):
|
74 |
+
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
|
75 |
+
r"""
|
76 |
+
|
77 |
+
Parameters
|
78 |
+
----------
|
79 |
+
features : torch.Tensor
|
80 |
+
(B, C, N) tensor
|
81 |
+
|
82 |
+
idx : torch.Tensor
|
83 |
+
(B, npoint) tensor of the features to gather
|
84 |
+
|
85 |
+
Returns
|
86 |
+
-------
|
87 |
+
torch.Tensor
|
88 |
+
(B, C, npoint) tensor
|
89 |
+
"""
|
90 |
+
|
91 |
+
ctx.save_for_backward(idx, features)
|
92 |
+
|
93 |
+
return _ext.gather_points(features, idx)
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
97 |
+
def backward(ctx, grad_out):
|
98 |
+
idx, features = ctx.saved_tensors
|
99 |
+
N = features.size(2)
|
100 |
+
|
101 |
+
grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
|
102 |
+
return grad_features, None
|
103 |
+
|
104 |
+
|
105 |
+
gather_operation = GatherOperation.apply
|
106 |
+
|
107 |
+
|
108 |
+
class ThreeNN(Function):
|
109 |
+
@staticmethod
|
110 |
+
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
111 |
+
def forward(ctx, unknown, known):
|
112 |
+
# type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
113 |
+
r"""
|
114 |
+
Find the three nearest neighbors of unknown in known
|
115 |
+
Parameters
|
116 |
+
----------
|
117 |
+
unknown : torch.Tensor
|
118 |
+
(B, n, 3) tensor of known features
|
119 |
+
known : torch.Tensor
|
120 |
+
(B, m, 3) tensor of unknown features
|
121 |
+
|
122 |
+
Returns
|
123 |
+
-------
|
124 |
+
dist : torch.Tensor
|
125 |
+
(B, n, 3) l2 distance to the three nearest neighbors
|
126 |
+
idx : torch.Tensor
|
127 |
+
(B, n, 3) index of 3 nearest neighbors
|
128 |
+
"""
|
129 |
+
dist2, idx = _ext.three_nn(unknown, known)
|
130 |
+
dist = torch.sqrt(dist2)
|
131 |
+
|
132 |
+
ctx.mark_non_differentiable(dist, idx)
|
133 |
+
|
134 |
+
return dist, idx
|
135 |
+
|
136 |
+
@staticmethod
|
137 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
138 |
+
def backward(ctx, grad_dist, grad_idx):
|
139 |
+
return ()
|
140 |
+
|
141 |
+
|
142 |
+
three_nn = ThreeNN.apply
|
143 |
+
|
144 |
+
|
145 |
+
class ThreeInterpolate(Function):
|
146 |
+
@staticmethod
|
147 |
+
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
148 |
+
def forward(ctx, features, idx, weight):
|
149 |
+
# type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
|
150 |
+
r"""
|
151 |
+
Performs weight linear interpolation on 3 features
|
152 |
+
Parameters
|
153 |
+
----------
|
154 |
+
features : torch.Tensor
|
155 |
+
(B, c, m) Features descriptors to be interpolated from
|
156 |
+
idx : torch.Tensor
|
157 |
+
(B, n, 3) three nearest neighbors of the target features in features
|
158 |
+
weight : torch.Tensor
|
159 |
+
(B, n, 3) weights
|
160 |
+
|
161 |
+
Returns
|
162 |
+
-------
|
163 |
+
torch.Tensor
|
164 |
+
(B, c, n) tensor of the interpolated features
|
165 |
+
"""
|
166 |
+
ctx.save_for_backward(idx, weight, features)
|
167 |
+
|
168 |
+
return _ext.three_interpolate(features, idx, weight)
|
169 |
+
|
170 |
+
@staticmethod
|
171 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
172 |
+
def backward(ctx, grad_out):
|
173 |
+
# type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
174 |
+
r"""
|
175 |
+
Parameters
|
176 |
+
----------
|
177 |
+
grad_out : torch.Tensor
|
178 |
+
(B, c, n) tensor with gradients of ouputs
|
179 |
+
|
180 |
+
Returns
|
181 |
+
-------
|
182 |
+
grad_features : torch.Tensor
|
183 |
+
(B, c, m) tensor with gradients of features
|
184 |
+
|
185 |
+
None
|
186 |
+
|
187 |
+
None
|
188 |
+
"""
|
189 |
+
idx, weight, features = ctx.saved_tensors
|
190 |
+
m = features.size(2)
|
191 |
+
|
192 |
+
grad_features = _ext.three_interpolate_grad(
|
193 |
+
grad_out.contiguous(), idx, weight, m
|
194 |
+
)
|
195 |
+
|
196 |
+
return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
|
197 |
+
|
198 |
+
|
199 |
+
three_interpolate = ThreeInterpolate.apply
|
200 |
+
|
201 |
+
|
202 |
+
class GroupingOperation(Function):
|
203 |
+
@staticmethod
|
204 |
+
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
205 |
+
def forward(ctx, features, idx):
|
206 |
+
# type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
|
207 |
+
r"""
|
208 |
+
|
209 |
+
Parameters
|
210 |
+
----------
|
211 |
+
features : torch.Tensor
|
212 |
+
(B, C, N) tensor of features to group
|
213 |
+
idx : torch.Tensor
|
214 |
+
(B, npoint, nsample) tensor containing the indicies of features to group with
|
215 |
+
|
216 |
+
Returns
|
217 |
+
-------
|
218 |
+
torch.Tensor
|
219 |
+
(B, C, npoint, nsample) tensor
|
220 |
+
"""
|
221 |
+
ctx.save_for_backward(idx, features)
|
222 |
+
|
223 |
+
return _ext.group_points(features, idx)
|
224 |
+
|
225 |
+
@staticmethod
|
226 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
227 |
+
def backward(ctx, grad_out):
|
228 |
+
# type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
|
229 |
+
r"""
|
230 |
+
|
231 |
+
Parameters
|
232 |
+
----------
|
233 |
+
grad_out : torch.Tensor
|
234 |
+
(B, C, npoint, nsample) tensor of the gradients of the output from forward
|
235 |
+
|
236 |
+
Returns
|
237 |
+
-------
|
238 |
+
torch.Tensor
|
239 |
+
(B, C, N) gradient of the features
|
240 |
+
None
|
241 |
+
"""
|
242 |
+
idx, features = ctx.saved_tensors
|
243 |
+
N = features.size(2)
|
244 |
+
|
245 |
+
grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
|
246 |
+
|
247 |
+
return grad_features, torch.zeros_like(idx)
|
248 |
+
|
249 |
+
|
250 |
+
grouping_operation = GroupingOperation.apply
|
251 |
+
|
252 |
+
|
253 |
+
class BallQuery(Function):
|
254 |
+
@staticmethod
|
255 |
+
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
256 |
+
def forward(ctx, radius, nsample, xyz, new_xyz):
|
257 |
+
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
|
258 |
+
r"""
|
259 |
+
|
260 |
+
Parameters
|
261 |
+
----------
|
262 |
+
radius : float
|
263 |
+
radius of the balls
|
264 |
+
nsample : int
|
265 |
+
maximum number of features in the balls
|
266 |
+
xyz : torch.Tensor
|
267 |
+
(B, N, 3) xyz coordinates of the features
|
268 |
+
new_xyz : torch.Tensor
|
269 |
+
(B, npoint, 3) centers of the ball query
|
270 |
+
|
271 |
+
Returns
|
272 |
+
-------
|
273 |
+
torch.Tensor
|
274 |
+
(B, npoint, nsample) tensor with the indices of the features that form the query balls
|
275 |
+
"""
|
276 |
+
output = _ext.ball_query(new_xyz, xyz, radius, nsample)
|
277 |
+
|
278 |
+
ctx.mark_non_differentiable(output)
|
279 |
+
|
280 |
+
return output
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
@torch.amp.custom_bwd(device_type="cuda")
|
284 |
+
def backward(ctx, grad_out):
|
285 |
+
return ()
|
286 |
+
|
287 |
+
|
288 |
+
ball_query = BallQuery.apply
|
289 |
+
|
290 |
+
|
291 |
+
class QueryAndGroup(nn.Module):
|
292 |
+
r"""
|
293 |
+
Groups with a ball query of radius
|
294 |
+
|
295 |
+
Parameters
|
296 |
+
----------
|
297 |
+
radius : float32
|
298 |
+
Radius of ball
|
299 |
+
nsample : int32
|
300 |
+
Maximum number of features to gather in the ball
|
301 |
+
"""
|
302 |
+
|
303 |
+
def __init__(self, radius, nsample, use_xyz=True):
|
304 |
+
# type: (QueryAndGroup, float, int, bool) -> None
|
305 |
+
super(QueryAndGroup, self).__init__()
|
306 |
+
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
|
307 |
+
|
308 |
+
def forward(self, xyz, new_xyz, features=None):
|
309 |
+
# type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
|
310 |
+
r"""
|
311 |
+
Parameters
|
312 |
+
----------
|
313 |
+
xyz : torch.Tensor
|
314 |
+
xyz coordinates of the features (B, N, 3)
|
315 |
+
new_xyz : torch.Tensor
|
316 |
+
centroids (B, npoint, 3)
|
317 |
+
features : torch.Tensor
|
318 |
+
Descriptors of the features (B, C, N)
|
319 |
+
|
320 |
+
Returns
|
321 |
+
-------
|
322 |
+
new_features : torch.Tensor
|
323 |
+
(B, 3 + C, npoint, nsample) tensor
|
324 |
+
"""
|
325 |
+
|
326 |
+
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
|
327 |
+
xyz_trans = xyz.transpose(1, 2).contiguous()
|
328 |
+
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
|
329 |
+
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
|
330 |
+
|
331 |
+
if features is not None:
|
332 |
+
grouped_features = grouping_operation(features, idx)
|
333 |
+
if self.use_xyz:
|
334 |
+
new_features = torch.cat(
|
335 |
+
[grouped_xyz, grouped_features], dim=1
|
336 |
+
) # (B, C + 3, npoint, nsample)
|
337 |
+
else:
|
338 |
+
new_features = grouped_features
|
339 |
+
else:
|
340 |
+
assert (
|
341 |
+
self.use_xyz
|
342 |
+
), "Cannot have not features and not use xyz as a feature!"
|
343 |
+
new_features = grouped_xyz
|
344 |
+
|
345 |
+
return new_features
|
346 |
+
|
347 |
+
|
348 |
+
class GroupAll(nn.Module):
|
349 |
+
r"""
|
350 |
+
Groups all features
|
351 |
+
|
352 |
+
Parameters
|
353 |
+
----------
|
354 |
+
"""
|
355 |
+
|
356 |
+
def __init__(self, use_xyz=True):
|
357 |
+
# type: (GroupAll, bool) -> None
|
358 |
+
super(GroupAll, self).__init__()
|
359 |
+
self.use_xyz = use_xyz
|
360 |
+
|
361 |
+
def forward(self, xyz, new_xyz, features=None):
|
362 |
+
# type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
|
363 |
+
r"""
|
364 |
+
Parameters
|
365 |
+
----------
|
366 |
+
xyz : torch.Tensor
|
367 |
+
xyz coordinates of the features (B, N, 3)
|
368 |
+
new_xyz : torch.Tensor
|
369 |
+
Ignored
|
370 |
+
features : torch.Tensor
|
371 |
+
Descriptors of the features (B, C, N)
|
372 |
+
|
373 |
+
Returns
|
374 |
+
-------
|
375 |
+
new_features : torch.Tensor
|
376 |
+
(B, C + 3, 1, N) tensor
|
377 |
+
"""
|
378 |
+
|
379 |
+
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
|
380 |
+
if features is not None:
|
381 |
+
grouped_features = features.unsqueeze(2)
|
382 |
+
if self.use_xyz:
|
383 |
+
new_features = torch.cat(
|
384 |
+
[grouped_xyz, grouped_features], dim=1
|
385 |
+
) # (B, 3 + C, 1, N)
|
386 |
+
else:
|
387 |
+
new_features = grouped_features
|
388 |
+
else:
|
389 |
+
new_features = grouped_xyz
|
390 |
+
|
391 |
+
return new_features
|
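Note: a minimal usage sketch of QueryAndGroup (hypothetical shapes; assumes the pointnet2_ops extension is built and a CUDA device is available), for illustration only:

    import torch
    from pointnet2_ops.pointnet2_utils import QueryAndGroup

    xyz = torch.rand(2, 1024, 3, device="cuda")        # (B, N, 3) point coordinates
    new_xyz = xyz[:, :128, :].contiguous()             # (B, npoint, 3) ball-query centers
    features = torch.rand(2, 32, 1024, device="cuda")  # (B, C, N) per-point descriptors

    grouper = QueryAndGroup(radius=0.2, nsample=16, use_xyz=True)
    grouped = grouper(xyz, new_xyz, features)          # (B, C + 3, npoint, nsample) = (2, 35, 128, 16)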
hort/models/tgs/models/snowflake/pointnet2_ops_lib/setup.py
ADDED
@@ -0,0 +1,41 @@
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
|
5 |
+
from setuptools import find_packages, setup
|
6 |
+
import torch
|
7 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
8 |
+
|
9 |
+
this_dir = osp.dirname(osp.abspath(__file__))
|
10 |
+
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
|
11 |
+
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
|
12 |
+
osp.join(_ext_src_root, "src", "*.cu")
|
13 |
+
)
|
14 |
+
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
|
15 |
+
|
16 |
+
requirements = ["torch>=1.4"]
|
17 |
+
|
18 |
+
exec(open(osp.join("pointnet2_ops", "_version.py")).read())
|
19 |
+
|
20 |
+
# os.environ["TORCH_CUDA_ARCH_LIST"] = ".".join(map(str, torch.cuda.get_device_capability()))
|
21 |
+
os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
|
22 |
+
setup(
|
23 |
+
name="pointnet2_ops",
|
24 |
+
version=__version__,
|
25 |
+
author="Erik Wijmans",
|
26 |
+
packages=find_packages(),
|
27 |
+
install_requires=requirements,
|
28 |
+
ext_modules=[
|
29 |
+
CUDAExtension(
|
30 |
+
name="pointnet2_ops._ext",
|
31 |
+
sources=_ext_sources,
|
32 |
+
extra_compile_args={
|
33 |
+
"cxx": ["-O3"],
|
34 |
+
"nvcc": ["-O3", "-Xfatbin", "-compress-all"],
|
35 |
+
},
|
36 |
+
include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
|
37 |
+
)
|
38 |
+
],
|
39 |
+
cmdclass={"build_ext": BuildExtension},
|
40 |
+
include_package_data=True,
|
41 |
+
)
|
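Note: as with any torch.utils.cpp_extension package, this extension is typically built by running `pip install .` (or `python setup.py install`) from the pointnet2_ops_lib directory with a CUDA toolchain available; the hard-coded TORCH_CUDA_ARCH_LIST above controls which GPU architectures get compiled.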
hort/models/tgs/models/snowflake/skip_transformer.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Author: Peng Xiang
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, einsum
|
6 |
+
from .utils import MLP_Res, grouping_operation, query_knn
|
7 |
+
|
8 |
+
|
9 |
+
class SkipTransformer(nn.Module):
|
10 |
+
def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
|
11 |
+
super(SkipTransformer, self).__init__()
|
12 |
+
self.mlp_v = MLP_Res(in_dim=in_channel*2, hidden_dim=in_channel, out_dim=in_channel)
|
13 |
+
self.n_knn = n_knn
|
14 |
+
self.conv_key = nn.Conv1d(in_channel, dim, 1)
|
15 |
+
self.conv_query = nn.Conv1d(in_channel, dim, 1)
|
16 |
+
self.conv_value = nn.Conv1d(in_channel, dim, 1)
|
17 |
+
|
18 |
+
self.pos_mlp = nn.Sequential(
|
19 |
+
nn.Conv2d(3, pos_hidden_dim, 1),
|
20 |
+
nn.BatchNorm2d(pos_hidden_dim),
|
21 |
+
nn.ReLU(),
|
22 |
+
nn.Conv2d(pos_hidden_dim, dim, 1)
|
23 |
+
)
|
24 |
+
|
25 |
+
self.attn_mlp = nn.Sequential(
|
26 |
+
nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
|
27 |
+
nn.BatchNorm2d(dim * attn_hidden_multiplier),
|
28 |
+
nn.ReLU(),
|
29 |
+
nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
|
30 |
+
)
|
31 |
+
|
32 |
+
self.conv_end = nn.Conv1d(dim, in_channel, 1)
|
33 |
+
|
34 |
+
def forward(self, pos, key, query, include_self=True):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
pos: (B, 3, N)
|
38 |
+
key: (B, in_channel, N)
|
39 |
+
query: (B, in_channel, N)
|
40 |
+
include_self: boolean
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
Tensor: (B, in_channel, N), shape context feature
|
44 |
+
"""
|
45 |
+
value = self.mlp_v(torch.cat([key, query], 1))
|
46 |
+
identity = value
|
47 |
+
key = self.conv_key(key)
|
48 |
+
query = self.conv_query(query)
|
49 |
+
value = self.conv_value(value)
|
50 |
+
b, dim, n = value.shape
|
51 |
+
|
52 |
+
pos_flipped = pos.permute(0, 2, 1).contiguous()
|
53 |
+
idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped, include_self=include_self)
|
54 |
+
|
55 |
+
key = grouping_operation(key, idx_knn) # b, dim, n, n_knn
|
56 |
+
qk_rel = query.reshape((b, -1, n, 1)) - key
|
57 |
+
|
58 |
+
pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn
|
59 |
+
pos_embedding = self.pos_mlp(pos_rel)
|
60 |
+
|
61 |
+
attention = self.attn_mlp(qk_rel + pos_embedding) # b, dim, n, n_knn
|
62 |
+
attention = torch.softmax(attention, -1)
|
63 |
+
|
64 |
+
value = value.reshape((b, -1, n, 1)) + pos_embedding #
|
65 |
+
|
66 |
+
agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n
|
67 |
+
y = self.conv_end(agg)
|
68 |
+
|
69 |
+
return y + identity
|
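Note: a minimal usage sketch of SkipTransformer (hypothetical sizes; needs the pointnet2_ops extension and a CUDA device because it relies on grouping_operation), illustrative only:

    import torch
    from tgs.models.snowflake.skip_transformer import SkipTransformer  # assumed import path

    st = SkipTransformer(in_channel=128, dim=64, n_knn=16).cuda()
    pos = torch.rand(2, 3, 512, device="cuda")      # (B, 3, N) point positions
    key = torch.rand(2, 128, 512, device="cuda")    # (B, in_channel, N)
    query = torch.rand(2, 128, 512, device="cuda")  # (B, in_channel, N)
    out = st(pos, key, query)                       # (B, in_channel, N) shape-context feature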
hort/models/tgs/models/snowflake/utils.py
ADDED
@@ -0,0 +1,741 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Author: Peng Xiang
|
3 |
+
|
4 |
+
import types
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import numpy as np
|
8 |
+
from torch import nn, einsum
|
9 |
+
from pointnet2_ops.pointnet2_utils import furthest_point_sample, \
|
10 |
+
gather_operation, ball_query, three_nn, three_interpolate, grouping_operation
|
11 |
+
|
12 |
+
class Conv1d(nn.Module):
|
13 |
+
def __init__(self, in_channel, out_channel, kernel_size=1, stride=1, if_bn=True, activation_fn=torch.relu):
|
14 |
+
super(Conv1d, self).__init__()
|
15 |
+
self.conv = nn.Conv1d(in_channel, out_channel, kernel_size, stride=stride)
|
16 |
+
self.if_bn = if_bn
|
17 |
+
self.bn = nn.BatchNorm1d(out_channel)
|
18 |
+
self.activation_fn = activation_fn
|
19 |
+
|
20 |
+
def forward(self, input):
|
21 |
+
out = self.conv(input)
|
22 |
+
if self.if_bn:
|
23 |
+
out = self.bn(out)
|
24 |
+
|
25 |
+
if self.activation_fn is not None:
|
26 |
+
out = self.activation_fn(out)
|
27 |
+
|
28 |
+
return out
|
29 |
+
|
30 |
+
class Conv2d(nn.Module):
|
31 |
+
def __init__(self, in_channel, out_channel, kernel_size=(1, 1), stride=(1, 1), if_bn=True, activation_fn=torch.relu):
|
32 |
+
super(Conv2d, self).__init__()
|
33 |
+
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride)
|
34 |
+
self.if_bn = if_bn
|
35 |
+
self.bn = nn.BatchNorm2d(out_channel)
|
36 |
+
self.activation_fn = activation_fn
|
37 |
+
|
38 |
+
def forward(self, input):
|
39 |
+
out = self.conv(input)
|
40 |
+
if self.if_bn:
|
41 |
+
out = self.bn(out)
|
42 |
+
|
43 |
+
if self.activation_fn is not None:
|
44 |
+
out = self.activation_fn(out)
|
45 |
+
|
46 |
+
return out
|
47 |
+
|
48 |
+
class MLP(nn.Module):
|
49 |
+
def __init__(self, in_channel, layer_dims, bn=None):
|
50 |
+
super(MLP, self).__init__()
|
51 |
+
layers = []
|
52 |
+
last_channel = in_channel
|
53 |
+
for out_channel in layer_dims[:-1]:
|
54 |
+
layers.append(nn.Linear(last_channel, out_channel))
|
55 |
+
if bn:
|
56 |
+
layers.append(nn.BatchNorm1d(out_channel))
|
57 |
+
layers.append(nn.ReLU())
|
58 |
+
last_channel = out_channel
|
59 |
+
layers.append(nn.Linear(last_channel, layer_dims[-1]))
|
60 |
+
self.mlp = nn.Sequential(*layers)
|
61 |
+
|
62 |
+
def forward(self, inputs):
|
63 |
+
return self.mlp(inputs)
|
64 |
+
|
65 |
+
class MLP_CONV(nn.Module):
|
66 |
+
def __init__(self, in_channel, layer_dims, bn=None):
|
67 |
+
super(MLP_CONV, self).__init__()
|
68 |
+
layers = []
|
69 |
+
last_channel = in_channel
|
70 |
+
for out_channel in layer_dims[:-1]:
|
71 |
+
layers.append(nn.Conv1d(last_channel, out_channel, 1))
|
72 |
+
if bn:
|
73 |
+
layers.append(nn.BatchNorm1d(out_channel))
|
74 |
+
layers.append(nn.ReLU())
|
75 |
+
last_channel = out_channel
|
76 |
+
layers.append(nn.Conv1d(last_channel, layer_dims[-1], 1))
|
77 |
+
self.mlp = nn.Sequential(*layers)
|
78 |
+
|
79 |
+
def forward(self, inputs):
|
80 |
+
return self.mlp(inputs)
|
81 |
+
|
82 |
+
class MLP_Res(nn.Module):
|
83 |
+
def __init__(self, in_dim=128, hidden_dim=None, out_dim=128):
|
84 |
+
super(MLP_Res, self).__init__()
|
85 |
+
if hidden_dim is None:
|
86 |
+
hidden_dim = in_dim
|
87 |
+
self.conv_1 = nn.Conv1d(in_dim, hidden_dim, 1)
|
88 |
+
self.conv_2 = nn.Conv1d(hidden_dim, out_dim, 1)
|
89 |
+
self.conv_shortcut = nn.Conv1d(in_dim, out_dim, 1)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
"""
|
93 |
+
Args:
|
94 |
+
x: (B, out_dim, n)
|
95 |
+
"""
|
96 |
+
shortcut = self.conv_shortcut(x)
|
97 |
+
out = self.conv_2(torch.relu(self.conv_1(x))) + shortcut
|
98 |
+
return out
|
99 |
+
|
100 |
+
|
101 |
+
def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
xyz: Tensor, (B, 3, N)
|
105 |
+
points: Tensor, (B, f, N)
|
106 |
+
npoint: int
|
107 |
+
nsample: int
|
108 |
+
radius: float
|
109 |
+
use_xyz: boolean
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
new_xyz: Tensor, (B, 3, npoint)
|
113 |
+
new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
|
114 |
+
idx_local: Tensor, (B, npoint, nsample)
|
115 |
+
grouped_xyz: Tensor, (B, 3, npoint, nsample)
|
116 |
+
|
117 |
+
"""
|
118 |
+
xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3)
|
119 |
+
new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint)
|
120 |
+
|
121 |
+
idx = ball_query(radius, nsample, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous()) # (B, npoint, nsample)
|
122 |
+
grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample)
|
123 |
+
grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, nsample)
|
124 |
+
|
125 |
+
if points is not None:
|
126 |
+
grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample)
|
127 |
+
if use_xyz:
|
128 |
+
new_points = torch.cat([grouped_xyz, grouped_points], 1)
|
129 |
+
else:
|
130 |
+
new_points = grouped_points
|
131 |
+
else:
|
132 |
+
new_points = grouped_xyz
|
133 |
+
|
134 |
+
return new_xyz, new_points, idx, grouped_xyz
|
135 |
+
|
136 |
+
|
137 |
+
def sample_and_group_all(xyz, points, use_xyz=True):
|
138 |
+
"""
|
139 |
+
Args:
|
140 |
+
xyz: Tensor, (B, 3, nsample)
|
141 |
+
points: Tensor, (B, f, nsample)
|
142 |
+
use_xyz: boolean
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
new_xyz: Tensor, (B, 3, 1)
|
146 |
+
new_points: Tensor, (B, f|f+3|3, 1, nsample)
|
147 |
+
idx: Tensor, (B, 1, nsample)
|
148 |
+
grouped_xyz: Tensor, (B, 3, 1, nsample)
|
149 |
+
"""
|
150 |
+
b, _, nsample = xyz.shape
|
151 |
+
device = xyz.device
|
152 |
+
new_xyz = torch.zeros((1, 3, 1), dtype=torch.float, device=device).repeat(b, 1, 1)
|
153 |
+
grouped_xyz = xyz.reshape((b, 3, 1, nsample))
|
154 |
+
idx = torch.arange(nsample, device=device).reshape(1, 1, nsample).repeat(b, 1, 1)
|
155 |
+
if points is not None:
|
156 |
+
if use_xyz:
|
157 |
+
new_points = torch.cat([xyz, points], 1)
|
158 |
+
else:
|
159 |
+
new_points = points
|
160 |
+
new_points = new_points.unsqueeze(2)
|
161 |
+
else:
|
162 |
+
new_points = grouped_xyz
|
163 |
+
|
164 |
+
return new_xyz, new_points, idx, grouped_xyz
|
165 |
+
|
166 |
+
|
167 |
+
class PointNet_SA_Module(nn.Module):
|
168 |
+
def __init__(self, npoint, nsample, radius, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True):
|
169 |
+
"""
|
170 |
+
Args:
|
171 |
+
npoint: int, number of points to sample
|
172 |
+
nsample: int, number of points in each local region
|
173 |
+
radius: float
|
174 |
+
in_channel: int, input channel of features(points)
|
175 |
+
mlp: list of int,
|
176 |
+
"""
|
177 |
+
super(PointNet_SA_Module, self).__init__()
|
178 |
+
self.npoint = npoint
|
179 |
+
self.nsample = nsample
|
180 |
+
self.radius = radius
|
181 |
+
self.mlp = mlp
|
182 |
+
self.group_all = group_all
|
183 |
+
self.use_xyz = use_xyz
|
184 |
+
if use_xyz:
|
185 |
+
in_channel += 3
|
186 |
+
|
187 |
+
last_channel = in_channel
|
188 |
+
self.mlp_conv = []
|
189 |
+
for out_channel in mlp:
|
190 |
+
self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn))
|
191 |
+
last_channel = out_channel
|
192 |
+
|
193 |
+
self.mlp_conv = nn.Sequential(*self.mlp_conv)
|
194 |
+
|
195 |
+
def forward(self, xyz, points):
|
196 |
+
"""
|
197 |
+
Args:
|
198 |
+
xyz: Tensor, (B, 3, N)
|
199 |
+
points: Tensor, (B, f, N)
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
new_xyz: Tensor, (B, 3, npoint)
|
203 |
+
new_points: Tensor, (B, mlp[-1], npoint)
|
204 |
+
"""
|
205 |
+
if self.group_all:
|
206 |
+
new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
|
207 |
+
else:
|
208 |
+
new_xyz, new_points, idx, grouped_xyz = sample_and_group(xyz, points, self.npoint, self.nsample, self.radius, self.use_xyz)
|
209 |
+
|
210 |
+
new_points = self.mlp_conv(new_points)
|
211 |
+
new_points = torch.max(new_points, 3)[0]
|
212 |
+
|
213 |
+
return new_xyz, new_points
|
214 |
+
|
215 |
+
|
216 |
+
class PointNet_FP_Module(nn.Module):
|
217 |
+
def __init__(self, in_channel, mlp, use_points1=False, in_channel_points1=None, if_bn=True):
|
218 |
+
"""
|
219 |
+
Args:
|
220 |
+
in_channel: int, input channel of points2
|
221 |
+
mlp: list of int
|
222 |
+
use_points1: boolean, whether to concatenate points1 with the interpolated features
|
223 |
+
in_channel_points1: int, input channel of points1
|
224 |
+
"""
|
225 |
+
super(PointNet_FP_Module, self).__init__()
|
226 |
+
self.use_points1 = use_points1
|
227 |
+
|
228 |
+
if use_points1:
|
229 |
+
in_channel += in_channel_points1
|
230 |
+
|
231 |
+
last_channel = in_channel
|
232 |
+
self.mlp_conv = []
|
233 |
+
for out_channel in mlp:
|
234 |
+
self.mlp_conv.append(Conv1d(last_channel, out_channel, if_bn=if_bn))
|
235 |
+
last_channel = out_channel
|
236 |
+
|
237 |
+
self.mlp_conv = nn.Sequential(*self.mlp_conv)
|
238 |
+
|
239 |
+
def forward(self, xyz1, xyz2, points1, points2):
|
240 |
+
"""
|
241 |
+
Args:
|
242 |
+
xyz1: Tensor, (B, 3, N)
|
243 |
+
xyz2: Tensor, (B, 3, M)
|
244 |
+
points1: Tensor, (B, in_channel, N)
|
245 |
+
points2: Tensor, (B, in_channel, M)
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
new_points: Tensor, (B, mlp[-1], N)
|
249 |
+
"""
|
250 |
+
dist, idx = three_nn(xyz1.permute(0, 2, 1).contiguous(), xyz2.permute(0, 2, 1).contiguous())
|
251 |
+
dist = torch.clamp_min(dist, 1e-10) # (B, N, 3)
|
252 |
+
recip_dist = 1.0/dist
|
253 |
+
norm = torch.sum(recip_dist, 2, keepdim=True).repeat((1, 1, 3))
|
254 |
+
weight = recip_dist / norm
|
255 |
+
interpolated_points = three_interpolate(points2, idx, weight) # B, in_channel, N
|
256 |
+
|
257 |
+
if self.use_points1:
|
258 |
+
new_points = torch.cat([interpolated_points, points1], 1)
|
259 |
+
else:
|
260 |
+
new_points = interpolated_points
|
261 |
+
|
262 |
+
new_points = self.mlp_conv(new_points)
|
263 |
+
return new_points
|
264 |
+
|
265 |
+
|
266 |
+
def square_distance(src, dst):
|
267 |
+
"""
|
268 |
+
Calculate the squared Euclidean distance between each pair of points.
|
269 |
+
|
270 |
+
src^T * dst = xn * xm + yn * ym + zn * zm;
|
271 |
+
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
|
272 |
+
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
|
273 |
+
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
|
274 |
+
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
|
275 |
+
|
276 |
+
Input:
|
277 |
+
src: source points, [B, N, C]
|
278 |
+
dst: target points, [B, M, C]
|
279 |
+
Output:
|
280 |
+
dist: per-point square distance, [B, N, M]
|
281 |
+
"""
|
282 |
+
B, N, _ = src.shape
|
283 |
+
_, M, _ = dst.shape
|
284 |
+
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # B, N, M
|
285 |
+
dist += torch.sum(src ** 2, -1).view(B, N, 1)
|
286 |
+
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
|
287 |
+
return dist
|
288 |
+
|
289 |
+
|
290 |
+
def query_knn(nsample, xyz, new_xyz, include_self=True):
|
291 |
+
"""Find k-NN of new_xyz in xyz"""
|
292 |
+
pad = 0 if include_self else 1
|
293 |
+
sqrdists = square_distance(new_xyz, xyz) # B, S, N
|
294 |
+
idx = torch.argsort(sqrdists, dim=-1, descending=False)[:, :, pad: nsample+pad]
|
295 |
+
return idx.int()
|
296 |
+
|
297 |
+
|
298 |
+
def sample_and_group_knn(xyz, points, npoint, k, use_xyz=True, idx=None):
|
299 |
+
"""
|
300 |
+
Args:
|
301 |
+
xyz: Tensor, (B, 3, N)
|
302 |
+
points: Tensor, (B, f, N)
|
303 |
+
npoint: int
|
304 |
+
nsample: int
|
305 |
+
radius: float
|
306 |
+
use_xyz: boolean
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
new_xyz: Tensor, (B, 3, npoint)
|
310 |
+
new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
|
311 |
+
idx_local: Tensor, (B, npoint, nsample)
|
312 |
+
grouped_xyz: Tensor, (B, 3, npoint, nsample)
|
313 |
+
|
314 |
+
"""
|
315 |
+
xyz_flipped = xyz.permute(0, 2, 1).contiguous() # (B, N, 3)
|
316 |
+
new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint)) # (B, 3, npoint)
|
317 |
+
if idx is None:
|
318 |
+
idx = query_knn(k, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous())
|
319 |
+
grouped_xyz = grouping_operation(xyz, idx) # (B, 3, npoint, nsample)
|
320 |
+
grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, k)
|
321 |
+
|
322 |
+
if points is not None:
|
323 |
+
grouped_points = grouping_operation(points, idx) # (B, f, npoint, nsample)
|
324 |
+
if use_xyz:
|
325 |
+
new_points = torch.cat([grouped_xyz, grouped_points], 1)
|
326 |
+
else:
|
327 |
+
new_points = grouped_points
|
328 |
+
else:
|
329 |
+
new_points = grouped_xyz
|
330 |
+
|
331 |
+
return new_xyz, new_points, idx, grouped_xyz
|
332 |
+
|
333 |
+
|
334 |
+
class PointNet_SA_Module_KNN(nn.Module):
|
335 |
+
def __init__(self, npoint, nsample, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True, if_idx=False):
|
336 |
+
"""
|
337 |
+
Args:
|
338 |
+
npoint: int, number of points to sample
|
339 |
+
nsample: int, number of points in each local region
|
340 |
+
radius: float
|
341 |
+
in_channel: int, input channel of features(points)
|
342 |
+
mlp: list of int,
|
343 |
+
"""
|
344 |
+
super(PointNet_SA_Module_KNN, self).__init__()
|
345 |
+
self.npoint = npoint
|
346 |
+
self.nsample = nsample
|
347 |
+
self.mlp = mlp
|
348 |
+
self.group_all = group_all
|
349 |
+
self.use_xyz = use_xyz
|
350 |
+
self.if_idx = if_idx
|
351 |
+
if use_xyz:
|
352 |
+
in_channel += 3
|
353 |
+
|
354 |
+
last_channel = in_channel
|
355 |
+
self.mlp_conv = []
|
356 |
+
for out_channel in mlp[:-1]:
|
357 |
+
self.mlp_conv.append(Conv2d(last_channel, out_channel, if_bn=if_bn))
|
358 |
+
last_channel = out_channel
|
359 |
+
self.mlp_conv.append(Conv2d(last_channel, mlp[-1], if_bn=False, activation_fn=None))
|
360 |
+
self.mlp_conv = nn.Sequential(*self.mlp_conv)
|
361 |
+
|
362 |
+
def forward(self, xyz, points, idx=None):
|
363 |
+
"""
|
364 |
+
Args:
|
365 |
+
xyz: Tensor, (B, 3, N)
|
366 |
+
points: Tensor, (B, f, N)
|
367 |
+
|
368 |
+
Returns:
|
369 |
+
new_xyz: Tensor, (B, 3, npoint)
|
370 |
+
new_points: Tensor, (B, mlp[-1], npoint)
|
371 |
+
"""
|
372 |
+
if self.group_all:
|
373 |
+
new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
|
374 |
+
else:
|
375 |
+
new_xyz, new_points, idx, grouped_xyz = sample_and_group_knn(xyz, points, self.npoint, self.nsample, self.use_xyz, idx=idx)
|
376 |
+
|
377 |
+
new_points = self.mlp_conv(new_points)
|
378 |
+
new_points = torch.max(new_points, 3)[0]
|
379 |
+
|
380 |
+
if self.if_idx:
|
381 |
+
return new_xyz, new_points, idx
|
382 |
+
else:
|
383 |
+
return new_xyz, new_points
|
384 |
+
|
385 |
+
|
386 |
+
def fps_subsample(pcd, n_points=2048):
|
387 |
+
"""
|
388 |
+
Args
|
389 |
+
pcd: (b, 16384, 3)
|
390 |
+
|
391 |
+
returns
|
392 |
+
new_pcd: (b, n_points, 3)
|
393 |
+
"""
|
394 |
+
new_pcd = gather_operation(pcd.permute(0, 2, 1).contiguous(), furthest_point_sample(pcd, n_points))
|
395 |
+
new_pcd = new_pcd.permute(0, 2, 1).contiguous()
|
396 |
+
return new_pcd
|
397 |
+
|
398 |
+
|
399 |
+
class Transformer(nn.Module):
|
400 |
+
def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4):
|
401 |
+
super(Transformer, self).__init__()
|
402 |
+
self.n_knn = n_knn
|
403 |
+
self.conv_key = nn.Conv1d(dim, dim, 1)
|
404 |
+
self.conv_query = nn.Conv1d(dim, dim, 1)
|
405 |
+
self.conv_value = nn.Conv1d(dim, dim, 1)
|
406 |
+
|
407 |
+
self.pos_mlp = nn.Sequential(
|
408 |
+
nn.Conv2d(3, pos_hidden_dim, 1),
|
409 |
+
nn.BatchNorm2d(pos_hidden_dim),
|
410 |
+
nn.ReLU(),
|
411 |
+
nn.Conv2d(pos_hidden_dim, dim, 1)
|
412 |
+
)
|
413 |
+
|
414 |
+
self.attn_mlp = nn.Sequential(
|
415 |
+
nn.Conv2d(dim, dim * attn_hidden_multiplier, 1),
|
416 |
+
nn.BatchNorm2d(dim * attn_hidden_multiplier),
|
417 |
+
nn.ReLU(),
|
418 |
+
nn.Conv2d(dim * attn_hidden_multiplier, dim, 1)
|
419 |
+
)
|
420 |
+
|
421 |
+
self.linear_start = nn.Conv1d(in_channel, dim, 1)
|
422 |
+
self.linear_end = nn.Conv1d(dim, in_channel, 1)
|
423 |
+
|
424 |
+
def forward(self, x, pos):
|
425 |
+
"""feed forward of transformer
|
426 |
+
Args:
|
427 |
+
x: Tensor of features, (B, in_channel, n)
|
428 |
+
pos: Tensor of positions, (B, 3, n)
|
429 |
+
|
430 |
+
Returns:
|
431 |
+
y: Tensor of features with attention, (B, in_channel, n)
|
432 |
+
"""
|
433 |
+
|
434 |
+
identity = x
|
435 |
+
|
436 |
+
x = self.linear_start(x)
|
437 |
+
b, dim, n = x.shape
|
438 |
+
|
439 |
+
pos_flipped = pos.permute(0, 2, 1).contiguous()
|
440 |
+
idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped)
|
441 |
+
key = self.conv_key(x)
|
442 |
+
value = self.conv_value(x)
|
443 |
+
query = self.conv_query(x)
|
444 |
+
|
445 |
+
key = grouping_operation(key, idx_knn) # b, dim, n, n_knn
|
446 |
+
qk_rel = query.reshape((b, -1, n, 1)) - key
|
447 |
+
|
448 |
+
pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn
|
449 |
+
pos_embedding = self.pos_mlp(pos_rel) # b, dim, n, n_knn
|
450 |
+
|
451 |
+
attention = self.attn_mlp(qk_rel + pos_embedding)
|
452 |
+
attention = torch.softmax(attention, -1)
|
453 |
+
|
454 |
+
value = value.reshape((b, -1, n, 1)) + pos_embedding
|
455 |
+
|
456 |
+
agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n
|
457 |
+
y = self.linear_end(agg)
|
458 |
+
|
459 |
+
return y+identity
|
460 |
+
|
461 |
+
|
462 |
+
class CouplingLayer(nn.Module):
|
463 |
+
|
464 |
+
def __init__(self, d, intermediate_dim, swap=False):
|
465 |
+
nn.Module.__init__(self)
|
466 |
+
self.d = d - (d // 2)
|
467 |
+
self.swap = swap
|
468 |
+
self.net_s_t = nn.Sequential(
|
469 |
+
nn.Linear(self.d, intermediate_dim),
|
470 |
+
nn.ReLU(inplace=True),
|
471 |
+
nn.Linear(intermediate_dim, intermediate_dim),
|
472 |
+
nn.ReLU(inplace=True),
|
473 |
+
nn.Linear(intermediate_dim, (d - self.d) * 2),
|
474 |
+
)
|
475 |
+
|
476 |
+
def forward(self, x, logpx=None, reverse=False):
|
477 |
+
|
478 |
+
if self.swap:
|
479 |
+
x = torch.cat([x[:, self.d:], x[:, :self.d]], 1)
|
480 |
+
|
481 |
+
in_dim = self.d
|
482 |
+
out_dim = x.shape[1] - self.d
|
483 |
+
|
484 |
+
s_t = self.net_s_t(x[:, :in_dim])
|
485 |
+
scale = torch.sigmoid(s_t[:, :out_dim] + 2.)
|
486 |
+
shift = s_t[:, out_dim:]
|
487 |
+
|
488 |
+
logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True)
|
489 |
+
|
490 |
+
if not reverse:
|
491 |
+
y1 = x[:, self.d:] * scale + shift
|
492 |
+
delta_logp = -logdetjac
|
493 |
+
else:
|
494 |
+
y1 = (x[:, self.d:] - shift) / scale
|
495 |
+
delta_logp = logdetjac
|
496 |
+
|
497 |
+
y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1)
|
498 |
+
|
499 |
+
if logpx is None:
|
500 |
+
return y
|
501 |
+
else:
|
502 |
+
return y, logpx + delta_logp
|
503 |
+
|
504 |
+
|
505 |
+
class SequentialFlow(nn.Module):
|
506 |
+
"""A generalized nn.Sequential container for normalizing flows.
|
507 |
+
"""
|
508 |
+
|
509 |
+
def __init__(self, layersList):
|
510 |
+
super(SequentialFlow, self).__init__()
|
511 |
+
self.chain = nn.ModuleList(layersList)
|
512 |
+
|
513 |
+
def forward(self, x, logpx=None, reverse=False, inds=None):
|
514 |
+
if inds is None:
|
515 |
+
if reverse:
|
516 |
+
inds = range(len(self.chain) - 1, -1, -1)
|
517 |
+
else:
|
518 |
+
inds = range(len(self.chain))
|
519 |
+
|
520 |
+
if logpx is None:
|
521 |
+
for i in inds:
|
522 |
+
x = self.chain[i](x, reverse=reverse)
|
523 |
+
return x
|
524 |
+
else:
|
525 |
+
for i in inds:
|
526 |
+
x, logpx = self.chain[i](x, logpx, reverse=reverse)
|
527 |
+
return x, logpx
|
528 |
+
|
529 |
+
|
530 |
+
def build_latent_flow(args):
|
531 |
+
chain = []
|
532 |
+
for i in range(args.latent_flow_depth):
|
533 |
+
chain.append(CouplingLayer(args.latent_dim, args.latent_flow_hidden_dim, swap=(i % 2 == 0)))
|
534 |
+
return SequentialFlow(chain)
|
535 |
+
|
536 |
+
|
537 |
+
##################
|
538 |
+
## SpectralNorm ##
|
539 |
+
##################
|
540 |
+
|
541 |
+
POWER_ITERATION_FN = "spectral_norm_power_iteration"
|
542 |
+
|
543 |
+
|
544 |
+
class SpectralNorm(object):
|
545 |
+
def __init__(self, name='weight', dim=0, eps=1e-12):
|
546 |
+
self.name = name
|
547 |
+
self.dim = dim
|
548 |
+
self.eps = eps
|
549 |
+
|
550 |
+
def compute_weight(self, module, n_power_iterations):
|
551 |
+
if n_power_iterations < 0:
|
552 |
+
raise ValueError(
|
553 |
+
'Expected n_power_iterations to be non-negative, but '
|
554 |
+
'got n_power_iterations={}'.format(n_power_iterations)
|
555 |
+
)
|
556 |
+
|
557 |
+
weight = getattr(module, self.name + '_orig')
|
558 |
+
u = getattr(module, self.name + '_u')
|
559 |
+
v = getattr(module, self.name + '_v')
|
560 |
+
weight_mat = weight
|
561 |
+
if self.dim != 0:
|
562 |
+
# permute dim to front
|
563 |
+
weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim])
|
564 |
+
height = weight_mat.size(0)
|
565 |
+
weight_mat = weight_mat.reshape(height, -1)
|
566 |
+
with torch.no_grad():
|
567 |
+
for _ in range(n_power_iterations):
|
568 |
+
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
|
569 |
+
# are the first left and right singular vectors.
|
570 |
+
# This power iteration produces approximations of `u` and `v`.
|
571 |
+
v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
|
572 |
+
u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
|
573 |
+
setattr(module, self.name + '_u', u)
|
574 |
+
setattr(module, self.name + '_v', v)
|
575 |
+
|
576 |
+
sigma = torch.dot(u, torch.matmul(weight_mat, v))
|
577 |
+
weight = weight / sigma
|
578 |
+
setattr(module, self.name, weight)
|
579 |
+
|
580 |
+
def remove(self, module):
|
581 |
+
weight = getattr(module, self.name)
|
582 |
+
delattr(module, self.name)
|
583 |
+
delattr(module, self.name + '_u')
|
584 |
+
delattr(module, self.name + '_orig')
|
585 |
+
module.register_parameter(self.name, torch.nn.Parameter(weight))
|
586 |
+
|
587 |
+
def get_update_method(self, module):
|
588 |
+
def update_fn(module, n_power_iterations):
|
589 |
+
self.compute_weight(module, n_power_iterations)
|
590 |
+
|
591 |
+
return update_fn
|
592 |
+
|
593 |
+
def __call__(self, module, unused_inputs):
|
594 |
+
del unused_inputs
|
595 |
+
self.compute_weight(module, n_power_iterations=0)
|
596 |
+
|
597 |
+
# requires_grad might be either True or False during inference.
|
598 |
+
if not module.training:
|
599 |
+
r_g = getattr(module, self.name + '_orig').requires_grad
|
600 |
+
setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g))
|
601 |
+
|
602 |
+
@staticmethod
|
603 |
+
def apply(module, name, dim, eps):
|
604 |
+
fn = SpectralNorm(name, dim, eps)
|
605 |
+
weight = module._parameters[name]
|
606 |
+
height = weight.size(dim)
|
607 |
+
|
608 |
+
u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
|
609 |
+
v = F.normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps)
|
610 |
+
delattr(module, fn.name)
|
611 |
+
module.register_parameter(fn.name + "_orig", weight)
|
612 |
+
# We still need to assign weight back as fn.name because all sorts of
|
613 |
+
# things may assume that it exists, e.g., when initializing weights.
|
614 |
+
# However, we can't directly assign as it could be an nn.Parameter and
|
615 |
+
# gets added as a parameter. Instead, we register weight.data as a
|
616 |
+
# buffer, which will cause weight to be included in the state dict
|
617 |
+
# and also supports nn.init due to shared storage.
|
618 |
+
module.register_buffer(fn.name, weight.data)
|
619 |
+
module.register_buffer(fn.name + "_u", u)
|
620 |
+
module.register_buffer(fn.name + "_v", v)
|
621 |
+
|
622 |
+
setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module))
|
623 |
+
|
624 |
+
module.register_forward_pre_hook(fn)
|
625 |
+
return fn
|
626 |
+
|
627 |
+
|
628 |
+
def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12):
|
629 |
+
r"""Applies spectral normalization to a parameter in the given module.
|
630 |
+
.. math::
|
631 |
+
\mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
|
632 |
+
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
633 |
+
Spectral normalization stabilizes the training of discriminators (critics)
|
634 |
+
in Generative Adversarial Networks (GANs) by rescaling the weight tensor
|
635 |
+
with spectral norm :math:`\sigma` of the weight matrix calculated using
|
636 |
+
power iteration method. If the dimension of the weight tensor is greater
|
637 |
+
than 2, it is reshaped to 2D in power iteration method to get spectral
|
638 |
+
norm. This is implemented via a hook that calculates spectral norm and
|
639 |
+
rescales weight before every :meth:`~Module.forward` call.
|
640 |
+
See `Spectral Normalization for Generative Adversarial Networks`_ .
|
641 |
+
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
|
642 |
+
Args:
|
643 |
+
module (nn.Module): containing module
|
644 |
+
name (str, optional): name of weight parameter
|
645 |
+
n_power_iterations (int, optional): number of power iterations to
|
646 |
+
calculate spectral norm
|
647 |
+
dim (int, optional): dimension corresponding to number of outputs,
|
648 |
+
the default is 0, except for modules that are instances of
|
649 |
+
ConvTranspose1/2/3d, when it is 1
|
650 |
+
eps (float, optional): epsilon for numerical stability in
|
651 |
+
calculating norms
|
652 |
+
Returns:
|
653 |
+
The original module with the spectral norm hook
|
654 |
+
Example::
|
655 |
+
>>> m = spectral_norm(nn.Linear(20, 40))
|
656 |
+
Linear (20 -> 40)
|
657 |
+
>>> m.weight_u.size()
|
658 |
+
torch.Size([20])
|
659 |
+
"""
|
660 |
+
if dim is None:
|
661 |
+
if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
|
662 |
+
dim = 1
|
663 |
+
else:
|
664 |
+
dim = 0
|
665 |
+
SpectralNorm.apply(module, name, dim=dim, eps=eps)
|
666 |
+
return module
|
667 |
+
|
668 |
+
|
669 |
+
def remove_spectral_norm(module, name='weight'):
|
670 |
+
r"""Removes the spectral normalization reparameterization from a module.
|
671 |
+
Args:
|
672 |
+
module (nn.Module): containing module
|
673 |
+
name (str, optional): name of weight parameter
|
674 |
+
Example:
|
675 |
+
>>> m = spectral_norm(nn.Linear(40, 10))
|
676 |
+
>>> remove_spectral_norm(m)
|
677 |
+
"""
|
678 |
+
for k, hook in module._forward_pre_hooks.items():
|
679 |
+
if isinstance(hook, SpectralNorm) and hook.name == name:
|
680 |
+
hook.remove(module)
|
681 |
+
del module._forward_pre_hooks[k]
|
682 |
+
return module
|
683 |
+
|
684 |
+
raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
|
685 |
+
|
686 |
+
|
687 |
+
def add_spectral_norm(model, logger=None):
|
688 |
+
"""Applies spectral norm to all modules within the scope of a CNF."""
|
689 |
+
|
690 |
+
def apply_spectral_norm(module):
|
691 |
+
if 'weight' in module._parameters:
|
692 |
+
if logger: logger.info("Adding spectral norm to {}".format(module))
|
693 |
+
inplace_spectral_norm(module, 'weight')
|
694 |
+
|
695 |
+
def find_coupling_layer(module):
|
696 |
+
if isinstance(module, CouplingLayer):
|
697 |
+
module.apply(apply_spectral_norm)
|
698 |
+
else:
|
699 |
+
for child in module.children():
|
700 |
+
find_coupling_layer(child)
|
701 |
+
|
702 |
+
find_coupling_layer(model)
|
703 |
+
|
704 |
+
|
705 |
+
def spectral_norm_power_iteration(model, n_power_iterations=1):
|
706 |
+
|
707 |
+
def recursive_power_iteration(module):
|
708 |
+
if hasattr(module, POWER_ITERATION_FN):
|
709 |
+
getattr(module, POWER_ITERATION_FN)(n_power_iterations)
|
710 |
+
|
711 |
+
model.apply(recursive_power_iteration)
|
712 |
+
|
713 |
+
def reparameterize_gaussian(mean, logvar):
|
714 |
+
std = torch.exp(0.5 * logvar)
|
715 |
+
eps = torch.randn(std.size()).to(mean)
|
716 |
+
return mean + std * eps
|
717 |
+
|
718 |
+
|
719 |
+
|
720 |
+
def gaussian_entropy(logvar):
|
721 |
+
const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
|
722 |
+
ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
|
723 |
+
return ent
|
724 |
+
|
725 |
+
|
726 |
+
def standard_normal_logprob(z):
|
727 |
+
dim = z.size(-1)
|
728 |
+
log_z = -0.5 * dim * np.log(2 * np.pi)
|
729 |
+
return log_z - z.pow(2) / 2
|
730 |
+
|
731 |
+
def truncated_normal_(tensor, mean=0, std=1, trunc_std=2):
|
732 |
+
"""
|
733 |
+
Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
|
734 |
+
"""
|
735 |
+
size = tensor.shape
|
736 |
+
tmp = tensor.new_empty(size + (4,)).normal_()
|
737 |
+
valid = (tmp < trunc_std) & (tmp > -trunc_std)
|
738 |
+
ind = valid.max(-1, keepdim=True)[1]
|
739 |
+
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
|
740 |
+
tensor.data.mul_(std).add_(mean)
|
741 |
+
return tensor
|
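Note: query_knn above is pure PyTorch, so a quick CPU sanity check is possible (import path assumed; importing the module still requires pointnet2_ops to be installed), illustrative only:

    import torch
    from tgs.models.snowflake.utils import query_knn  # assumed import path

    xyz = torch.rand(2, 256, 3)        # (B, N, 3) reference points
    new_xyz = xyz[:, :32, :]           # (B, S, 3) query points
    idx = query_knn(8, xyz, new_xyz)   # (B, S, 8) int indices of the 8 nearest neighbors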
hort/models/tgs/models/tokenizers/dinov2.py
ADDED
@@ -0,0 +1,1179 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch DINOv2 model."""
|
16 |
+
|
17 |
+
|
18 |
+
import collections.abc
|
19 |
+
import math
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
from dataclasses import dataclass
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from torch import nn
|
27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
+
|
29 |
+
from transformers.activations import ACT2FN
|
30 |
+
from transformers.modeling_outputs import (
|
31 |
+
BackboneOutput,
|
32 |
+
BaseModelOutput,
|
33 |
+
BaseModelOutputWithPooling,
|
34 |
+
ImageClassifierOutput,
|
35 |
+
)
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.pytorch_utils import (
|
38 |
+
find_pruneable_heads_and_indices,
|
39 |
+
prune_linear_layer,
|
40 |
+
)
|
41 |
+
from transformers.utils import (
|
42 |
+
add_code_sample_docstrings,
|
43 |
+
add_start_docstrings,
|
44 |
+
add_start_docstrings_to_model_forward,
|
45 |
+
logging,
|
46 |
+
replace_return_docstrings,
|
47 |
+
)
|
48 |
+
from transformers.utils.backbone_utils import BackboneMixin
|
49 |
+
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
|
50 |
+
|
51 |
+
from tgs.models.transformers import MemoryEfficientAttentionMixin
|
52 |
+
from tgs.utils.typing import *
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__)
|
56 |
+
|
57 |
+
# General docstring
|
58 |
+
_CONFIG_FOR_DOC = "Dinov2Config"
|
59 |
+
|
60 |
+
# Base docstring
|
61 |
+
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
|
62 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
|
63 |
+
|
64 |
+
# Image classification docstring
|
65 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
|
66 |
+
|
67 |
+
|
68 |
+
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
69 |
+
"facebook/dinov2-base",
|
70 |
+
# See all DINOv2 models at https://huggingface.co/models?filter=dinov2
|
71 |
+
]
|
72 |
+
|
73 |
+
|
74 |
+
class Dinov2Embeddings(nn.Module):
|
75 |
+
"""
|
76 |
+
Construct the CLS token, mask token, position and patch embeddings.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, config: Dinov2Config) -> None:
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
83 |
+
# register as mask token as it's not used in optimization
|
84 |
+
# to avoid the use of find_unused_parameters_true
|
85 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
|
86 |
+
self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
|
87 |
+
self.patch_embeddings = Dinov2PatchEmbeddings(config)
|
88 |
+
num_patches = self.patch_embeddings.num_patches
|
89 |
+
self.position_embeddings = nn.Parameter(
|
90 |
+
torch.randn(1, num_patches + 1, config.hidden_size)
|
91 |
+
)
|
92 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
93 |
+
self.config = config
|
94 |
+
|
95 |
+
def interpolate_pos_encoding(
|
96 |
+
self, embeddings: torch.Tensor, height: int, width: int
|
97 |
+
) -> torch.Tensor:
|
98 |
+
"""
|
99 |
+
This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution images.
|
101 |
+
|
102 |
+
Source:
|
103 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
104 |
+
"""
|
105 |
+
|
106 |
+
num_patches = embeddings.shape[1] - 1
|
107 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
108 |
+
if num_patches == num_positions and height == width:
|
109 |
+
return self.position_embeddings
|
110 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
111 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
112 |
+
dim = embeddings.shape[-1]
|
113 |
+
height = height // self.config.patch_size
|
114 |
+
width = width // self.config.patch_size
|
115 |
+
# we add a small number to avoid floating point error in the interpolation
|
116 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
117 |
+
height, width = height + 0.1, width + 0.1
|
118 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
119 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
120 |
+
)
|
121 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
122 |
+
patch_pos_embed = nn.functional.interpolate(
|
123 |
+
patch_pos_embed,
|
124 |
+
scale_factor=(
|
125 |
+
height / math.sqrt(num_positions),
|
126 |
+
width / math.sqrt(num_positions),
|
127 |
+
),
|
128 |
+
mode="bicubic",
|
129 |
+
align_corners=False,
|
130 |
+
)
|
131 |
+
if (
|
132 |
+
int(height) != patch_pos_embed.shape[-2]
|
133 |
+
or int(width) != patch_pos_embed.shape[-1]
|
134 |
+
):
|
135 |
+
raise ValueError(
|
136 |
+
"Width or height does not match with the interpolated position embeddings"
|
137 |
+
)
|
138 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
139 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
140 |
+
|
141 |
+
def forward(
|
142 |
+
self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None,
|
143 |
+
) -> torch.Tensor:
|
144 |
+
batch_size, _, height, width = pixel_values.shape
|
145 |
+
patch_embeddings = self.patch_embeddings(pixel_values)
|
146 |
+
embeddings = patch_embeddings
|
147 |
+
|
148 |
+
if bool_masked_pos is not None:
|
149 |
+
embeddings = torch.where(
|
150 |
+
bool_masked_pos.unsqueeze(-1),
|
151 |
+
self.mask_token.to(embeddings.dtype).unsqueeze(0),
|
152 |
+
embeddings,
|
153 |
+
)
|
154 |
+
|
155 |
+
# add the [CLS] token to the embedded patch tokens
|
156 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
157 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
158 |
+
|
159 |
+
# add positional encoding to each token
|
160 |
+
embeddings = embeddings + self.interpolate_pos_encoding(
|
161 |
+
embeddings, height, width
|
162 |
+
)
|
163 |
+
|
164 |
+
embeddings = self.dropout(embeddings)
|
165 |
+
|
166 |
+
return embeddings
|
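Note: a minimal smoke test of Dinov2Embeddings (illustrative; assumes hort/models is on the import path, and passes an explicit image/patch size so the shapes are unambiguous):

    import torch
    from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
    from tgs.models.tokenizers.dinov2 import Dinov2Embeddings  # assumed import path

    config = Dinov2Config(image_size=224, patch_size=16, hidden_size=768, num_channels=3)
    emb = Dinov2Embeddings(config)
    pixels = torch.rand(1, 3, 224, 224)
    tokens = emb(pixels)               # (1, 1 + (224 // 16) ** 2, 768) -> (1, 197, 768)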
167 |
+
|
168 |
+
|
169 |
+
class Dinov2PatchEmbeddings(nn.Module):
|
170 |
+
"""
|
171 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
172 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
173 |
+
Transformer.
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(self, config):
|
177 |
+
super().__init__()
|
178 |
+
image_size, patch_size = config.image_size, config.patch_size
|
179 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
180 |
+
|
181 |
+
image_size = (
|
182 |
+
image_size
|
183 |
+
if isinstance(image_size, collections.abc.Iterable)
|
184 |
+
else (image_size, image_size)
|
185 |
+
)
|
186 |
+
patch_size = (
|
187 |
+
patch_size
|
188 |
+
if isinstance(patch_size, collections.abc.Iterable)
|
189 |
+
else (patch_size, patch_size)
|
190 |
+
)
|
191 |
+
num_patches = (image_size[1] // patch_size[1]) * (
|
192 |
+
image_size[0] // patch_size[0]
|
193 |
+
)
|
194 |
+
self.image_size = image_size
|
195 |
+
self.patch_size = patch_size
|
196 |
+
self.num_channels = num_channels
|
197 |
+
self.num_patches = num_patches
|
198 |
+
|
199 |
+
self.projection = nn.Conv2d(
|
200 |
+
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
204 |
+
num_channels = pixel_values.shape[1]
|
205 |
+
if num_channels != self.num_channels:
|
206 |
+
raise ValueError(
|
207 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
208 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
209 |
+
)
|
210 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
211 |
+
return embeddings
|
212 |
+
|
213 |
+
|
214 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
|
215 |
+
class Dinov2SelfAttention(nn.Module):
|
216 |
+
def __init__(self, config: Dinov2Config) -> None:
|
217 |
+
super().__init__()
|
218 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
219 |
+
config, "embedding_size"
|
220 |
+
):
|
221 |
+
raise ValueError(
|
222 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
223 |
+
f"heads {config.num_attention_heads}."
|
224 |
+
)
|
225 |
+
|
226 |
+
self.num_attention_heads = config.num_attention_heads
|
227 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
228 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
229 |
+
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
|
230 |
+
|
231 |
+
self.query = nn.Linear(
|
232 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
233 |
+
)
|
234 |
+
self.key = nn.Linear(
|
235 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
236 |
+
)
|
237 |
+
self.value = nn.Linear(
|
238 |
+
config.hidden_size, self.all_head_size, bias=config.qkv_bias
|
239 |
+
)
|
240 |
+
|
241 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
242 |
+
|
243 |
+
self.use_memory_efficient_attention_xformers: bool = False
|
244 |
+
|
245 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
246 |
+
new_x_shape = x.size()[:-1] + (
|
247 |
+
self.num_attention_heads,
|
248 |
+
self.attention_head_size,
|
249 |
+
)
|
250 |
+
x = x.view(new_x_shape)
|
251 |
+
return x.permute(0, 2, 1, 3)
|
252 |
+
|
253 |
+
def forward(
|
254 |
+
self,
|
255 |
+
hidden_states,
|
256 |
+
head_mask: Optional[torch.Tensor] = None,
|
257 |
+
output_attentions: bool = False,
|
258 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
259 |
+
mixed_query_layer = self.query(hidden_states)
|
260 |
+
|
261 |
+
if self.use_memory_efficient_attention_xformers:
|
262 |
+
import xformers
|
263 |
+
assert head_mask is None and not output_attentions
|
264 |
+
new_size = hidden_states.size()[:-1] + (
|
265 |
+
self.num_attention_heads,
|
266 |
+
self.attention_head_size,
|
267 |
+
)
|
268 |
+
key_layer = self.key(hidden_states).view(new_size)
|
269 |
+
value_layer = self.value(hidden_states).view(new_size)
|
270 |
+
query_layer = mixed_query_layer.view(new_size)
|
271 |
+
context_layer = xformers.ops.memory_efficient_attention(
|
272 |
+
query_layer, key_layer, value_layer, p=self.attention_probs_dropout_prob
|
273 |
+
)
|
274 |
+
context_layer = context_layer.view(*hidden_states.size()[:-1], -1)
|
275 |
+
else:
|
276 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
277 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
278 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
279 |
+
|
280 |
+
try:
|
281 |
+
context_layer = F.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=head_mask, dropout_p=(self.dropout.p if self.training else 0.0), scale=1/math.sqrt(self.attention_head_size))
|
282 |
+
except Exception:
|
283 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
284 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
285 |
+
|
286 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
287 |
+
|
288 |
+
# Normalize the attention scores to probabilities.
|
289 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
290 |
+
|
291 |
+
# This is actually dropping out entire tokens to attend to, which might
|
292 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
293 |
+
attention_probs = self.dropout(attention_probs)
|
294 |
+
|
295 |
+
# Mask heads if we want to
|
296 |
+
if head_mask is not None:
|
297 |
+
attention_probs = attention_probs * head_mask
|
298 |
+
|
299 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
300 |
+
|
301 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
302 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
303 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
304 |
+
|
305 |
+
outputs = (
|
306 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
307 |
+
)
|
308 |
+
|
309 |
+
return outputs
|
310 |
+
|
311 |
+
def set_use_memory_efficient_attention_xformers(
|
312 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
313 |
+
):
|
314 |
+
self.use_memory_efficient_attention_xformers = valid
|
315 |
+
|
316 |
+
|
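# Shape sketch for Dinov2SelfAttention above (illustrative, using the
# facebook/dinov2-base configuration with hidden_size=768 and num_attention_heads=12):
# transpose_for_scores reshapes (B, N, 768) -> (B, N, 12, 64) -> (B, 12, N, 64),
# attention mixes values per head, and the final permute/view in forward folds the
# heads back into a (B, N, 768) sequence.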
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
    """
    The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
class Dinov2Attention(nn.Module):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.attention.num_attention_heads,
            self.attention.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = (
            self.attention.attention_head_size * self.attention.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class Dinov2LayerScale(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.lambda1 = nn.Parameter(
            config.layerscale_value * torch.ones(config.hidden_size)
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return hidden_state * self.lambda1


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


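# Worked example for drop_path above (hypothetical values): with drop_prob=0.1,
# training=True and an input of shape (4, 257, 768), `shape` is (4, 1, 1) and
# `random_tensor` holds 0.9 + U(0, 1) per sample; flooring binarizes it, e.g.
# [[[1.]], [[1.]], [[0.]], [[1.]]], so one sample's residual branch is dropped
# while the survivors are rescaled by 1 / keep_prob, keeping the expected
# activation unchanged.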
class Dinov2MLP(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class Dinov2SwiGLUFFN(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
        self.weights_out = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.weights_in(hidden_state)
        x1, x2 = hidden_state.chunk(2, dim=-1)
        hidden = nn.functional.silu(x1) * x2
        return self.weights_out(hidden)


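# Worked example for the SwiGLU width above (illustrative numbers only): with
# hidden_size=1536 and mlp_ratio=4, hidden_features starts at 6144, is reduced to
# 2/3 (4096) and rounded up to a multiple of 8, i.e. (4096 + 7) // 8 * 8 = 4096,
# so weights_in maps 1536 -> 8192 (two chunks of 4096) and weights_out maps 4096 -> 1536.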
class Dinov2Layer(nn.Module, MemoryEfficientAttentionMixin):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm1_modulation = None
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        self.drop_path1 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2_modulation = None

        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)
        self.drop_path2 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        hidden_states_norm = self.norm1(hidden_states)
        if self.norm1_modulation is not None:
            assert modulation_cond is not None
            hidden_states_norm = self.norm1_modulation(
                hidden_states_norm, modulation_cond
            )
        self_attention_outputs = self.attention(
            hidden_states_norm,  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        if self.norm2_modulation is not None:
            assert modulation_cond is not None
            layer_output = self.norm2_modulation(layer_output, modulation_cond)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = layer_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

    def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
        self.norm1_modulation = norm1_mod
        self.norm2_modulation = norm2_mod


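# Block flow of Dinov2Layer in brief (pre-norm transformer block with optional
# AdaLN-style modulation registered via register_ada_norm_modulation; sketch only):
#   x = x + LayerScale1(Attention(mod1(LN1(x), cond)))
#   x = x + LayerScale2(MLP(mod2(LN2(x), cond)))
# mod1/mod2 default to identity; note that drop_path1/drop_path2 are constructed
# in __init__ but not applied in the forward pass as written.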
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
class Dinov2Encoder(nn.Module, MemoryEfficientAttentionMixin):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                    modulation_cond,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, layer_head_mask, modulation_cond, output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class Dinov2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

    def _set_gradient_checkpointing(
        self, module: Dinov2Encoder, value: bool = False
    ) -> None:
        if isinstance(module, Dinov2Encoder):
            module.gradient_checkpointing = value


DINOV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DINOV2_BASE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

DINOV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@dataclass
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
    patch_embeddings: Optional[torch.FloatTensor] = None


@add_start_docstrings(
    "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
    DINOV2_START_DOCSTRING,
)
class Dinov2Model(Dinov2PreTrainedModel, MemoryEfficientAttentionMixin):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            modulation_cond=modulation_cond,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return CustomBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            patch_embeddings=embedding_output,
        )

    def set_gradient_checkpointing(self, value: bool = False) -> None:
        self._set_gradient_checkpointing(self.encoder, value)


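# Output-shape sketch for Dinov2Model above (facebook/dinov2-base, 224x224 input,
# patch size 14; illustrative): the embeddings yield 16*16 = 256 patch tokens plus
# one CLS token, so last_hidden_state is (B, 257, 768), pooler_output is the CLS
# row (B, 768), and patch_embeddings exposes the pre-encoder token sequence that
# CustomBaseModelOutputWithPooling adds on top of the standard HF output.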
@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size

        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]

        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [
            config.hidden_size for _ in range(config.num_hidden_layers + 1)
        ]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    hidden_state = hidden_state[:, 1:, :].reshape(
                        batch_size, width // patch_size, height // patch_size, -1
                    )
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )


class CustomPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int):
        super().__init__()

        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        patch_size = (
            patch_size
            if isinstance(patch_size, collections.abc.Iterable)
            else (patch_size, patch_size)
        )
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


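# Worked example for CustomPatchEmbeddings above (hypothetical numbers): with
# image_size=224 and patch_size=16, num_patches = (224 // 16) * (224 // 16) = 196,
# so the strided Conv2d projection maps (B, num_channels, 224, 224) to a
# (B, 196, hidden_size) token sequence.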
class CustomEmbeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.
    """

    def __init__(self, image_size: int, patch_size: int, num_channels: int, hidden_size: int) -> None:
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size

        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))

        self.patch_embeddings = CustomPatchEmbeddings(image_size, patch_size, num_channels, hidden_size)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, self.hidden_size)
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (
            int(height) != patch_pos_embed.shape[-2]
            or int(width) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

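    # Shape sketch for interpolate_pos_encoding above (hypothetical numbers): if
    # the table was trained for a 16x16 patch grid (256 positions + CLS) and the
    # current input produces a 32x32 grid, the 256 patch position embeddings are
    # reshaped to (1, 16, 16, dim), bicubically resized to (1, 32, 32, dim),
    # flattened back to (1, 1024, dim), and re-concatenated with the CLS embedding.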
    def forward(
        self, pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        patch_embeddings = self.patch_embeddings(pixel_values)
        embeddings = patch_embeddings

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(
            embeddings, height, width
        )

        return embeddings
hort/models/tgs/models/tokenizers/image.py
ADDED
@@ -0,0 +1,123 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange

from tgs.utils.base import BaseModule
from tgs.models.tokenizers.dinov2 import Dinov2Model
from tgs.models.transformers import Modulation
from tgs.utils.typing import *


class DINOV2SingleImageTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-base"
        width: int = 224
        height: int = 224
        modulation: bool = False
        modulation_zero_init: bool = False
        modulation_single_layer: bool = False
        modulation_cond_dim: int = 16
        freeze_backbone_params: bool = True
        enable_memory_efficient_attention: bool = False
        enable_gradient_checkpointing: bool = False
        use_patch_embeddings: bool = False
        patch_embeddings_aggr_method: str = 'concat'

    cfg: Config

    def configure(self) -> None:
        super().configure()
        model: Dinov2Model

        if self.cfg.freeze_backbone_params:
            # freeze dino backbone parameters
            self.register_non_module(
                "model",
                Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to(
                    self.device
                ),
            )

            model = self.non_module("model")
            for p in model.parameters():
                p.requires_grad_(False)
            model.eval()
        else:
            self.model = Dinov2Model.from_pretrained(
                self.cfg.pretrained_model_name_or_path
            ).to(self.device)
            model = self.model

        model.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )
        model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing)

        # add modulation
        if self.cfg.modulation:
            modulations = []
            for layer in model.encoder.layer:
                norm1_modulation = Modulation(
                    model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=self.cfg.modulation_zero_init,
                    single_layer=self.cfg.modulation_single_layer,
                )
                norm2_modulation = Modulation(
                    model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=self.cfg.modulation_zero_init,
                    single_layer=self.cfg.modulation_single_layer,
                )
                layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
                modulations += [norm1_modulation, norm2_modulation]
            self.modulations = nn.ModuleList(modulations)

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
    ) -> Float[Tensor, "B *N Ct Nt"]:
        model: Dinov2Model
        if self.cfg.freeze_backbone_params:
            model = self.non_module("model")
        else:
            model = self.model

        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        if self.cfg.use_patch_embeddings:
            patch_embeddings = out.patch_embeddings
            if self.cfg.patch_embeddings_aggr_method == 'concat':
                local_features = torch.cat([local_features, patch_embeddings], dim=1)
            elif self.cfg.patch_embeddings_aggr_method == 'add':
                local_features = local_features + patch_embeddings
            else:
                raise NotImplementedError
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
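# Minimal usage sketch for DINOV2SingleImageTokenizer (illustrative only; the exact
# BaseModule construction signature is assumed, and shapes correspond to
# facebook/dinov2-base at 224x224):
#
#   tokenizer = DINOV2SingleImageTokenizer(cfg={"freeze_backbone_params": True})
#   images = torch.rand(2, 3, 224, 224)                 # (B, C, H, W)
#   tokens = tokenizer(images, modulation_cond=None)    # (B, Ct, Nt) = (2, 768, 257)
#
# With use_patch_embeddings=True and the default 'concat' aggregation, Nt grows to
# 514 because the pre-encoder patch embeddings are appended along the token axis.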