Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import math | |
from dataclasses import dataclass, field | |
import os | |
import imageio | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from tgs.utils.config import parse_structured | |
from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays | |
from tgs.utils.typing import * | |
def _parse_scene_list_single(scene_list_path: str): | |
if scene_list_path.endswith(".json"): | |
with open(scene_list_path) as f: | |
all_scenes = json.loads(f.read()) | |
elif scene_list_path.endswith(".txt"): | |
with open(scene_list_path) as f: | |
all_scenes = [p.strip() for p in f.readlines()] | |
else: | |
all_scenes = [scene_list_path] | |
return all_scenes | |
def _parse_scene_list(scene_list_path: Union[str, List[str]]): | |
all_scenes = [] | |
if isinstance(scene_list_path, str): | |
scene_list_path = [scene_list_path] | |
for scene_list_path_ in scene_list_path: | |
all_scenes += _parse_scene_list_single(scene_list_path_) | |
return all_scenes | |
@dataclass
class CustomImageDataModuleConfig:
    """Configuration for ``CustomImageOrbitDataset``.

    Must be a dataclass: ``background_color`` relies on
    ``field(default_factory=...)`` and the class is handed to
    ``parse_structured`` as the config schema.
    """

    # A scene path, a ``.json``/``.txt`` list file, or a list of either.
    image_list: Any = ""
    # RGB colour composited behind the conditioning image where alpha < 1.
    background_color: Tuple[float, float, float] = field(
        default_factory=lambda: (1.0, 1.0, 1.0)
    )

    # If True a fixed conditioning pose is used instead of one derived from
    # cond_elevation_deg / cond_azimuth_deg.
    relative_pose: bool = False
    # Conditioning image size and camera parameters.
    cond_height: int = 512
    cond_width: int = 512
    cond_camera_distance: float = 1.6
    cond_fovy_deg: float = 40.0
    cond_elevation_deg: float = 0.0
    cond_azimuth_deg: float = 0.0
    num_workers: int = 16

    # Evaluation (orbit) image size and camera parameters.
    eval_height: int = 512
    eval_width: int = 512
    eval_batch_size: int = 1
    eval_elevation_deg: float = 0.0
    eval_camera_distance: float = 1.6
    eval_fovy_deg: float = 40.0

    # Number of cameras on the orbit; must be a multiple of num_views_output.
    n_test_views: int = 120
    # Number of orbit views returned per dataset item.
    num_views_output: int = 120
    # If True each scene yields a single item that carries only view 0.
    only_3dgs: bool = False
