import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
from sklearn.cluster import KMeans
from .build_pipe import *
import spaces

pipe = build_pipe()

# Vertical exaggeration applied to the elevation map before meshing.
DEFAULT_ELEVATION_SCALE = 500

# Fallback vertex color for clusters that end up with no member pixels.
_EMPTY_CLUSTER_COLOR = np.array([255, 0, 0], dtype=np.uint8)  # Red


@spaces.GPU
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed,
                     random_seed, prefix, crop_size=None):
    """Generate terrain data (RGB image and elevation map) from a text prompt.

    Args:
        prompt: Text description of the terrain.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA generator; ignored when ``random_seed`` is true.
        random_seed: If true, the generator is seeded non-deterministically.
        prefix: Optional text prepended to ``prompt``; a separating space is
            added if it does not already end with one. ``None`` is treated
            as an empty prefix.
        crop_size: If given, side length of a square center crop applied to
            both outputs. Values larger than the image are clamped to the
            image size.

    Returns:
        Tuple ``(rgb, elevation)``: ``rgb`` is a uint8 array of shape
        (H, W, C) and ``elevation`` is the DEM averaged over its last axis.

    Raises:
        ValueError: If ``crop_size`` is not positive.
    """
    prefix = prefix or ""  # Tolerate a missing (None) prefix.
    if prefix and not prefix.endswith(" "):
        prefix += " "  # Ensure prefix ends with a space
    full_prompt = prefix + prompt

    generator = torch.Generator("cuda")
    if random_seed:
        # A freshly constructed Generator always starts from the same fixed
        # default seed, so it must be reseeded to actually be random.
        generator.seed()
    else:
        generator.manual_seed(int(seed))  # UI widgets may deliver floats.

    image, dem = pipe(full_prompt, num_inference_steps=num_inference_steps,
                      guidance_scale=guidance_scale, generator=generator)
    # NOTE(review): assumes pipe returns batched numpy arrays with the RGB
    # image in [0, 1] (hence the *255 below) -- confirm against build_pipe.
    rgb = image[0]
    dem_map = dem[0]

    if crop_size is not None:
        height, width = rgb.shape[:2]
        size = min(int(crop_size), height, width)  # Never crop past the image.
        if size <= 0:
            raise ValueError(f"crop_size must be positive, got {crop_size!r}")
        top = (height - size) // 2
        left = (width - size) // 2
        # The DEM is cropped with the image's window; both are assumed to
        # share the same spatial size.
        rgb = rgb[top:top + size, left:left + size]
        dem_map = dem_map[top:top + size, left:left + size]

    # Clip before casting so out-of-range values saturate instead of wrapping.
    rgb_uint8 = (255 * np.clip(rgb, 0.0, 1.0)).astype(np.uint8)
    return rgb_uint8, dem_map.mean(-1)


