# MESA / src/utils.py
# mikonvergence
# minor update (random seed etc)
# f328707
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from sklearn.cluster import KMeans
from .build_pipe import *
import spaces
# Pipeline built once, when this module is imported (build_pipe comes from the
# star import of .build_pipe). generate_terrain calls it as
# pipe(prompt, num_inference_steps=..., guidance_scale=..., generator=...)
# and unpacks the result into (image, dem).
pipe = build_pipe()
@spaces.GPU
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size=None):
    """Generate terrain data (RGB image and elevation map) from a text prompt.

    Args:
        prompt: Text prompt describing the terrain.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        seed: Seed for the CUDA generator; ignored when ``random_seed`` is true.
            Cast to ``int`` because ``Generator.manual_seed`` rejects floats.
        random_seed: If true, seed the generator non-deterministically so each
            call produces a different result.
        prefix: Optional text prepended to ``prompt``; a separating space is
            added if missing. ``None`` or ``""`` means no prefix.
        crop_size: Optional side length of a centred square crop applied to
            both the image and the DEM. Clamped to the image size.

    Returns:
        Tuple ``(rgb, elevation)``: ``rgb`` is a ``uint8`` array of shape
        ``(H, W, C)`` (values scaled by 255 and clipped to [0, 255]);
        ``elevation`` is the DEM averaged over its last axis, shape ``(H, W)``.
    """
    prefix = prefix or ""
    if prefix and not prefix.endswith(" "):
        prefix += " "  # Ensure prefix and prompt are separated by a space
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda")
    if random_seed:
        # A freshly constructed torch.Generator always starts from the same
        # default seed, so it must be reseeded explicitly to get a new sample.
        generator.seed()
    else:
        generator.manual_seed(int(seed))

    image, dem = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    rgb, height_map = image[0], dem[0]

    if crop_size is not None:
        # Centre crop; the window is computed from the image and applied to
        # both arrays. Clamping keeps the slice start non-negative.
        h, w = rgb.shape[:2]
        size = min(int(crop_size), h, w)
        top = (h - size) // 2
        left = (w - size) // 2
        rgb = rgb[top:top + size, left:left + size, :]
        height_map = height_map[top:top + size, left:left + size, :]

    # Clip before casting: float -> uint8 conversion of out-of-range values
    # wraps around instead of saturating.
    rgb_uint8 = np.clip(255 * rgb, 0, 255).astype(np.uint8)
    return rgb_uint8, height_map.mean(-1)
def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a vertex-coloured 3D triangle mesh from RGB and elevation data.

    Every pixel becomes a point ``(x, y, elevation[y, x])`` carrying its RGB
    colour.

    * If ``n_clusters <= 0``, the full mesh is built: one vertex per pixel,
      triangulated with a 2D Delaunay triangulation of the pixel grid.
    * Otherwise the points are simplified with KMeans (``n_clusters`` capped at
      the number of pixels). The cluster centres become the vertices and are
      triangulated in XY with Delaunay. Each vertex gets the mean colour of
      the pixels assigned to its cluster (red for a cluster with no pixels).

    Args:
        rgb: Image array reshapeable to ``(H*W, 3)``; presumably uint8
            ``(H, W, 3)`` — confirm against callers.
        elevation: 2D array ``(H, W)``, used directly as the Z coordinate.
        n_clusters: Target vertex count of the simplified mesh; ``<= 0``
            disables simplification. Cast to ``int`` (UI widgets may pass
            floats, which KMeans rejects).

    Returns:
        A ``trimesh.Trimesh``, or ``None`` if Delaunay triangulation fails.
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.stack([x.ravel(), y.ravel()], axis=-1)
    points_3d = np.column_stack([points_2d, elevation.ravel()])
    original_colors = rgb.reshape(-1, 3)

    n_clusters = int(n_clusters)
    if n_clusters <= 0:
        # Full-resolution mesh without clustering
        try:
            tri = Delaunay(points_2d)
        except Exception as e:
            print(f"Error during Delaunay triangulation (full mesh): {e}")
            return None
        return trimesh.Trimesh(vertices=points_3d, faces=tri.simplices, vertex_colors=original_colors)

    n_clusters = min(n_clusters, len(points_3d))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init='auto')
    kmeans.fit(points_3d)
    simplified_vertices = kmeans.cluster_centers_

    try:
        # Triangulate in XY only; Z is carried by the cluster centres.
        # Delaunay simplices always index valid input points, so no filtering
        # of the faces is needed.
        tri = Delaunay(simplified_vertices[:, :2])
    except Exception as e:
        print(f"Error during Delaunay triangulation (clustered mesh): {e}")
        return None

    vertex_colors = _mean_color_per_cluster(original_colors, kmeans.labels_, n_clusters)
    return trimesh.Trimesh(vertices=simplified_vertices, faces=tri.simplices, vertex_colors=vertex_colors)


