|
import torch |
|
from einops import rearrange |
|
|
|
|
|
def merge_points(mast3r_output, view1, view2, grid_size=0.01):
    """Fuse the two MASt3R view predictions into one normalized point cloud.

    Both views' per-pixel 3D points are flattened into a single cloud per
    batch element, centered on their mean, and scaled so the farthest point
    lies on the unit sphere. Colors are taken from the corresponding images
    and flattened in the same (view, row, col) order so they stay aligned
    with the points.

    Args:
        mast3r_output: pair of dicts; ``[0]['pts3d']`` and
            ``[1]['pts3d_in_other_view']`` hold the per-view point maps.
            # assumes each is (B, H, W, 3) — TODO confirm against MASt3R output
        view1, view2: dicts whose ``'img'`` entries are the view images.
            # assumes (B, C, H, W) with C matching the point dim — TODO confirm
        grid_size: voxelization cell size passed through to the consumer.

    Returns:
        dict with flattened ``coord``/``color``/``feat`` of shape (B*N, C),
        per-batch cumulative ``offset``, the pass-through ``grid_size``, and
        the ``center``/``scale``/``shape`` needed to undo the normalization.
    """
    pts_a = mast3r_output[0]['pts3d'].detach()
    pts_b = mast3r_output[1]['pts3d_in_other_view'].detach()
    # Keep the original per-view map shape so callers can reverse the flatten.
    shape = pts_a.shape

    # Stack views: (B, 2, C, H, W) -> (B, 2*H*W, C), channels moved last.
    colors = torch.stack([view1['img'], view2['img']], dim=1)
    b, v, c, h, w = colors.shape
    colors = colors.permute(0, 1, 3, 4, 2).reshape(b, v * h * w, c)

    # Same flattening order for the points: (B, 2, H, W, 3) -> (B, 2*H*W, 3).
    points = torch.stack([pts_a, pts_b], dim=1)
    points = points.reshape(b, -1, points.shape[-1])

    num_batches, pts_per_batch, _ = points.shape
    # Cumulative end-index of each batch element in the flattened cloud:
    # [N, 2N, ..., B*N].
    offset = torch.arange(1, num_batches + 1, device=points.device) * pts_per_batch

    # Normalize each cloud to a unit sphere: subtract the centroid, then
    # divide by the largest point radius.
    center = points.mean(dim=1, keepdim=True)
    points = points - center
    scale = torch.norm(points, dim=2, keepdim=True).max(dim=1, keepdim=True).values
    points = points / scale

    # Per-point feature = normalized position concatenated with color.
    feat = torch.cat([points, colors], dim=-1)

    return {
        'coord': points.reshape(-1, points.shape[-1]),
        'color': colors.reshape(-1, colors.shape[-1]),
        'feat': feat.reshape(-1, feat.shape[-1]),
        'offset': offset,
        'grid_size': grid_size,
        'center': center,
        'scale': scale,
        'shape': shape,
    }
|
|