import gzip
import json
import os
import warnings
from dataclasses import dataclass, field
from typing import List
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader, Dataset, IterableDataset
from threestudio import register
from threestudio.data.uncond import (
RandomCameraDataModuleConfig,
RandomCameraDataset,
RandomCameraIterableDataset,
)
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
get_mvp_matrix,
get_projection_matrix,
get_ray_directions,
get_rays,
)
from threestudio.utils.typing import *
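# CO3D dataset loading for threestudio: parse CO3D frame annotations, load
# images, depth maps, and foreground masks, normalize the cameras, and expose
# map-style / iterable datasets plus a Lightning datamodule.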
def _load_16big_png_depth(depth_png) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
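# Load a 16-bit PNG depth map, apply the per-frame CO3D scale adjustment,
# zero out non-finite values, and add a leading channel dimension.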
def _load_depth(path, scale_adjustment) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)
d = _load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel
# Code adapted from https://github.com/eldar/snes/blob/473ff2b1f6/3rdparty/co3d/dataset/co3d_dataset.py
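# First and last indices of the nonzero entries of a 1D array.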
def _get_1d_bounds(arr):
nz = np.flatnonzero(arr)
return nz[0], nz[-1]
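# Tight xywh bounding box around the thresholded mask; the threshold is lowered
# until the box covers more than one pixel.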
def get_bbox_from_mask(mask, thr, decrease_quant=0.05):
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1 - x0, y1 - y0
def get_clamp_bbox(bbox, box_crop_context=0.0, impath=""):
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.astype(np.float32)
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
if (bbox[2:] <= 1.0).any():
warnings.warn(f"squashed image {impath}!!")
return None
# bbox[2:] = np.clip(bbox[2:], 2, )
bbox[2:] = np.maximum(bbox[2:], 2)
bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax]
# +1 because upper bound is not inclusive
return bbox
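# Clip the xyxy box to the image bounds and crop an (H, W, C) array with it.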
def crop_around_box(tensor, bbox, impath=""):
bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0.0, tensor.shape[-2])
bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0.0, tensor.shape[-3])
bbox = bbox.round().astype(np.longlong)
return tensor[bbox[1] : bbox[3], bbox[0] : bbox[2], ...]
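# Fit an (H, W, C) array into (height, width) preserving aspect ratio, zero-pad
# the remainder, and return the resized image, the scale factor, and a mask of
# valid (non-padded) pixels.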
def resize_image(image, height, width, mode="bilinear"):
if image.shape[:2] == (height, width):
return image, 1.0, np.ones_like(image[..., :1])
image = torch.from_numpy(image).permute(2, 0, 1)
minscale = min(height / image.shape[-2], width / image.shape[-1])
imre = torch.nn.functional.interpolate(
image[None],
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], height, width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, height, width)
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
return imre_.permute(1, 2, 0).numpy(), minscale, mask.permute(1, 2, 0).numpy()
# Code adapted from https://github.com/POSTECH-CVLab/PeRFception/data_util/co3d.py
def similarity_from_cameras(c2w, fix_rot=False, radius=1.0):
"""
Get a similarity transform to normalize dataset
from c2w (OpenCV convention) cameras
    :param c2w: (N, 4, 4) camera-to-world matrices
    :return: T (4, 4), scale (float)
"""
t = c2w[:, :3, 3]
R = c2w[:, :3, :3]
# (1) Rotate the world so that z+ is the up axis
# we estimate the up axis by averaging the camera up axes
ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
world_up = np.mean(ups, axis=0)
world_up /= np.linalg.norm(world_up)
up_camspace = np.array([0.0, 0.0, 1.0])
c = (up_camspace * world_up).sum()
cross = np.cross(world_up, up_camspace)
skew = np.array(
[
[0.0, -cross[2], cross[1]],
[cross[2], 0.0, -cross[0]],
[-cross[1], cross[0], 0.0],
]
)
if c > -1:
R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
else:
# In the unlikely case the original data has y+ up axis,
# rotate 180-deg about x axis
R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
if fix_rot:
R_align = np.eye(3)
R = np.eye(3)
else:
R = R_align @ R
fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
t = (R_align @ t[..., None])[..., 0]
# (2) Recenter the scene using camera center rays
# find the closest point to the origin for each camera's center ray
nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
# median for more robustness
translate = -np.median(nearest, axis=0)
# translate = -np.mean(t, axis=0) # DEBUG
transform = np.eye(4)
transform[:3, 3] = translate
transform[:3, :3] = R_align
# (3) Rescale the scene using camera distances
scale = radius / np.median(np.linalg.norm(t + translate, axis=-1))
return transform, scale
@dataclass
class Co3dDataModuleConfig:
root_dir: str = ""
batch_size: int = 1
height: int = 256
width: int = 256
load_preprocessed: bool = False
cam_scale_factor: float = 0.95
max_num_frames: int = 300
v2_mode: bool = True
use_mask: bool = True
box_crop: bool = True
box_crop_mask_thr: float = 0.4
box_crop_context: float = 0.3
train_num_rays: int = -1
train_views: Optional[list] = None
train_split: str = "train"
val_split: str = "val"
test_split: str = "test"
scale_radius: float = 1.0
use_random_camera: bool = True
random_camera: dict = field(default_factory=dict)
rays_noise_scale: float = 0.0
render_path: str = "circle"
class Co3dDatasetBase:
def setup(self, cfg, split):
self.split = split
self.rank = get_rank()
self.cfg: Co3dDataModuleConfig = cfg
if self.cfg.use_random_camera:
random_camera_cfg = parse_structured(
RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
)
if split == "train":
self.random_pose_generator = RandomCameraIterableDataset(
random_camera_cfg
)
else:
self.random_pose_generator = RandomCameraDataset(
random_camera_cfg, split
)
self.use_mask = self.cfg.use_mask
cam_scale_factor = self.cfg.cam_scale_factor
assert os.path.exists(self.cfg.root_dir), f"{self.cfg.root_dir} doesn't exist!"
cam_trans = np.diag(np.array([-1, -1, 1, 1], dtype=np.float32))
scene_number = self.cfg.root_dir.split("/")[-1]
json_path = os.path.join(self.cfg.root_dir, "..", "frame_annotations.jgz")
with gzip.open(json_path, "r") as fp:
all_frames_data = json.load(fp)
frame_data, images, intrinsics, extrinsics, image_sizes = [], [], [], [], []
masks = []
depths = []
for temporal_data in all_frames_data:
if temporal_data["sequence_name"] == scene_number:
frame_data.append(temporal_data)
self.all_directions = []
self.all_fg_masks = []
for frame in frame_data:
if "unseen" in frame["meta"]["frame_type"]:
continue
img = cv2.imread(
os.path.join(self.cfg.root_dir, "..", "..", frame["image"]["path"])
)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
# TODO: use estimated depth
depth = _load_depth(
os.path.join(self.cfg.root_dir, "..", "..", frame["depth"]["path"]),
frame["depth"]["scale_adjustment"],
)[0]
H, W = frame["image"]["size"]
image_size = np.array([H, W])
fxy = np.array(frame["viewpoint"]["focal_length"])
cxy = np.array(frame["viewpoint"]["principal_point"])
R = np.array(frame["viewpoint"]["R"])
T = np.array(frame["viewpoint"]["T"])
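            # CO3D v2 stores focal_length / principal_point in NDC units scaled
            # by half the shorter image side; re-express them relative to the
            # half image size so the pixel-space focal and principal point
            # computed below come out correctly.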
if self.cfg.v2_mode:
min_HW = min(W, H)
image_size_half = np.array([W * 0.5, H * 0.5], dtype=np.float32)
scale_arr = np.array([min_HW * 0.5, min_HW * 0.5], dtype=np.float32)
fxy_x = fxy * scale_arr
prp_x = np.array([W * 0.5, H * 0.5], dtype=np.float32) - cxy * scale_arr
cxy = (image_size_half - prp_x) / image_size_half
fxy = fxy_x / image_size_half
scale_arr = np.array([W * 0.5, H * 0.5], dtype=np.float32)
focal = fxy * scale_arr
prp = -1.0 * (cxy - 1.0) * scale_arr
pose = np.eye(4)
pose[:3, :3] = R
pose[:3, 3:] = -R @ T[..., None]
# original camera: x left, y up, z in (Pytorch3D)
# transformed camera: x right, y down, z in (OpenCV)
pose = pose @ cam_trans
intrinsic = np.array(
[
[focal[0], 0.0, prp[0], 0.0],
[0.0, focal[1], prp[1], 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
)
if any([np.all(pose == _pose) for _pose in extrinsics]):
continue
image_sizes.append(image_size)
intrinsics.append(intrinsic)
extrinsics.append(pose)
images.append(img)
depths.append(depth)
self.all_directions.append(get_ray_directions(W, H, focal, prp))
# vis_utils.vis_depth_pcd([depth], [pose], intrinsic, [(img * 255).astype(np.uint8)])
if self.use_mask:
mask = np.array(
Image.open(
os.path.join(
self.cfg.root_dir, "..", "..", frame["mask"]["path"]
)
)
)
mask = mask.astype(np.float32) / 255.0 # (h, w)
else:
mask = torch.ones_like(img[..., 0])
self.all_fg_masks.append(mask)
intrinsics = np.stack(intrinsics)
extrinsics = np.stack(extrinsics)
image_sizes = np.stack(image_sizes)
self.all_directions = torch.stack(self.all_directions, dim=0)
self.all_fg_masks = np.stack(self.all_fg_masks, 0)
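        # Drop outlier frames: image sizes more than 10% away from the median
        # size, or camera centers more than 5x the median distance from the
        # median camera center.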
        H_median, W_median = np.median(image_sizes, axis=0)
H_inlier = np.abs(image_sizes[:, 0] - H_median) / H_median < 0.1
W_inlier = np.abs(image_sizes[:, 1] - W_median) / W_median < 0.1
inlier = np.logical_and(H_inlier, W_inlier)
dists = np.linalg.norm(
extrinsics[:, :3, 3] - np.median(extrinsics[:, :3, 3], axis=0), axis=-1
)
med = np.median(dists)
good_mask = dists < (med * 5.0)
inlier = np.logical_and(inlier, good_mask)
if inlier.sum() != 0:
intrinsics = intrinsics[inlier]
extrinsics = extrinsics[inlier]
image_sizes = image_sizes[inlier]
images = [images[i] for i in range(len(inlier)) if inlier[i]]
depths = [depths[i] for i in range(len(inlier)) if inlier[i]]
self.all_directions = self.all_directions[inlier]
self.all_fg_masks = self.all_fg_masks[inlier]
extrinsics = np.stack(extrinsics)
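        # Normalize the scene: rotate so +z is up, recenter on the cameras'
        # center rays, and rescale camera positions and depths by the same
        # factor.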
T, sscale = similarity_from_cameras(extrinsics, radius=self.cfg.scale_radius)
extrinsics = T @ extrinsics
extrinsics[:, :3, 3] *= sscale * cam_scale_factor
depths = [depth * sscale * cam_scale_factor for depth in depths]
num_frames = len(extrinsics)
if self.cfg.max_num_frames < num_frames:
num_frames = self.cfg.max_num_frames
extrinsics = extrinsics[:num_frames]
intrinsics = intrinsics[:num_frames]
image_sizes = image_sizes[:num_frames]
images = images[:num_frames]
depths = depths[:num_frames]
self.all_directions = self.all_directions[:num_frames]
self.all_fg_masks = self.all_fg_masks[:num_frames]
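        # Optionally crop every frame around its foreground-mask bounding box,
        # resize to (height, width), and recompute ray directions with the
        # principal point shifted into the crop.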
if self.cfg.box_crop:
print("cropping...")
crop_masks = []
crop_imgs = []
crop_depths = []
crop_directions = []
crop_xywhs = []
max_sl = 0
for i in range(num_frames):
bbox_xywh = np.array(
get_bbox_from_mask(self.all_fg_masks[i], self.cfg.box_crop_mask_thr)
)
clamp_bbox_xywh = get_clamp_bbox(bbox_xywh, self.cfg.box_crop_context)
max_sl = max(clamp_bbox_xywh[2] - clamp_bbox_xywh[0], max_sl)
max_sl = max(clamp_bbox_xywh[3] - clamp_bbox_xywh[1], max_sl)
mask = crop_around_box(self.all_fg_masks[i][..., None], clamp_bbox_xywh)
img = crop_around_box(images[i], clamp_bbox_xywh)
depth = crop_around_box(depths[i][..., None], clamp_bbox_xywh)
# resize to the same shape
mask, _, _ = resize_image(mask, self.cfg.height, self.cfg.width)
depth, _, _ = resize_image(depth, self.cfg.height, self.cfg.width)
img, scale, _ = resize_image(img, self.cfg.height, self.cfg.width)
fx, fy, cx, cy = (
intrinsics[i][0, 0],
intrinsics[i][1, 1],
intrinsics[i][0, 2],
intrinsics[i][1, 2],
)
crop_masks.append(mask)
crop_imgs.append(img)
crop_depths.append(depth)
crop_xywhs.append(clamp_bbox_xywh)
crop_directions.append(
get_ray_directions(
self.cfg.height,
self.cfg.width,
(fx * scale, fy * scale),
(
(cx - clamp_bbox_xywh[0]) * scale,
(cy - clamp_bbox_xywh[1]) * scale,
),
)
)
# # pad all images to the same shape
# for i in range(num_frames):
# uh = (max_sl - crop_imgs[i].shape[0]) // 2 # h
# dh = max_sl - crop_imgs[i].shape[0] - uh
# lw = (max_sl - crop_imgs[i].shape[1]) // 2
# rw = max_sl - crop_imgs[i].shape[1] - lw
# crop_masks[i] = np.pad(crop_masks[i], pad_width=((uh, dh), (lw, rw), (0, 0)), mode='constant', constant_values=0.)
# crop_imgs[i] = np.pad(crop_imgs[i], pad_width=((uh, dh), (lw, rw), (0, 0)), mode='constant', constant_values=1.)
# crop_depths[i] = np.pad(crop_depths[i], pad_width=((uh, dh), (lw, rw), (0, 0)), mode='constant', constant_values=0.)
# fx, fy, cx, cy = intrinsics[i][0, 0], intrinsics[i][1, 1], intrinsics[i][0, 2], intrinsics[i][1, 2]
# crop_directions.append(get_ray_directions(max_sl, max_sl, (fx, fy), (cx - crop_xywhs[i][0] + lw, cy - crop_xywhs[i][1] + uh)))
# self.w, self.h = max_sl, max_sl
images = crop_imgs
depths = crop_depths
self.all_fg_masks = np.stack(crop_masks, 0)
self.all_directions = torch.from_numpy(np.stack(crop_directions, 0))
# self.width, self.height = self.w, self.h
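        # Convert camera-to-world matrices from OpenCV (x right, y down, z in)
        # to OpenGL (x right, y up, z back) by flipping the y and z axes.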
self.all_c2w = torch.from_numpy(
(
extrinsics
@ np.diag(np.array([1, -1, -1, 1], dtype=np.float32))[None, ...]
)[..., :3, :4]
)
self.all_images = torch.from_numpy(np.stack(images, axis=0))
self.all_depths = torch.from_numpy(np.stack(depths, axis=0))
# self.all_c2w = []
# self.all_images = []
# for i in range(num_frames):
# # convert to: x right, y up, z back (OpenGL)
# c2w = torch.from_numpy(extrinsics[i] @ np.diag(np.array([1, -1, -1, 1], dtype=np.float32)))[:3, :4]
# self.all_c2w.append(c2w)
# img = torch.from_numpy(images[i])
# self.all_images.append(img)
# TODO: save data for fast loading next time
        if self.cfg.load_preprocessed and os.path.exists(
            os.path.join(self.cfg.root_dir, "nerf_preprocessed.npy")
        ):
pass
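        # Build the index splits: by default every 10th frame goes to val, the
        # rest to train, and the test split keeps all frames; with train_views
        # set, those indices form the training set and the remaining frames
        # make up val.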
i_all = np.arange(num_frames)
if self.cfg.train_views is None:
i_test = i_all[::10]
i_val = i_test
            i_train = np.array([i for i in i_all if i not in i_test])
else:
# use provided views
i_train = self.cfg.train_views
            i_test = np.array([i for i in i_all if i not in i_train])
i_val = i_test
if self.split == "train":
print("[INFO] num of train views: ", len(i_train))
print("[INFO] train view ids = ", i_train)
i_split = {"train": i_train, "val": i_val, "test": i_all}
# if self.split == 'test':
# self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.cfg.n_test_traj_steps)
# self.all_images = torch.zeros((self.cfg.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
# self.all_fg_masks = torch.zeros((self.cfg.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
# self.directions = self.directions[0].to(self.rank)
# else:
self.all_images, self.all_c2w = (
self.all_images[i_split[self.split]],
self.all_c2w[i_split[self.split]],
)
self.all_directions = self.all_directions[i_split[self.split]].to(self.rank)
self.all_fg_masks = torch.from_numpy(self.all_fg_masks)[i_split[self.split]]
self.all_depths = self.all_depths[i_split[self.split]]
# if render_random_pose:
# render_poses = random_pose(extrinsics[i_all], 50)
# elif render_scene_interp:
# render_poses = pose_interp(extrinsics[i_all], interp_fac)
# render_poses = spherical_poses(sscale * cam_scale_factor * np.eye(4))
# near, far = 0., 1.
# ndc_coeffs = (-1., -1.)
self.all_c2w, self.all_images, self.all_fg_masks = (
self.all_c2w.float().to(self.rank),
self.all_images.float().to(self.rank),
self.all_fg_masks.float().to(self.rank),
)
# self.all_c2w, self.all_images, self.all_fg_masks = \
# self.all_c2w.float(), \
# self.all_images.float(), \
# self.all_fg_masks.float()
self.all_depths = self.all_depths.float().to(self.rank)
def get_all_images(self):
return self.all_images
class Co3dDataset(Dataset, Co3dDatasetBase):
def __init__(self, cfg, split):
self.setup(cfg, split)
def __len__(self):
if self.split == "test":
if self.cfg.render_path == "circle":
return len(self.random_pose_generator)
else:
return len(self.all_images)
else:
return len(self.random_pose_generator)
# return len(self.all_images)
def prepare_data(self, index):
# prepare batch data here
c2w = self.all_c2w[index]
light_positions = c2w[..., :3, -1]
directions = self.all_directions[index]
rays_o, rays_d = get_rays(
directions, c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
)
rgb = self.all_images[index]
depth = self.all_depths[index]
mask = self.all_fg_masks[index]
# TODO: get projection matrix and mvp matrix
# proj_mtx = get_projection_matrix()
batch = {
"rays_o": rays_o,
"rays_d": rays_d,
"mvp_mtx": 0,
"camera_positions": c2w[..., :3, -1],
"light_positions": light_positions,
"elevation": 0,
"azimuth": 0,
"camera_distances": 0,
"rgb": rgb,
"depth": depth,
"mask": mask,
}
# c2w = self.all_c2w[index]
# return {
# 'index': index,
# 'c2w': c2w,
# 'light_positions': c2w[:3, -1],
# 'H': self.h,
# 'W': self.w
# }
return batch
def __getitem__(self, index):
if self.split == "test":
if self.cfg.render_path == "circle":
return self.random_pose_generator[index]
else:
return self.prepare_data(index)
else:
return self.random_pose_generator[index]
class Co3dIterableDataset(IterableDataset, Co3dDatasetBase):
def __init__(self, cfg, split):
self.setup(cfg, split)
self.idx = 0
self.image_perm = torch.randperm(len(self.all_images))
def __iter__(self):
while True:
yield {}
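    # __iter__ only yields placeholders; the real batch (one training view plus
    # an optional random-camera batch) is assembled here from the permuted
    # image order.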
def collate(self, batch) -> Dict[str, Any]:
idx = self.image_perm[self.idx]
# prepare batch data here
c2w = self.all_c2w[idx][None]
light_positions = c2w[..., :3, -1]
directions = self.all_directions[idx][None]
rays_o, rays_d = get_rays(
directions, c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
)
rgb = self.all_images[idx][None]
depth = self.all_depths[idx][None]
mask = self.all_fg_masks[idx][None]
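        # With train_num_rays set, sample that many random pixels per image
        # instead of returning the full ray grid.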
if (
self.cfg.train_num_rays != -1
and self.cfg.train_num_rays < self.cfg.height * self.cfg.width
):
_, height, width, _ = rays_o.shape
x = torch.randint(
0, width, size=(self.cfg.train_num_rays,), device=rays_o.device
)
y = torch.randint(
0, height, size=(self.cfg.train_num_rays,), device=rays_o.device
)
rays_o = rays_o[:, y, x].unsqueeze(-2)
rays_d = rays_d[:, y, x].unsqueeze(-2)
directions = directions[:, y, x].unsqueeze(-2)
rgb = rgb[:, y, x].unsqueeze(-2)
mask = mask[:, y, x].unsqueeze(-2)
depth = depth[:, y, x].unsqueeze(-2)
# TODO: get projection matrix and mvp matrix
# proj_mtx = get_projection_matrix()
batch = {
"rays_o": rays_o,
"rays_d": rays_d,
"mvp_mtx": None,
"camera_positions": c2w[..., :3, -1],
"light_positions": light_positions,
"elevation": None,
"azimuth": None,
"camera_distances": None,
"rgb": rgb,
"depth": depth,
"mask": mask,
}
if self.cfg.use_random_camera:
batch["random_camera"] = self.random_pose_generator.collate(None)
# prepare batch data in system
# c2w = self.all_c2w[idx][None]
# batch = {
# 'index': torch.tensor([idx]),
# 'c2w': c2w,
# 'light_positions': c2w[..., :3, -1],
# 'H': self.h,
# 'W': self.w
# }
self.idx += 1
if self.idx == len(self.all_images):
self.idx = 0
self.image_perm = torch.randperm(len(self.all_images))
# self.idx = (self.idx + 1) % len(self.all_images)
return batch
@register("co3d-datamodule")
class Co3dDataModule(pl.LightningDataModule):
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(Co3dDataModuleConfig, cfg)
def setup(self, stage=None):
if stage in [None, "fit"]:
self.train_dataset = Co3dIterableDataset(self.cfg, self.cfg.train_split)
if stage in [None, "fit", "validate"]:
self.val_dataset = Co3dDataset(self.cfg, self.cfg.val_split)
if stage in [None, "test", "predict"]:
self.test_dataset = Co3dDataset(self.cfg, self.cfg.test_split)
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
sampler = None
return DataLoader(
dataset,
num_workers=0,
batch_size=batch_size,
# pin_memory=True,
collate_fn=collate_fn,
)
def train_dataloader(self):
return self.general_loader(
self.train_dataset, batch_size=1, collate_fn=self.train_dataset.collate
)
def val_dataloader(self):
return self.general_loader(self.val_dataset, batch_size=1)
def test_dataloader(self):
return self.general_loader(self.test_dataset, batch_size=1)
def predict_dataloader(self):
return self.general_loader(self.test_dataset, batch_size=1)