|
import os, subprocess, shlex, sys, gc |
|
import time |
|
import torch |
|
import numpy as np |
|
import shutil |
|
import argparse |
|
import gradio as gr |
|
import uuid |
|
import spaces |
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
# Prebuilt extension wheels shipped alongside this app. They are installed
# before the `src` imports below, which presumably need them — TODO confirm.
_WHEELS = (
    "wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl",
    "wheel/flash_attn-2.6.3+cu123torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
    "wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/pointops-1.0-cp310-cp310-linux_x86_64.whl",
)


def _install_wheels(wheels=_WHEELS):
    """Install each local wheel file into the interpreter running this script.

    Installation is best-effort: a failing ``pip`` invocation does not abort
    start-up, but it is reported on stderr instead of passing silently.

    Args:
        wheels: Iterable of wheel file paths, relative to the working directory.
    """
    for wheel in wheels:
        # `sys.executable -m pip` targets this interpreter's environment rather
        # than whichever `pip` executable happens to be first on PATH.
        result = subprocess.run([sys.executable, "-m", "pip", "install", wheel])
        if result.returncode != 0:
            print(
                f"WARNING: pip install {wheel} failed with exit code {result.returncode}",
                file=sys.stderr,
            )


_install_wheels()
|
|
|
from src.utils.visualization_utils import render_video_from_file |
|
from src.model import LSM_MASt3R |
|
|
|
|
|
# Hugging Face Hub repository holding the pretrained checkpoints.
repo_id = "Journey9ni/LSM"
# Directory (relative to the snapshot root) where the checkpoints are expected.
remote_dir = "checkpoints/pretrained_models"
# Directory the demo loads checkpoints from.
local_dir = "checkpoints/pretrained_model"
# Maps checkpoint file names under `remote_dir` to the names used under `local_dir`.
model_path_map = {
    "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
    "checkpoint-40.pth": "checkpoint-40.pth",
    "demo_e200.ckpt": "lang_seg.ckpt",
}

os.makedirs(local_dir, exist_ok=True)

# Download the whole repo snapshot into the working directory; the loop below
# expects the checkpoint files to appear under `remote_dir`.
snapshot_download(repo_id=repo_id, local_dir='./')

for remote_name, local_name in model_path_map.items():
    src = os.path.join(remote_dir, remote_name)
    dst = os.path.join(local_dir, local_name)
    if not os.path.exists(src) and os.path.exists(dst):
        # Already moved on a previous start-up; nothing left to do.
        continue
    # os.replace overwrites an existing destination on every platform
    # (os.rename raises on Windows in that case).
    os.replace(src, dst)
|
|
|
|
|
# Path of the LSM checkpoint, derived from the same mapping that arranged it on disk.
model_path = os.path.join(local_dir, model_path_map["checkpoint-40.pth"])
# Load the model once at import time and switch it to inference mode.
model = LSM_MASt3R.from_pretrained(model_path, device='cuda').eval()
|
|
|
@spaces.GPU(duration=80)
def process(inputfiles, input_path=None):
    """Reconstruct a scene from a pair of images and render preview videos.

    Args:
        inputfiles: Files from the "Load Images" upload widget. May be ``None``
            when the button is pressed before anything is uploaded. Ignored
            when ``input_path`` is given.
        input_path: Name of a directory under ``./assets/examples``. When set
            (non-empty), every entry in that directory, sorted by name, is used
            as input instead of ``inputfiles``.

    Returns:
        A 6-tuple ``(filelist, rgb_video_path, depth_video_path,
        feature_video_path, ply_path, ply_path)`` in the order of the Gradio
        outputs (gallery, RGB/depth/feature videos, 3D viewer, PLY download),
        or six ``None`` values when the input is not exactly two files.
    """
    if input_path:
        # String concatenation is kept deliberately: os.path.join would drop
        # the examples prefix entirely if `input_path` were an absolute path.
        imgs_path = './assets/examples/' + input_path
        inputfiles = [os.path.join(imgs_path, name) for name in sorted(os.listdir(imgs_path))]
        print(inputfiles)

    filelist = inputfiles
    # `not filelist` also covers None (nothing uploaded); len(None) would raise.
    if not filelist or len(filelist) != 2:
        gr.Warning("Please select 2 images")
        return None, None, None, None, None, None

    # Per-request output directory, created only once the input is known to be valid.
    cache_dir = os.path.join('outputs', str(uuid.uuid4()))
    os.makedirs(cache_dir, exist_ok=True)

    # NOTE(review): the paths below assume render_video_from_file writes
    # gaussians.ply into cache_dir and the videos into cache_dir/moved — confirm.
    render_video_from_file(filelist, model, output_path=cache_dir, resolution=512)

    ply_path = os.path.join(cache_dir, 'gaussians.ply')
    moved_dir = os.path.join(cache_dir, 'moved')
    rgb_video_path = os.path.join(moved_dir, 'output_images_video.mp4')
    depth_video_path = os.path.join(moved_dir, 'output_depth_video.mp4')
    feature_video_path = os.path.join(moved_dir, 'output_fmap_video.mp4')

    # The PLY path feeds both the 3D viewer and the download widget.
    return filelist, rgb_video_path, depth_video_path, feature_video_path, ply_path, ply_path
