File size: 2,428 Bytes
eebae35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import torch
import trimesh
from PIL import Image
from skimage import measure

from detailgen3d.inference_utils import generate_dense_grid_points
from detailgen3d.pipelines.pipeline_detailgen3d import (
    DetailGen3DPipeline,
)

def load_mesh(mesh_path, num_pc=20480, num_samples=1000000, device="cuda"):
    """Load a mesh and sample an oriented point cloud from its surface.

    The mesh is translated so its bounding-box centroid sits at the origin
    and uniformly scaled so its longest bounding-box extent equals 1.9.
    A dense pool of surface points is drawn first; ``num_pc`` of them are
    then picked at random without replacement, each paired with the normal
    of the face it was sampled from.

    Args:
        mesh_path: Location of the mesh, passed straight to ``trimesh.load``.
        num_pc: Number of points in the returned point cloud.
        num_samples: Size of the intermediate surface sample pool. It is
            raised to ``num_pc`` when smaller, so the subsampling step can
            always draw without replacement.
        device: Device the returned tensor is moved to.

    Returns:
        A float32 tensor of shape ``(1, num_pc, 6)`` on ``device``; the last
        axis is the xyz position followed by the face normal.

    Raises:
        ValueError: If the loaded geometry has no faces, or its bounding box
            has no positive finite extent (it cannot be normalised).
    """
    mesh = trimesh.load(mesh_path, force="mesh")
    if len(getattr(mesh, "faces", ())) == 0:
        raise ValueError(f"mesh has no faces to sample: {mesh_path!r}")

    # Centre on the bounding-box centroid, then normalise the longest side.
    mesh.apply_translation(-mesh.bounding_box.centroid)
    longest_extent = float(max(mesh.bounding_box.extents))
    if not np.isfinite(longest_extent) or longest_extent <= 0.0:
        raise ValueError(
            f"mesh has a degenerate bounding box and cannot be scaled: {mesh_path!r}"
        )
    mesh.apply_scale(1.9 / longest_extent)

    # Never let the pool be smaller than the requested cloud; drawing
    # without replacement from a smaller pool would raise.
    pool_size = max(int(num_samples), int(num_pc))
    points, face_ids = trimesh.sample.sample_surface(mesh, pool_size)
    normals = mesh.face_normals[face_ids]

    rng = np.random.default_rng()
    keep = rng.choice(points.shape[0], num_pc, replace=False)

    # (num_pc, 3) positions + (num_pc, 3) normals -> (num_pc, 6)
    cloud = np.concatenate([points[keep], normals[keep]], axis=-1)
    cloud = torch.as_tensor(np.asarray(cloud), dtype=torch.float32)

    return cloud.unsqueeze(0).to(device)

def main(
    image_path="assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
    mesh_path="assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb",
    output_path="output.glb",
    device="cuda",
    dtype=torch.float16,
):
    """Run DetailGen3D on one image/mesh pair and export the resulting mesh.

    Steps: load the pretrained pipeline, read the conditioning image, encode
    a point cloud sampled from the input mesh into latents, evaluate the
    pipeline on a dense grid of query points, extract the zero level set with
    marching cubes and write it to ``output_path`` as GLB.

    Args:
        image_path: Conditioning image; opened with PIL and converted to RGB.
        mesh_path: Input mesh handed to ``load_mesh``.
        output_path: Destination of the exported GLB file.
        device: Device used for the pipeline and all tensors.
        dtype: Floating-point dtype used for the pipeline and all tensors.
    """
    # prepare pipeline
    pipeline = DetailGen3DPipeline.from_pretrained(
        "VAST-AI/DetailGen3D",
        low_cpu_mem_usage=False,
    ).to(device, dtype=dtype)

    # prepare data
    image = Image.open(image_path).convert("RGB")
    surface = load_mesh(mesh_path).to(device, dtype=dtype)

    batch_size = 1

    # sample query points for decoding
    box_min = np.array([-1.005, -1.005, -1.005])
    box_max = np.array([1.005, 1.005, 1.005])
    sampled_points, grid_size, bbox_size = generate_dense_grid_points(
        bbox_min=box_min, bbox_max=box_max, octree_depth=9, indexing="ij"
    )
    sampled_points = torch.FloatTensor(sampled_points).to(device, dtype=dtype)
    sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)

    # inference pipeline
    # Pure inference: disable autograd so encoding does not build (and keep
    # alive) a graph that is never back-propagated through.
    with torch.no_grad():
        sample = pipeline.vae.encode(surface).latent_dist.sample()
        sdf = pipeline(
            image,
            latents=sample,
            sampled_points=sampled_points,
            noise_aug_level=0,
        ).samples[0]

    # marching cubes
    grid_logits = sdf.view(grid_size).cpu().numpy()
    vertices, faces, normals, _ = measure.marching_cubes(
        grid_logits, 0, method="lewiner"
    )
    # Rescale vertices from grid-index coordinates back into the query box
    # (assumes grid_size / bbox_size are the per-axis resolution and box
    # extent returned by generate_dense_grid_points — TODO confirm).
    vertices = vertices / grid_size * bbox_size + box_min
    mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
    mesh.export(output_path, file_type="glb")


if __name__ == "__main__":
    main()