kairunwen commited on
Commit
d0cb6f1
·
1 Parent(s): 54900f3

Update Model Loading

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -7,6 +7,7 @@ import argparse
7
  import gradio as gr
8
  import uuid
9
  import spaces
 
10
  #
11
 
12
  subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
@@ -19,13 +20,21 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
19
  from src.utils.visualization_utils import render_video_from_file
20
  from src.model import LSM_MASt3R
21
 
22
- model = LSM_MASt3R.from_pretrained("checkpoints/pretrained_model/checkpoint-40.pth")
 
 
 
 
 
 
 
 
23
  model = model.eval()
24
 
25
 
26
  @spaces.GPU(duration=80)
27
  def process(inputfiles, input_path=None):
28
- # 创建唯一的缓存目录
29
  cache_dir = os.path.join('outputs', str(uuid.uuid4()))
30
  os.makedirs(cache_dir, exist_ok=True)
31
 
@@ -43,7 +52,7 @@ def process(inputfiles, input_path=None):
43
  filelist = inputfiles
44
  if len(filelist) != 2:
45
  gr.Warning("Please select 2 images")
46
- shutil.rmtree(cache_dir) # 清理缓存目录
47
  return None, None, None, None, None, None
48
 
49
  ply_path = os.path.join(cache_dir, 'gaussians.ply')
@@ -96,8 +105,8 @@ with block:
96
  show_label=False,
97
  elem_id="gallery",
98
  columns=[2],
99
- height=300, # 固定高度
100
- object_fit="cover" # 确保图片填满空间
101
  )
102
 
103
  button_gen = gr.Button("Start Reconstruction", elem_id="button_gen")
@@ -118,7 +127,7 @@ with block:
118
  output_model = gr.Model3D(
119
  label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
120
  interactive=False,
121
- camera_position=[0.5, 0.5, 1], # 稍微偏移一点,以便更好地查看模型
122
  height=600,
123
  )
124
  gr.Markdown(
 
7
  import gradio as gr
8
  import uuid
9
  import spaces
10
+ from huggingface_hub import hf_hub_download
11
  #
12
 
13
  subprocess.run(shlex.split("pip install wheel/torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl"))
 
20
  from src.utils.visualization_utils import render_video_from_file
21
  from src.model import LSM_MASt3R
22
 
23
+ # Assuming your model has been uploaded to HuggingFace
24
+ model_repo = "kairunwen/LSM" # Replace with the actual repository name
25
+ model_filename = "checkpoint-40.pth" # Model filename
26
+
27
+ # Download model from HuggingFace
28
+ model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
29
+
30
+ # Load model
31
+ model = LSM_MASt3R.from_pretrained(model_path)
32
  model = model.eval()
33
 
34
 
35
  @spaces.GPU(duration=80)
36
  def process(inputfiles, input_path=None):
37
+ # Create a unique cache directory
38
  cache_dir = os.path.join('outputs', str(uuid.uuid4()))
39
  os.makedirs(cache_dir, exist_ok=True)
40
 
 
52
  filelist = inputfiles
53
  if len(filelist) != 2:
54
  gr.Warning("Please select 2 images")
55
+ shutil.rmtree(cache_dir) # Clean up cache directory
56
  return None, None, None, None, None, None
57
 
58
  ply_path = os.path.join(cache_dir, 'gaussians.ply')
 
105
  show_label=False,
106
  elem_id="gallery",
107
  columns=[2],
108
+ height=300, # Fixed height
109
+ object_fit="cover" # Ensure images fill the space
110
  )
111
 
112
  button_gen = gr.Button("Start Reconstruction", elem_id="button_gen")
 
127
  output_model = gr.Model3D(
128
  label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
129
  interactive=False,
130
+ camera_position=[0.5, 0.5, 1], # Slight offset for better model viewing
131
  height=600,
132
  )
133
  gr.Markdown(