# MESA / src/utils.py
# Last commit 37c80b4 by mikonvergence: "first test" (2.25 kB)
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from .build_pipe import *
# Build the generation pipeline once, at import time (build_pipe comes from the
# `.build_pipe` star import). generate_terrain() reuses this module-level
# instance on every call.
pipe = build_pipe()
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generate terrain data (RGB and elevation) from a text prompt.

    Runs the module-level ``pipe`` on ``prefix + prompt`` and center-crops both
    outputs to a square of at most ``crop_size`` pixels per side.

    Args:
        prompt: Text description of the terrain.
        num_inference_steps: Number of inference steps forwarded to ``pipe``.
        guidance_scale: Guidance scale forwarded to ``pipe``.
        seed: Random seed. Coerced to ``int`` because UI widgets may deliver
            floats, which ``torch.Generator.manual_seed`` rejects.
        crop_size: Side length of the centered square crop. Values larger
            than the generated image are clamped to the image size.
        prefix: Optional text prepended to ``prompt``. A separating space is
            added if missing. ``None`` or ``""`` means no prefix.

    Returns:
        Tuple ``(rgb, elevation)``:
        - ``rgb``: ``uint8`` array of shape ``(size, size, C)``.
        - ``elevation``: array of shape ``(size, size)`` equal to 500 times
          the per-pixel channel mean of the DEM output.
    """
    prefix = prefix or ""
    if prefix and not prefix.endswith(" "):
        prefix += " "  # Ensure prefix and prompt are separated by a space.
    full_prompt = prefix + prompt

    # Seed on CUDA when available, as before; fall back to CPU so the function
    # does not crash on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))
    image, dem = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )

    # Center crop both outputs. Clamp the crop to the image so an oversized
    # crop_size cannot produce negative start indices (which would slice
    # from the wrong end).
    rgb_full, dem_full = image[0], dem[0]
    h, w = rgb_full.shape[:2]
    size = min(int(crop_size), h, w)
    top = (h - size) // 2
    left = (w - size) // 2
    rows = slice(top, top + size)
    cols = slice(left, left + size)
    cropped_image = rgb_full[rows, cols, :]
    cropped_dem = dem_full[rows, cols, :]

    # NOTE(review): assumes pipe returns float images nominally in [0, 1].
    # Clip first so out-of-range values saturate instead of wrapping around
    # in the uint8 cast.
    rgb = (255 * np.clip(cropped_image, 0.0, 1.0)).astype(np.uint8)
    # NOTE(review): 500 appears to be a fixed DEM scale factor. Its units
    # are not visible in this file; confirm with the pipeline's DEM encoding.
    elevation = 500 * cropped_dem.mean(-1)
    return rgb, elevation
def simplify_mesh(mesh, target_face_count):
    """Reduce ``mesh`` to roughly ``target_face_count`` faces via quadric decimation.

    The face count is passed by keyword on purpose. Recent trimesh releases
    changed the signature to ``simplify_quadric_decimation(percent=None,
    face_count=None, aggression=None)``, so a positional argument would be
    interpreted as a *percentage* rather than a face count.

    Args:
        mesh: Mesh object exposing ``simplify_quadric_decimation``
            (e.g. ``trimesh.Trimesh``).
        target_face_count: Desired number of faces after decimation.

    Returns:
        The simplified mesh returned by ``simplify_quadric_decimation``.
    """
    return mesh.simplify_quadric_decimation(face_count=target_face_count)
def create_3d_mesh(rgb, elevation):
    """Create a colored 3D surface mesh from an RGB image and an elevation grid.

    Each pixel ``(row, col)`` becomes a vertex at
    ``(col, row, elevation[row, col])``, colored with ``rgb[row, col]``. Every
    grid cell is split into two triangles along the same diagonal. Both
    triangles are wound counterclockwise in the x-y plane, matching the
    orientation ``scipy.spatial.Delaunay`` produced previously.

    The connectivity is built directly from the grid instead of via a Delaunay
    triangulation. This is deterministic, avoids a Qhull pass over every pixel,
    and does not fail on single-row or single-column inputs, whose points are
    collinear and rejected by Qhull.

    Args:
        rgb: Per-pixel colors with shape ``(H, W, C)`` matching ``elevation``.
            ``C`` is typically 3 (RGB) or 4 (RGBA).
        elevation: 2-D array of heights with shape ``(H, W)``.

    Returns:
        A ``trimesh.Trimesh`` with one vertex per pixel.
    """
    n_rows, n_cols = elevation.shape
    x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    vertices = np.stack([x.ravel(), y.ravel(), elevation.ravel()], axis=-1)

    # Vertex index of every pixel, then the four corner indices of each cell.
    idx = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
    top_left = idx[:-1, :-1].ravel()
    top_right = idx[:-1, 1:].ravel()
    bottom_left = idx[1:, :-1].ravel()
    bottom_right = idx[1:, 1:].ravel()
    faces = np.concatenate([
        np.stack([top_left, top_right, bottom_right], axis=-1),
        np.stack([top_left, bottom_right, bottom_left], axis=-1),
    ])

    # Keep the channel count of the input so RGBA colors are not mis-shaped.
    colors = rgb.reshape(-1, rgb.shape[-1])
    return trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
def generate_and_display(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generate terrain for a prompt and export it as an OBJ file for display.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix:
            Forwarded unchanged to ``generate_terrain``.

    Returns:
        Path (str) of a newly created ``.obj`` file containing the mesh. The
        file is created with ``delete=False``, so it is not removed
        automatically; cleanup is the caller's responsibility.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix)
    mesh = create_3d_mesh(rgb, elevation)
    # Only reserve a unique file name here and close the handle right away.
    # Re-opening a NamedTemporaryFile by name while it is still open fails on
    # Windows, so the export happens after the ``with`` block has closed it.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)
    return file_path