import numpy as np
import torch
import torch.nn.functional as F
import trimesh


def dot(x, y):
    """Return the dot product of ``x`` and ``y`` over the last dimension (kept as size 1)."""
    return torch.sum(x * y, -1, keepdim=True)


class Mesh:
    """Triangle mesh backed by torch tensors, with lazily computed, cached attributes.

    Vertex normals, tangents, UV coordinates (via xatlas) and the unique edge
    list are computed on first access and cached on the instance.
    """

    def __init__(self, v_pos, t_pos_idx, material=None):
        # The methods below index t_pos_idx[:, 0..2] and v_pos[idx, :], so these
        # are presumably (V, 3) float positions and (F, 3) integer vertex
        # indices -- TODO confirm with callers.
        self.v_pos = v_pos
        self.t_pos_idx = t_pos_idx
        self.material = material

        # Lazily populated caches (see the matching properties).
        self._v_nrm = None
        self._v_tng = None
        self._v_tex = None
        self._t_tex_idx = None
        self._v_rgb = None
        self._edges = None
        # Free-form user data attached via add_extra().
        self.extras = {}

    def add_extra(self, k, v) -> None:
        """Store an arbitrary value ``v`` under key ``k`` in ``self.extras``."""
        self.extras[k] = v

    def remove_outlier(self, n_face_threshold=5):
        """Drop connected components that have at most ``n_face_threshold`` faces.

        Components are found on the raw geometry (positions + triangle indices
        only). UV data is ignored so that UV seams do not split the surface
        into spurious components and no UV unwrap is triggered.

        Args:
            n_face_threshold: a component is kept only if it has strictly more
                faces than this.

        Returns:
            A new ``Mesh`` containing only the kept components, or ``self``
            unchanged if no component survives the filter. Vertices are
            re-indexed, so cached UVs / colors are not carried over.
        """
        geometry = trimesh.Trimesh(
            vertices=self.v_pos.detach().cpu().numpy(),
            faces=self.t_pos_idx.detach().cpu().numpy(),
            process=False,
        )
        components = geometry.split(only_watertight=False)

        kept = [c for c in components if len(c.faces) > n_face_threshold]
        if not kept:
            # Nothing passes the filter: keep the original mesh rather than
            # returning an empty one.
            return self

        combined = trimesh.util.concatenate(kept)
        return Mesh(
            torch.tensor(combined.vertices, dtype=self.v_pos.dtype, device=self.v_pos.device),
            torch.tensor(
                combined.faces, dtype=self.t_pos_idx.dtype, device=self.t_pos_idx.device
            ),
        )

    @property
    def requires_grad(self):
        """Whether the vertex positions track gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex unit normals, computed on first access."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex unit tangents (orthogonal to ``v_nrm``), computed on first access."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """UV coordinates; triggers an xatlas unwrap if none are cached."""
        if self._v_tex is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._v_tex

    @property
    def t_tex_idx(self):
        """Per-face indices into ``v_tex``; triggers an xatlas unwrap if none are cached."""
        if self._t_tex_idx is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._t_tex_idx

    @property
    def v_rgb(self):
        """Per-vertex colors set via ``set_vertex_color`` (``None`` if never set)."""
        return self._v_rgb

    @property
    def edges(self):
        """Unique undirected edges as sorted vertex-index pairs, computed on first access."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Return area-weighted unit vertex normals.

        Vertices whose accumulated normal is (near) zero -- e.g. vertices not
        referenced by any face -- get the default normal (0, 0, 1).
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        # Unnormalized face normals; their length is proportional to triangle
        # area, so summing them weights each face by its area. dim=-1 is
        # explicit because torch.cross without it picks the first size-3
        # dimension, which is the face dimension when there are exactly 3 faces.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to their three vertices.
        v_nrm = torch.zeros_like(self.v_pos)
        for idx in (i0, i1, i2):
            v_nrm.scatter_add_(0, idx[:, None].repeat(1, 3), face_normals)

        # Replace zero (degenerate) normals with a default, then normalize.
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Return per-vertex unit tangents derived from positions and UVs.

        Per-face tangents are averaged onto vertices, normalized, and then
        orthogonalized against ``v_nrm``. Accessing ``v_tex`` may trigger a UV
        unwrap.
        """
        pos = [self.v_pos[self.t_pos_idx[:, i]] for i in range(3)]
        tex = [self.v_tex[self.t_tex_idx[:, i]] for i in range(3)]
        # t_nrm_idx is always the same as t_pos_idx
        vn_idx = [self.t_pos_idx[:, i] for i in range(3)]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        uve1 = tex[1] - tex[0]
        uve2 = tex[2] - tex[0]
        pe1 = pos[1] - pos[0]
        pe2 = pos[2] - pos[0]

        nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
        denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        # (keep the sign of denom, but push its magnitude to at least 1e-6).
        tang = nom / torch.where(
            denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
        )

        # Accumulate each face tangent onto all 3 of its vertices, counting contributions.
        for i in range(3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)
            tansum.scatter_add_(0, idx, torch.ones_like(tang))

        # Vertices referenced by no face have tansum == 0; clamping avoids a
        # 0/0 = NaN (their tangent stays zero). Counts >= 1 are unaffected.
        tangents = tangents / tansum.clamp(min=1.0)

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def _unwrap_uv(self, xatlas_chart_options=None, xatlas_pack_options=None):
        """Run an xatlas UV unwrap and return ``(uvs, indices)`` on the mesh's device.

        Args:
            xatlas_chart_options: optional attribute overrides applied to
                ``xatlas.ChartOptions``.
            xatlas_pack_options: optional attribute overrides applied to
                ``xatlas.PackOptions``.

        Returns:
            ``uvs`` as a float tensor and ``indices`` (per-face indices into
            ``uvs``) as a long tensor.
        """
        import xatlas

        atlas = xatlas.Atlas()
        atlas.add_mesh(
            self.v_pos.detach().cpu().numpy(),
            self.t_pos_idx.cpu().numpy(),
        )

        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        for k, v in (xatlas_chart_options or {}).items():
            setattr(co, k, v)
        for k, v in (xatlas_pack_options or {}).items():
            setattr(po, k, v)
        atlas.generate(co, po)

        # The first return value (vertex mapping) is not needed here: UV faces
        # are stored separately in t_tex_idx rather than re-indexing positions.
        _vmapping, indices, uvs = atlas.get_mesh(0)

        uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
        indices = (
            torch.from_numpy(indices.astype(np.uint64, casting="same_kind").view(np.int64))
            .to(self.v_pos.device)
            .long()
        )
        return uvs, indices

    def unwrap_uv(self, xatlas_chart_options=None, xatlas_pack_options=None):
        """(Re)compute and cache UVs with the given xatlas chart/pack option overrides."""
        self._v_tex, self._t_tex_idx = self._unwrap_uv(
            xatlas_chart_options, xatlas_pack_options
        )

    def set_vertex_color(self, v_rgb):
        """Attach per-vertex colors; the first dimension must match the vertex count."""
        assert v_rgb.shape[0] == self.v_pos.shape[0]
        self._v_rgb = v_rgb

    def _compute_edges(self):
        """Return the unique undirected edges, each as a (low, high) vertex-index pair."""
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sorting each pair makes (a, b) and (b, a) identical before deduplication.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges

    def normal_consistency(self):
        """Mean ``1 - cos`` between the vertex normals at the two ends of each edge."""
        edge_nrm = self.v_nrm[self.edges]
        nc = (1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)).mean()
        return nc

    def _laplacian_uniform(self):
        """Build the uniform graph Laplacian (degree on the diagonal, -1 per neighbor) as sparse COO."""
        # from stable-dreamfusion
        # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
        verts, faces = self.v_pos, self.t_pos_idx

        V = verts.shape[0]

        # Neighbor indices (both directions of every face edge, deduplicated)
        ii = faces[:, [1, 2, 0]].flatten()
        jj = faces[:, [2, 0, 1]].flatten()
        adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
        adj_values = torch.ones(adj.shape[1]).to(verts)

        # Diagonal indices
        diag_idx = adj[0]

        # Build the sparse matrix
        idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
        values = torch.cat((-adj_values, adj_values))

        # The coalesce operation sums the duplicate indices, resulting in the
        # correct diagonal
        return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()

    def laplacian(self):
        """Mean norm of the uniform-Laplacian coordinates of ``v_pos`` (a smoothness loss)."""
        with torch.no_grad():
            L = self._laplacian_uniform()
        loss = L.mm(self.v_pos)
        loss = loss.norm(dim=1)
        loss = loss.mean()
        return loss

    def to(self, device):
        """Return a new ``Mesh`` whose tensors live on ``device``.

        Positions, faces, and any cached UVs / vertex colors are moved, and
        ``material`` is passed through by reference. Derived caches (normals,
        tangents, edges) are recomputed lazily on the new mesh.
        """
        mesh = Mesh(self.v_pos.to(device), self.t_pos_idx.to(device), material=self.material)
        # Re-unwrapping UVs is expensive and may not reproduce the same layout,
        # and vertex colors cannot be recomputed, so carry them over.
        if self._v_tex is not None:
            mesh._v_tex = self._v_tex.to(device)
        if self._t_tex_idx is not None:
            mesh._t_tex_idx = self._t_tex_idx.to(device)
        if self._v_rgb is not None:
            mesh._v_rgb = (
                self._v_rgb.to(device) if torch.is_tensor(self._v_rgb) else self._v_rgb
            )
        return mesh

    def as_trimesh(self):
        """Convert to a ``trimesh.Trimesh`` (no vertex merging or other processing).

        If an ``albedo_map`` attribute has been attached to this instance, the
        result carries UV texture visuals. UVs are indexed by ``t_tex_idx`` on
        their own vertex set, so in that case the mesh is built with one vertex
        per UV vertex (positions copied through the face correspondence) and
        ``t_tex_idx`` as its faces.
        """
        vertices = self.v_pos.detach().cpu().numpy()
        faces = self.t_pos_idx.detach().cpu().numpy()

        if getattr(self, "albedo_map", None) is None:
            return trimesh.Trimesh(vertices=vertices, faces=faces, process=False)

        uv = self.v_tex.detach().cpu().numpy()
        uv_faces = self.t_tex_idx.detach().cpu().numpy()

        # Each face corner maps a position index (t_pos_idx) to a UV index
        # (t_tex_idx); copy positions onto the UV vertex set through that map.
        uv_vertices = np.zeros((uv.shape[0], 3), dtype=vertices.dtype)
        uv_vertices[uv_faces.reshape(-1)] = vertices[faces.reshape(-1)]

        mesh = trimesh.Trimesh(vertices=uv_vertices, faces=uv_faces, process=False)
        # NOTE(review): the albedo_map image itself is not attached to the
        # material (matching previous behavior) -- confirm whether it should be.
        mesh.visual = trimesh.visual.texture.TextureVisuals(
            uv=uv, material=trimesh.visual.material.SimpleMaterial()
        )
        return mesh


def scale_tensor(x, input_range, target_range):
    """Linearly map ``x`` from ``input_range`` ``(lo, hi)`` to ``target_range`` ``(lo, hi)``."""
    x_unit = (x - input_range[0]) / (input_range[1] - input_range[0])
    x_scaled = x_unit * (target_range[1] - target_range[0]) + target_range[0]
    return x_scaled