ZhiyuanthePony's picture
remove_type_annotator
fc44d4b
from dataclasses import dataclass
import torch
import torch.nn as nn
import os
import numpy as np
from .saving import SaverMixin
from ..utils.mesh import Mesh
from ..utils.general_utils import scale_tensor
@dataclass()
class ExporterOutput:
    """Record describing one artifact produced by an exporter.

    Fields are positional in the order ``save_name, save_type, params``.
    """

    # Name under which the artifact is saved (naming convention is defined by
    # the consumer of this record — TODO confirm).
    save_name: str
    # Kind/format tag of the artifact, as a string.
    save_type: str
    # Extra keyword data associated with the artifact.
    params: dict
class IsosurfaceHelper(nn.Module):
    """Abstract base for modules that extract an isosurface from a sampled grid.

    Concrete helpers must expose the grid sample locations through the
    ``grid_vertices`` property.
    """

    # Interval covered by the grid on each axis; subclasses may replace it,
    # either on the class or per instance.
    points_range = (0, 1)

    @property
    def grid_vertices(self):
        """Grid sample locations; concrete subclasses must override this."""
        raise NotImplementedError
class DiffMarchingCubeHelper(IsosurfaceHelper):
    """Differentiable marching-cubes isosurface extraction using ``diso.DiffMC``.

    Args:
        resolution: Number of grid samples along each axis; the grid holds
            ``resolution ** 3`` points.
        point_range: ``(min, max)`` interval covered by the grid on every axis;
            stored as ``self.points_range``.
    """

    def __init__(
        self,
        resolution,
        point_range=(0, 1),
    ):
        super().__init__()
        self.resolution = resolution
        self.points_range = point_range
        # Local import: ``diso`` is only needed once this helper is constructed.
        from diso import DiffMC
        self.mc_func = DiffMC(dtype=torch.float32)
        # Lazily built by ``grid_vertices``.
        self._grid_vertices = None
        # Empty, non-persistent buffer that is not read in this class; presumably
        # kept so the module has a device-tracking tensor — TODO confirm.
        self.register_buffer(
            "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
        )

    @property
    def grid_vertices(self):
        """Return the ``(resolution ** 3, 3)`` grid sample points in ``points_range``.

        Points are built once and cached; ``"ij"`` meshgrid indexing plus a
        C-order reshape makes the last axis vary fastest.
        """
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            # linspace already spans points_range, so no extra rescaling here
            # (rescaling again would push vertices outside the range whenever
            # points_range != (0, 1)).
            self._grid_vertices = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
        return self._grid_vertices

    def forward(
        self,
        level,
        deformation=None,
        isovalue=0.0,
    ):
        """Extract the ``isovalue`` level set of ``level`` as a mesh.

        Args:
            level: Field values with ``resolution ** 3`` elements, presumably
                in the same order as ``grid_vertices``.
            deformation: Optional grid offsets with ``resolution ** 3 * 3``
                elements, or None.
            isovalue: Level at which the surface is extracted.

        Returns:
            Mesh whose vertex positions are mapped into ``points_range``.
        """
        res = self.resolution
        # reshape (not view) so non-contiguous per-batch slices are accepted
        level = level.reshape(res, res, res)
        if deformation is not None:
            deformation = deformation.reshape(res, res, res, 3)
        v_pos, t_pos_idx = self.mc_func(level, deformation, isovalue=isovalue)
        # Assumes DiffMC returns vertices normalized to [0, 1] — TODO confirm;
        # map them into points_range.
        low, high = self.points_range
        v_pos = v_pos * (high - low) + low
        # TODO: validate the extracted mesh (e.g. empty output) before returning
        return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
def _space_cache_batch_info(space_cache):
    """Return ``(batch_size, device)`` for a tensor or dict-of-lists space cache.

    Raises:
        TypeError: if ``space_cache`` is neither a tensor nor a dict.
        ValueError: if a dict space cache has no entries.
    """
    if torch.is_tensor(space_cache):  # space cache
        return space_cache.shape[0], space_cache.device
    if isinstance(space_cache, dict):  # hyper net
        # Dict[str, List[Tensor of shape (B, ...)]]: the first tensor of the
        # first entry determines the batch size and the device.
        first_entry = next(iter(space_cache.values()), None)
        if first_entry is None:
            raise ValueError("space_cache dict is empty; cannot infer batch size")
        first = first_entry[0]
        return first.shape[0], first.device
    raise TypeError(
        "space_cache must be a tensor or a dict, "
        f"got {type(space_cache).__name__}"
    )


