Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
import trimesh | |
from PIL import Image | |
from skimage import measure | |
from detailgen3d.inference_utils import generate_dense_grid_points | |
from detailgen3d.pipelines.pipeline_detailgen3d import ( | |
DetailGen3DPipeline, | |
) | |
def load_mesh(mesh_path, num_pc=20480): | |
mesh = trimesh.load(mesh_path,force="mesh") | |
center = mesh.bounding_box.centroid | |
mesh.apply_translation(-center) | |
scale = max(mesh.bounding_box.extents) | |
mesh.apply_scale(1.9 / scale) | |
surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000,) | |
normal = mesh.face_normals[face_indices] | |
rng = np.random.default_rng() | |
ind = rng.choice(surface.shape[0], num_pc, replace=False) | |
surface = torch.FloatTensor(surface[ind]) | |
normal = torch.FloatTensor(normal[ind]) | |
surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() | |
return surface | |
if __name__ == "__main__": | |
device = "cuda" | |
dtype = torch.float16 | |
# prepare pipeline | |
pipeline = DetailGen3DPipeline.from_pretrained( | |
"VAST-AI/DetailGen3D", | |
low_cpu_mem_usage=False | |
).to(device, dtype=dtype) | |
# prepare data | |
image_path = "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png" | |
image = Image.open(image_path).convert("RGB") | |
mesh_path = "assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb" | |
surface = load_mesh(mesh_path).to(device, dtype=dtype) | |
batch_size = 1 | |
# sample query points for decoding | |
box_min = np.array([-1.005, -1.005, -1.005]) | |
box_max = np.array([1.005, 1.005, 1.005]) | |
sampled_points, grid_size, bbox_size = generate_dense_grid_points( | |
bbox_min=box_min, bbox_max=box_max, octree_depth=9, indexing="ij" | |
) | |
sampled_points = torch.FloatTensor(sampled_points).to(device, dtype=dtype) | |
sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1) | |
# inference pipeline | |
sample = pipeline.vae.encode(surface).latent_dist.sample() | |
sdf = pipeline(image, latents=sample, sampled_points=sampled_points, noise_aug_level=0).samples[0] | |
# marching cubes | |
grid_logits = sdf.view(grid_size).cpu().numpy() | |
vertices, faces, normals, _ = measure.marching_cubes( | |
grid_logits, 0, method="lewiner" | |
) | |
vertices = vertices / grid_size * bbox_size + box_min | |
mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces)) | |
mesh.export("output.glb", file_type="glb") | |