import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from dataclasses import asdict, is_dataclass
import gc


def config_to_primitive(config):
    """Recursively convert a (possibly nested) config into plain containers.

    Dataclass instances become dicts; dicts, lists and tuples are rebuilt with
    every element converted the same way. Any other value is returned as-is.
    """
    if is_dataclass(config):
        config_dict = asdict(config)
        return {k: config_to_primitive(v) for k, v in config_dict.items()}
    elif isinstance(config, dict):
        return {k: config_to_primitive(v) for k, v in config.items()}
    elif isinstance(config, list):
        return [config_to_primitive(v) for v in config]
    elif isinstance(config, tuple):
        return tuple(config_to_primitive(v) for v in config)
    else:
        return config


def scale_tensor(dat, inp_scale, tgt_scale):
    """Linearly remap ``dat`` from the range ``inp_scale`` to ``tgt_scale``.

    Each scale is indexed as ``(low, high)``. The two entries may be Python
    scalars or tensors that broadcast against ``dat``. ``None`` means ``(0, 1)``.
    When ``tgt_scale`` is a tensor, its last dimension must match ``dat``'s.
    """
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat


def contract_to_unisphere_custom(x, bbox, unbounded=False):
    """Normalize points ``x`` with respect to ``bbox``.

    Bounded mode maps ``bbox`` linearly onto ``[-1, 1]`` per axis.

    Unbounded mode first maps ``bbox`` onto ``[-1, 1]``. Points with norm > 1
    are then contracted via ``x -> (2 - 1/|x|) * x/|x|`` (into a ball of
    radius 2), and the result is mapped with ``x / 4 + 0.5``.

    Args:
        x: points; the last dimension holds the coordinates.
        bbox: bounds indexable as ``bbox[0]`` (min) and ``bbox[1]`` (max).
        unbounded: whether to apply the contraction.

    Returns:
        A new tensor; ``x`` itself is not modified.
    """
    if not unbounded:
        return scale_tensor(x, bbox, (-1, 1))

    # Scale to (0, 1) -- not (-1, 1) -- so that ``* 2 - 1`` lands the box on
    # [-1, 1]; scaling to (-1, 1) first would place it at [-3, 1].
    x = scale_tensor(x, bbox, (0, 1))
    x = x * 2 - 1  # aabb is at [-1, 1]
    mag = x.norm(dim=-1, keepdim=True)
    # Out-of-place contraction: an in-place masked write would modify ``x``
    # after ``norm`` saved it for backward, breaking autograd. Clamping the
    # magnitude keeps the (unselected) branch free of division by ~0.
    safe_mag = mag.clamp(min=1.0)
    contracted = (2 - 1 / safe_mag) * (x / safe_mag)
    x = torch.where(mag > 1, contracted, x)
    # NOTE(review): the bounded branch returns values in [-1, 1] while this
    # branch returns values in [0, 1]; confirm callers expect that difference.
    return x / 4 + 0.5  # [-inf, inf] is at [0, 1]


# Axis permutations for the three triplane projections; with
# ``project_onto_planes`` they select (x, y), (x, z) and (z, y) respectively.
# bug fix in https://github.com/NVlabs/eg3d/issues/67
planes = torch.tensor(
    [
        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
    ],
    dtype=torch.float32,
)


def grid_sample(input, grid, mode='bilinear', padding_mode='zeros',
                align_corners=False):
    """Thin wrapper around ``torch.nn.functional.grid_sample``.

    Defaults reproduce the module's original fixed settings
    (bilinear, zero padding, ``align_corners=False``).
    """
    return torch.nn.functional.grid_sample(
        input=input,
        grid=grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )


def project_onto_planes(planes, coordinates):
    """Project 3D points onto a batch of 2D planes.

    Args:
        planes: plane axes, shape ``(n_planes, 3, 3)``.
        coordinates: points, shape ``(N, M, 3)``.

    Returns:
        2D plane coordinates of shape ``(N * n_planes, M, 2)``; row
        ``n * n_planes + p`` holds batch item ``n`` projected onto plane ``p``.
    """
    N, M, C = coordinates.shape
    n_planes, _, _ = planes.shape
    coordinates = (
        coordinates.unsqueeze(1)
        .expand(-1, n_planes, -1, -1)
        .reshape(N * n_planes, M, 3)
    )
    inv_planes = (
        torch.linalg.inv(planes)
        .unsqueeze(0)
        .expand(N, -1, -1, -1)
        .reshape(N * n_planes, 3, 3)
    )
    projections = torch.bmm(coordinates, inv_planes)
    return projections[..., :2]


def sample_from_planes(plane_features, coordinates, mode='bilinear',
                       padding_mode='zeros', box_warp=2,
                       interpolate_feat=None):
    """Sample triplane features at 3D coordinates.

    Args:
        plane_features: shape ``(N, n_planes, C, H, W)``.
        coordinates: shape ``(N, M, 3)``; scaled by ``2 / box_warp`` before
            projection, so the plane extent covers ``[-box_warp/2, box_warp/2]``.
        mode: interpolation mode forwarded to ``grid_sample``.
        padding_mode: only ``'zeros'`` is supported.
        box_warp: side length of the cube covered by the planes.
        interpolate_feat: ``None``/``"v1"`` sums features over planes, giving
            ``(N, M, C)``; ``"v2"`` concatenates them per point, giving
            ``(N, M, n_planes * C)``.

    Raises:
        ValueError: if ``interpolate_feat`` is not one of the values above.
    """
    assert padding_mode == 'zeros'
    if interpolate_feat not in (None, "v1", "v2"):
        raise ValueError(
            f"unknown interpolate_feat {interpolate_feat!r}; "
            "expected None, 'v1' or 'v2'"
        )
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (not view) so non-contiguous feature tensors are accepted too.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)

    coordinates = (2 / box_warp) * coordinates  # add specific box bounds
    projected_coordinates = project_onto_planes(
        planes.to(coordinates), coordinates
    ).unsqueeze(1)  # (N * n_planes, 1, M, 2)
    output_features = grid_sample(
        plane_features,
        projected_coordinates.float(),
        mode=mode,
        padding_mode=padding_mode,
    )  # (N * n_planes, C, 1, M)
    output_features = output_features.permute(0, 3, 2, 1).reshape(
        N, n_planes, M, C
    )

    if interpolate_feat == "v2":
        output_features = output_features.permute(0, 2, 1, 3).reshape(
            N, M, n_planes * C
        )
    else:  # None or "v1"
        output_features = output_features.sum(dim=1)
    return output_features.contiguous()


def cleanup():
    """Run garbage collection and, when CUDA is available, free cached memory.

    The CUDA calls are skipped on CPU-only setups, where ``ipc_collect``
    would otherwise try to initialize CUDA and raise.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()