File size: 4,388 Bytes
f876753
 
 
 
fc44d4b
 
 
f876753
fc44d4b
 
 
 
 
 
 
 
 
 
 
 
 
f876753
 
fc44d4b
 
 
f876753
 
 
 
 
 
 
 
 
 
 
 
fc44d4b
 
f876753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc44d4b
f876753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc44d4b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from dataclasses import asdict, is_dataclass
import gc

def config_to_primitive(config):
    """Recursively turn a (possibly dataclass) config into plain Python containers.

    Dataclasses are first converted with ``dataclasses.asdict``. Dicts, lists and
    tuples (including subclasses) are rebuilt as plain ``dict`` / ``list`` /
    ``tuple`` with every element converted in turn. Any other value is returned
    unchanged.
    """
    if is_dataclass(config):
        # asdict always yields a plain dict, which the dict branch below handles.
        config = asdict(config)
    if isinstance(config, dict):
        return {key: config_to_primitive(value) for key, value in config.items()}
    if isinstance(config, (list, tuple)):
        converted = [config_to_primitive(item) for item in config]
        return converted if isinstance(config, list) else tuple(converted)
    return config

def scale_tensor(dat, inp_scale, tgt_scale):
    """Linearly remap ``dat`` from the ``inp_scale`` range to the ``tgt_scale`` range.

    Each range is a ``(low, high)`` pair read via ``[0]`` / ``[1]``. The bounds
    may be scalars or tensors broadcast against ``dat``. ``None`` stands for
    ``(0, 1)``. When ``tgt_scale`` is a Tensor, its last dimension must match
    that of ``dat`` (asserted).
    """
    src = (0, 1) if inp_scale is None else inp_scale
    dst = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(dst, Tensor):
        assert dat.shape[-1] == dst.shape[-1]
    # Normalize to [0, 1] relative to the source range, then stretch/shift into the target.
    unit = (dat - src[0]) / (src[1] - src[0])
    return unit * (dst[1] - dst[0]) + dst[0]

def contract_to_unisphere_custom(
    x, bbox, unbounded=False
):
    """Map points given in the ``bbox`` frame into a normalized frame.

    Bounded (``unbounded=False``): linearly rescales ``x`` so that ``bbox`` maps
    onto [-1, 1].

    Unbounded: rescales so that ``bbox`` maps onto [-1, 1], contracts every
    point whose norm exceeds 1 via ``|x| -> 2 - 1/|x|`` (so all of space lands
    inside the radius-2 ball), then maps that ball into [0, 1] with
    ``x / 4 + 0.5``.

    NOTE(review): the two branches return different ranges ([-1, 1] vs [0, 1]);
    confirm callers of the unbounded path expect [0, 1].

    Args:
        x: points; the last dimension holds the coordinates.
        bbox: ``(min, max)`` pair passed to ``scale_tensor`` as the input range.
        unbounded: whether to apply the contraction.

    Returns:
        A new tensor shaped like ``x``; the input tensor is not modified.
    """
    if unbounded:
        # Map the bbox onto [0, 1] first so that `* 2 - 1` centres it on
        # [-1, 1]; scaling straight to (-1, 1) here placed the bbox at [-3, 1].
        x = scale_tensor(x, bbox, (0, 1))
        x = x * 2 - 1  # aabb is at [-1, 1]
        mag = x.norm(dim=-1, keepdim=True)
        # Out-of-place replacement for the old masked in-place write, which
        # modified a tensor saved by norm() for backward. Clamping the
        # magnitude keeps the unselected branch finite (and its gradient
        # NaN-free); wherever mag > 1 the clamp is a no-op.
        safe_mag = mag.clamp_min(1.0)
        contracted = (2 - 1 / safe_mag) * (x / safe_mag)
        x = torch.where(mag > 1, contracted, x)
        x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
    else:
        x = scale_tensor(x, bbox, (-1, 1))
    return x

# Tri-plane axes, using the corrected axis ordering from
# https://github.com/NVlabs/eg3d/issues/67 (bug fix).
# Each 3x3 matrix is a row permutation of the identity. project_onto_planes
# multiplies points by each matrix's inverse and keeps the first two
# components, giving (x, y), (x, z) and (z, y) plane coordinates respectively.
planes = torch.eye(3, dtype=torch.float32)[
    torch.tensor(
        [
            [0, 1, 2],  # identity  -> (x, y)
            [0, 2, 1],  # swap y/z  -> (x, z)
            [2, 1, 0],  # swap x/z  -> (z, y)
        ]
    )
]


