# TRELLIS-Texto3D — trellis/datasets/sparse_structure.py
import os
import json
from typing import Union
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import utils3d
from .components import StandardDatasetBase
from ..representations.octree import DfsOctree as Octree
from ..renderers import OctreeRenderer
class SparseStructure(StandardDatasetBase):
    """
    Sparse structure dataset.

    Loads per-instance voxel occupancy stored as PLY point clouds and exposes
    each instance as a dense binary occupancy grid of shape (1, R, R, R).

    Args:
        roots (str): path to the dataset
        resolution (int): resolution of the voxel grid
        min_aesthetic_score (float): minimum aesthetic score of the instances
            to be included in the dataset
    """
    def __init__(self,
        roots,
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
    ):
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.value_range = (0, 1)
        # Attributes must be set before the base-class __init__, which
        # presumably triggers filter_metadata() — keep this call last.
        super().__init__(roots)

    def filter_metadata(self, metadata):
        """Keep only voxelized instances whose aesthetic score passes the threshold.

        Args:
            metadata: pandas DataFrame with at least 'voxelized' (bool) and
                'aesthetic_score' (float) columns.

        Returns:
            (filtered_metadata, stats): the filtered DataFrame and a dict of
            row counts remaining after each filter stage.
        """
        stats = {}
        # Fix: 'voxelized' was written as an f-string with no placeholder.
        metadata = metadata[metadata['voxelized']]
        stats['Voxelized'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """Load one instance's voxel point cloud and rasterize it to a dense grid.

        Args:
            root (str): dataset root containing a 'voxels' subdirectory.
            instance (str): instance identifier (PLY file stem).

        Returns:
            dict: {'ss': LongTensor of shape (1, R, R, R)} with 1 at occupied voxels.
        """
        position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
        # Points are assumed to lie in [-0.5, 0.5] — map to voxel indices.
        # clamp guards against a point at exactly +0.5 truncating to index
        # `resolution`, which would be out of range.
        coords = ((torch.tensor(position) + 0.5) * self.resolution).int()
        coords = coords.clamp(0, self.resolution - 1).contiguous()
        ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
        return {'ss': ss}

    @torch.no_grad()
    def visualize_sample(self, ss: Union[torch.Tensor, dict]):
        """Render each sample in a batch from four views into a 2x2 image tile.

        Args:
            ss: occupancy batch of shape (B, 1, R, R, R), or a dict holding it
                under key 'ss' (as produced by get_instance).

        Returns:
            Tensor of shape (B, 3, 1024, 1024) with the four rendered views
            tiled 2x2. Requires CUDA.
        """
        ss = ss if isinstance(ss, torch.Tensor) else ss['ss']

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = 512
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Build four cameras: 90°-spaced yaws sharing one random offset, each
        # with an independent random pitch, orbiting at distance 2 from origin.
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        # Renamed from `pitch` so the loop variable below no longer shadows it.
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
        # FOV and intrinsics are identical for every view — hoisted out of the loop.
        fov = torch.deg2rad(torch.tensor(30.0)).cuda()
        intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            exts.append(extrinsics)
            ints.append(intrinsics)

        images = []
        ss = ss.cuda()
        for i in range(ss.shape[0]):
            # One octree of solid voxels per sample; depth entries encode the
            # voxel size as log2(resolution).
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            coords = torch.nonzero(ss[i, 0], as_tuple=False)
            # NOTE(review): positions land in [0, 1) while the aabb starts at
            # -0.5 — verify centering against the renderer's aabb convention.
            representation.position = coords.float() / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')

            # Tile the four 512x512 renders into one 1024x1024 image (2x2 grid);
            # voxel positions double as RGB colors.
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            images.append(image)
        return torch.stack(images)