# Source: LSM/src/utils/points_process.py (commit 57746f1, "Update Code", kairunwen)
import torch
from einops import rearrange
# merge points from two views and add color information
def merge_points(mast3r_output, view1, view2, grid_size=0.01):
    """Merge the point maps of two views into one normalized, colored point cloud.

    Args:
        mast3r_output: Indexable pair. Element 0 must provide ``'pts3d'`` and
            element 1 ``'pts3d_in_other_view'``; both are read as
            (B, H, W, 3) tensors and detached from the autograd graph.
        view1: Mapping whose ``'img'`` entry is read as a (B, 3, H, W) tensor.
        view2: Same layout as ``view1``, for the second view.
        grid_size: Stored unchanged under the ``'grid_size'`` key.

    Returns:
        dict with the following entries (N = 2 * H * W points per batch item):
            'coord':     (B * N, 3) centered points divided by ``scale``.
            'color':     (B * N, 3) per-point colors taken from the images.
            'feat':      (B * N, 6) ``coord`` and ``color`` concatenated.
            'offset':    (B,) cumulative point counts ``[N, 2N, ..., B * N]``.
            'grid_size': the ``grid_size`` argument.
            'center':    (B, 1, 3) per-batch mean of the merged points.
            'scale':     (B, 1, 1) per-batch max distance from ``center``
                         (1 when all points of a batch item coincide).
            'shape':     shape of the first view's point map.
    """
    # get points from mast3r_output
    points1 = mast3r_output[0]['pts3d'].detach()  # B, H, W, 3
    points2 = mast3r_output[1]['pts3d_in_other_view'].detach()  # B, H, W, 3
    shape = points1.shape

    # add color information: move channels last, then flatten (view, h, w)
    colors = torch.stack([view1['img'], view2['img']], dim=1)  # B, V, 3, H, W
    colors = colors.permute(0, 1, 3, 4, 2).flatten(1, 3)  # B, V * H * W, 3

    # merge points, using the same (view, h, w) flattening order as the colors
    points = torch.stack([points1, points2], dim=1)  # B, V, H, W, 3
    points = points.flatten(1, 3)  # B, V * H * W, 3

    B, N, _ = points.shape
    offset = torch.arange(1, B + 1, device=points.device) * N

    # Center and normalize points
    center = torch.mean(points, dim=1, keepdim=True)
    points = points - center
    scale = torch.max(torch.norm(points, dim=2, keepdim=True), dim=1, keepdim=True)[0]
    # A batch item whose points all coincide has zero extent; dividing by it
    # would fill coord/feat with NaN. Use a unit scale there so the centered
    # points stay at zero and `coord * scale + center` still recovers the input.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    points = points / scale

    # concat points and colors
    feat = torch.cat([points, colors], dim=-1)  # B, V * H * W, 6

    data_dict = {
        'coord': points.flatten(0, 1),
        'color': colors.flatten(0, 1),
        'feat': feat.flatten(0, 1),
        'offset': offset,
        'grid_size': grid_size,
        'center': center,
        'scale': scale,
        'shape': shape,
    }
    return data_dict