# LSM/src/utils/gaussian_model.py
# Web-view header residue: "kairunwen's picture" - "Update Code" (commit 57746f1)
import os
import torch
from einops import rearrange
import numpy as np
from plyfile import PlyData, PlyElement
from os import makedirs, path
from errno import EEXIST
def mkdir_p(folder_path):
    """Create ``folder_path`` and any missing parents, like ``mkdir -p``.

    An already-existing directory is not an error. An empty string (which is
    what ``os.path.dirname`` returns for a bare filename) denotes the current
    directory and is treated as a no-op instead of raising.

    Args:
        folder_path: Directory path to create.

    Raises:
        OSError: If the directory cannot be created, e.g. when the path
            already exists but is not a directory.
    """
    if folder_path == "":
        # Nothing to create: the target is the current working directory.
        return
    # exist_ok=True tolerates an existing directory but still raises when the
    # path exists as a non-directory, matching the former EEXIST/isdir check.
    makedirs(folder_path, exist_ok=True)
# Zeroth-order (DC) real spherical-harmonics basis constant, 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814


def RGB2SH(rgb):
    """Convert RGB values to zeroth-order SH coefficients.

    Computes ``(rgb - 0.5) / C0``; works on anything supporting scalar
    subtraction and division (floats, arrays, tensors).
    """
    centered = rgb - 0.5
    return centered / C0
# Reference implementation:
# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
def quaternion_to_matrix(
    quaternions,
    eps=1e-8,
):
    """Convert quaternions to 3x3 rotation matrices.

    Components are read scalar-last along the final axis, i.e. (x, y, z, w),
    which is the scipy ordering rather than the scalar-first ordering of the
    reference implementation.

    Args:
        quaternions: Tensor whose last dimension holds the four components.
        eps: Added to the squared norm before dividing, so an all-zero
            quaternion does not cause a division by zero.

    Returns:
        Tensor with the input's leading dimensions followed by ``(3, 3)``.
    """
    x, y, z, w = torch.unbind(quaternions, dim=-1)
    squared_norm = (quaternions * quaternions).sum(dim=-1)
    two_s = 2 / (squared_norm + eps)

    # Rows of the rotation matrix, listed in row-major order.
    row0 = (
        1 - two_s * (y * y + z * z),
        two_s * (x * y - z * w),
        two_s * (x * z + y * w),
    )
    row1 = (
        two_s * (x * y + z * w),
        1 - two_s * (x * x + z * z),
        two_s * (y * z - x * w),
    )
    row2 = (
        two_s * (x * z - y * w),
        two_s * (y * z + x * w),
        1 - two_s * (x * x + y * y),
    )

    flat = torch.stack(row0 + row1 + row2, -1)
    return rearrange(flat, "... (i j) -> ... i j", i=3, j=3)
def build_covariance(
    scale,
    rotation_xyzw,
):
    """Build covariance matrices as ``R @ S @ S^T @ R^T``.

    Args:
        scale: Tensor of per-axis scales; its last dimension is expanded into
            a diagonal matrix ``S``.
        rotation_xyzw: Scalar-last quaternions, converted to rotation
            matrices ``R`` via ``quaternion_to_matrix``.

    Returns:
        Tensor of square covariance matrices, one per input Gaussian.
    """
    scale_mat = scale.diag_embed()
    rot_mat = quaternion_to_matrix(rotation_xyzw)

    # Transposes of the last two axes.
    scale_mat_t = rearrange(scale_mat, "... i j -> ... j i")
    rot_mat_t = rearrange(rot_mat, "... i j -> ... j i")

    return rot_mat @ scale_mat @ scale_mat_t @ rot_mat_t
def inverse_sigmoid(x):
    """Return the elementwise logit of ``x``, i.e. ``log(x / (1 - x))``."""
    odds = x / (1 - x)
    return torch.log(odds)