class CustomImageOrbitDataset(Dataset):
    """Pairs each conditioning image with a precomputed orbit of cameras.

    At construction time the dataset builds ``n_test_views`` cameras evenly
    spaced in azimuth over a full turn, at a fixed elevation and distance,
    all looking at the origin with world +z as up.  Rays, intrinsics and
    camera-to-world matrices of that orbit are cached on the instance, along
    with the pose and intrinsics of the single conditioning camera.

    Every item holds one conditioning image (composited over the background
    colour using its alpha channel) plus ``num_views_output`` consecutive
    orbit views, or only view 0 when ``only_3dgs`` is set.
    """

    def __init__(self, cfg: Any) -> None:
        """Parse ``cfg`` and precompute all camera quantities.

        Args:
            cfg: Anything ``parse_structured`` accepts for
                ``CustomImageDataModuleConfig``.
        """
        super().__init__()
        self.cfg: CustomImageDataModuleConfig = parse_structured(
            CustomImageDataModuleConfig, cfg
        )

        self.n_views = self.cfg.n_test_views
        # __len__ / __getitem__ cut the orbit into equally sized chunks.
        assert self.n_views % self.cfg.num_views_output == 0, (
            "n_test_views must be a multiple of num_views_output"
        )

        self.all_scenes = _parse_scene_list(self.cfg.image_list)

        # ---- evaluation orbit -------------------------------------------
        # n_views + 1 samples with the last one dropped, so 0 and 360 degrees
        # are not both present.
        azimuth_deg: Float[Tensor, "B"] = torch.linspace(
            0, 360.0, self.n_views + 1
        )[: self.n_views]
        elevation_deg: Float[Tensor, "B"] = torch.full_like(
            azimuth_deg, self.cfg.eval_elevation_deg
        )
        camera_distances: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_camera_distance
        )

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = self._look_at_c2w(camera_positions)

        fovy_deg: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_fovy_deg
        )
        fovy = fovy_deg * math.pi / 180

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = (
            0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
        )
        directions_unit_focal = get_ray_directions(
            H=self.cfg.eval_height,
            W=self.cfg.eval_width,
            focal=1.0,
        )
        directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
            None, :, :, :
        ].repeat(self.n_views, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )

        # NOTE(review): the original comment says directions must be
        # normalized here ("normalize=True"); no such argument is passed, so
        # this presumably relies on get_rays' default -- confirm.
        rays_o, rays_d = get_rays(directions, c2w, keepdim=True)

        intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov(
            self.cfg.eval_fovy_deg * math.pi / 180,
            H=self.cfg.eval_height,
            W=self.cfg.eval_width,
            bs=self.n_views,
        )
        intrinsic_normed: Float[Tensor, "B 3 3"] = self._normalize_intrinsic(
            intrinsic, self.cfg.eval_width, self.cfg.eval_height
        )

        self.rays_o, self.rays_d = rays_o, rays_d
        self.intrinsic = intrinsic
        self.intrinsic_normed = intrinsic_normed
        self.c2w = c2w
        self.camera_positions = camera_positions

        self.background_color = torch.as_tensor(self.cfg.background_color)

        # ---- conditioning camera ----------------------------------------
        self.intrinsic_cond = get_intrinsic_from_fov(
            np.deg2rad(self.cfg.cond_fovy_deg),
            H=self.cfg.cond_height,
            W=self.cfg.cond_width,
        )
        self.intrinsic_normed_cond = self._normalize_intrinsic(
            self.intrinsic_cond, self.cfg.cond_width, self.cfg.cond_height
        )

        if self.cfg.relative_pose:
            # Fixed pose: camera on the +x axis at cond_camera_distance.
            self.c2w_cond = torch.as_tensor(
                [
                    [0, 0, 1, self.cfg.cond_camera_distance],
                    [1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 0, 1],
                ]
            ).float()
        else:
            cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180
            cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180
            cond_camera_position: Float[Tensor, "3"] = torch.as_tensor(
                [
                    self.cfg.cond_camera_distance
                    * np.cos(cond_elevation)
                    * np.cos(cond_azimuth),
                    self.cfg.cond_camera_distance
                    * np.cos(cond_elevation)
                    * np.sin(cond_azimuth),
                    self.cfg.cond_camera_distance * np.sin(cond_elevation),
                ],
                dtype=torch.float32,
            )
            # Same look-at construction as the orbit, on a batch of one.
            self.c2w_cond = self._look_at_c2w(cond_camera_position[None, :])[0]

    @staticmethod
    def _look_at_c2w(camera_positions: torch.Tensor) -> torch.Tensor:
        """Build (B, 4, 4) camera-to-world matrices for cameras facing the origin.

        ``camera_positions`` is (B, 3).  World +z is used as the up hint.  The
        rotation columns are ``right``, ``up`` and ``-lookat``; the fourth
        column is the camera position.
        """
        # default scene center at origin
        center = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up = camera_positions.new_tensor([0.0, 0.0, 1.0]).expand_as(camera_positions)

        lookat = F.normalize(center - camera_positions, dim=-1)
        # dim=-1 is explicit: without it torch.cross picks the first size-3
        # dimension, which is the batch axis when B == 3.
        right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)

        c2w3x4 = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
        c2w[:, 3, 3] = 1.0
        return c2w

    @staticmethod
    def _normalize_intrinsic(intrinsic: torch.Tensor, width, height) -> torch.Tensor:
        """Return a copy of ``intrinsic`` with fx, cx divided by width and fy, cy by height."""
        normed = intrinsic.clone()
        normed[..., 0, 2] /= width
        normed[..., 1, 2] /= height
        normed[..., 0, 0] /= width
        normed[..., 1, 1] /= height
        return normed

    @staticmethod
    def _flip_yz(c2w: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``c2w`` with the camera y and z axes negated.

        Works on a clone so tensors cached on the dataset are never modified.
        Presumably this converts between OpenGL- and OpenCV-style camera
        conventions -- confirm against the consumer of ``c2w``.
        """
        flipped = c2w.clone()
        flipped[..., :3, 1:3] *= -1
        return flipped

    def __len__(self):
        """Number of items: one per scene, times the number of view chunks."""
        if self.cfg.only_3dgs:
            return len(self.all_scenes)
        return len(self.all_scenes) * self.n_views // self.cfg.num_views_output

    def __getitem__(self, index):
        """Return the conditioning data of one scene plus one chunk of orbit views.

        Args:
            index: Item index.  With ``only_3dgs`` it is the scene index;
                otherwise consecutive indices walk through the view chunks of
                a scene before moving on to the next scene.

        Returns:
            A dict of tensors (plus the string ``instance_id``) describing
            the conditioning image/camera and the selected orbit views.
        """
        if self.cfg.only_3dgs:
            scene_index = index
            view_index = [0]
        else:
            views_per_item = self.cfg.num_views_output
            scene_index = index * views_per_item // self.n_views
            chunk = index % (self.n_views // views_per_item)
            view_index = list(
                range(chunk * views_per_item, (chunk + 1) * views_per_item)
            )

        img_path = self.all_scenes[scene_index]
        rgba = (
            Image.fromarray(imageio.v2.imread(img_path))
            .convert("RGBA")
            .resize((self.cfg.cond_width, self.cfg.cond_height))
        )
        img_cond = torch.from_numpy(np.asarray(rgba) / 255.0).float()
        mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:]
        # Alpha-composite the image over the configured background colour.
        rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[
            :, :, :3
        ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond)

        out = {
            "rgb_cond": rgb_cond.unsqueeze(0),
            # unsqueeze() returns a view of self.c2w_cond; flipping it in
            # place would toggle the cached pose on every call, so the flip
            # is applied to a copy.
            "c2w_cond": self._flip_yz(self.c2w_cond.unsqueeze(0)),
            "mask_cond": mask_cond.unsqueeze(0),
            "intrinsic_cond": self.intrinsic_cond.unsqueeze(0),
            "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0),
            "view_index": torch.as_tensor(view_index),
            "rays_o": self.rays_o[view_index],
            "rays_d": self.rays_d[view_index],
            "intrinsic": self.intrinsic[view_index],
            "intrinsic_normed": self.intrinsic_normed[view_index],
            "c2w": self._flip_yz(self.c2w[view_index]),
            "camera_positions": self.camera_positions[view_index],
        }

        # File name up to its first dot, without the directory part.
        instance_id = os.path.split(img_path)[-1].split('.')[0]
        out["index"] = torch.as_tensor(scene_index)
        out["background_color"] = self.background_color
        out["instance_id"] = instance_id
        return out

    def collate(self, batch):
        """Collate items with torch's default collate and attach the eval image size."""
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
        return batch