File size: 3,710 Bytes
21202e6
37c80b4
 
 
088fc4c
e964e52
 
 
 
 
 
 
 
 
 
 
37c80b4
e964e52
 
 
 
 
 
 
 
 
 
 
 
f328707
 
 
e964e52
 
37c80b4
 
 
 
f328707
 
e964e52
37c80b4
e964e52
 
 
f328707
e964e52
 
 
 
 
f328707
 
37c80b4
21202e6
f328707
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
from src.utils import *

# Visual theme for the demo. It is defined at module level, not under the
# ``__main__`` guard, because the ``gr.Blocks`` layout below references it and
# is built at import time. Defining it only under the guard made any
# non-script import of this module fail with NameError on ``theme``.
theme = gr.themes.Soft(
    primary_hue="emerald",
    secondary_hue="stone",
    font=[gr.themes.GoogleFont("Source Sans 3", weights=(400, 600)), "arial"],
)

# Build the UI. ``demo`` stays a module-level name so it can be imported (or
# picked up by Gradio's reload mode) without launching a server.
with gr.Blocks(theme=theme) as demo:
    # --- Title, authors and project links ---------------------------------
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🏔 MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data")
        gr.Markdown("### Paul Borne–Pons, Mikolaj Czerkawski, Rosalie Martin, Romain Rouffet")
        # Hugging Face links use the plain "huggingface.co" host (a stray
        # trailing dot, "huggingface.co.", was previously in the URLs).
        gr.Markdown('[[Website](https://paulbornep.github.io/mesa-terrain/)] [[GitHub](https://github.com/PaulBorneP/MESA)] [[Model](https://huggingface.co/NewtNewt/MESA)] [[Dataset](https://huggingface.co/datasets/Major-TOM/Core-DEM)]')

    # --- Abstract, prompt input, outputs and advanced options -------------
    with gr.Column(elem_classes="abstract"):
        gr.Markdown("MESA is a novel generative model based on latent denoising diffusion capable of generating 2.5D representations of terrain based on the text prompt conditioning supplied via natural language. The model produces two co-registered modalities of optical and depth maps.")
        gr.Markdown("This is a test version of the demo app. Please be aware that MESA supports primarily complex, mountainous terrains as opposed to flat land")
        gr.Markdown("> ⚠️ **The generated image is quite large, so for the larger resolution (768) it might take a while to load the surface**")
        with gr.Row():
            prompt_input = gr.Textbox(lines=2, placeholder="Enter a terrain description...")

        # Each tab has its own generate button. The 2D tab fills the two
        # image outputs. The 3D tab fills those same images plus a 3D model.
        with gr.Tabs() as output_tabs:
            with gr.Tab("2D View (Fast)"):
                generate_2d_button = gr.Button("Generate Terrain", variant="primary")
                with gr.Row():
                    rgb_output = gr.Image(label="RGB Image")
                    elevation_output = gr.Image(label="Elevation Map")

            with gr.Tab("3D View (Slow)"):
                generate_3d_button = gr.Button("Generate Terrain", variant="primary")
                model_3d_output = gr.Model3D(
                    camera_position=[90, 135, 512],
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                    # Alternative rendering mode, if desired: display_mode="point_cloud"
                )

        with gr.Accordion("Advanced Options", open=False) as advanced_options:
            num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
            guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
            # precision=0 makes Gradio round the value and pass it as an int,
            # not a float. NOTE(review): presumably the seed feeds an RNG that
            # expects an integer; confirm in the generate_* helpers.
            seed_number = gr.Number(value=6378, label="Seed", precision=0)
            random_seed = gr.Checkbox(value=True, label="Random Seed")
            crop_size_slider = gr.Slider(minimum=128, maximum=768, step=64, value=768, label="(3D Only) Crop Size")
            # The slider step must be positive. The original step=0 is invalid.
            vertex_count_slider = gr.Slider(minimum=0, maximum=10000, step=1, value=0, label="(3D Only) Vertex Count (Default: 0 - no reduction)")
            prefix_textbox = gr.Textbox(label="Prompt Prefix", value="A Sentinel-2 image of ")

    # --- Event wiring ------------------------------------------------------
    # Both handlers come from ``src.utils`` via the star import at the top of
    # the file.
    generate_2d_button.click(
        fn=generate_2d_view_output,
        inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, random_seed, prefix_textbox],
        outputs=[rgb_output, elevation_output],
    )

    generate_3d_button.click(
        fn=generate_3d_view_output,
        inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, random_seed, crop_size_slider, vertex_count_slider, prefix_textbox],
        outputs=[rgb_output, elevation_output, model_3d_output],
    )

if __name__ == '__main__':
    # Launch only when run as a script, after the Blocks context has closed,
    # so importing this module never starts (and blocks on) a server.
    demo.queue().launch()