class GaussianModel:
    """Container for per-Gaussian attributes with PLY export.

    Attributes are stored as tensors exactly as provided (no activation is
    applied by the accessors). ``save_ply`` applies ``inverse_sigmoid`` to the
    opacities and ``log`` to the scales before writing.
    """

    def __init__(self, sh_degree: int):
        """Create an empty model.

        Args:
            sh_degree: Maximum spherical-harmonics degree, stored as
                ``max_sh_degree``.
        """
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self._semantic_feature = torch.empty(0)

    @property
    def get_scaling(self):
        """Stored scales, returned as-is."""
        return self._scaling

    @property
    def get_rotation(self):
        """Stored rotations, returned as-is."""
        return self._rotation

    @property
    def get_xyz(self):
        """Stored Gaussian centers."""
        return self._xyz

    @property
    def get_features(self):
        """DC and higher-order SH features concatenated along dim 1."""
        return torch.cat((self._features_dc, self._features_rest), dim=1)

    @property
    def get_opacity(self):
        """Stored opacities, returned as-is."""
        return self._opacity

    @property
    def get_semantic_feature(self):
        """Stored semantic features."""
        return self._semantic_feature

    def construct_list_of_attributes(self):
        """Return the PLY property names in the column order of ``save_ply``.

        Returns:
            List of names: position, normal, ``f_dc_*``, ``f_rest_*``,
            ``opacity``, ``scale_*``, ``rot_*`` and ``semantic_*``.
        """
        dc = self._features_dc
        rest = self._features_rest
        semantic = self._semantic_feature

        names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # One column per (coefficient, channel) pair of the DC term.
        names.extend(f'f_dc_{i}' for i in range(dc.shape[1] * dc.shape[2]))
        # All remaining (non-DC) SH coefficients.
        names.extend(f'f_rest_{i}' for i in range(rest.shape[1] * rest.shape[2]))
        names.append('opacity')
        names.extend(f'scale_{i}' for i in range(self._scaling.shape[1]))
        names.extend(f'rot_{i}' for i in range(self._rotation.shape[1]))
        # Semantic features are appended after the standard attributes.
        names.extend(
            f'semantic_{i}' for i in range(semantic.shape[1] * semantic.shape[2])
        )
        return names

    @classmethod
    def from_predictions(cls, pred, sh_degree):
        """Build a model from a dictionary of predicted tensors.

        Args:
            pred: Mapping with keys ``'means'``, ``'sh_coeffs'``,
                ``'opacities'``, ``'scales'``, ``'rotations'`` and
                ``'gs_feats'``. The shape notes below are what ``save_ply``
                needs; TODO confirm against the producer of ``pred``.
            sh_degree: Maximum SH degree passed to the constructor.

        Returns:
            A populated model. Tensors are stored by reference, not copied.
        """
        # NOTE(review): active_sh_degree stays 0 here - confirm that no
        # downstream consumer expects it to equal sh_degree.
        gaussians = cls(sh_degree=sh_degree)
        sh_coeffs = pred['sh_coeffs']  # 3-D: (N, n_coeffs, channels)
        gaussians._xyz = pred['means']  # 2-D: (N, 3) presumably
        gaussians._features_dc = sh_coeffs[:, :1]  # first coefficient only
        gaussians._features_rest = sh_coeffs[:, 1:]  # remaining coefficients
        gaussians._opacity = pred['opacities']  # 2-D, one column: (N, 1)
        gaussians._scaling = pred['scales']  # must be 2-D for save_ply
        gaussians._rotation = pred['rotations']  # 2-D: (N, 4) presumably
        gaussians._semantic_feature = pred['gs_feats'][:, None, :]  # (N, 1, d)
        return gaussians

    @staticmethod
    def _flatten_channel_major(tensor):
        """Detach a 3-D tensor, swap dims 1 and 2, flatten to 2-D NumPy."""
        return (
            tensor.detach()
            .transpose(1, 2)
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )

    def save_ply(self, path):
        """Write all Gaussians to ``path`` as a PLY ``vertex`` element.

        Opacities are written as ``inverse_sigmoid(opacity)`` and scales as
        ``log(scale)``; normals are written as zeros. Every property is
        stored as 32-bit float.

        Args:
            path: Output file path. Missing parent directories are created;
                a bare filename writes into the current directory.

        Raises:
            ValueError: If the number of attribute columns does not match
                the number of property names.
        """
        # A bare filename has an empty dirname; there is nothing to create.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            mkdir_p(parent_dir)

        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._flatten_channel_major(self._features_dc)
        f_rest = self._flatten_channel_major(self._features_rest)
        opacities = inverse_sigmoid(self._opacity).detach().cpu().numpy()
        scale = torch.log(self._scaling).detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()
        semantic_feature = self._flatten_channel_major(self._semantic_feature)

        names = self.construct_list_of_attributes()
        attributes = np.concatenate(
            (xyz, normals, f_dc, f_rest, opacities, scale, rotation, semantic_feature),
            axis=1,
        )
        if attributes.shape[1] != len(names):
            raise ValueError(
                'attribute column count {} does not match property count {}'.format(
                    attributes.shape[1], len(names)
                )
            )

        elements = np.empty(xyz.shape[0], dtype=[(name, 'f4') for name in names])
        # Fill column by column; avoids building one Python tuple per vertex.
        for column, name in enumerate(names):
            elements[name] = attributes[:, column]

        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)