File size: 3,041 Bytes
851751e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from functools import partial

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class GeoConverter(nn.Module):
    """Project normalised depth images to Cartesian coordinates.

    Per-pixel yaw/pitch trigonometric tables of shape ``(H, W)`` are built
    once from ``dataset_config`` and registered as buffers; at call time the
    decoded depth is multiplied by those tables to obtain ``x``/``y``
    (and ``z`` unless ``bev_only``), concatenated along ``dim=1``.

    Args:
        curve_length: Width of the ``(1, curve_length)`` average-pooling
            window applied to the coordinates in ``forward``; pooling is
            skipped when it is ``<= 1``.
        bev_only: When true only ``x`` and ``y`` are produced
            (``coord_dim == 2``), otherwise ``x``, ``y`` and ``z``.
        dataset_config: Mapping or attribute-style object providing
            ``fov`` (``[up, down]`` in degrees), ``depth_scale``,
            ``depth_range`` (``[min, max]``), ``log_scale`` and ``size``
            (``[H, W]``).

    Raises:
        AttributeError: If a required configuration field is missing.
    """

    def __init__(self, curve_length=4, bev_only=False, dataset_config=None):
        super().__init__()
        self.curve_length = curve_length
        self.coord_dim = 2 if bev_only else 3
        self.convert_fn = self.batch_range2bev if bev_only else self.batch_range2xyz

        fov = self._config_value(dataset_config, 'fov')
        self.fov_up = fov[0] / 180.0 * np.pi  # field of view up in rad
        self.fov_down = fov[1] / 180.0 * np.pi  # field of view down in rad
        self.fov_range = abs(self.fov_down) + abs(self.fov_up)  # total field of view in rad
        self.depth_scale = self._config_value(dataset_config, 'depth_scale')
        self.depth_min, self.depth_max = self._config_value(dataset_config, 'depth_range')
        self.log_scale = self._config_value(dataset_config, 'log_scale')
        self.size = self._config_value(dataset_config, 'size')
        self.register_conversion()

    @staticmethod
    def _config_value(config, key):
        """Fetch ``key`` from ``config`` by item access, then by attribute.

        Supporting both access styles lets plain ``dict`` objects, hybrid
        config containers and namespace-like objects be used interchangeably.
        """
        try:
            return config[key]
        except (KeyError, IndexError, TypeError):
            # Not subscriptable, or the key is absent: try attribute access.
            pass
        try:
            return getattr(config, key)
        except AttributeError:
            raise AttributeError(
                f"dataset_config is missing required field {key!r}"
            ) from None

    def register_conversion(self):
        """Pre-compute and register the ``(H, W)`` yaw/pitch sin/cos buffers."""
        height, width = self.size[0], self.size[1]
        scan_x, scan_y = np.meshgrid(np.arange(width), np.arange(height))
        scan_x = scan_x.astype(np.float64) / width  # column index -> [0, 1)
        scan_y = scan_y.astype(np.float64) / height  # row index -> [0, 1)

        # Columns sweep yaw over [-pi, pi); rows sweep pitch from fov_up
        # (first row) down towards fov_down.
        yaw = np.pi * (scan_x * 2 - 1)
        pitch = (1.0 - scan_y) * self.fov_range - abs(self.fov_down)

        to_torch = partial(torch.tensor, dtype=torch.float32)
        yaw_t, pitch_t = to_torch(yaw), to_torch(pitch)

        self.register_buffer('cos_yaw', torch.cos(yaw_t))
        self.register_buffer('sin_yaw', torch.sin(yaw_t))
        self.register_buffer('cos_pitch', torch.cos(pitch_t))
        self.register_buffer('sin_pitch', torch.sin(pitch_t))

    def _decode_depth(self, imgs):
        """Rescale ``imgs`` to depth, undo log scaling if enabled, and clamp."""
        depth = (imgs * 0.5 + 0.5) * self.depth_scale
        if self.log_scale:
            depth = torch.exp2(depth) - 1
        return depth.clamp(self.depth_min, self.depth_max)

    def batch_range2xyz(self, imgs):
        """Return ``x``, ``y``, ``z`` coordinates concatenated along ``dim=1``.

        ``imgs`` must broadcast against the ``(H, W)`` buffers
        (presumably ``(B, 1, H, W)`` -- confirm against callers).
        """
        depth = self._decode_depth(imgs)
        planar = self.cos_pitch * depth  # depth projected onto the x-y plane

        batch_x = self.cos_yaw * planar
        batch_y = -self.sin_yaw * planar
        batch_z = self.sin_pitch * depth
        return torch.cat([batch_x, batch_y, batch_z], dim=1)

    def batch_range2bev(self, imgs):
        """Return only the ``x`` and ``y`` coordinates, concatenated along ``dim=1``.

        ``imgs`` must broadcast against the ``(H, W)`` buffers
        (presumably ``(B, 1, H, W)`` -- confirm against callers).
        """
        depth = self._decode_depth(imgs)
        planar = self.cos_pitch * depth  # depth projected onto the x-y plane

        batch_x = self.cos_yaw * planar
        batch_y = -self.sin_yaw * planar
        return torch.cat([batch_x, batch_y], dim=1)

    def curve_compress(self, batch_coord):
        """Average-pool coordinates over ``(1, curve_length)`` windows."""
        return F.avg_pool2d(batch_coord, (1, self.curve_length))

    def forward(self, input):
        """Convert ``input`` to coordinates and optionally pool them.

        Returns the output of the selected converter, average-pooled along
        the last axis when ``curve_length > 1``.
        """
        # NOTE(review): the converters apply `* 0.5 + 0.5` again after this
        # rescale, so the value is normalised twice. Kept as-is because
        # downstream results may depend on it -- confirm this is intentional.
        input = input / 2. + .5  # [-1, 1] -> [0, 1]

        input_coord = self.convert_fn(input)
        if self.curve_length > 1:
            input_coord = self.curve_compress(input_coord)

        return input_coord