import os
import json

import numpy as np
import pandas as pd
import torch
from PIL import Image
import utils3d.torch

from ..modules.sparse.basic import SparseTensor
from .components import StandardDatasetBase


class SparseFeat2Render(StandardDatasetBase):
    """
    SparseFeat2Render dataset: pairs sparse voxel features with randomly
    sampled rendered views of the same instance.

    Args:
        roots (str): paths to the dataset
        image_size (int): side length the rendered image is resized to
        model (str): name of the feature-extraction model
        resolution (int): target resolution of the sparse voxel grid
        min_aesthetic_score (float): minimum aesthetic score an instance must have
        max_num_voxels (int): maximum number of voxels an instance may have
    """

    def __init__(
        self,
        roots: str,
        image_size: int,
        model: str = 'dinov2_vitl14_reg',
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
    ):
        self.image_size = image_size
        self.model = model
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        self.value_range = (0, 1)
        super().__init__(roots)

    def filter_metadata(self, metadata):
        stats = {}
        # Keep only instances that have precomputed features for the chosen model.
        metadata = metadata[metadata[f'feature_{self.model}']]
        stats['With features'] = len(metadata)
        # Drop instances below the aesthetic-score threshold.
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        # Drop instances whose voxel count exceeds the budget.
        metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
        return metadata, stats
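
    # A minimal sketch of the filtering above, assuming (as the code implies) that
    # `metadata` is a pandas DataFrame with a boolean column f'feature_{model}' and
    # numeric columns 'aesthetic_score' and 'num_voxels' (values are illustrative):
    #
    #   df = pd.DataFrame({
    #       'feature_dinov2_vitl14_reg': [True, True, False],
    #       'aesthetic_score': [5.5, 4.0, 6.0],
    #       'num_voxels': [1024, 2048, 40960],
    #   })
    #   df = df[df['feature_dinov2_vitl14_reg']]    # drops row 2 (no features)
    #   df = df[df['aesthetic_score'] >= 5.0]       # drops row 1 (score 4.0)
    #   df = df[df['num_voxels'] <= 32768]          # row 0 survives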

    def _get_image(self, root, instance):
        # transforms.json lists all rendered views of the instance together with
        # their camera parameters; sample one view uniformly at random.
        with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
            metadata = json.load(f)
        n_views = len(metadata['frames'])
        view = np.random.randint(n_views)
        metadata = metadata['frames'][view]

        fov = metadata['camera_angle_x']
        intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
        c2w = torch.tensor(metadata['transform_matrix'])
        # Flip the y and z camera axes (OpenGL -> OpenCV convention), then invert
        # the camera-to-world matrix to obtain the extrinsics (world-to-camera).
        c2w[:3, 1:3] *= -1
        extrinsics = torch.inverse(c2w)

        # Split the RGBA render into an RGB image and an alpha mask, resize both.
        image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
        image = Image.open(image_path)
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        # Convert to CHW float tensors in [0, 1].
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0
        return {
            'image': image,
            'alpha': alpha,
            'extrinsics': extrinsics,
            'intrinsics': intrinsics,
        }
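
    # A minimal sketch of the transforms.json layout read above; only the keys the
    # code actually touches are shown, and the values are illustrative:
    #
    #   {
    #     "frames": [
    #       {
    #         "camera_angle_x": 0.8575,
    #         "transform_matrix": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 2], [0, 0, 0, 1]],
    #         "file_path": "000.png"
    #       }
    #     ]
    #   }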

    def _get_feat(self, root, instance):
        DATA_RESOLUTION = 64
        feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
        feats = np.load(feats_path, allow_pickle=True)
        coords = torch.tensor(feats['indices']).int()
        feats = torch.tensor(feats['patchtokens']).float()

        if self.resolution != DATA_RESOLUTION:
            # Downsample to the target grid: map coordinates to the coarser cells,
            # deduplicate them, and mean-pool the features of fine voxels that
            # collapse onto the same coarse cell. Note that scatter_reduce defaults
            # to include_self=True, so the zero initializer is counted in the mean.
            factor = DATA_RESOLUTION // self.resolution
            coords = coords // factor
            coords, idx = coords.unique(return_inverse=True, dim=0)
            feats = torch.scatter_reduce(
                torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
                dim=0,
                index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
                src=feats,
                reduce='mean',
            )
        return {
            'coords': coords,
            'feats': feats,
        }
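
    # A minimal toy run of the pooling above, with two fine voxels landing in the
    # same coarse cell (feature values are illustrative):
    #
    #   feats = torch.tensor([[2.0], [4.0]])
    #   coords = torch.tensor([[0, 0, 1], [0, 0, 0]]) // 2       # both -> [0, 0, 0]
    #   coords, idx = coords.unique(return_inverse=True, dim=0)  # one coarse voxel
    #   out = torch.scatter_reduce(
    #       torch.zeros(1, 1), 0, idx.unsqueeze(-1), feats, reduce='mean'
    #   )
    #   # out == 2.0, i.e. (0 + 2 + 4) / 3: the default include_self=True counts
    #   # the zero initializer; include_self=False would give the plain mean, 3.0.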

    @torch.no_grad()
    def visualize_sample(self, sample: dict):
        return sample['image']

    @staticmethod
    def collate_fn(batch):
        pack = {}
        # Prepend each sample's batch index to its voxel coordinates so the whole
        # batch shares one [N, 4] coordinate tensor (batch, x, y, z).
        coords = []
        for i, b in enumerate(batch):
            coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
        coords = torch.cat(coords)
        feats = torch.cat([b['feats'] for b in batch])
        pack['feats'] = SparseTensor(
            coords=coords,
            feats=feats,
        )

        pack['image'] = torch.stack([b['image'] for b in batch])
        pack['alpha'] = torch.stack([b['alpha'] for b in batch])
        pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
        pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])
        return pack

    def get_instance(self, root, instance):
        image = self._get_image(root, instance)
        feat = self._get_feat(root, instance)
        return {
            **image,
            **feat,
        }
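

# A minimal usage sketch, assuming a dataset root laid out as the loaders above
# expect ('renders/<instance>/transforms.json', 'features/<model>/<instance>.npz');
# the path '/data/objaverse' is hypothetical:
#
#   from torch.utils.data import DataLoader
#
#   dataset = SparseFeat2Render('/data/objaverse', image_size=512, resolution=64)
#   loader = DataLoader(
#       dataset,
#       batch_size=4,
#       shuffle=True,
#       collate_fn=SparseFeat2Render.collate_fn,  # packs voxels into one SparseTensor
#   )
#   batch = next(iter(loader))
#   # batch['image']: [4, 3, 512, 512] float in [0, 1]
#   # batch['feats']: SparseTensor with coords [N, 4] and per-voxel patch tokens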