Aluren committed
Commit e470c5b · verified · 1 Parent(s): a85db8d

Update app.py

Files changed (1)
  1. app.py +106 -170
app.py CHANGED
@@ -1,13 +1,12 @@
 import os
 import random
 import tempfile
-from typing import Any, List, Union
+from typing import Any, List
 
 import spaces
 import gradio as gr
 import numpy as np
 import torch
-# from gradio_image_prompter import ImagePrompter
 from gradio_litmodel3d import LitModel3D
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -22,113 +21,52 @@ MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
 DTYPE = torch.bfloat16
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-REPO_ID = "VAST-AI/DetailGen3D"  # seems not to be available yet
+REPO_ID = "VAST-AI/DetailGen3D"
 
 MARKDOWN = """
 ## Generating geometry details guided by reference image with [DetailGen3D](https://detailgen3d.github.io/DetailGen3D/)
-1. Upload a detailed image of the frontal view and a coarse model. Then clik "Run " to generate the refined result.
-2. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
-3. If you want the refine result to be more consistent with the image, please manually increase the CFG strength.
+1. Upload a detailed image of the frontal view and a coarse model. Then click "Run" to generate the refined result.
+2. If satisfied, download the result using the "Download GLB" button.
+3. Increase CFG strength for better image consistency.
 """
+
 EXAMPLES = [
-    ["assets/image/100.png","assets/model/100.glb",42,False,]
+    [
+        "assets/image/100.png",
+        "assets/model/100.glb",
+        42,
+        False
+    ]
 ]
-# EXAMPLES = [
-#     [
-#         # {
-#         #     "image": "assets/image/100.png",
-#         # },
-#         "assets/image/100.png",
-#         "assets/model/100.glb",
-#         42,
-#         False,
-#     ],
-#     [
-#         {
-#             "image": "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
-#         },
-#         "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
-#         "assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb",
-#         42,
-#         False,
-#     ],
-#     [
-#         {
-#             "image": "assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png",
-#         },
-#         "assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb",
-#         42,
-#         False,
-#     ],
-#     [
-#         {
-#             "image": "assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png",
-#         },
-#         "assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb",
-#         42,
-#         False,
-#     ],
-#     [
-#         {
-#             "image": "assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png",
-#         },
-#         "assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb",
-#         42,
-#         False,
-#     ],
-#     [
-#         {
-#             "image": "assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png",
-#         },
-#         "assets/model/instant3d/e799e6b4-3b47-40e0-befb-b156af8758ad.glb",
-#         42,
-#         False,
-#     ],
-# ]
 
 os.makedirs(TMP_DIR, exist_ok=True)
-
 local_dir = "pretrained_weights/DetailGen3D"
 snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
-pipeline = DetailGen3DPipeline.from_pretrained(
-    local_dir
-).to(DEVICE, dtype=DTYPE)
-
+pipeline = DetailGen3DPipeline.from_pretrained(local_dir).to(DEVICE, dtype=DTYPE)
 
 def load_mesh(mesh_path, num_pc=20480):
-    mesh = trimesh.load(mesh_path,force="mesh")
-
+    mesh = trimesh.load(mesh_path, force="mesh")
     center = mesh.bounding_box.centroid
     mesh.apply_translation(-center)
     scale = max(mesh.bounding_box.extents)
     mesh.apply_scale(1.9 / scale)
 
-    surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000,)
+    surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000)
     normal = mesh.face_normals[face_indices]
 
     rng = np.random.default_rng()
     ind = rng.choice(surface.shape[0], num_pc, replace=False)
     surface = torch.FloatTensor(surface[ind])
     normal = torch.FloatTensor(normal[ind])
-    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
-
-    return surface
+    return torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
 
 @torch.no_grad()
 @torch.autocast(device_type=DEVICE)
-def run_detailgen3d(
-    pipeline,
-    image,
-    mesh,
-    seed,
-    num_inference_steps,
-    guidance_scale,
-):
+def run_detailgen3d(pipeline, image, mesh, seed, num_inference_steps, guidance_scale):
     surface = load_mesh(mesh)
-
     batch_size = 1
 
-    # sample query points for decoding
+    # Grid generation
     box_min = np.array([-1.005, -1.005, -1.005])
     box_max = np.array([1.005, 1.005, 1.005])
     sampled_points, grid_size, bbox_size = generate_dense_grid_points(
@@ -137,25 +75,27 @@ def run_detailgen3d(
     sampled_points = torch.FloatTensor(sampled_points).to(DEVICE, dtype=DTYPE)
     sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
 
-    # inference pipeline
+    # Pipeline execution
     sample = pipeline.vae.encode(surface).latent_dist.sample()
-    occ = pipeline(image, latents=sample, sampled_points=sampled_points, guidance_scale=guidance_scale, noise_aug_level=0, num_inference_steps=num_inference_steps).samples[0]
-
-    # marching cubes
+    occ = pipeline(
+        image,
+        latents=sample,
+        sampled_points=sampled_points,
+        guidance_scale=guidance_scale,
+        noise_aug_level=0,
+        num_inference_steps=num_inference_steps
+    ).samples[0]
+
+    # Mesh processing
     grid_logits = occ.view(grid_size).cpu().numpy()
-    vertices, faces, normals, _ = measure.marching_cubes(
-        grid_logits, 0, method="lewiner"
-    )
+    vertices, faces, normals, _ = measure.marching_cubes(grid_logits, 0, method="lewiner")
     vertices = vertices / grid_size * bbox_size + box_min
-    mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
-    return mesh
+    return trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
 
 @spaces.GPU(duration=180)
-@torch.no_grad()
-@torch.autocast(device_type=DEVICE)
 def run_refinement(
-    rgb_image: Any,
-    mesh: Any,
+    image_path: str,
+    mesh_path: str,
     seed: int,
     randomize_seed: bool = False,
     num_inference_steps: int = 50,
@@ -163,92 +103,88 @@ def run_refinement(
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    # print("rgb_image", rgb_image)
-    # print("mesh", rgb_image)
-
-    # if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
-    #     rgb_image = Image.open(rgb_image["image"]).convert("RGB")
-
-    rgb_image = Image.open(rgb_image).convert("RGB")
-
-    scene = run_detailgen3d(
-        pipeline,
-        rgb_image,
-        mesh,
-        seed,
-        num_inference_steps,
-        guidance_scale,
-    )
-
-    _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
-    scene.export(tmp_path)
-
-    torch.cuda.empty_cache()
-
-    return tmp_path, tmp_path, seed
+    try:
+        # Validate inputs
+        if not os.path.exists(image_path):
+            raise ValueError(f"Image path {image_path} not found")
+        if not os.path.exists(mesh_path):
+            raise ValueError(f"Mesh path {mesh_path} not found")
+
+        image = Image.open(image_path).convert("RGB")
+        scene = run_detailgen3d(
+            pipeline,
+            image,
+            mesh_path,
+            seed,
+            num_inference_steps,
+            guidance_scale,
+        )
+
+        # Save temporary result
+        _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
+        scene.export(tmp_path)
+
+        return tmp_path, tmp_path, seed
+
+    finally:
+        torch.cuda.empty_cache()
 
-# Demo
+# Demo interface
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
 
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                # image_prompts = ImagePrompter(label="Input Image", type="pil")
-                image_prompts = gr.Image(label="Example Image", type="filepath")
-                mesh = gr.Model3D(label="Input Coarse Model",camera_position=(90,90,3))
-
-            with gr.Accordion("Generation Settings", open=False):
-                seed = gr.Slider(
-                    label="Seed",
-                    minimum=0,
-                    maximum=MAX_SEED,
-                    step=1,
-                    value=0,
-                )
-                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=50,
-                )
-                guidance_scale = gr.Slider(
-                    label="CFG scale",
-                    minimum=0.0,
-                    maximum=50.0,
-                    step=0.1,
-                    value=4.0,
-                )
-            gen_button = gr.Button("Generate details", variant="primary")
-
-        with gr.Column():
-            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500,camera_position=(90,90,3))
-            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+                image_input = gr.Image(
+                    label="Reference Image",
+                    type="filepath",
+                    sources=["upload", "clipboard"],
+                )
+                mesh_input = gr.Model3D(
+                    label="Input Model",
+                    camera_position=(90, 90, 3)
+                )
 
-    with gr.Row():
-        gr.Examples(
-            examples=EXAMPLES,
-            fn=run_refinement,
-            inputs=[image_prompts, mesh, seed, randomize_seed],
-            outputs=[model_output, download_glb, seed],
-            cache_examples=False,
-        )
+            with gr.Accordion("Advanced Settings", open=False):
+                seed_input = gr.Slider(0, MAX_SEED, value=0, label="Seed")
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                steps_input = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
+                cfg_scale = gr.Slider(0.0, 20.0, value=4.0, step=0.1, label="CFG Scale")
+
+            run_btn = gr.Button("Generate", variant="primary")
+
+        with gr.Column():
+            model_output = LitModel3D(
+                label="Result Preview",
+                height=500,
+                camera_position=(90, 90, 3)
+            )
+            download_btn = gr.DownloadButton(
+                "Download GLB",
+                interactive=False
+            )
+
+    # Examples section
+    gr.Examples(
+        examples=EXAMPLES,
+        inputs=[image_input, mesh_input, seed_input, randomize_seed],
+        outputs=[model_output, download_btn, seed_input],
+        fn=run_refinement,
+        cache_examples=False,
+        label="Example Inputs"
+    )
 
-    gen_button.click(
-        run_refinement,
-        inputs=[
-            image_prompts,
-            mesh,
-            seed,
-            randomize_seed,
-            num_inference_steps,
-            guidance_scale,
-        ],
-        outputs=[model_output, download_glb, seed],
-    ).then(lambda: gr.Button(interactive=True), outputs=[download_glb])
+    # Event handling
+    run_btn.click(
+        run_refinement,
+        inputs=[image_input, mesh_input, seed_input, randomize_seed, steps_input, cfg_scale],
+        outputs=[model_output, download_btn, seed_input]
+    ).then(
+        lambda: gr.DownloadButton(interactive=True),
+        outputs=[download_btn]
+    )
 
-demo.launch()
+demo.launch()
 
75
  sampled_points = torch.FloatTensor(sampled_points).to(DEVICE, dtype=DTYPE)
76
  sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
77
 
78
+ # Pipeline execution
79
  sample = pipeline.vae.encode(surface).latent_dist.sample()
80
+ occ = pipeline(
81
+ image,
82
+ latents=sample,
83
+ sampled_points=sampled_points,
84
+ guidance_scale=guidance_scale,
85
+ noise_aug_level=0,
86
+ num_inference_steps=num_inference_steps
87
+ ).samples[0]
88
+
89
+ # Mesh processing
90
  grid_logits = occ.view(grid_size).cpu().numpy()
91
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logits, 0, method="lewiner")
 
 
92
  vertices = vertices / grid_size * bbox_size + box_min
93
+ return trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
 
94
 
95
  @spaces.GPU(duration=180)
 
 
96
  def run_refinement(
97
+ image_path: str,
98
+ mesh_path: str,
99
  seed: int,
100
  randomize_seed: bool = False,
101
  num_inference_steps: int = 50,
 
103
  ):
104
  if randomize_seed:
105
  seed = random.randint(0, MAX_SEED)
106
+
107
+ try:
108
+ # Validate inputs
109
+ if not os.path.exists(image_path):
110
+ raise ValueError(f"Image path {image_path} not found")
111
+ if not os.path.exists(mesh_path):
112
+ raise ValueError(f"Mesh path {mesh_path} not found")
113
+
114
+ image = Image.open(image_path).convert("RGB")
115
+ scene = run_detailgen3d(
116
+ pipeline,
117
+ image,
118
+ mesh_path,
119
+ seed,
120
+ num_inference_steps,
121
+ guidance_scale,
122
+ )
123
 
124
+ # Save temporary result
125
+ _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
126
+ scene.export(tmp_path)
127
+
128
+ return tmp_path, tmp_path, seed
129
+
130
+ finally:
131
+ torch.cuda.empty_cache()
 
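With `run_refinement` now taking plain file paths, the refactored entry point can be smoke-tested without the Gradio UI. A hedged sketch, assuming the Space's bundled example assets are present, a CUDA device is available, and the trailing `demo.launch()` is temporarily guarded behind `if __name__ == "__main__":` (in this revision it runs on import):

```python
# Hypothetical smoke test for the path-based API; EXAMPLES and
# run_refinement come from the updated app.py above.
from app import EXAMPLES, run_refinement

image_path, mesh_path, seed, randomize = EXAMPLES[0]
glb_path, download_path, used_seed = run_refinement(
    image_path,
    mesh_path,
    seed,
    randomize_seed=randomize,
    num_inference_steps=50,
    guidance_scale=4.0,
)
print(f"Refined GLB written to {glb_path} (seed={used_seed})")
```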