mikonvergence committed on
Commit
f328707
·
1 Parent(s): fbfc18e

minor update (random seed etc)

Browse files
Files changed (2) hide show
  1. app.py +9 -6
  2. src/utils.py +24 -8
app.py CHANGED
@@ -27,27 +27,30 @@ with gr.Blocks(theme=theme) as demo:
27
  with gr.Tab("3D View (Slow)"):
28
  generate_3d_button = gr.Button("Generate Terrain", variant="primary")
29
  model_3d_output = gr.Model3D(
30
- camera_position=[90, 135, 512]
 
 
31
  )
32
 
33
  with gr.Accordion("Advanced Options", open=False) as advanced_options:
34
  num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
35
  guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
36
  seed_number = gr.Number(value=6378, label="Seed")
37
- crop_size_slider = gr.Slider(minimum=128, maximum=768, step=64, value=512, label="(3D Only) Crop Size")
 
38
  vertex_count_slider = gr.Slider(minimum=0, maximum=10000, step=0, value=0, label="(3D Only) Vertex Count (Default: 0 - no reduction)")
39
  prefix_textbox = gr.Textbox(label="Prompt Prefix", value="A Sentinel-2 image of ")
40
 
41
  generate_2d_button.click(
42
  fn=generate_2d_view_output,
43
- inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, prefix_textbox],
44
  outputs=[rgb_output, elevation_output],
45
  )
46
 
47
  generate_3d_button.click(
48
  fn=generate_3d_view_output,
49
- inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, crop_size_slider, vertex_count_slider, prefix_textbox],
50
- outputs=[model_3d_output],
51
  )
52
 
53
- demo.queue().launch(share=True)
 
27
  with gr.Tab("3D View (Slow)"):
28
  generate_3d_button = gr.Button("Generate Terrain", variant="primary")
29
  model_3d_output = gr.Model3D(
30
+ camera_position=[90, 135, 512],
31
+ clear_color=[0.0, 0.0, 0.0, 0.0],
32
+ #display_mode = 'point_cloud'
33
  )
34
 
35
  with gr.Accordion("Advanced Options", open=False) as advanced_options:
36
  num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
37
  guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
38
  seed_number = gr.Number(value=6378, label="Seed")
39
+ random_seed = gr.Checkbox(value=True, label="Random Seed")
40
+ crop_size_slider = gr.Slider(minimum=128, maximum=768, step=64, value=768, label="(3D Only) Crop Size")
41
  vertex_count_slider = gr.Slider(minimum=0, maximum=10000, step=0, value=0, label="(3D Only) Vertex Count (Default: 0 - no reduction)")
42
  prefix_textbox = gr.Textbox(label="Prompt Prefix", value="A Sentinel-2 image of ")
43
 
44
  generate_2d_button.click(
45
  fn=generate_2d_view_output,
46
+ inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, random_seed, prefix_textbox],
47
  outputs=[rgb_output, elevation_output],
48
  )
49
 
50
  generate_3d_button.click(
51
  fn=generate_3d_view_output,
52
+ inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, random_seed, crop_size_slider, vertex_count_slider, prefix_textbox],
53
+ outputs=[rgb_output, elevation_output, model_3d_output],
54
  )
55
 
56
+ demo.queue().launch()
src/utils.py CHANGED
@@ -10,13 +10,17 @@ import spaces
10
  pipe = build_pipe()
11
 
12
  @spaces.GPU
13
- def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, prefix, crop_size=None):
14
  """Generates terrain data (RGB and elevation) from a text prompt."""
15
  if prefix and not prefix.endswith(' '):
16
  prefix += ' ' # Ensure prefix ends with a space
17
 
18
  full_prompt = prefix + prompt
19
- generator = torch.Generator("cuda").manual_seed(seed)
 
 
 
 
20
  image, dem = pipe(full_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator)
21
 
22
  if crop_size is not None:
@@ -50,6 +54,7 @@ def create_3d_mesh(rgb, elevation, n_clusters=1000):
50
  if n_clusters <= 0:
51
  # Generate full mesh without clustering
52
  vertices = points_3d
 
53
  try:
54
  tri = Delaunay(points_2d)
55
  faces = tri.simplices
@@ -97,18 +102,29 @@ def create_3d_mesh(rgb, elevation, n_clusters=1000):
97
  mesh = trimesh.Trimesh(vertices=simplified_vertices, faces=valid_faces, vertex_colors=vertex_colors)
98
  return mesh
99
 
100
- def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, crop_size, vertex_count, prefix):
101
- rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, prefix, crop_size)
 
 
 
 
 
 
 
 
102
 
103
  mesh = create_3d_mesh(rgb, 500*elevation, n_clusters=vertex_count)
104
-
105
  with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
106
  mesh.export(temp_file.name)
107
  file_path = temp_file.name
 
 
 
 
108
 
109
- return file_path
110
 
111
- def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, prefix):
112
- rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, prefix)
113
 
114
  return rgb, elevation
 
10
  pipe = build_pipe()
11
 
12
  @spaces.GPU
13
+ def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size=None):
14
  """Generates terrain data (RGB and elevation) from a text prompt."""
15
  if prefix and not prefix.endswith(' '):
16
  prefix += ' ' # Ensure prefix ends with a space
17
 
18
  full_prompt = prefix + prompt
19
+ if random_seed:
20
+ generator = torch.Generator("cuda")
21
+ else:
22
+ generator = torch.Generator("cuda").manual_seed(seed)
23
+
24
  image, dem = pipe(full_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator)
25
 
26
  if crop_size is not None:
 
54
  if n_clusters <= 0:
55
  # Generate full mesh without clustering
56
  vertices = points_3d
57
+
58
  try:
59
  tri = Delaunay(points_2d)
60
  faces = tri.simplices
 
102
  mesh = trimesh.Trimesh(vertices=simplified_vertices, faces=valid_faces, vertex_colors=vertex_colors)
103
  return mesh
104
 
105
+ def create_3d_point_cloud(rgb, elevation):
106
+ height, width = elevation.shape
107
+ x, y = np.meshgrid(np.arange(width), np.arange(height))
108
+ points = np.stack([x.flatten(), y.flatten(), elevation.flatten()], axis=-1)
109
+ colors = rgb.reshape(-1, 3)
110
+
111
+ return trimesh.PointCloud(vertices=points, colors=colors)
112
+
113
+ def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, crop_size, vertex_count, prefix):
114
+ rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size)
115
 
116
  mesh = create_3d_mesh(rgb, 500*elevation, n_clusters=vertex_count)
 
117
  with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
118
  mesh.export(temp_file.name)
119
  file_path = temp_file.name
120
+ # pc = create_3d_point_cloud(rgb, 500*elevation)
121
+ # with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
122
+ # pc.export(temp_file.name, file_type="ply")
123
+ # file_path = temp_file.name
124
 
125
+ return rgb, elevation, file_path
126
 
127
+ def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix):
128
+ rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix)
129
 
130
  return rgb, elevation