Aluren committed
Commit 9d8c0ee · verified · 1 Parent(s): b3b2a19

Update app.py

Files changed (1)
  1. app.py +170 -105
app.py CHANGED
@@ -1,12 +1,13 @@
 import os
 import random
 import tempfile
-from typing import Any, List
+from typing import Any, List, Union
 
 import spaces
 import gradio as gr
 import numpy as np
 import torch
+# from gradio_image_prompter import ImagePrompter
 from gradio_litmodel3d import LitModel3D
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -21,52 +22,113 @@ MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
 DTYPE = torch.bfloat16
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-REPO_ID = "VAST-AI/DetailGen3D"
+REPO_ID = "VAST-AI/DetailGen3D"  # doesn't seem to exist yet
 
 MARKDOWN = """
 ## Generating geometry details guided by reference image with [DetailGen3D](https://detailgen3d.github.io/DetailGen3D/)
 1. Upload a detailed image of the frontal view and a coarse model. Then click "Run" to generate the refined result.
-2. If satisfied, download the result using the "Download GLB" button.
-3. Increase CFG strength for better image consistency.
+2. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
+3. If you want the refined result to be more consistent with the image, manually increase the CFG strength.
 """
-
 EXAMPLES = [
-    [
-        "assets/image/100.png",
-        "assets/model/100.glb",
-        42,
-        False
-    ]
+    ["assets/image/100.png", "assets/model/100.glb", 42, False],
 ]
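+# each EXAMPLES row is [reference image, coarse GLB, seed, randomize_seed]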
+# EXAMPLES = [
+#     [
+#         # {
+#         #     "image": "assets/image/100.png",
+#         # },
+#         "assets/image/100.png",
+#         "assets/model/100.glb",
+#         42,
+#         False,
+#     ],
+#     [
+#         {
+#             "image": "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
+#         },
+#         "assets/image/503d193a-1b9b-4685-b05f-00ac82f93d7b.png",
+#         "assets/model/503d193a-1b9b-4685-b05f-00ac82f93d7b.glb",
+#         42,
+#         False,
+#     ],
+#     [
+#         {
+#             "image": "assets/image/34933195-9c2c-4271-8d31-a28bc5348b7a.png",
+#         },
+#         "assets/model/34933195-9c2c-4271-8d31-a28bc5348b7a.glb",
+#         42,
+#         False,
+#     ],
+#     [
+#         {
+#             "image": "assets/image/a5d09c66-1617-465c-aec9-431f48d9a7e1.png",
+#         },
+#         "assets/model/a5d09c66-1617-465c-aec9-431f48d9a7e1.glb",
+#         42,
+#         False,
+#     ],
+#     [
+#         {
+#             "image": "assets/image/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.png",
+#         },
+#         "assets/model/cb7e6c4a-b4dd-483c-9789-3d4887ee7434.glb",
+#         42,
+#         False,
+#     ],
+#     [
+#         {
+#             "image": "assets/image/e799e6b4-3b47-40e0-befb-b156af8758ad.png",
+#         },
+#         "assets/model/instant3d/e799e6b4-3b47-40e0-befb-b156af8758ad.glb",
+#         42,
+#         False,
+#     ],
+# ]
 
 os.makedirs(TMP_DIR, exist_ok=True)
+
 local_dir = "pretrained_weights/DetailGen3D"
 snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
-pipeline = DetailGen3DPipeline.from_pretrained(local_dir).to(DEVICE, dtype=DTYPE)
+pipeline = DetailGen3DPipeline.from_pretrained(
+    local_dir
+).to(DEVICE, dtype=DTYPE)
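+# weights are fetched once at startup; the pipeline runs in bfloat16 on CUDA when available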
+
 
 def load_mesh(mesh_path, num_pc=20480):
     mesh = trimesh.load(mesh_path, force="mesh")
+
     center = mesh.bounding_box.centroid
     mesh.apply_translation(-center)
     scale = max(mesh.bounding_box.extents)
     mesh.apply_scale(1.9 / scale)
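+    # the mesh is now centered with its longest side scaled to 1.9, so it
+    # stays inside the [-1.005, 1.005]^3 query box used below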
 
     surface, face_indices = trimesh.sample.sample_surface(mesh, 1000000)
     normal = mesh.face_normals[face_indices]
 
     rng = np.random.default_rng()
     ind = rng.choice(surface.shape[0], num_pc, replace=False)
     surface = torch.FloatTensor(surface[ind])
     normal = torch.FloatTensor(normal[ind])
-    return torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
+    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
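+    # positions and normals are concatenated into a (1, num_pc, 6) tensor on the GPU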
+
+    return surface
 
 @torch.no_grad()
 @torch.autocast(device_type=DEVICE)
-def run_detailgen3d(pipeline, image, mesh, seed, num_inference_steps, guidance_scale):
+def run_detailgen3d(
+    pipeline,
+    image,
+    mesh,
+    seed,
+    num_inference_steps,
+    guidance_scale,
+):
     surface = load_mesh(mesh)
+
     batch_size = 1
 
-    # Grid generation
+    # sample query points for decoding
     box_min = np.array([-1.005, -1.005, -1.005])
     box_max = np.array([1.005, 1.005, 1.005])
     sampled_points, grid_size, bbox_size = generate_dense_grid_points(
@@ -75,27 +137,25 @@ def run_detailgen3d(pipeline, image, mesh, seed, num_inference_steps, guidance_s
     sampled_points = torch.FloatTensor(sampled_points).to(DEVICE, dtype=DTYPE)
     sampled_points = sampled_points.unsqueeze(0).repeat(batch_size, 1, 1)
 
-    # Pipeline execution
+    # inference pipeline
    sample = pipeline.vae.encode(surface).latent_dist.sample()
-    occ = pipeline(
-        image,
-        latents=sample,
-        sampled_points=sampled_points,
-        guidance_scale=guidance_scale,
-        noise_aug_level=0,
-        num_inference_steps=num_inference_steps
-    ).samples[0]
+    occ = pipeline(image, latents=sample, sampled_points=sampled_points, guidance_scale=guidance_scale, noise_aug_level=0, num_inference_steps=num_inference_steps).samples[0]
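+    # occ holds one predicted occupancy value per query point of the dense grid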
 
-    # Mesh processing
+    # marching cubes
     grid_logits = occ.view(grid_size).cpu().numpy()
-    vertices, faces, normals, _ = measure.marching_cubes(grid_logits, 0, method="lewiner")
+    vertices, faces, normals, _ = measure.marching_cubes(
+        grid_logits, 0, method="lewiner"
+    )
     vertices = vertices / grid_size * bbox_size + box_min
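+    # the rescale above maps marching-cubes grid indices back to world coordinates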
-    return trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
+    mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
+    return mesh
 
 @spaces.GPU(duration=180)
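+# ZeroGPU: this call may hold the GPU for up to 180 seconds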
+@torch.no_grad()
+@torch.autocast(device_type=DEVICE)
 def run_refinement(
-    image_path: str,
-    mesh_path: str,
+    rgb_image: Any,
+    mesh: Any,
     seed: int,
     randomize_seed: bool = False,
     num_inference_steps: int = 50,
@@ -103,87 +163,92 @@ def run_refinement(
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    try:
-        # Validate inputs
-        if not os.path.exists(image_path):
-            raise ValueError(f"Image path {image_path} not found")
-        if not os.path.exists(mesh_path):
-            raise ValueError(f"Mesh path {mesh_path} not found")
-
-        image = Image.open(image_path).convert("RGB")
-        scene = run_detailgen3d(
-            pipeline,
-            image,
-            mesh_path,
-            seed,
-            num_inference_steps,
-            guidance_scale,
-        )
-
-        # Save temporary result
-        _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
-        scene.export(tmp_path)
-
-        return tmp_path, tmp_path, seed
-
-    finally:
-        torch.cuda.empty_cache()
+    # print("rgb_image", rgb_image)
+    # print("mesh", rgb_image)
+
+    # if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
+    #     rgb_image = Image.open(rgb_image["image"]).convert("RGB")
+
+    rgb_image = Image.open(rgb_image).convert("RGB")
+
+    scene = run_detailgen3d(
+        pipeline,
+        rgb_image,
+        mesh,
+        seed,
+        num_inference_steps,
+        guidance_scale,
+    )
+
+    _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="detailgen3d_", dir=TMP_DIR)
+    scene.export(tmp_path)
+
+    torch.cuda.empty_cache()
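+
+    # the same GLB path feeds both the LitModel3D preview and the download button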
+    return tmp_path, tmp_path, seed
 
-# Demo interface
+
+# Demo
 with gr.Blocks() as demo:
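+    # two-column layout: inputs and settings on the left, generated result on the right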
     gr.Markdown(MARKDOWN)
 
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                image_input = gr.Image(
-                    label="Reference Image",
-                    type="filepath",
-                    sources=["upload", "clipboard"],
+                # image_prompts = ImagePrompter(label="Input Image", type="pil")
+                image_prompts = gr.Image(label="Example Image", type="pil")
+                mesh = gr.Model3D(label="Input Coarse Model", camera_position=(90, 90, 3))
+
+            with gr.Accordion("Generation Settings", open=False):
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=0,
                 )
-                mesh_input = gr.Model3D(
-                    label="Input Model",
-                    camera_position=(90, 90, 3)
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=50,
+                    step=1,
+                    value=50,
                 )
-
-            with gr.Accordion("Advanced Settings", open=False):
-                seed_input = gr.Slider(0, MAX_SEED, value=0, label="Seed")
-                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                steps_input = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
-                cfg_scale = gr.Slider(0.0, 20.0, value=4.0, step=0.1, label="CFG Scale")
+                guidance_scale = gr.Slider(
+                    label="CFG scale",
+                    minimum=0.0,
+                    maximum=50.0,
+                    step=0.1,
+                    value=4.0,
+                )
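+                # higher CFG pushes the result to match the reference image more closely (see step 3 above)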
-
-            run_btn = gr.Button("Generate", variant="primary")
-
+            gen_button = gr.Button("Generate details", variant="primary")
+
         with gr.Column():
-            model_output = LitModel3D(
-                label="Result Preview",
-                height=500,
-                camera_position=(90, 90, 3)
-            )
-            download_btn = gr.DownloadButton(
-                "Download GLB",
-                interactive=False
-            )
+            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500, camera_position=(90, 90, 3))
+            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
 
-    # Examples section
-    gr.Examples(
-        examples=EXAMPLES,
-        inputs=[image_input, mesh_input, seed_input, randomize_seed],
-        outputs=[model_output, download_btn, seed_input],
-        fn=run_refinement,
-        cache_examples=False,
-        label="Example Inputs"
-    )
+    with gr.Row():
+        gr.Examples(
+            examples=EXAMPLES,
+            fn=run_refinement,
+            inputs=[image_prompts, mesh, seed, randomize_seed],
+            outputs=[model_output, download_glb, seed],
+            cache_examples=False,
+        )
 
-    # Event handling
-    run_btn.click(
+    gen_button.click(
         run_refinement,
-        inputs=[image_input, mesh_input, seed_input, randomize_seed, steps_input, cfg_scale],
-        outputs=[model_output, download_btn, seed_input]
-    ).then(
-        lambda: gr.DownloadButton(interactive=True),
-        outputs=[download_btn]
-    )
+        inputs=[
+            image_prompts,
+            mesh,
+            seed,
+            randomize_seed,
+            num_inference_steps,
+            guidance_scale,
+        ],
+        outputs=[model_output, download_glb, seed],
+    ).then(lambda: gr.DownloadButton(interactive=True), outputs=[download_glb])
 
+
 demo.launch()
 