|
import argparse |
|
import os |
|
|
|
import torch |
|
import trimesh |
|
|
|
from cube3d.inference.engine import Engine, EngineFast |
|
from cube3d.mesh_utils.postprocessing import ( |
|
PYMESHLAB_AVAILABLE, |
|
create_pymeshset, |
|
postprocess_mesh, |
|
save_mesh, |
|
) |
|
from cube3d.renderer import renderer |
|
|
|
def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
):
    """Generate a mesh from a text prompt and export it as an ``.obj`` file.

    Args:
        engine: Inference engine (``Engine`` or ``EngineFast``) exposing a
            ``t2s`` method that maps a list of prompts to meshes.
        prompt: Text prompt describing the shape to generate.
        output_dir: Directory the ``.obj`` file is written to. It is created
            if it does not exist yet; an empty string means the current
            working directory.
        output_name: File name stem; the mesh is written to
            ``<output_dir>/<output_name>.obj``.
        resolution_base: Resolution base forwarded to ``engine.t2s`` for the
            shape decoder.
        disable_postprocess: When True, skip the pymeshlab postprocessing
            step that reduces the face count. Postprocessing is always
            skipped when pymeshlab is not installed.
        top_p: Nucleus-sampling threshold forwarded to ``engine.t2s``;
            ``None`` requests deterministic generation.

    Returns:
        The path of the exported ``.obj`` file.
    """
    # Callers other than the CLI entry point may pass a directory that does
    # not exist yet. Skip the empty string: os.path.join("", name) already
    # resolves to the CWD, but os.makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    mesh_v_f = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
    )
    # A single prompt was submitted, so take the first result; it holds the
    # vertices at index 0 and the faces at index 1.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]

    obj_path = os.path.join(output_dir, f"{output_name}.obj")

    if PYMESHLAB_AVAILABLE:
        ms = create_pymeshset(vertices, faces)
        if not disable_postprocess:
            # Target roughly 10% of the generated faces, but never fewer
            # than 10000.
            target_face_num = max(10000, int(faces.shape[0] * 0.1))
            print(f"Postprocessing mesh to {target_face_num} faces")
            postprocess_mesh(ms, target_face_num, obj_path)

        save_mesh(ms, obj_path)
    else:
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.export(obj_path)

    return obj_path
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, build the inference engine,
    # generate one mesh for --prompt, and optionally render a turntable gif.
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/",
        help="Path to the output directory to store .obj and .gif files",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    # store_true flags already default to False; no explicit default needed.
    parser.add_argument(
        "--fast-inference",
        help="Use optimized inference",
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability >= top_p. Default None: deterministic generation.",
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    if args.fast_inference:
        # EngineFast captures CUDA graphs while it is constructed, which is
        # why construction is slow and announced up front.
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        engine = EngineFast(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )
        print("Compiled the graph.")
    else:
        engine = Engine(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )

    obj_path = generate_mesh(
        engine,
        args.prompt,
        args.output_dir,
        "output",
        args.resolution_base,
        args.disable_postprocessing,
        args.top_p,
    )
    if args.render_gif:
        gif_path = renderer.render_turntable(obj_path, args.output_dir)
        print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`")
    print(f"Generated mesh for {args.prompt} at `{obj_path}`")
|
|