|
|
|
|
|
_TITLE = 'LargeSpatialModel' |
|
_DESCRIPTION = ''' |
|
<div style="display: flex; justify-content: center; align-items: center;"> |
|
<div style="width: 100%; text-align: center; font-size: 30px;"> |
|
<strong>Large Spatial Model: End-to-end Unposed Images to Semantic 3D</strong> |
|
</div> |
|
</div> |
|
<p></p> |
|
|
|
<div align="center"> |
|
<a style="display:inline-block" href="https://arxiv.org/abs/2410.18956"><img src="https://img.shields.io/badge/ArXiv-2410.18956-b31b1b?logo=arxiv" alt='arxiv'></a> |
|
<a style="display:inline-block" href="https://largespatialmodel.github.io/"><img src='https://img.shields.io/badge/Project_Page-ff7512?logo=lightning'></a> |
|
<a title="Social" href="https://x.com/WayneINR" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> |
|
</a> |
|
|
|
</div> |
|
<p></p> |
|
|
|
* Official demo of: [LargeSpatialModel: End-to-end Unposed Images to Semantic 3D](https://largespatialmodel.github.io/). |
|
* Examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data. |
|
''' |
|
|
|
# Page title shown in the browser tab comes from _TITLE (previously unused).
block = gr.Blocks(title=_TITLE).queue()

with block:
    gr.Markdown(_DESCRIPTION)

    # ---- Input panel: image upload, preview gallery, run button ----
    with gr.Column(variant="panel"):
        with gr.Tab("Input"):
            with gr.Row():
                with gr.Column(scale=1):
                    inputfiles = gr.File(file_count="multiple", label="Load Images")
                    # Hidden field that carries the selected example's directory name.
                    input_path = gr.Textbox(visible=False, label="example_path")
                with gr.Column(scale=1):
                    image_gallery = gr.Gallery(
                        label="Gallery",
                        show_label=False,
                        elem_id="gallery",
                        columns=[2],
                        height=300,
                        object_fit="cover",
                    )

            button_gen = gr.Button("Start Reconstruction", elem_id="button_gen")
            # Status line shown only while a reconstruction is running (wired below).
            processing_msg = gr.Markdown("Processing...", visible=False, elem_id="processing_msg")

    # ---- Output panel: rendered videos, 3D viewer, PLY download ----
    with gr.Column(variant="panel"):
        with gr.Tab("Output"):
            with gr.Row():
                with gr.Column(scale=1):
                    rgb_video = gr.Video(label="RGB Video", autoplay=True)
                with gr.Column(scale=1):
                    feature_video = gr.Video(label="Feature Video", autoplay=True)
                with gr.Column(scale=1):
                    depth_video = gr.Video(label="Depth Video", autoplay=True)
            with gr.Row():
                with gr.Group():
                    output_model = gr.Model3D(
                        label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],
                        height=600,
                    )
                    gr.Markdown(
                        """
                        <div class="model-description">
                        Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
                        </div>
                        """
                    )
            with gr.Row():
                output_file = gr.File(label="PLY File")

    # Example scenes; results are precomputed at start-up (cache_examples=True).
    examples = gr.Examples(
        examples=[
            "sofa",
        ],
        inputs=[input_path],
        outputs=[image_gallery, rgb_video, depth_video, feature_video, output_model, output_file],
        fn=lambda x: process(inputfiles=None, input_path=x),
        cache_examples=True,
        label="Examples",
    )

    # Show the status line, run the reconstruction, then hide the status line.
    # `.then` runs its step whether or not the previous one succeeded, so the
    # message is hidden again even when `process` raises.
    button_gen.click(
        lambda: gr.update(visible=True),
        outputs=[processing_msg],
    ).then(
        process,
        inputs=[inputfiles],
        outputs=[image_gallery, rgb_video, depth_video, feature_video, output_model, output_file],
    ).then(
        lambda: gr.update(visible=False),
        outputs=[processing_msg],
    )

block.launch(server_name="0.0.0.0", share=False)
|
|