def _mean_color_per_cluster(colors, labels, n_clusters):
    """Return an ``(n_clusters, C)`` uint8 array of per-cluster mean colours.

    Uses one ``np.bincount`` pass per channel (O(N)) instead of scanning all
    points once per cluster. Means are truncated to uint8, as ``astype`` does.
    Clusters with no members are coloured red so every vertex has a colour.
    """
    counts = np.bincount(labels, minlength=n_clusters)
    sums = np.stack(
        [np.bincount(labels, weights=colors[:, ch], minlength=n_clusters)
         for ch in range(colors.shape[1])],
        axis=-1,
    )
    # Avoid dividing by zero for empty clusters; they are overwritten below.
    safe_counts = np.maximum(counts, 1)[:, None]
    mean_colors = (sums / safe_counts).astype(np.uint8)
    mean_colors[counts == 0] = np.array([255, 0, 0], dtype=np.uint8)  # Red
    return mean_colors
def create_3d_point_cloud(rgb, elevation):
    """Build a coloured point cloud with one point per elevation sample.

    Points are in row-major order; point k is
    ``(column, row, elevation[row, column])`` and is coloured by row k of
    ``rgb.reshape(-1, 3)``.
    """
    grid_rows, grid_cols = np.indices(elevation.shape)
    xyz = np.column_stack((grid_cols.ravel(), grid_rows.ravel(), elevation.ravel()))
    return trimesh.PointCloud(vertices=xyz, colors=rgb.reshape(-1, 3))
def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, crop_size, vertex_count, prefix, elevation_scale=500):
    """Generate terrain and export it as a vertex-coloured OBJ mesh.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, random_seed,
        prefix, crop_size: Forwarded to ``generate_terrain``.
        vertex_count: Forwarded to ``create_3d_mesh`` as ``n_clusters``
            (``<= 0`` builds the full-resolution mesh).
        elevation_scale: Multiplier applied to the elevation map before it is
            used as the mesh Z coordinate (default 500, the previous constant).

    Returns:
        Tuple ``(rgb, elevation, file_path)``. ``elevation`` is unscaled, and
        ``file_path`` points to a ``.obj`` temp file the caller owns (it is
        not deleted automatically).

    Raises:
        RuntimeError: If the mesh could not be triangulated.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size)
    mesh = create_3d_mesh(rgb, elevation_scale * elevation, n_clusters=vertex_count)
    if mesh is None:
        # create_3d_mesh reports the triangulation error and returns None.
        raise RuntimeError("Failed to build a 3D mesh from the generated terrain.")
    # Reserve a unique path and close the handle before trimesh writes to it:
    # a NamedTemporaryFile cannot be reopened while open on Windows.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)
    return rgb, elevation, file_path
def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix):
    """Return the 2D outputs for a prompt: the RGB image and the elevation map.

    Thin wrapper over ``generate_terrain`` with no crop (``crop_size`` is left
    at its default).
    """
    terrain_rgb, terrain_elevation = generate_terrain(
        prompt,
        num_inference_steps,
        guidance_scale,
        seed,
        random_seed,
        prefix,
    )
    return terrain_rgb, terrain_elevation