from submodules.mast3r.dust3r.dust3r.losses import *
import torch
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, JaccardIndex, Accuracy
import lpips
from src.utils.gaussian_model import GaussianModel
from src.utils.cuda_splatting import render, DummyPipeline
from einops import rearrange
from src.utils.camera_utils import get_scaled_camera
from torchvision.utils import save_image
from dust3r.inference import make_batch_symmetric
class L2Loss (LLoss):
    """ Euclidean distance between 3d points """

    def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # per-point Euclidean (L2) distance


class L1Loss (LLoss):
    """ Manhattan distance between 3d points """

    def distance(self, a, b):
        return torch.abs(a - b).mean(dim=-1)  # per-point mean absolute (L1) difference


L2 = L2Loss()
L1 = L1Loss()
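# Shape sketch (illustration only; assumes the LLoss base class from dust3r reduces
# per-point distances under a validity mask):
#   a, b: (B, H, W, 3) point maps
#   L2.distance(a, b) -> (B, H, W)  Euclidean norm over xyz
#   L1.distance(a, b) -> (B, H, W)  mean absolute difference over xyz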
def merge_and_split_predictions(pred1, pred2):
    """Merge the per-view prediction dicts and split them into one dict per sample."""
    merged = {}
    for key in pred1.keys():
        merged_pred = torch.stack([pred1[key], pred2[key]], dim=1)  # B, V=2, H, W, ...
        merged_pred = rearrange(merged_pred, 'b v h w ... -> b (v h w) ...')  # B, V*H*W, ...
        merged[key] = merged_pred
    # split along the batch dimension
    batch_size = next(iter(merged.values())).shape[0]
    split = [{key: value[i] for key, value in merged.items()} for i in range(batch_size)]
    return split
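# Layout example (hypothetical sizes, for illustration):
#   pred1[k], pred2[k]: (B, H, W, ...) per-view Gaussian parameters
#   stacked:            (B, 2, H, W, ...)
#   rearranged:         (B, 2*H*W, ...)
#   return value: length-B list where split[i][k] has shape (2*H*W, ...), i.e. one
#   flattened Gaussian set per sample, as consumed by GaussianModel.from_predictions.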
class GaussianLoss(MultiLoss):
    def __init__(self, ssim_weight=0.2):
        super().__init__()
        self.ssim_weight = ssim_weight
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()
        self.lpips_vgg = lpips.LPIPS(net='vgg').cuda()
        self.pipeline = DummyPipeline()
        # background color used by the splatting renderer
        self.register_buffer('bg_color', torch.tensor([0.0, 0.0, 0.0]).cuda())

    def get_name(self):
        return f'GaussianLoss(ssim_weight={self.ssim_weight})'
    # Earlier version, kept for reference: renders only the held-out target view
    # and weights the feature loss by 100.
    # def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
    #     # render images
    #     # 1. merge predictions
    #     pred = merge_and_split_predictions(pred1, pred2)
    #     # 2. calculate optimal scaling
    #     pred_pts1 = pred1['means']
    #     pred_pts2 = pred2['means']
    #     # convert to camera1 coordinates
    #     # everything is normalized w.r.t. camera of view1
    #     valid1 = gt1['valid_mask'].clone()
    #     valid2 = gt2['valid_mask'].clone()
    #     in_camera1 = inv(gt1['camera_pose'])
    #     gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
    #     gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
    #     scaling = find_opt_scaling(gt_pts1, gt_pts2, pred_pts1, pred_pts2, valid1=valid1, valid2=valid2)
    #     # 3. render images (need gaussian model, camera, pipeline)
    #     rendered_images = []
    #     rendered_feats = []
    #     for i in range(len(pred)):
    #         # get gaussian model
    #         gaussians = GaussianModel.from_predictions(pred[i], sh_degree=3)
    #         # get camera
    #         ref_camera_extrinsics = gt1['camera_pose'][i]
    #         target_extrinsics = target_view['camera_pose'][i]
    #         target_intrinsics = target_view['camera_intrinsics'][i]
    #         image_shape = target_view['true_shape'][i]
    #         scale = scaling[i]
    #         camera = get_scaled_camera(ref_camera_extrinsics, target_extrinsics, target_intrinsics, scale, image_shape)
    #         # render (image and features)
    #         rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
    #         rendered_images.append(rendered_output['render'])
    #         rendered_feats.append(rendered_output['feature_map'])
    #     rendered_images = torch.stack(rendered_images, dim=0)  # B, 3, H, W
    #     rendered_feats = torch.stack(rendered_feats, dim=0)  # B, d_feats, H, W
    #     rendered_feats = model.feature_expansion(rendered_feats)  # B, 512, H//2, W//2
    #     gt_images = target_view['img'] * 0.5 + 0.5
    #     gt_feats = model.lseg_feature_extractor.extract_features(target_view['img'])  # B, 512, H//2, W//2
    #     image_loss = torch.abs(rendered_images - gt_images).mean()
    #     feature_loss = torch.abs(rendered_feats - gt_feats).mean()
    #     loss = image_loss + 100 * feature_loss
    #     # temp: semantic labels decoded from features (unused)
    #     # gt_logits = model.lseg_feature_extractor.decode_feature(gt_feats, ['wall', 'floor', 'others'])
    #     # gt_labels = torch.argmax(gt_logits, dim=1, keepdim=True)
    #     # rendered_logits = model.lseg_feature_extractor.decode_feature(rendered_feats, ['wall', 'floor', 'others'])
    #     # rendered_labels = torch.argmax(rendered_logits, dim=1, keepdim=True)
    #     # calculate metrics
    #     with torch.no_grad():
    #         ssim = self.ssim(rendered_images, gt_images)
    #         psnr = self.psnr(rendered_images, gt_images)
    #         lpips = self.lpips_vgg(rendered_images, gt_images).mean()
    #     return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips, 'image_loss': image_loss, 'feature_loss': feature_loss}
    def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
        # 1. merge the two per-view predictions into one Gaussian set per sample
        pred = merge_and_split_predictions(pred1, pred2)

        # 2. calculate the optimal scaling between predicted and GT point maps;
        #    everything is expressed in the coordinate frame of camera 1
        pred_pts1 = pred1['means']
        pred_pts2 = pred2['means']
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
        scaling = find_opt_scaling(gt_pts1, gt_pts2, pred_pts1, pred_pts2, valid1=valid1, valid2=valid2)

        # 3. render images and feature maps (needs gaussian model, camera, pipeline)
        rendered_images = []
        rendered_feats = []
        gt_images = []
        for i in range(len(pred)):
            # build the gaussian model for this sample
            gaussians = GaussianModel.from_predictions(pred[i], sh_degree=3)
            ref_camera_extrinsics = gt1['camera_pose'][i]
            # supervise on both input views and the held-out target view
            target_view_list = [gt1, gt2, target_view]
            for j in range(len(target_view_list)):
                target_extrinsics = target_view_list[j]['camera_pose'][i]
                target_intrinsics = target_view_list[j]['camera_intrinsics'][i]
                image_shape = target_view_list[j]['true_shape'][i]
                scale = scaling[i]
                camera = get_scaled_camera(ref_camera_extrinsics, target_extrinsics, target_intrinsics, scale, image_shape)
                # render image and feature map
                rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
                rendered_images.append(rendered_output['render'])
                rendered_feats.append(rendered_output['feature_map'])
                gt_images.append(target_view_list[j]['img'][i] * 0.5 + 0.5)
        rendered_images = torch.stack(rendered_images, dim=0)  # B*3, 3, H, W
        gt_images = torch.stack(gt_images, dim=0)  # B*3, 3, H, W
        rendered_feats = torch.stack(rendered_feats, dim=0)  # B*3, d_feats, H, W
        rendered_feats = model.feature_expansion(rendered_feats)  # B*3, 512, H//2, W//2
        gt_feats = model.lseg_feature_extractor.extract_features(gt_images)  # B*3, 512, H//2, W//2
        image_loss = torch.abs(rendered_images - gt_images).mean()
        feature_loss = torch.abs(rendered_feats - gt_feats).mean()
        loss = image_loss + feature_loss

        # metrics for logging only
        with torch.no_grad():
            ssim = self.ssim(rendered_images, gt_images)
            psnr = self.psnr(rendered_images, gt_images)
            lpips_val = self.lpips_vgg(rendered_images, gt_images).mean()
        return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_val,
                      'image_loss': image_loss, 'feature_loss': feature_loss}
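# Usage sketch (illustration only; assumes `model` exposes the `feature_expansion` and
# `lseg_feature_extractor` modules used above, and that MultiLoss routes calls to
# compute_loss):
#   criterion = GaussianLoss(ssim_weight=0.2)
#   loss, metrics = criterion(gt1, gt2, target_view, pred1, pred2, model)
#   # `loss` is the L1 image term plus the L1 feature term; `metrics` holds
#   # SSIM / PSNR / LPIPS on the rendered views for logging.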
# loss for one batch
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    view1, view2, target_view = batch
    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng', 'pts3d'])
    # move every remaining field of each view to the target device
    for view in batch:
        for name in view.keys():  # pseudo_focal
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)
    if symmetrize_batch:
        # make_batch_symmetric expects only the two input views
        view1, view2 = make_batch_symmetric((view1, view2))

    # unwrap the model if it is wrapped in DistributedDataParallel
    actual_model = model.module if hasattr(model, 'module') else model
    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = actual_model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, target_view, pred1, pred2, actual_model) if criterion is not None else None

    result = dict(view1=view1, view2=view2, target_view=target_view, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
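# Training-loop sketch (illustration only; assumes a dataloader yielding
# (view1, view2, target_view) dict triplets, with the model and optimizer built elsewhere):
#   criterion = GaussianLoss()
#   for batch in dataloader:
#       result = loss_of_one_batch(batch, model, criterion, device, use_amp=True)
#       loss, metrics = result['loss']
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()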