def isosurface(
    space_cache,
    forward_field,
    isosurface_helper,
):
    """Extract one mesh per batch item from the field defined by ``space_cache``.

    Args:
        space_cache: Either a batched tensor or a dict mapping names to lists
            of batched tensors.
        forward_field: Callable invoked as ``forward_field(points, space_cache)``
            with ``points`` of shape ``(B, N, 3)`` in ``[-1, 1]``. It must return
            ``(sdf_batch, deformation_batch)``, where ``deformation_batch`` may
            be None.
        isosurface_helper: Helper providing ``grid_vertices``, ``points_range``
            and a call ``helper(sdf, deformation)`` that returns a mesh.

    Returns:
        List of meshes (one per batch item) with vertices rescaled to ``[-1, 1]``.

    Raises:
        TypeError / ValueError: if the batch size cannot be derived from
            ``space_cache``.
    """
    batch_size, device = _space_cache_batch_info(space_cache)

    # scale the grid points from points_range to [-1, 1]
    points = scale_tensor(
        isosurface_helper.grid_vertices.to(device),
        isosurface_helper.points_range,
        [-1, 1],  # hard coded isosurface_bbox
    )

    sdf_batch, deformation_batch = forward_field(
        points[None, ...].expand(batch_size, -1, -1),
        space_cache,
    )

    mesh_list = []
    for index, sdf in enumerate(sdf_batch):
        deformation = None if deformation_batch is None else deformation_batch[index]

        # No sign change means no surface: fall back to the sphere |p| = 1 so the
        # helper still has something to extract.
        if torch.all(sdf > 0) or torch.all(sdf < 0):
            print("All sdf values are positive or negative, no isosurface")
            sdf = torch.norm(points, dim=-1) - 1

        mesh = isosurface_helper(sdf, deformation)
        mesh.v_pos = scale_tensor(
            mesh.v_pos,
            isosurface_helper.points_range,
            [-1, 1],  # hard coded isosurface_bbox
        )
        # TODO: optional outlier removal (drop small disconnected components)
        mesh_list.append(mesh)
    return mesh_list
def _slice_space_cache(space_cache, index):
    """Return batch item ``index`` of ``space_cache``, keeping a batch dim of 1.

    Raises:
        TypeError: if ``space_cache`` is neither a tensor nor a dict.
    """
    if torch.is_tensor(space_cache):
        return space_cache[index:index + 1]
    if isinstance(space_cache, dict):
        # Dict[str, List[Tensor]]: slice every tensor of every list along dim 0.
        return {
            key: [weight[index:index + 1] for weight in weights]
            for key, weights in space_cache.items()
        }
    raise TypeError(
        "space_cache must be a tensor or a dict, "
        f"got {type(space_cache).__name__}"
    )


def colorize_mesh(
    space_cache,
    export_fn,
    mesh_list,
    activation,
):
    """Colorize meshes in place using the geometry's export function.

    Args:
        space_cache: Batched tensor, or dict of lists of batched tensors. Item
            ``i`` is paired with ``mesh_list[i]``.
        export_fn: Callable ``export_fn(points, space_cache_slice)`` with
            ``points`` of shape ``[1, N, 3]``. It returns a dict that may
            contain ``"features"`` of shape ``[1, N, C]``.
        mesh_list: Meshes to colorize; each needs a ``v_pos`` tensor ``[N, 3]``.
        activation: Callable mapping features ``[N, C]`` to vertex colors.

    Returns:
        The same ``mesh_list``; meshes for which features were produced get
        their ``_v_rgb`` attribute set.

    Raises:
        TypeError: if ``space_cache`` is neither a tensor nor a dict.
    """
    for i, mesh in enumerate(mesh_list):
        points = mesh.v_pos[None, ...]  # add batch dimension -> [1, N, 3]
        out = export_fn(points, _slice_space_cache(space_cache, i))
        if "features" in out:
            features = out["features"].squeeze(0)  # drop batch dim -> [N, C]
            # Written to the private attribute directly, as Mesh exposes v_rgb
            # read-only here (presumably a property without a setter — TODO confirm).
            mesh._v_rgb = activation(features)
    return mesh_list
class MeshExporter(SaverMixin):
    """``SaverMixin`` host that writes every file under one fixed directory.

    Args:
        save_dir: Target directory; it is created on construction if missing.
    """

    def __init__(self, save_dir="outputs"):
        self.save_dir = save_dir
        # An already-existing directory is not an error.
        os.makedirs(save_dir, exist_ok=True)

    def get_save_dir(self):
        """Directory all outputs are written to."""
        return self.save_dir

    def get_save_path(self, filename):
        """Join ``filename`` onto the save directory."""
        return os.path.join(self.save_dir, filename)

    def convert_data(self, x):
        """Turn tensors into detached CPU numpy arrays; pass anything else through."""
        if not isinstance(x, torch.Tensor):
            return x
        return x.detach().cpu().numpy()
def export_obj(
    mesh,
    save_path,
    save_normal=False,
):
    """Export a mesh to a Wavefront OBJ file.

    Args:
        mesh: Mesh to export. Its ``v_nrm`` and ``v_rgb`` attributes decide
            whether normals and vertex colors are written.
        save_path: Destination path of the OBJ file. A bare filename is
            written into the current working directory.
        save_normal: Write vertex normals when True and the mesh has them.

    Returns:
        Whatever ``save_obj`` returns (presumably the list of saved file
        paths — TODO confirm against SaverMixin).
    """
    # dirname("mesh.obj") is "", and os.makedirs("") raises FileNotFoundError,
    # so fall back to the current directory.
    save_dir = os.path.dirname(save_path) or "."
    exporter = MeshExporter(save_dir)
    return exporter.save_obj(
        os.path.basename(save_path),
        mesh,
        save_mat=None,
        save_normal=save_normal and mesh.v_nrm is not None,
        save_uv=False,
        save_vertex_color=mesh.v_rgb is not None,
    )