import json
import os
from dataclasses import dataclass
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm
import threestudio
from threestudio import register
from threestudio.utils.config import parse_structured
from threestudio.utils.ops import get_mvp_matrix, get_ray_directions, get_rays
from threestudio.utils.typing import *
def convert_pose(C2W):
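    """Flip the Y and Z axes of a camera-to-world matrix, switching between
    OpenGL and OpenCV/COLMAP camera conventions.

    A minimal usage sketch (the identity pose below is just illustrative):

        c2w_cv = convert_pose(torch.eye(4))
    """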
flip_yz = torch.eye(4)
flip_yz[1, 1] = -1
flip_yz[2, 2] = -1
C2W = torch.matmul(C2W, flip_yz)
return C2W
def convert_proj(K, H, W, near, far):
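    """Build an OpenGL-style 4x4 projection matrix (as a nested list) from an
    OpenCV intrinsic matrix K, image size H x W, and near/far planes.

    A minimal usage sketch (focal lengths and principal point are
    hypothetical):

        K = torch.eye(4)
        K[0, 0] = K[1, 1] = 500.0
        K[0, 2], K[1, 2] = 320.0, 240.0
        proj = torch.FloatTensor(convert_proj(K, 480, 640, 0.1, 1000.0))
    """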
return [
[2 * K[0, 0] / W, -2 * K[0, 1] / W, (W - 2 * K[0, 2]) / W, 0],
[0, -2 * K[1, 1] / H, (H - 2 * K[1, 2]) / H, 0],
[0, 0, (-far - near) / (far - near), -2 * far * near / (far - near)],
[0, 0, -1, 0],
]
def inter_pose(pose_0, pose_1, ratio):
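    """Interpolate between two camera-to-world poses: slerp the rotations and
    lerp the translations in world-to-camera space, then invert back to a
    camera-to-world matrix (returned as a float32 np.ndarray).

    A minimal usage sketch (c2w_a and c2w_b are hypothetical 4x4 tensors):

        mid = inter_pose(c2w_a, c2w_b, 0.5)
    """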
pose_0 = pose_0.detach().cpu().numpy()
pose_1 = pose_1.detach().cpu().numpy()
pose_0 = np.linalg.inv(pose_0)
pose_1 = np.linalg.inv(pose_1)
rot_0 = pose_0[:3, :3]
rot_1 = pose_1[:3, :3]
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
key_times = [0, 1]
slerp = Slerp(key_times, rots)
rot = slerp(ratio)
    pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = rot.as_matrix()
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
pose = np.linalg.inv(pose)
return pose
@dataclass
class MultiviewsDataModuleConfig:
dataroot: str = ""
train_downsample_resolution: int = 4
eval_downsample_resolution: int = 4
train_data_interval: int = 1
eval_data_interval: int = 1
batch_size: int = 1
eval_batch_size: int = 1
camera_layout: str = "around"
camera_distance: float = -1
    eval_interpolation: Optional[Tuple[int, int, int]] = None  # (start_idx, end_idx, n_steps), e.g. (0, 1, 30)
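# A hedged sketch of how this config might look in a threestudio YAML file
# (the dataroot path is hypothetical):
#
#   data_type: "multiview-camera-datamodule"
#   data:
#     dataroot: "data/scene"
#     train_downsample_resolution: 4
#     camera_layout: "around"
#     eval_interpolation: [0, 1, 30]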
class MultiviewIterableDataset(IterableDataset):
def __init__(self, cfg: Any) -> None:
super().__init__()
self.cfg: MultiviewsDataModuleConfig = cfg
assert self.cfg.batch_size == 1
scale = self.cfg.train_downsample_resolution
        with open(os.path.join(self.cfg.dataroot, "transforms.json"), "r") as f:
            camera_dict = json.load(f)
assert camera_dict["camera_model"] == "OPENCV"
frames = camera_dict["frames"]
frames = frames[:: self.cfg.train_data_interval]
frames_proj = []
frames_c2w = []
frames_position = []
frames_direction = []
frames_img = []
self.frame_w = frames[0]["w"] // scale
self.frame_h = frames[0]["h"] // scale
threestudio.info("Loading frames...")
self.n_frames = len(frames)
c2w_list = []
for frame in tqdm(frames):
            c2w: Float[Tensor, "4 4"] = torch.as_tensor(
                frame["transform_matrix"], dtype=torch.float32
            )
            c2w_list.append(c2w)
c2w_list = torch.stack(c2w_list, dim=0)
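        # Normalize the camera layout: "around" recenters poses at the mean
        # camera position; "front" additionally pushes cameras back along the
        # mean viewing direction by cfg.camera_distance.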
if self.cfg.camera_layout == "around":
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
elif self.cfg.camera_layout == "front":
assert self.cfg.camera_distance > 0
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
z_vector = torch.zeros(c2w_list.shape[0], 3, 1)
z_vector[:, 2, :] = -1
rot_z_vector = c2w_list[:, :3, :3] @ z_vector
rot_z_vector = torch.mean(rot_z_vector, dim=0).unsqueeze(0)
c2w_list[:, :3, 3] -= rot_z_vector[:, :, 0] * self.cfg.camera_distance
else:
            raise ValueError(
                f"Unknown camera layout {self.cfg.camera_layout}. Only 'around' and 'front' are supported."
            )
        for idx, frame in enumerate(tqdm(frames)):
intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
intrinsic[0, 0] = frame["fl_x"] / scale
intrinsic[1, 1] = frame["fl_y"] / scale
intrinsic[0, 2] = frame["cx"] / scale
intrinsic[1, 2] = frame["cy"] / scale
frame_path = os.path.join(self.cfg.dataroot, frame["file_path"])
img = cv2.imread(frame_path)[:, :, ::-1].copy()
img = cv2.resize(img, (self.frame_w, self.frame_h))
img: Float[Tensor, "H W 3"] = torch.FloatTensor(img) / 255
frames_img.append(img)
direction: Float[Tensor, "H W 3"] = get_ray_directions(
self.frame_h,
self.frame_w,
(intrinsic[0, 0], intrinsic[1, 1]),
(intrinsic[0, 2], intrinsic[1, 2]),
use_pixel_centers=False,
)
c2w = c2w_list[idx]
camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)
near = 0.1
far = 1000.0
proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
frames_proj.append(proj)
frames_c2w.append(c2w)
frames_position.append(camera_position)
frames_direction.append(direction)
threestudio.info("Loaded frames.")
self.frames_proj: Float[Tensor, "B 4 4"] = torch.stack(frames_proj, dim=0)
self.frames_c2w: Float[Tensor, "B 4 4"] = torch.stack(frames_c2w, dim=0)
self.frames_position: Float[Tensor, "B 3"] = torch.stack(frames_position, dim=0)
self.frames_direction: Float[Tensor, "B H W 3"] = torch.stack(
frames_direction, dim=0
)
self.frames_img: Float[Tensor, "B H W 3"] = torch.stack(frames_img, dim=0)
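        # Precompute rays, MVP matrices, and (placeholder) light positions for
        # all views once; collate() only indexes into these tensors.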
self.rays_o, self.rays_d = get_rays(
self.frames_direction, self.frames_c2w, keepdim=True
)
self.mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(
self.frames_c2w, self.frames_proj
)
self.light_positions: Float[Tensor, "B 3"] = torch.zeros_like(
self.frames_position
)
def __iter__(self):
while True:
yield {}
def collate(self, batch):
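        # Sample one random training view per step; the dummy batch produced
        # by __iter__ is ignored.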
index = torch.randint(0, self.n_frames, (1,)).item()
return {
"index": index,
"rays_o": self.rays_o[index : index + 1],
"rays_d": self.rays_d[index : index + 1],
"mvp_mtx": self.mvp_mtx[index : index + 1],
"c2w": self.frames_c2w[index : index + 1],
"camera_positions": self.frames_position[index : index + 1],
"light_positions": self.light_positions[index : index + 1],
"gt_rgb": self.frames_img[index : index + 1],
"height": self.frame_h,
"width": self.frame_w,
}
class MultiviewDataset(Dataset):
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.cfg: MultiviewsDataModuleConfig = cfg
assert self.cfg.eval_batch_size == 1
scale = self.cfg.eval_downsample_resolution
        with open(os.path.join(self.cfg.dataroot, "transforms.json"), "r") as f:
            camera_dict = json.load(f)
assert camera_dict["camera_model"] == "OPENCV"
frames = camera_dict["frames"]
frames = frames[:: self.cfg.eval_data_interval]
frames_proj = []
frames_c2w = []
frames_position = []
frames_direction = []
frames_img = []
self.frame_w = frames[0]["w"] // scale
self.frame_h = frames[0]["h"] // scale
threestudio.info("Loading frames...")
self.n_frames = len(frames)
c2w_list = []
for frame in tqdm(frames):
            c2w: Float[Tensor, "4 4"] = torch.as_tensor(
                frame["transform_matrix"], dtype=torch.float32
            )
            c2w_list.append(c2w)
c2w_list = torch.stack(c2w_list, dim=0)
if self.cfg.camera_layout == "around":
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
elif self.cfg.camera_layout == "front":
assert self.cfg.camera_distance > 0
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
z_vector = torch.zeros(c2w_list.shape[0], 3, 1)
z_vector[:, 2, :] = -1
rot_z_vector = c2w_list[:, :3, :3] @ z_vector
rot_z_vector = torch.mean(rot_z_vector, dim=0).unsqueeze(0)
c2w_list[:, :3, 3] -= rot_z_vector[:, :, 0] * self.cfg.camera_distance
else:
            raise ValueError(
                f"Unknown camera layout {self.cfg.camera_layout}. Only 'around' and 'front' are supported."
            )
        if self.cfg.eval_interpolation is not None:
idx0 = self.cfg.eval_interpolation[0]
idx1 = self.cfg.eval_interpolation[1]
eval_nums = self.cfg.eval_interpolation[2]
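            # Render eval_nums novel views along the interpolated path between
            # frames idx0 and idx1 (no ground-truth images, hence the zero
            # placeholders below).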
frame = frames[idx0]
intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
intrinsic[0, 0] = frame["fl_x"] / scale
intrinsic[1, 1] = frame["fl_y"] / scale
intrinsic[0, 2] = frame["cx"] / scale
intrinsic[1, 2] = frame["cy"] / scale
for ratio in np.linspace(0, 1, eval_nums):
img: Float[Tensor, "H W 3"] = torch.zeros(
(self.frame_h, self.frame_w, 3)
)
frames_img.append(img)
direction: Float[Tensor, "H W 3"] = get_ray_directions(
self.frame_h,
self.frame_w,
(intrinsic[0, 0], intrinsic[1, 1]),
(intrinsic[0, 2], intrinsic[1, 2]),
use_pixel_centers=False,
)
c2w = torch.FloatTensor(
inter_pose(c2w_list[idx0], c2w_list[idx1], ratio)
)
camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)
near = 0.1
far = 1000.0
proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
frames_proj.append(proj)
frames_c2w.append(c2w)
frames_position.append(camera_position)
frames_direction.append(direction)
else:
            for idx, frame in enumerate(tqdm(frames)):
intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
intrinsic[0, 0] = frame["fl_x"] / scale
intrinsic[1, 1] = frame["fl_y"] / scale
intrinsic[0, 2] = frame["cx"] / scale
intrinsic[1, 2] = frame["cy"] / scale
frame_path = os.path.join(self.cfg.dataroot, frame["file_path"])
img = cv2.imread(frame_path)[:, :, ::-1].copy()
img = cv2.resize(img, (self.frame_w, self.frame_h))
img: Float[Tensor, "H W 3"] = torch.FloatTensor(img) / 255
frames_img.append(img)
direction: Float[Tensor, "H W 3"] = get_ray_directions(
self.frame_h,
self.frame_w,
(intrinsic[0, 0], intrinsic[1, 1]),
(intrinsic[0, 2], intrinsic[1, 2]),
use_pixel_centers=False,
)
c2w = c2w_list[idx]
camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)
near = 0.1
far = 1000.0
                proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
frames_proj.append(proj)
frames_c2w.append(c2w)
frames_position.append(camera_position)
frames_direction.append(direction)
threestudio.info("Loaded frames.")
self.frames_proj: Float[Tensor, "B 4 4"] = torch.stack(frames_proj, dim=0)
self.frames_c2w: Float[Tensor, "B 4 4"] = torch.stack(frames_c2w, dim=0)
self.frames_position: Float[Tensor, "B 3"] = torch.stack(frames_position, dim=0)
self.frames_direction: Float[Tensor, "B H W 3"] = torch.stack(
frames_direction, dim=0
)
self.frames_img: Float[Tensor, "B H W 3"] = torch.stack(frames_img, dim=0)
self.rays_o, self.rays_d = get_rays(
self.frames_direction, self.frames_c2w, keepdim=True
)
self.mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(
self.frames_c2w, self.frames_proj
)
self.light_positions: Float[Tensor, "B 3"] = torch.zeros_like(
self.frames_position
)
def __len__(self):
return self.frames_proj.shape[0]
def __getitem__(self, index):
return {
"index": index,
"rays_o": self.rays_o[index],
"rays_d": self.rays_d[index],
"mvp_mtx": self.mvp_mtx[index],
"c2w": self.frames_c2w[index],
"camera_positions": self.frames_position[index],
"light_positions": self.light_positions[index],
"gt_rgb": self.frames_img[index],
}
def collate(self, batch):
batch = torch.utils.data.default_collate(batch)
batch.update({"height": self.frame_h, "width": self.frame_w})
return batch
@register("multiview-camera-datamodule")
class MultiviewDataModule(pl.LightningDataModule):
cfg: MultiviewsDataModuleConfig
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(MultiviewsDataModuleConfig, cfg)
def setup(self, stage=None) -> None:
if stage in [None, "fit"]:
self.train_dataset = MultiviewIterableDataset(self.cfg)
if stage in [None, "fit", "validate"]:
self.val_dataset = MultiviewDataset(self.cfg, "val")
if stage in [None, "test", "predict"]:
self.test_dataset = MultiviewDataset(self.cfg, "test")
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
return DataLoader(
dataset,
num_workers=1, # type: ignore
batch_size=batch_size,
collate_fn=collate_fn,
)
def train_dataloader(self) -> DataLoader:
return self.general_loader(
self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
)
def val_dataloader(self) -> DataLoader:
return self.general_loader(
self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
)
def test_dataloader(self) -> DataLoader:
return self.general_loader(
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
)
def predict_dataloader(self) -> DataLoader:
return self.general_loader(
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
)
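# A minimal smoke-test sketch (assumptions: parse_structured accepts a plain
# dict, and a nerfstudio-style dataset with an OPENCV-model transforms.json
# lives at the hypothetical path "data/scene"):
if __name__ == "__main__":
    dm = MultiviewDataModule({"dataroot": "data/scene"})
    dm.setup("fit")
    # batch_size=None disables automatic batching, so collate() runs per sample
    batch = next(iter(dm.train_dataloader()))
    print(batch["gt_rgb"].shape, batch["height"], batch["width"])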