Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import trimesh | |
import tempfile | |
import torch | |
from scipy.spatial import Delaunay | |
from sklearn.cluster import KMeans | |
from .build_pipe import * | |
import spaces | |
# Build the text-to-terrain pipeline once, at import time. generate_terrain()
# below calls this shared module-level instance on every invocation.
pipe = build_pipe()
def _center_crop_slices(height, width, crop_size):
    """Return (row_slice, col_slice) for a centred crop, clamped to the image size."""
    crop_h = min(int(crop_size), height)
    crop_w = min(int(crop_size), width)
    top = (height - crop_h) // 2
    left = (width - crop_w) // 2
    return slice(top, top + crop_h), slice(left, left + crop_w)


def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size=None):
    """Generate terrain data (RGB image and elevation map) from a text prompt.

    Args:
        prompt: Text description of the terrain.
        num_inference_steps: Forwarded to the module-level ``pipe``.
        guidance_scale: Forwarded to the module-level ``pipe``.
        seed: Seed used when ``random_seed`` is false (coerced to ``int``).
        random_seed: If true, ignore ``seed`` and draw a fresh non-deterministic seed.
        prefix: Optional text prepended to ``prompt``; a separating space is added
            if missing. ``None`` or ``""`` means no prefix.
        crop_size: If given, centre-crop both outputs to ``crop_size`` x ``crop_size``
            pixels (clamped to the generated image size).

    Returns:
        Tuple ``(rgb, elevation)``: ``rgb`` is a ``uint8`` array of shape (H, W, C);
        ``elevation`` is the DEM averaged over its last axis, shape (H, W).
        The pipeline's image is presumably in [0, 1] (it is scaled by 255) —
        TODO confirm against build_pipe.
    """
    prefix = prefix or ""
    if prefix and not prefix.endswith(" "):
        prefix += " "  # ensure the prefix is separated from the prompt
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda")
    if random_seed:
        # A newly constructed torch.Generator always starts from the same fixed
        # default seed, so it must be reseeded explicitly to actually vary.
        generator.seed()
    else:
        generator.manual_seed(int(seed))

    image, dem = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    rgb = image[0]
    elevation = dem[0]

    if crop_size is not None:
        # The crop window is derived from the image and applied to both outputs.
        height, width = rgb.shape[:2]
        rows, cols = _center_crop_slices(height, width, crop_size)
        rgb = rgb[rows, cols]
        elevation = elevation[rows, cols]

    # Clip before casting: out-of-range floats would otherwise wrap around in uint8.
    rgb_uint8 = np.clip(255 * rgb, 0, 255).astype(np.uint8)
    return rgb_uint8, elevation.mean(-1)
def _mean_cluster_colors(labels, colors, n_clusters, empty_color=(255, 0, 0)):
    """Return an (n_clusters, channels) uint8 array holding each cluster's mean colour.

    Clusters with no member points get ``empty_color`` (red by default).
    """
    counts = np.bincount(labels, minlength=n_clusters)
    sums = np.stack(
        [
            np.bincount(labels, weights=colors[:, channel], minlength=n_clusters)
            for channel in range(colors.shape[1])
        ],
        axis=1,
    )
    result = np.empty((n_clusters, colors.shape[1]), dtype=np.uint8)
    filled = counts > 0
    # Integer sums are exact in float64, so this matches np.mean(...).astype(uint8).
    result[filled] = (sums[filled] / counts[filled, None]).astype(np.uint8)
    result[~filled] = empty_color
    return result


def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a coloured 3D mesh from an RGB image and an elevation map.

    Pixel (column x, row y) becomes the point ``(x, y, elevation[y, x])``, coloured
    by the matching row of ``rgb.reshape(-1, 3)``. Faces come from a 2D Delaunay
    triangulation of the vertices' x/y coordinates.

    Args:
        rgb: Colour image; ``rgb.reshape(-1, 3)`` must line up pixel-for-pixel with
            ``elevation`` (presumably shape (H, W, 3) uint8 — TODO confirm).
        elevation: 2D height map of shape (H, W).
        n_clusters: If <= 0, every pixel becomes a vertex (full-resolution mesh).
            Otherwise the points are reduced to at most ``n_clusters`` KMeans cluster
            centres, each coloured with the mean colour of its member pixels
            (red for a cluster that ends up empty).

    Returns:
        A ``trimesh.Trimesh``, or ``None`` if the Delaunay triangulation fails
        (e.g. too few or collinear points).
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.stack([x.ravel(), y.ravel()], axis=-1)
    points_3d = np.column_stack([points_2d, elevation.ravel()])
    original_colors = rgb.reshape(-1, 3)

    n_clusters = int(n_clusters)  # UI widgets may deliver floats
    if n_clusters <= 0:
        vertices = points_3d
        vertex_colors = original_colors
        mesh_kind = "full mesh"
    else:
        n_clusters = min(n_clusters, len(points_3d))
        kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto")
        kmeans.fit(points_3d)
        vertices = kmeans.cluster_centers_
        # Cover every cluster index, not only the labels that occur: KMeans can
        # leave clusters empty (e.g. with duplicate points).
        vertex_colors = _mean_cluster_colors(kmeans.labels_, original_colors, n_clusters)
        mesh_kind = "clustered mesh"

    try:
        faces = Delaunay(vertices[:, :2]).simplices
    except Exception as e:
        # Best effort: report and let the caller decide what to do with None.
        print(f"Error during Delaunay triangulation ({mesh_kind}): {e}")
        return None

    return trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=vertex_colors)
def create_3d_point_cloud(rgb, elevation):
    """Return a coloured ``trimesh.PointCloud`` with one point per elevation sample.

    Each point is ``(x, y, elevation[y, x])``, with the pixel column/row as x/y.
    Its colour is the matching row of ``rgb.reshape(-1, 3)``.
    """
    row_idx, col_idx = np.indices(elevation.shape)
    vertices = np.column_stack((col_idx.ravel(), row_idx.ravel(), elevation.ravel()))
    return trimesh.PointCloud(vertices=vertices, colors=rgb.reshape(-1, 3))
def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, crop_size, vertex_count, prefix, elevation_scale=500):
    """Generate terrain from a prompt and export it as a 3D mesh (.obj) file.

    Args:
        prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix,
        crop_size: Forwarded to ``generate_terrain``.
        vertex_count: Forwarded to ``create_3d_mesh`` as ``n_clusters``
            (<= 0 means a full-resolution mesh).
        elevation_scale: Multiplier applied to the elevation before meshing
            (default 500, the previously hard-coded value).

    Returns:
        Tuple ``(rgb, elevation, file_path)``. ``elevation`` is unscaled, and
        ``file_path`` names a temporary ``.obj`` file that is not deleted automatically.

    Raises:
        RuntimeError: If the mesh could not be built.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix, crop_size)
    # Presumably a vertical exaggeration so relief is visible against pixel-unit
    # x/y coordinates — TODO confirm the DEM's units.
    mesh = create_3d_mesh(rgb, elevation_scale * elevation, n_clusters=vertex_count)
    if mesh is None:
        raise RuntimeError("Failed to build a 3D mesh from the generated terrain.")

    # Reserve a unique path, then close the handle before exporting: re-opening
    # a still-open NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)
    return rgb, elevation, file_path
def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed, random_seed, prefix):
    """Generate an uncropped terrain and return its RGB image and elevation map."""
    rgb_image, elevation_map = generate_terrain(
        prompt,
        num_inference_steps,
        guidance_scale,
        seed,
        random_seed,
        prefix=prefix,
        crop_size=None,
    )
    return rgb_image, elevation_map