import torch
from submodules.mast3r.dust3r.dust3r.losses import *  # provides LLoss, MultiLoss, inv, geotrf, find_opt_scaling
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, JaccardIndex, Accuracy
import lpips
from src.utils.gaussian_model import GaussianModel
from src.utils.cuda_splatting import render, DummyPipeline
from einops import rearrange
from src.utils.camera_utils import get_scaled_camera
from torchvision.utils import save_image
from dust3r.inference import make_batch_symmetric

class L2Loss (LLoss):
    """ Euclidean distance between 3d points """

    def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # per-point Euclidean (L2) distance

class L1Loss (LLoss):
    """ Manhattan distance between 3d points """

    def distance(self, a, b):
        return torch.abs(a - b).sum(dim=-1)  # per-point Manhattan (L1) distance

L2 = L2Loss()
L1 = L1Loss()
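
# Quick sanity check (sketch): LLoss expects distance() to return one value per
# point, so for (B, N, 3) inputs both losses yield a (B, N) tensor:
#
#   a, b = torch.randn(2, 100, 3), torch.randn(2, 100, 3)
#   assert L2.distance(a, b).shape == (2, 100)
#   assert L1.distance(a, b).shape == (2, 100)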

def merge_and_split_predictions(pred1, pred2):
    """Stack the two per-view prediction dicts, flatten the view/pixel axes,
    then split the result into one dict per batch element."""
    merged = {}
    for key in pred1.keys():
        merged_pred = torch.stack([pred1[key], pred2[key]], dim=1)  # B, V=2, H, W, ...
        merged_pred = rearrange(merged_pred, 'b v h w ... -> b (v h w) ...')
        merged[key] = merged_pred

    # split along the batch dimension into per-sample dicts
    batch_size = next(iter(merged.values())).shape[0]
    split = [{key: value[i] for key, value in merged.items()} for i in range(batch_size)]

    return split
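
# Example (sketch, hypothetical shapes): if pred1['means'] and pred2['means'] are
# both (B, H, W, 3), each element of the returned list is a dict whose 'means'
# entry has shape (2*H*W, 3), i.e. both views flattened into one point set:
#
#   pred1 = {'means': torch.randn(2, 4, 4, 3)}
#   pred2 = {'means': torch.randn(2, 4, 4, 3)}
#   per_scene = merge_and_split_predictions(pred1, pred2)
#   assert len(per_scene) == 2 and per_scene[0]['means'].shape == (32, 3)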

class GaussianLoss(MultiLoss):
    def __init__(self, ssim_weight=0.2):
        super().__init__()
        self.ssim_weight = ssim_weight  # kept for get_name(); the current loss does not apply it
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()
        self.lpips_vgg = lpips.LPIPS(net='vgg').cuda()
        self.pipeline = DummyPipeline()
        # black background color used when rendering
        self.register_buffer('bg_color', torch.tensor([0.0, 0.0, 0.0]).cuda())
        
    def get_name(self):
        return f'GaussianLoss(ssim_weight={self.ssim_weight})'
    

    def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
        # 1. merge the two per-view predictions into one Gaussian set per scene
        pred = merge_and_split_predictions(pred1, pred2)

        # 2. compute the optimal scale aligning predicted and GT point clouds;
        # everything is expressed in the coordinate frame of camera 1
        pred_pts1 = pred1['means']
        pred_pts2 = pred2['means']
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
        scaling = find_opt_scaling(gt_pts1, gt_pts2, pred_pts1, pred_pts2, valid1=valid1, valid2=valid2)

        # 3. render images and feature maps (needs a Gaussian model, a camera and a pipeline)
        rendered_images = []
        rendered_feats = []
        gt_images = []

        for i in range(len(pred)):
            # build the Gaussian model for scene i
            gaussians = GaussianModel.from_predictions(pred[i], sh_degree=3)
            ref_camera_extrinsics = gt1['camera_pose'][i]
            # supervise on the two input views and the held-out target view
            target_view_list = [gt1, gt2, target_view]
            for j in range(len(target_view_list)):
                target_extrinsics = target_view_list[j]['camera_pose'][i]
                target_intrinsics = target_view_list[j]['camera_intrinsics'][i]
                image_shape = target_view_list[j]['true_shape'][i]
                scale = scaling[i]
                camera = get_scaled_camera(ref_camera_extrinsics, target_extrinsics, target_intrinsics, scale, image_shape)
                # render both the image and the feature map
                rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
                rendered_images.append(rendered_output['render'])
                rendered_feats.append(rendered_output['feature_map'])
                gt_images.append(target_view_list[j]['img'][i] * 0.5 + 0.5)  # [-1,1] -> [0,1]

        rendered_images = torch.stack(rendered_images, dim=0)  # B*V, 3, H, W (V=3 views)
        gt_images = torch.stack(gt_images, dim=0)              # B*V, 3, H, W
        rendered_feats = torch.stack(rendered_feats, dim=0)    # B*V, d_feats, H, W
        rendered_feats = model.feature_expansion(rendered_feats)  # B*V, 512, H//2, W//2
        gt_feats = model.lseg_feature_extractor.extract_features(gt_images)  # B*V, 512, H//2, W//2
        image_loss = torch.abs(rendered_images - gt_images).mean()
        feature_loss = torch.abs(rendered_feats - gt_feats).mean()
        loss = image_loss + feature_loss

        # metrics only, no gradients
        with torch.no_grad():
            ssim = self.ssim(rendered_images, gt_images)
            psnr = self.psnr(rendered_images, gt_images)
            lpips_score = self.lpips_vgg(rendered_images, gt_images).mean()

        return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_score, 'image_loss': image_loss, 'feature_loss': feature_loss}

# loss for one batch
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    view1, view2, target_view = batch
    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng', 'pts3d'])
    for view in batch:
        for name in view.keys():  # pseudo_focal
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        # make_batch_symmetric expects only the two input views, not the target view
        view1, view2 = make_batch_symmetric((view1, view2))

    # unwrap the underlying model if it is wrapped in DistributedDataParallel
    actual_model = model.module if hasattr(model, 'module') else model

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = actual_model(view1, view2)

        # compute the loss in full precision; it is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, target_view, pred1, pred2, actual_model) if criterion is not None else None

    result = dict(view1=view1, view2=view2, target_view=target_view, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
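
# Usage sketch (illustrative; assumes a DataLoader yielding (view1, view2,
# target_view) dicts and a model already on `device` -- the `loader` and
# `optimizer` names below are hypothetical, not part of this file):
#
#   criterion = GaussianLoss(ssim_weight=0.2)
#   for batch in loader:
#       out = loss_of_one_batch(batch, model, criterion, device, use_amp=True)
#       loss, metrics = out['loss']  # compute_loss returns (loss, metrics dict)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()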