File size: 2,250 Bytes
37c80b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from .build_pipe import *

# Pipeline constructed once when this module is imported and reused by every
# call to generate_terrain, which invokes it as
# ``image, dem = pipe(prompt, num_inference_steps=..., guidance_scale=..., generator=...)``.
# NOTE(review): build_pipe comes from the star-import of .build_pipe; presumably it
# loads model weights (likely onto CUDA), making module import slow/heavy — confirm.
pipe = build_pipe()

def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generate terrain data (RGB and elevation) from a text prompt.

    Args:
        prompt: Text description of the terrain.
        num_inference_steps: Forwarded unchanged to the module-level ``pipe``.
        guidance_scale: Forwarded unchanged to the module-level ``pipe``.
        seed: Seed for the CUDA ``torch.Generator``. Coerced to ``int`` because
            ``torch.Generator.manual_seed`` rejects floats (e.g. UI slider values).
        crop_size: Side length in pixels of the centered square crop. Coerced to
            ``int``; values larger than the generated image are clamped to the
            image's smaller side.
        prefix: Optional text prepended to ``prompt``. A separating space is added
            when missing. ``None`` or ``""`` means no prefix.

    Returns:
        A tuple ``(rgb, elevation)``:
        ``rgb`` is a ``uint8`` array of shape ``(crop, crop, C)`` holding the
        first generated image scaled by 255 and clipped to [0, 255].
        ``elevation`` is a 2-D float array of shape ``(crop, crop)`` equal to 500
        times the mean over the last axis of the cropped first DEM.

    Raises:
        ValueError: If ``crop_size`` is not positive.
    """
    crop_size = int(crop_size)
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")

    # Treat None like an empty prefix instead of failing on ``None + prompt``.
    prefix = prefix or ""
    if prefix and not prefix.endswith(' '):
        prefix += ' '  # Ensure prefix ends with a space
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda").manual_seed(int(seed))
    image, dem = pipe(full_prompt, num_inference_steps=num_inference_steps,
                      guidance_scale=guidance_scale, generator=generator)

    image0 = image[0]
    dem0 = dem[0]

    # Center crop the image and dem. Clamp the crop so an oversized request
    # cannot produce negative start offsets (which would silently yield an
    # empty or wrapped slice).
    h, w = image0.shape[:2]
    crop_size = min(crop_size, h, w)
    start_h = (h - crop_size) // 2
    start_w = (w - crop_size) // 2
    rows = slice(start_h, start_h + crop_size)
    cols = slice(start_w, start_w + crop_size)

    cropped_image = image0[rows, cols, :]
    cropped_dem = dem0[rows, cols, :]

    # Clip before the cast: converting out-of-range floats to uint8 is not
    # well-defined in NumPy and can wrap around instead of saturating.
    rgb = np.clip(255 * cropped_image, 0, 255).astype(np.uint8)
    # NOTE(review): 500 is the original z scale applied to the channel-averaged
    # DEM; presumably maps normalized DEM values to mesh units — confirm.
    elevation = 500 * cropped_dem.mean(-1)
    return rgb, elevation

def simplify_mesh(mesh, target_face_count):
    """Simplify a mesh using quadric decimation.

    Args:
        mesh: A ``trimesh.Trimesh`` (or any object exposing a compatible
            ``simplify_quadric_decimation`` method).
        target_face_count: Desired number of faces in the simplified mesh.

    Returns:
        The mesh returned by ``mesh.simplify_quadric_decimation``.
    """
    # Pass the count by keyword. trimesh 4.x changed the first positional
    # parameter of simplify_quadric_decimation to ``percent``, so a positional
    # face count would be misread there. Older releases also name it ``face_count``.
    return mesh.simplify_quadric_decimation(face_count=target_face_count)

def create_3d_mesh(rgb, elevation):
    """Create a vertex-coloured triangle mesh from RGB and elevation grids.

    Every grid sample becomes one vertex at ``(col, row, elevation[row, col])``,
    coloured by the matching ``rgb`` pixel. Each grid cell is split into two
    triangles, wound counter-clockwise in the x-y plane. This is the same
    orientation scipy's 2-D Delaunay produces, but deterministic, O(n), and
    valid for single-row/column grids.

    Args:
        rgb: Per-sample colours with ``rows * cols`` entries, e.g. shape
            ``(rows, cols, 3)``. A trailing channel count other than 3 (such as
            RGBA) is passed through to trimesh unchanged.
        elevation: 2-D array of heights with shape ``(rows, cols)``.

    Returns:
        A ``trimesh.Trimesh`` with ``rows * cols`` vertices and
        ``2 * (rows - 1) * (cols - 1)`` faces.

    Raises:
        ValueError: If ``elevation`` is not 2-D, or if ``rgb`` cannot be
            reshaped to one colour per vertex.
    """
    elevation = np.asarray(elevation)
    if elevation.ndim != 2:
        raise ValueError(f"elevation must be 2-D, got shape {elevation.shape}")
    rows, cols = elevation.shape

    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    vertices = np.column_stack([x.ravel(), y.ravel(), elevation.ravel()])

    # Vertex index of grid sample (r, c) is r * cols + c.
    idx = np.arange(rows * cols).reshape(rows, cols)
    top_left = idx[:-1, :-1].ravel()
    top_right = idx[:-1, 1:].ravel()
    bottom_left = idx[1:, :-1].ravel()
    bottom_right = idx[1:, 1:].ravel()
    # Both triangles are counter-clockwise when viewed from +z.
    faces = np.concatenate([
        np.column_stack([top_left, top_right, bottom_right]),
        np.column_stack([top_left, bottom_right, bottom_left]),
    ])

    colors = np.asarray(rgb).reshape(rows * cols, -1)
    return trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)

def generate_and_display(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generate a terrain mesh from a prompt and export it to a temporary OBJ file.

    Nothing is displayed here. The mesh is written to a new ``.obj`` file and
    its path is returned, presumably for a viewer component to load (confirm
    against the caller).

    Args:
        prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix:
            Forwarded unchanged to ``generate_terrain``.

    Returns:
        Path (``str``) of the exported ``.obj`` file. The file is not deleted
        automatically; the caller is responsible for cleanup.
    """
    import os

    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix)
    mesh = create_3d_mesh(rgb, elevation)

    # Only reserve a unique filename here. The handle is closed (but not
    # deleted) on leaving the ``with``, because re-opening a still-open
    # NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name

    try:
        mesh.export(file_path)
    except Exception:
        # Don't leave an empty placeholder file behind on failure.
        os.remove(file_path)
        raise
    return file_path