import gc
from dataclasses import asdict, is_dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def config_to_primitive(config):
    """Recursively convert a dataclass config (and nested containers) to plain Python primitives."""
    if is_dataclass(config):
        config_dict = asdict(config)
        return {k: config_to_primitive(v) for k, v in config_dict.items()}
    elif isinstance(config, dict):
        return {k: config_to_primitive(v) for k, v in config.items()}
    elif isinstance(config, list):
        return [config_to_primitive(v) for v in config]
    elif isinstance(config, tuple):
        return tuple(config_to_primitive(v) for v in config)
    else:
        return config
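

# Illustrative sketch, not part of the original module: how config_to_primitive
# flattens a nested config. `ExampleRenderConfig` is a hypothetical dataclass
# defined here purely for demonstration.
def _example_config_to_primitive():
    from dataclasses import dataclass, field

    @dataclass
    class ExampleRenderConfig:
        radius: float = 1.0
        resolution: tuple = (64, 64)
        tags: list = field(default_factory=lambda: ["triplane"])

    # Dataclasses, dicts, lists, and tuples are all recursed into.
    return config_to_primitive({"render": ExampleRenderConfig()})
    # -> {'render': {'radius': 1.0, 'resolution': (64, 64), 'tags': ['triplane']}}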


def scale_tensor(dat, inp_scale, tgt_scale):
    """Rescale `dat` from the `inp_scale` range to the `tgt_scale` range."""
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    # Normalize to [0, 1] with the input range, then map to the target range.
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat
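

# Illustrative sketch, not part of the original module: remapping values from one
# range to another with scale_tensor. The ranges below are arbitrary.
def _example_scale_tensor():
    pts = torch.tensor([[0.0, 0.5, 1.0]])
    return scale_tensor(pts, (0, 1), (-1, 1))  # -> tensor([[-1., 0., 1.]])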


def contract_to_unisphere_custom(x, bbox, unbounded=False):
    """Map points into a normalized range; if `unbounded`, additionally apply a
    scene contraction that squashes far-away points into a bounded volume."""
    if unbounded:
        x = scale_tensor(x, bbox, (-1, 1))
        x = x * 2 - 1  # aabb is at [-1, 1]
        mag = x.norm(dim=-1, keepdim=True)
        mask = mag.squeeze(-1) > 1
        # Contract points outside the unit sphere: radius r maps to 2 - 1 / r.
        x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
        x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
    else:
        x = scale_tensor(x, bbox, (-1, 1))
    return x
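

# Illustrative sketch, not part of the original module: contracting query points
# with a hypothetical unit bounding box. With unbounded=True, far-away points are
# squashed into a bounded range instead of growing without limit.
def _example_contract_to_unisphere():
    bbox = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])  # hypothetical aabb
    pts = torch.tensor([[0.2, -0.3, 0.1], [50.0, 0.0, 0.0]])    # a near and a far point
    return contract_to_unisphere_custom(pts, bbox, unbounded=True)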


# Axis-aligned plane bases used to project 3D points onto the three triplane
# feature maps (plane ordering follows the bug fix in
# https://github.com/NVlabs/eg3d/issues/67).
planes = torch.tensor(
    [
        [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ],
        [
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
        ],
        [
            [0, 0, 1],
            [0, 1, 0],
            [1, 0, 0],
        ],
    ],
    dtype=torch.float32,
)


def grid_sample(input, grid):
    # if grid.requires_grad and _should_use_custom_op():
    #     return grid_sample_2d(input, grid, padding_mode='zeros', align_corners=False)
    return torch.nn.functional.grid_sample(
        input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False
    )


def project_onto_planes(planes, coordinates):
    """
    Projects a batch of 3D points onto a batch of 2D planes and returns the
    2D plane coordinates.

    Takes plane axes of shape (n_planes, 3, 3) and coordinates of shape (N, M, 3);
    returns projections of shape (N * n_planes, M, 2).
    """
    N, M, C = coordinates.shape
    n_planes, _, _ = planes.shape
    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N * n_planes, M, 3)
    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N * n_planes, 3, 3)
    projections = torch.bmm(coordinates, inv_planes)
    return projections[..., :2]
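

# Illustrative sketch, not part of the original module: projecting points with the
# module-level `planes` bases. The first plane basis is the identity, so its
# projection simply keeps the (x, y) components of each point.
def _example_project_onto_planes():
    coords = torch.tensor([[[0.1, 0.2, 0.3]]])   # shape (N=1, M=1, 3)
    proj = project_onto_planes(planes, coords)   # shape (N * n_planes, M, 2)
    return proj[0]                               # -> tensor([[0.1000, 0.2000]])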


def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat=None):
    assert padding_mode == 'zeros'
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    plane_features = plane_features.view(N * n_planes, C, H, W)
    # Rescale coordinates so that a cube of side length `box_warp` maps to the
    # [-1, 1] sampling range expected by grid_sample.
    coordinates = (2 / box_warp) * coordinates
    if interpolate_feat in [None, "v1"]:
        # v1: bilinearly sample each plane, then sum the three plane features.
        projected_coordinates = project_onto_planes(planes.to(coordinates), coordinates).unsqueeze(1)
        output_features = grid_sample(plane_features, projected_coordinates.float())
        output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
        output_features = output_features.sum(dim=1, keepdim=True).reshape(N, M, C)
    elif interpolate_feat in ["v2"]:
        # v2: bilinearly sample each plane, then concatenate the three plane features.
        projected_coordinates = project_onto_planes(planes.to(coordinates), coordinates).unsqueeze(1)
        output_features = grid_sample(plane_features, projected_coordinates.float())
        output_features = output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
        output_features = output_features.permute(0, 2, 1, 3).reshape(N, M, n_planes * C)
    else:
        raise ValueError(f"Unsupported interpolate_feat mode: {interpolate_feat}")
    return output_features.contiguous()
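

# Illustrative sketch, not part of the original module: querying a random triplane
# with sample_from_planes. The sizes below are arbitrary; "v1" sums features across
# the three planes while "v2" concatenates them.
def _example_sample_from_planes():
    N, C, H, W, M = 2, 32, 64, 64, 1024
    plane_features = torch.randn(N, 3, C, H, W)    # one C-channel feature map per plane
    coordinates = torch.rand(N, M, 3) * 2 - 1      # query points in [-1, 1]^3
    feats_v1 = sample_from_planes(plane_features, coordinates, interpolate_feat="v1")
    feats_v2 = sample_from_planes(plane_features, coordinates, interpolate_feat="v2")
    return feats_v1.shape, feats_v2.shape          # (N, M, C) and (N, M, 3 * C)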


def cleanup():
    """Cleanup torch memory."""
    gc.collect()
    # Guard the CUDA calls so cleanup() also works on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
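

# Minimal smoke test for the illustrative sketches above; only runs when this file
# is executed directly.
if __name__ == "__main__":
    print(_example_config_to_primitive())
    print(_example_scale_tensor())
    print(_example_contract_to_unisphere())
    print(_example_project_onto_planes())
    print(_example_sample_from_planes())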