def create_3d_mesh(rgb, elevation, n_clusters=1000):
    """Create a colored 3D mesh from RGB and elevation data.

    Every pixel becomes a point ``(column, row, elevation)``.

    Args:
        rgb: Color image whose pixels reshape to ``(-1, 3)``, aligned with
            ``elevation``.
        elevation: 2-D array of heights (rows x cols).
        n_clusters: If <= 0, every pixel becomes a vertex (full mesh).
            Otherwise the points are simplified to at most ``n_clusters``
            KMeans cluster centers, each colored with the mean color of the
            pixels assigned to it.

    Returns:
        A ``trimesh.Trimesh``, or ``None`` if the Delaunay triangulation
        fails (the error is printed).
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    points_2d = np.column_stack([x.ravel(), y.ravel()])
    points_3d = np.column_stack([points_2d, elevation.ravel()])
    colors = rgb.reshape(-1, 3)

    n_clusters = int(n_clusters)  # UI sliders may deliver floats.
    if n_clusters <= 0:
        return _build_full_mesh(points_2d, points_3d, colors)
    return _build_clustered_mesh(points_3d, colors,
                                 min(n_clusters, len(points_3d)))


def _build_full_mesh(points_2d, points_3d, colors):
    """Triangulate every grid point in 2-D and lift it to 3-D (no simplification)."""
    try:
        faces = Delaunay(points_2d).simplices
    except Exception as e:
        print(f"Error during Delaunay triangulation (full mesh): {e}")
        return None
    return trimesh.Trimesh(vertices=points_3d, faces=faces,
                           vertex_colors=colors)


def _build_clustered_mesh(points_3d, colors, n_clusters):
    """Simplify the points to KMeans centers and triangulate them in the XY plane."""
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto")
    kmeans.fit(points_3d)
    vertices = kmeans.cluster_centers_

    try:
        # Delaunay only ever indexes its own input points, so every face is
        # already valid for ``vertices``.
        faces = Delaunay(vertices[:, :2]).simplices
    except Exception as e:
        print(f"Error during Delaunay triangulation (clustered mesh): {e}")
        return None

    vertex_colors = _mean_cluster_colors(kmeans.labels_, colors, len(vertices))
    return trimesh.Trimesh(vertices=vertices, faces=faces,
                           vertex_colors=vertex_colors)


def _mean_cluster_colors(labels, colors, n_clusters):
    """Return one uint8 color per cluster: the mean color of its member points.

    Clusters with no members (possible when KMeans finds fewer distinct
    clusters than requested) get ``_EMPTY_CLUSTER_COLOR``.
    """
    counts = np.bincount(labels, minlength=n_clusters)
    # Per-channel sums in one vectorized pass instead of one mask per cluster.
    sums = np.stack(
        [np.bincount(labels, weights=colors[:, ch], minlength=n_clusters)
         for ch in range(colors.shape[1])],
        axis=1,
    )
    result = np.tile(_EMPTY_CLUSTER_COLOR, (n_clusters, 1))
    populated = counts > 0
    result[populated] = (sums[populated] / counts[populated, None]).astype(np.uint8)
    return result


def create_3d_point_cloud(rgb, elevation):
    """Return a colored point cloud with one ``(column, row, elevation)`` point per pixel."""
    height, width = elevation.shape
    x, y = np.meshgrid(np.arange(width), np.arange(height))
    points = np.stack([x.flatten(), y.flatten(), elevation.flatten()], axis=-1)
    colors = rgb.reshape(-1, 3)
    return trimesh.PointCloud(vertices=points, colors=colors)


def generate_3d_view_output(prompt, num_inference_steps, guidance_scale, seed,
                            random_seed, crop_size, vertex_count, prefix,
                            elevation_scale=DEFAULT_ELEVATION_SCALE):
    """Generate terrain and export it as a mesh to a temporary ``.obj`` file.

    ``elevation_scale`` multiplies the elevation map before meshing.

    Returns:
        Tuple ``(rgb, elevation, file_path)``, where ``elevation`` is the
        unscaled map and ``file_path`` is the exported ``.obj`` file. The
        file is created with ``delete=False``; removing it is left to the
        caller.

    Raises:
        RuntimeError: If the mesh could not be built.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps,
                                      guidance_scale, seed, random_seed,
                                      prefix, crop_size)
    mesh = create_3d_mesh(rgb, elevation_scale * elevation,
                          n_clusters=vertex_count)
    if mesh is None:
        raise RuntimeError("Failed to build a 3D mesh from the generated terrain.")

    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    # Export only after the handle is closed: re-opening a still-open
    # NamedTemporaryFile by name fails on Windows.
    mesh.export(file_path)
    return rgb, elevation, file_path


def generate_2d_view_output(prompt, num_inference_steps, guidance_scale, seed,
                            random_seed, prefix):
    """Generate terrain without cropping and return ``(rgb, elevation)``."""
    rgb, elevation = generate_terrain(prompt, num_inference_steps,
                                      guidance_scale, seed, random_seed,
                                      prefix)
    return rgb, elevation