# hort/models/tgs/data.py
import json
import math
import os
from dataclasses import dataclass, field

import imageio
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

from tgs.utils.config import parse_structured
from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays
from tgs.utils.typing import *


def _parse_scene_list_single(scene_list_path: str):
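    """Parse one scene-list entry: a .json file holding a list of image paths,
    a .txt file with one path per line, or a literal image path."""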
if scene_list_path.endswith(".json"):
with open(scene_list_path) as f:
all_scenes = json.loads(f.read())
elif scene_list_path.endswith(".txt"):
with open(scene_list_path) as f:
all_scenes = [p.strip() for p in f.readlines()]
else:
all_scenes = [scene_list_path]
return all_scenes


def _parse_scene_list(scene_list_path: Union[str, List[str]]):
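    """Flatten one or more scene-list entries into a single list of image paths."""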
all_scenes = []
if isinstance(scene_list_path, str):
scene_list_path = [scene_list_path]
for scene_list_path_ in scene_list_path:
all_scenes += _parse_scene_list_single(scene_list_path_)
return all_scenes


@dataclass
class CustomImageDataModuleConfig:
image_list: Any = ""
background_color: Tuple[float, float, float] = field(
default_factory=lambda: (1.0, 1.0, 1.0)
)
relative_pose: bool = False
cond_height: int = 512
cond_width: int = 512
cond_camera_distance: float = 1.6
cond_fovy_deg: float = 40.0
cond_elevation_deg: float = 0.0
cond_azimuth_deg: float = 0.0
num_workers: int = 16
eval_height: int = 512
eval_width: int = 512
eval_batch_size: int = 1
eval_elevation_deg: float = 0.0
eval_camera_distance: float = 1.6
eval_fovy_deg: float = 40.0
n_test_views: int = 120
num_views_output: int = 120
only_3dgs: bool = False


class CustomImageOrbitDataset(Dataset):
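    """Single-image dataset pairing one conditioning view with an orbit of
    evaluation cameras for novel-view synthesis."""
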
def __init__(self, cfg: Any) -> None:
super().__init__()
self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg)
self.n_views = self.cfg.n_test_views
        assert self.n_views % self.cfg.num_views_output == 0, (
            "n_test_views must be divisible by num_views_output"
        )
self.all_scenes = _parse_scene_list(self.cfg.image_list)
azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[
: self.n_views
]
elevation_deg: Float[Tensor, "B"] = torch.full_like(
azimuth_deg, self.cfg.eval_elevation_deg
)
camera_distances: Float[Tensor, "B"] = torch.full_like(
elevation_deg, self.cfg.eval_camera_distance
)
elevation = elevation_deg * math.pi / 180
azimuth = azimuth_deg * math.pi / 180
# convert spherical coordinates to cartesian coordinates
# right hand coordinate system, x back, y right, z up
# elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
camera_positions: Float[Tensor, "B 3"] = torch.stack(
[
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
camera_distances * torch.sin(elevation),
],
dim=-1,
)
# default scene center at origin
center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
# default camera up direction as +z
up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
None, :
].repeat(self.n_views, 1)
fovy_deg: Float[Tensor, "B"] = torch.full_like(
elevation_deg, self.cfg.eval_fovy_deg
)
fovy = fovy_deg * math.pi / 180
lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
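        # Assemble the camera-to-world matrix in OpenGL convention: columns are
        # (right, up, backward), with the camera position as the translation.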
c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
dim=-1,
)
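        # Promote the 3x4 pose to a homogeneous 4x4 matrix.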
c2w: Float[Tensor, "B 4 4"] = torch.cat(
[c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
)
c2w[:, 3, 3] = 1.0
# get directions by dividing directions_unit_focal by focal length
focal_length: Float[Tensor, "B"] = (
0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
)
directions_unit_focal = get_ray_directions(
H=self.cfg.eval_height,
W=self.cfg.eval_width,
focal=1.0,
)
directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
None, :, :, :
].repeat(self.n_views, 1, 1, 1)
directions[:, :, :, :2] = (
directions[:, :, :, :2] / focal_length[:, None, None, None]
)
        # must use normalize=True to normalize directions here
        rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov(
self.cfg.eval_fovy_deg * math.pi / 180,
H=self.cfg.eval_height,
W=self.cfg.eval_width,
bs=self.n_views,
)
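        # Express the intrinsics in resolution-independent units: principal point
        # and focal length are divided by the image width/height.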
intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone()
intrinsic_normed[..., 0, 2] /= self.cfg.eval_width
intrinsic_normed[..., 1, 2] /= self.cfg.eval_height
intrinsic_normed[..., 0, 0] /= self.cfg.eval_width
intrinsic_normed[..., 1, 1] /= self.cfg.eval_height
self.rays_o, self.rays_d = rays_o, rays_d
self.intrinsic = intrinsic
self.intrinsic_normed = intrinsic_normed
self.c2w = c2w
self.camera_positions = camera_positions
self.background_color = torch.as_tensor(self.cfg.background_color)
# condition
self.intrinsic_cond = get_intrinsic_from_fov(
np.deg2rad(self.cfg.cond_fovy_deg),
H=self.cfg.cond_height,
W=self.cfg.cond_width,
)
self.intrinsic_normed_cond = self.intrinsic_cond.clone()
self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width
self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height
self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width
self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height
if self.cfg.relative_pose:
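            # Canonical conditioning pose: the camera sits on the +x axis at
            # cond_camera_distance, looking at the origin (right = +y, up = +z).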
self.c2w_cond = torch.as_tensor(
[
[0, 0, 1, self.cfg.cond_camera_distance],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
).float()
else:
cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180
cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180
cond_camera_position: Float[Tensor, "3"] = torch.as_tensor(
[
                    self.cfg.cond_camera_distance * math.cos(cond_elevation) * math.cos(cond_azimuth),
                    self.cfg.cond_camera_distance * math.cos(cond_elevation) * math.sin(cond_azimuth),
                    self.cfg.cond_camera_distance * math.sin(cond_elevation),
], dtype=torch.float32
)
cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position)
cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)
cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1)
            cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up, dim=-1), dim=-1)
            cond_up = F.normalize(torch.cross(cond_right, cond_lookat, dim=-1), dim=-1)
cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat(
[torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]],
dim=-1,
)
cond_c2w: Float[Tensor, "4 4"] = torch.cat(
[cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0
)
cond_c2w[3, 3] = 1.0
self.c2w_cond = cond_c2w

    def __len__(self):
if self.cfg.only_3dgs:
return len(self.all_scenes)
else:
return len(self.all_scenes) * self.n_views // self.cfg.num_views_output

    def __getitem__(self, index):
if self.cfg.only_3dgs:
scene_index = index
view_index = [0]
else:
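            # Each scene contributes n_views // num_views_output chunks of views;
            # map the flat index to a scene and a contiguous block of view indices.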
            chunks_per_scene = self.n_views // self.cfg.num_views_output
            scene_index = index // chunks_per_scene
            view_start = index % chunks_per_scene
            view_index = list(
                range(
                    view_start * self.cfg.num_views_output,
                    (view_start + 1) * self.cfg.num_views_output,
                )
            )
img_path = self.all_scenes[scene_index]
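        # Load the conditioning image as RGBA in [0, 1]; the alpha channel
        # doubles as the foreground mask.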
img_cond = torch.from_numpy(
np.asarray(
Image.fromarray(imageio.v2.imread(img_path))
.convert("RGBA")
.resize((self.cfg.cond_width, self.cfg.cond_height))
)
/ 255.0
).float()
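        # Composite the foreground over the configured background color.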
mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:]
rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[
:, :, :3
] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond)
out = {
"rgb_cond": rgb_cond.unsqueeze(0),
"c2w_cond": self.c2w_cond.unsqueeze(0),
"mask_cond": mask_cond.unsqueeze(0),
"intrinsic_cond": self.intrinsic_cond.unsqueeze(0),
"intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0),
"view_index": torch.as_tensor(view_index),
"rays_o": self.rays_o[view_index],
"rays_d": self.rays_d[view_index],
"intrinsic": self.intrinsic[view_index],
"intrinsic_normed": self.intrinsic_normed[view_index],
"c2w": self.c2w[view_index],
"camera_positions": self.camera_positions[view_index],
}
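        # Flip the y and z camera axes to convert the poses from OpenGL-style
        # (right, up, back) to OpenCV-style (right, down, forward) coordinates.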
out["c2w"][..., :3, 1:3] *= -1
out["c2w_cond"][..., :3, 1:3] *= -1
        instance_id = os.path.basename(img_path).split('.')[0]
out["index"] = torch.as_tensor(scene_index)
out["background_color"] = self.background_color
out["instance_id"] = instance_id
return out

    def collate(self, batch):
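        """Batch samples with the default collate and attach the render resolution."""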
batch = torch.utils.data.default_collate(batch)
batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
return batch
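

# Minimal smoke-test sketch (not part of the original module). It assumes a
# hypothetical RGBA image at "example.png" and that parse_structured accepts a
# plain dict; adjust the path and config fields for a real run.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = CustomImageOrbitDataset(
        {"image_list": "example.png", "num_views_output": 30}
    )
    loader = DataLoader(
        dataset,
        batch_size=dataset.cfg.eval_batch_size,
        num_workers=0,
        collate_fn=dataset.collate,
    )
    batch = next(iter(loader))
    print(batch["rgb_cond"].shape)  # expected: (1, 1, 512, 512, 3)
    print(batch["rays_o"].shape)    # expected: (1, 30, 512, 512, 3)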