def grid_sample(input, grid, *, mode='bilinear', padding_mode='zeros', align_corners=False):
    """Thin wrapper around ``torch.nn.functional.grid_sample``.

    Defaults reproduce the previously hard-coded settings (bilinear, zero
    padding, ``align_corners=False``); the keyword-only arguments let callers
    override them.

    Args:
        input: feature map passed to ``F.grid_sample``.
        grid: sampling locations in normalized [-1, 1] coordinates.
        mode: interpolation mode forwarded to ``F.grid_sample``.
        padding_mode: out-of-range handling forwarded to ``F.grid_sample``.
        align_corners: corner convention forwarded to ``F.grid_sample``.

    Returns:
        The sampled tensor produced by ``F.grid_sample``.
    """
    # NOTE: an earlier variant dispatched to a custom `grid_sample_2d` op when
    # `grid.requires_grad`; that path is disabled, so the built-in op is always used.
    return F.grid_sample(
        input=input,
        grid=grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )


def project_onto_planes(planes, coordinates):
    """Project 3D points onto a batch of 2D planes, returning 2D plane coordinates.

    Args:
        planes: plane axes of shape (n_planes, 3, 3).
        coordinates: points of shape (N, M, 3).

    Returns:
        Tensor of shape (N * n_planes, M, 2), batch-major: all planes of batch
        item 0 come first, then those of item 1, and so on.
    """
    batch, n_points, _ = coordinates.shape
    n_planes, _, _ = planes.shape
    # Each point (row vector) is right-multiplied by the inverse of each plane's axes.
    inv_axes = torch.linalg.inv(planes).unsqueeze(0)        # (1, n_planes, 3, 3)
    projected = coordinates.unsqueeze(1) @ inv_axes         # (N, n_planes, M, 3)
    flat = projected.reshape(batch * n_planes, n_points, 3)
    return flat[..., :2]

def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=2, interpolate_feat=None):
    """Sample tri-plane features at 3D points.

    Args:
        plane_features: tensor of shape (N, n_planes, C, H, W).
        coordinates: points of shape (N, M, 3).
        mode: currently unused; sampling always goes through the module's
            ``grid_sample`` helper with its default (bilinear) settings.
        padding_mode: must be ``'zeros'`` (asserted).
        box_warp: coordinates are multiplied by ``2 / box_warp`` before
            projection, so a box of that side length maps onto [-1, 1].
        interpolate_feat: ``None`` or ``"v1"`` sums the per-plane features,
            giving (N, M, C). ``"v2"`` concatenates each point's per-plane
            features in plane order, giving (N, M, n_planes * C).

    Returns:
        Contiguous feature tensor with the shape described above.

    Raises:
        ValueError: if ``interpolate_feat`` is not one of None, "v1", "v2".
    """
    assert padding_mode == 'zeros'
    if interpolate_feat not in (None, "v1", "v2"):
        # Previously an unknown value fell through and hit UnboundLocalError.
        raise ValueError(
            f"Unsupported interpolate_feat: {interpolate_feat!r}; expected None, 'v1' or 'v2'"
        )

    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (not view) so non-contiguous feature tensors are accepted too.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)

    coordinates = (2 / box_warp) * coordinates  # add specific box bounds

    # (N * n_planes, 1, M, 2) sampling grid, one row of M points per plane.
    projected_coordinates = project_onto_planes(planes.to(coordinates), coordinates).unsqueeze(1)
    sampled = grid_sample(plane_features, projected_coordinates.float())
    # (N * n_planes, C, 1, M) -> (N, n_planes, M, C)
    sampled = sampled.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)

    if interpolate_feat == "v2":
        output_features = sampled.permute(0, 2, 1, 3).reshape(N, M, n_planes * C)
    else:
        output_features = sampled.sum(dim=1)

    return output_features.contiguous()

def cleanup():
    """Cleanup torch memory.

    Runs Python's garbage collector. Then, only if this process has already
    initialized CUDA, releases unoccupied cached CUDA allocator memory and
    collects CUDA IPC memory. Safe to call on CPU-only machines.
    """
    gc.collect()
    # torch.cuda.ipc_collect() lazily initializes CUDA and raises when no GPU
    # is present. If CUDA was never initialized there is nothing to release,
    # and initializing it here would only allocate a new context.
    if torch.cuda.is_initialized():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()