|
# Pulls in LLoss, MultiLoss, inv, geotrf, find_opt_scaling, etc. from DUSt3R.
from submodules.mast3r.dust3r.dust3r.losses import *  # noqa: F401,F403

import torch
import lpips
from einops import rearrange
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

from dust3r.inference import make_batch_symmetric

from src.utils.gaussian_model import GaussianModel
from src.utils.cuda_splatting import render, DummyPipeline
from src.utils.camera_utils import get_scaled_camera
|
|
|
class L2Loss(LLoss):
    """Euclidean distance between 3D points."""

    def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # (..., 3) -> (...)


class L1Loss(LLoss):
    """Manhattan distance between 3D points."""

    def distance(self, a, b):
        # Sum over the coordinate axis so this returns one distance per point,
        # matching the shape contract of L2Loss.distance (a bare .mean() would
        # collapse everything to a scalar and break per-pixel masking).
        return torch.abs(a - b).sum(dim=-1)


L2 = L2Loss()
L1 = L1Loss()
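
# Quick sanity check of the two distances (illustrative values only, not taken
# from the training pipeline):
#
#   a = torch.zeros(1, 3)
#   b = torch.tensor([[3.0, 4.0, 0.0]])
#   L2.distance(a, b)  # tensor([5.])  Euclidean norm over the last axis
#   L1.distance(a, b)  # tensor([7.])  sum of absolute coordinate differences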
|
|
|
def merge_and_split_predictions(pred1, pred2):
    """Stack the two per-view prediction dicts, flatten views with pixels,
    and split the result into one dict per batch element."""
    merged = {}
    for key in pred1.keys():
        # (b, h, w, ...) x 2 views -> (b, v, h, w, ...) -> (b, v*h*w, ...)
        merged_pred = torch.stack([pred1[key], pred2[key]], dim=1)
        merged[key] = rearrange(merged_pred, 'b v h w ... -> b (v h w) ...')

    # Split along the batch dimension: one prediction dict per sample.
    batch_size = next(iter(merged.values())).shape[0]
    split = [{key: value[i] for key, value in merged.items()} for i in range(batch_size)]

    return split
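
# Example of the shape contract (hypothetical sizes: batch of 2, two views,
# 4x4 maps of 3-D means):
#
#   pred1 = {'means': torch.randn(2, 4, 4, 3)}
#   pred2 = {'means': torch.randn(2, 4, 4, 3)}
#   split = merge_and_split_predictions(pred1, pred2)
#   len(split)               # 2 -- one dict per batch element
#   split[0]['means'].shape  # (32, 3) -- 2 views * 4 * 4 pixels, flattened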
|
|
|
class GaussianLoss(MultiLoss):
    """Photometric + feature rendering loss for predicted 3D Gaussians,
    with SSIM / PSNR / LPIPS tracked as evaluation metrics."""

    def __init__(self, ssim_weight=0.2):
        super().__init__()
        # NOTE: ssim_weight is reported by get_name() but is not currently
        # used when combining the loss terms in compute_loss().
        self.ssim_weight = ssim_weight
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()
        self.lpips_vgg = lpips.LPIPS(net='vgg').cuda()
        self.pipeline = DummyPipeline()
        self.register_buffer('bg_color', torch.tensor([0.0, 0.0, 0.0]).cuda())

    def get_name(self):
        return f'GaussianLoss(ssim_weight={self.ssim_weight})'
|
    def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
        # One flattened prediction dict per batch element.
        pred = merge_and_split_predictions(pred1, pred2)

        pred_pts1 = pred1['means']
        pred_pts2 = pred2['means']

        # Express the GT point maps in view1's camera frame and recover the
        # scale that best aligns the predictions with the ground truth.
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))
        scaling = find_opt_scaling(gt_pts1, gt_pts2, pred_pts1, pred_pts2,
                                   valid1=valid1, valid2=valid2)

        # Render each sample's Gaussians from both input views and the target view.
        rendered_images = []
        rendered_feats = []
        gt_images = []
        for i in range(len(pred)):
            gaussians = GaussianModel.from_predictions(pred[i], sh_degree=3)

            ref_camera_extrinsics = gt1['camera_pose'][i]
            target_view_list = [gt1, gt2, target_view]
            for j in range(len(target_view_list)):
                target_extrinsics = target_view_list[j]['camera_pose'][i]
                target_intrinsics = target_view_list[j]['camera_intrinsics'][i]
                image_shape = target_view_list[j]['true_shape'][i]
                scale = scaling[i]
                camera = get_scaled_camera(ref_camera_extrinsics, target_extrinsics,
                                           target_intrinsics, scale, image_shape)

                rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
                rendered_images.append(rendered_output['render'])
                rendered_feats.append(rendered_output['feature_map'])
                gt_images.append(target_view_list[j]['img'][i] * 0.5 + 0.5)  # [-1, 1] -> [0, 1]

        rendered_images = torch.stack(rendered_images, dim=0)
        gt_images = torch.stack(gt_images, dim=0)
        rendered_feats = torch.stack(rendered_feats, dim=0)
        rendered_feats = model.feature_expansion(rendered_feats)
        gt_feats = model.lseg_feature_extractor.extract_features(gt_images)

        # L1 photometric term plus L1 term on the rendered feature maps.
        image_loss = torch.abs(rendered_images - gt_images).mean()
        feature_loss = torch.abs(rendered_feats - gt_feats).mean()
        loss = image_loss + feature_loss
|
        # Metrics only; no gradients needed. The local is named lpips_score so
        # it does not shadow the imported lpips module.
        with torch.no_grad():
            ssim = self.ssim(rendered_images, gt_images)
            psnr = self.psnr(rendered_images, gt_images)
            # normalize=True tells LPIPS the inputs are in [0, 1] rather than
            # the [-1, 1] range it expects by default.
            lpips_score = self.lpips_vgg(rendered_images, gt_images, normalize=True).mean()

        return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_score,
                      'image_loss': image_loss, 'feature_loss': feature_loss}
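
# Sketch of the criterion's contract (inferred from the code above, not from
# any external API): compute_loss returns
#   (loss, metrics) = (image_L1 + feature_L1,
#                      {'ssim', 'psnr', 'lpips', 'image_loss', 'feature_loss'})
# where all rendered and GT images are RGB tensors in [0, 1].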
|
|
|
|
|
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    view1, view2, target_view = batch
    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng', 'pts3d'])
    for view in batch:
        for name in view.keys():
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        # make_batch_symmetric expects a (view1, view2) pair, not the full triple.
        view1, view2 = make_batch_symmetric((view1, view2))

    # Unwrap DistributedDataParallel so custom attributes (feature_expansion,
    # lseg_feature_extractor) are reachable from the criterion.
    actual_model = model.module if hasattr(model, 'module') else model

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = actual_model(view1, view2)

    # Compute the loss in fp32 for numerical stability.
    with torch.cuda.amp.autocast(enabled=False):
        loss = criterion(view1, view2, target_view, pred1, pred2, actual_model) if criterion is not None else None

    result = dict(view1=view1, view2=view2, target_view=target_view, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
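
# Minimal usage sketch (the dataloader and model names are assumptions, not
# part of this module):
#
#   criterion = GaussianLoss(ssim_weight=0.2)
#   batch = next(iter(dataloader))  # (view1, view2, target_view) dicts
#   loss, metrics = loss_of_one_batch(batch, model, criterion,
#                                     device='cuda', use_amp=True, ret='loss')
#   loss.backward()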