|
import argparse |
|
import gradio as gr |
|
import os |
|
import torch |
|
import trimesh |
|
import sys |
|
from pathlib import Path |
|
|
|
# Put the `cube` directory that sits next to this script on sys.path so the
# `cube3d` package imported below can be resolved (presumably a checkout of
# the Roblox/cube repository linked in the UI — confirm against the deployment).
pathdir = Path(__file__).parent / 'cube'
sys.path.append(pathdir.as_posix())
|
|
|
|
|
|
|
|
|
|
|
|
|
from cube3d.inference.engine import EngineFast, Engine |
|
from pathlib import Path |
|
import uuid |
|
import shutil |
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
# Module-level state shared with the Gradio callbacks. It is populated in the
# `__main__` block with:
#   "engine_fast": the EngineFast instance used by handle_text_prompt
#   "SAVE_DIR":    root directory for generated meshes, read by gen_save_folder
GLOBAL_STATE = {}
|
|
|
def gen_save_folder(max_size=200, save_dir=None):
    """Create a fresh, uniquely named output folder, pruning old ones first.

    Output folders live directly under ``save_dir``. Before the new folder is
    created, the oldest existing sub-folders (ordered by ``st_ctime``) are
    deleted so that at most ``max_size`` folders remain, counting the new one.

    Args:
        max_size: Maximum number of sub-folders to keep under ``save_dir``,
            including the newly created one. Values below 1 are treated as 1.
        save_dir: Root directory for outputs. Defaults to
            ``GLOBAL_STATE["SAVE_DIR"]``. It is created if it does not exist.

    Returns:
        str: Path of the newly created folder.
    """
    if save_dir is None:
        save_dir = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_dir, exist_ok=True)

    dirs = [p for p in Path(save_dir).iterdir() if p.is_dir()]

    # Remove as many of the oldest folders as needed, not just one. This keeps
    # the cap intact even when the directory already holds more than
    # `max_size` folders (e.g. after the limit was lowered). It also avoids
    # calling min() on an empty list when max_size <= 0.
    excess = len(dirs) - max(max_size, 1) + 1
    if excess > 0:
        for oldest_dir in sorted(dirs, key=lambda p: p.stat().st_ctime)[:excess]:
            shutil.rmtree(oldest_dir)
            print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(save_dir, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")

    return new_folder
|
|
|
def handle_text_prompt(input_prompt, variance=0):
    """Generate a mesh from a text prompt and export it as a GLB file.

    Args:
        input_prompt: Text description of the shape to generate.
        variance: Value of the UI "Variance" slider (0-99). ``0`` passes
            ``top_p=None`` to the engine; any other value maps to
            ``top_p = (100 - variance) / 100``.

    Returns:
        str: Path to the exported ``output.glb`` inside a fresh save folder.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")

    # Reject empty prompts up front instead of spending a full GPU generation
    # on them. gr.Error is surfaced to the user as a message in the UI.
    if not input_prompt or not input_prompt.strip():
        raise gr.Error("Please enter a prompt.")

    top_p = None if variance == 0 else (100 - variance) / 100.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p
    )

    # Result for the first (only) prompt; its first two items are used as
    # vertices and faces (presumably one entry per prompt — confirm in t2s).
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]

    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.glb")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)

    return output_path
|
|
|
def build_interface():
    """Assemble the Gradio Blocks UI for text-to-3D generation.

    Left column: a prompt textbox, a "Variance" slider (0-99) and a Submit
    button. Right column: a read-only 3D viewer. Clicking Submit calls
    ``handle_text_prompt`` with the prompt and variance and shows the
    returned file in the viewer.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    app_title = "Cube 3D"

    with gr.Blocks(theme=gr.themes.Soft(), title=app_title, fill_width=True) as ui:
        gr.Markdown(
            f"""
            # {app_title}
            # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
            """
        )

        with gr.Row():
            # Inputs.
            with gr.Column(scale=2):
                with gr.Group():
                    prompt_box = gr.Textbox(value=None, label="Prompt", lines=2)
                    variance_slider = gr.Slider(
                        minimum=0, maximum=99, step=1, value=0, label="Variance"
                    )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")

            # Output viewer.
            with gr.Column(scale=3):
                viewer = gr.Model3D(label="Output", height="45em", interactive=False)

        submit_btn.click(
            handle_text_prompt,
            inputs=[prompt_box, variance_slider],
            outputs=[viewer],
        )

    return ui
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the config file",
        default="cube/cube3d/configs/open_model.yaml",
    )
    parser.add_argument(
        "--gpt_ckpt_path",
        type=str,
        help="Path to the gpt ckpt path",
        default="model_weights/shape_gpt.safetensors",
    )
    parser.add_argument(
        "--shape_ckpt_path",
        type=str,
        help="Path to the shape ckpt path",
        default="model_weights/shape_tokenizer.safetensors",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Directory where generated meshes are written",
        default="gradio_save_dir",
    )
    args = parser.parse_args()

    # Fetch the released weights into ./model_weights, which is where the
    # default --gpt_ckpt_path / --shape_ckpt_path values point.
    snapshot_download(
        repo_id="Roblox/cube3d-v0.1",
        local_dir="./model_weights"
    )

    # Use the checkpoint paths given on the command line. Previously they
    # were parsed but ignored in favor of hard-coded paths. The defaults
    # resolve to the same files as before.
    config_path = args.config_path
    gpt_ckpt_path = args.gpt_ckpt_path
    shape_ckpt_path = args.shape_ckpt_path

    engine_fast = EngineFast(
        config_path,
        gpt_ckpt_path,
        shape_ckpt_path,
        device=torch.device("cuda"),
    )

    # Shared with the Gradio callbacks (handle_text_prompt / gen_save_folder).
    GLOBAL_STATE["engine_fast"] = engine_fast
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # Process at most one event at a time: every request uses the single
    # shared engine instance.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
|
|