mikonvergence committed
Commit 21202e6 · verified · 1 parent: c0e84c2

Create app.py

Files changed (1):
  app.py (+107 -0)

app.py ADDED:

# Setup: Python dependencies (huggingface_hub[hf_transfer], huggingface_hub[cli],
# gradio, trimesh, scipy) are assumed to be installed beforehand (see the
# requirements sketch after this listing). Fetch the MESA repository and the
# pretrained weights if they are not present yet.
import os
import subprocess

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # faster weight downloads via hf_transfer

if not os.path.isdir("MESA"):
    subprocess.run(["git", "clone", "https://github.com/PaulBorneP/MESA.git"], check=True)
if not os.path.isdir("weights"):
    os.makedirs("weights", exist_ok=True)
    subprocess.run(["huggingface-cli", "download", "NewtNewt/MESA", "--local-dir", "weights"], check=True)

import sys
import tempfile

import torch
import numpy as np
import trimesh
import gradio as gr
from scipy.spatial import Delaunay

# Make the cloned repository importable before pulling in the pipeline code.
sys.path.append('MESA/')
from MESA.pipeline_terrain import TerrainDiffusionPipeline

# Load the pretrained MESA pipeline in half precision and move it to the GPU.
pipe = TerrainDiffusionPipeline.from_pretrained("./weights", torch_dtype=torch.float16)
pipe.to("cuda")

def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generates terrain data (RGB and elevation) from a text prompt."""
    if prefix and not prefix.endswith(' '):
        prefix += ' '  # Ensure the prefix ends with a space

    full_prompt = prefix + prompt
    generator = torch.Generator("cuda").manual_seed(seed)
    image, dem = pipe(full_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator)

    # Center-crop the image and the DEM
    h, w, c = image[0].shape
    start_h = (h - crop_size) // 2
    start_w = (w - crop_size) // 2
    end_h = start_h + crop_size
    end_w = start_w + crop_size

    cropped_image = image[0][start_h:end_h, start_w:end_w, :]
    cropped_dem = dem[0][start_h:end_h, start_w:end_w, :]

    return (255 * cropped_image).astype(np.uint8), 500 * cropped_dem.mean(-1)

def create_3d_mesh(rgb, elevation):
    """Creates a 3D mesh from RGB and elevation data."""
    # Triangulate the regular pixel grid, then lift each vertex to its elevation value.
    x, y = np.meshgrid(np.arange(elevation.shape[1]), np.arange(elevation.shape[0]))
    points = np.stack([x.flatten(), y.flatten()], axis=-1)
    tri = Delaunay(points)

    vertices = np.stack([x.flatten(), y.flatten(), elevation.flatten()], axis=-1)
    faces = tri.simplices

    # Color each vertex with the corresponding RGB pixel.
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=rgb.reshape(-1, 3))

    return mesh

def generate_and_display(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generates terrain and displays it as a 3D model."""
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix)
    mesh = create_3d_mesh(rgb, elevation)

    # Export the mesh to a temporary .obj file that gr.Model3D can display.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        mesh.export(temp_file.name)
        file_path = temp_file.name

    return file_path

theme = gr.themes.Soft(primary_hue="red", secondary_hue="red", font=['arial'])

with gr.Blocks(theme=theme) as demo:
    # Header
    with gr.Column(elem_classes="header"):
        gr.Markdown("# MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data")
        gr.Markdown("### Paul Borne–Pons, Mikolaj Czerkawski, Rosalie Martin, Romain Rouffet")
        gr.Markdown('[[GitHub](https://github.com/PaulBorneP/MESA)] [[Model](https://huggingface.co/NewtNewt/MESA)] [[Dataset](https://huggingface.co/datasets/Major-TOM/Core-DEM)]')

    # Abstract Section
    with gr.Column(elem_classes="abstract"):
        gr.Markdown("MESA is a novel generative model based on latent denoising diffusion, capable of generating 2.5D representations of terrain conditioned on a natural-language text prompt. The model produces two co-registered modalities: an optical image and a depth map.")
        gr.Markdown("This is a test version of the demo app. Please be aware that MESA primarily supports complex, mountainous terrain rather than flat land.")
        gr.Markdown("The generated image is quite large, so at the full resolution (768) the surface may take a while to load.")

    with gr.Row():
        prompt_input = gr.Textbox(lines=2, placeholder="Enter a terrain description...")
        generate_button = gr.Button("Generate Terrain", variant="primary")

    model_output = gr.Model3D(
        camera_position=[90, 180, 512]
    )

    with gr.Accordion("Advanced Options", open=False) as advanced_options:
        num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
        guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
        seed_number = gr.Number(value=6378, label="Seed")
        crop_size_slider = gr.Slider(minimum=128, maximum=768, step=64, value=512, label="Crop Size")
        prefix_textbox = gr.Textbox(label="Prompt Prefix", value="A Sentinel-2 image of ")

    generate_button.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, crop_size_slider, prefix_textbox],
        outputs=model_output,
    )

if __name__ == "__main__":
    demo.launch(debug=True,
                share=True)
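
Note: the original commit installs packages with notebook-style pip commands inside app.py, which do not execute when the file runs as a plain Python script (for example on a Hugging Face Space), so the listing above assumes the dependencies are installed up front. A minimal, unpinned requirements.txt sketch covering the packages referenced in the commit might look as follows; version pins and any extra packages required by the MESA repository itself are not specified in the commit and are left out:

    # requirements.txt (sketch; versions unpinned, MESA's own requirements not included)
    torch
    numpy
    scipy
    trimesh
    gradio
    huggingface_hub[hf_transfer,cli]

As a quick sanity check outside the Gradio UI (assuming the weights have been downloaded and a CUDA GPU is available), the generation path could be exercised directly; the prompt below is purely illustrative, and the parameter values mirror the interface defaults:

    obj_path = generate_and_display(
        prompt="snow-capped mountain ridges above a glacial valley",  # illustrative prompt
        num_inference_steps=50,
        guidance_scale=7.5,
        seed=6378,
        crop_size=512,
        prefix="A Sentinel-2 image of ",
    )
    print("Mesh written to:", obj_path)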