File size: 4,795 Bytes
57746f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import os.path as osp
import sys
sys.path.append("submodules/mast3r/dust3r")
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
import numpy as np
import cv2
from dust3r.utils.image import imread_cv2

class Scannet(BaseStereoViewDataset):
    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_views = 3 # render third view
        self._load_data()
        
    def _load_data(self):
        # Traverse all the folders in the data_root
        scene_names = [folder for folder in os.listdir(self.ROOT) if os.path.isdir(os.path.join(self.ROOT, folder))]
        # Filter out scenes without scene_data.npz
        valid_scenes = []
        for scene_name in scene_names:
            scene_data_path = osp.join(self.ROOT, scene_name, "scene_data.npz")
            if osp.exists(scene_data_path):
                valid_scenes.append(scene_name)
            else:
                print(f"Skipping {scene_name}: scene_data.npz not found")
        scene_names = valid_scenes
        scene_names.sort()
        if self.split == 'train':
            scene_names = scene_names[:-150]
        else:
            scene_names = scene_names[-150:]
        # merge all pairs and images
        pairs = [] # (scene_name, image_idx1, image_idx2)
        images = {} # (scene_name, image_idx) -> image_path
        for scene_name in scene_names:
            scene_path = osp.join(self.ROOT, scene_name, "scene_data.npz")
            scene_data = np.load(scene_path)
            pairs.extend([(scene_name, *pair) for pair in scene_data['pairs']])
            images.update({(scene_name, idx): path for idx, path in enumerate(scene_data['images'])})
        self.pairs = pairs
        self.images = images
        
    def __len__(self):
        return len(self.pairs)
    
    def _get_views(self, idx, resolution, rng):
        scene_name, image_idx1, image_idx2, _ = self.pairs[idx]
        image_idx1 = int(image_idx1)
        image_idx2 = int(image_idx2)
        image_idx3 = int((image_idx1 + image_idx2) / 2)
        views = []
        for view_idx in [image_idx1, image_idx2, image_idx3]:
            basename = self.images[(scene_name, view_idx)]
            # Load RGB image
            rgb_path = osp.join(self.ROOT, scene_name, 'images', f'{basename}.jpg')
            rgb_image = imread_cv2(rgb_path)
            # Load depthmap
            depthmap_path = osp.join(self.ROOT, scene_name, 'depths', f'{basename}.png')
            depthmap = imread_cv2(depthmap_path, cv2.IMREAD_UNCHANGED)
            depthmap = depthmap.astype(np.float32) / 1000
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            # Load camera parameters
            meta_path = osp.join(self.ROOT, scene_name, 'images', f'{basename}.npz')
            meta = np.load(meta_path)
            intrinsics = meta['camera_intrinsics']
            camera_pose = meta['camera_pose']
            # crop if necessary
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)
            views.append(dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='ScanNet',
                label=scene_name + '_' + basename,
                instance=f'{str(idx)}_{str(view_idx)}',
            ))
        return views

if __name__ == "__main__":
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = Scannet(split='train', ROOT="data/scannet_processed", resolution=224, aug_crop=16)

    print(len(dataset))

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 3
        print(view_name(views[0]), view_name(views[1]), view_name(views[2]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1, 2]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1, 2]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(idx*255, (1 - idx)*255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()