# LSM: src/losses.py
# Last commit 57746f1 ("Update Code") by kairunwen.
from submodules.mast3r.dust3r.dust3r.losses import *
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, JaccardIndex, Accuracy
import lpips
from src.utils.gaussian_model import GaussianModel
from src.utils.cuda_splatting import render, DummyPipeline
from einops import rearrange
from src.utils.camera_utils import get_scaled_camera
from torchvision.utils import save_image
from dust3r.inference import make_batch_symmetric
class L2Loss (LLoss):
    """Per-point Euclidean (L2) distance between 3D points."""

    def distance(self, a, b):
        """Return the 2-norm of ``a - b`` along the last axis (one value per point).

        This is the plain, un-normalized Euclidean distance; any reduction over
        points is left to ``LLoss``.
        """
        diff = a - b
        return diff.norm(dim=-1)
class L1Loss (LLoss):
    """Per-point L1 error between 3D points, averaged over the coordinate axis.

    dust3r's ``LLoss.forward`` applies its reduction ('none' / 'sum' / 'mean')
    to the tensor returned by ``distance``. It asserts that this tensor has
    exactly one dimension fewer than the inputs. Reducing only the last
    (coordinate) axis satisfies that contract. With the 'mean' reduction, the
    final loss equals the global mean of ``|a - b|``, which is what the former
    implementation computed.
    """

    def distance(self, a, b):
        """Return ``|a - b|`` averaged over the last axis (one value per point).

        This is the Manhattan distance divided by the number of coordinates.
        The previous ``.mean()`` over all axes returned a 0-d tensor, which
        fails the per-point shape check in ``LLoss.forward``.
        """
        return torch.abs(a - b).mean(dim=-1)
# Shared, ready-to-use instances of the point-distance criteria defined above.
L2, L1 = L2Loss(), L1Loss()
def merge_and_split_predictions(pred1, pred2):
    """Merge two per-view prediction dicts and split the result per sample.

    For every key of ``pred1``, the tensors ``pred1[key]`` and ``pred2[key]``
    (layout ``(B, H, W, ...)``) are stacked into ``(B, 2, H, W, ...)``. The
    view and pixel axes are then flattened into ``(B, 2*H*W, ...)``, so in
    each sample the first ``H*W`` entries come from view 1 and the remaining
    ones from view 2.

    Args:
        pred1, pred2: dicts mapping prediction names to tensors of matching
            shape. ``pred2`` must contain every key of ``pred1``.

    Returns:
        A list of length B. Element ``i`` maps each key to sample ``i``'s
        merged tensor of shape ``(2*H*W, ...)``.

    Raises:
        ValueError: if ``pred1`` is empty, because the batch size cannot be
            inferred.
    """
    if not pred1:
        # next(iter({}.values())) would leak a bare StopIteration, which can
        # silently end an enclosing iterator instead of failing loudly.
        raise ValueError('merge_and_split_predictions: pred1 contains no predictions')

    merged = {
        # (B, 2, H, W, ...) -> (B, 2*H*W, ...), i.e. 'b v h w ... -> b (v h w) ...'
        key: torch.stack([pred1[key], pred2[key]], dim=1).flatten(1, 3)
        for key in pred1
    }
    batch_size = next(iter(merged.values())).shape[0]
    return [{key: value[i] for key, value in merged.items()} for i in range(batch_size)]
class GaussianLoss(MultiLoss):
    """Rendering loss for per-pixel 3D Gaussian predictions of a two-view model.

    For every sample, the Gaussians predicted for both input views are merged
    and rescaled to the ground-truth scale found by ``find_opt_scaling``. They
    are then rendered into the cameras of ``gt1``, ``gt2`` and ``target_view``.
    The loss is the sum of two terms:

    - the mean absolute error between rendered and ground-truth RGB images;
    - the mean absolute error between the expanded rendered feature maps and
      the features the model's LSeg extractor computes from the ground-truth
      images.

    SSIM, PSNR and LPIPS are reported as metrics only.
    """

    def __init__(self, ssim_weight=0.2):
        """
        Args:
            ssim_weight: only recorded and shown by ``get_name``. The loss
                computed by this class does not currently use it.
        """
        super().__init__()
        self.ssim_weight = ssim_weight
        # Metrics assume images in [0, 1] (data_range=1.0).
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()
        self.lpips_vgg = lpips.LPIPS(net='vgg').cuda()
        self.pipeline = DummyPipeline()
        # Black background for the rasterizer. It is registered as a buffer
        # so it follows .to() and is stored in state_dict.
        self.register_buffer('bg_color', torch.tensor([0.0, 0.0, 0.0]).cuda())

    def get_name(self):
        return f'GaussianLoss(ssim_weight={self.ssim_weight})'

    def _render_views(self, gt1, gt2, target_view, pred1, pred2):
        """Render each sample's merged Gaussians into the gt1, gt2 and target cameras.

        Returns:
            ``(rendered_images, rendered_feats, gt_images)``. Each is stacked
            along a leading axis of length ``3 * B``, ordered sample-major:
            ``[s0/gt1, s0/gt2, s0/target, s1/gt1, ...]``. ``gt_images`` are the
            views' ``img`` tensors mapped with ``img * 0.5 + 0.5``.
        """
        per_sample_preds = merge_and_split_predictions(pred1, pred2)

        # Optimal scale aligning the predicted means with the ground-truth
        # points, both expressed in the coordinate frame of camera 1.
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
        scaling = find_opt_scaling(gt_pts1, gt_pts2, pred1['means'], pred2['means'],
                                   valid1=valid1, valid2=valid2)

        views = (gt1, gt2, target_view)  # supervise both input views and the target view
        rendered_images, rendered_feats, gt_images = [], [], []
        for i, sample_pred in enumerate(per_sample_preds):
            gaussians = GaussianModel.from_predictions(sample_pred, sh_degree=3)
            ref_extrinsics = gt1['camera_pose'][i]
            scale = scaling[i]
            for view in views:
                camera = get_scaled_camera(ref_extrinsics, view['camera_pose'][i],
                                           view['camera_intrinsics'][i], scale,
                                           view['true_shape'][i])
                output = render(camera, gaussians, self.pipeline, self.bg_color)
                rendered_images.append(output['render'])
                rendered_feats.append(output['feature_map'])
                gt_images.append(view['img'][i] * 0.5 + 0.5)

        return (torch.stack(rendered_images, dim=0),  # (3*B, 3, H, W)
                torch.stack(rendered_feats, dim=0),   # (3*B, d_feats, H, W)
                torch.stack(gt_images, dim=0))        # (3*B, 3, H, W)

    def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
        """Compute the image + feature rendering loss and evaluation metrics.

        Args:
            gt1, gt2: ground-truth dicts of the two input views.
            target_view: ground-truth dict of the held-out view.
            pred1, pred2: per-view Gaussian prediction dicts (must contain
                ``'means'``).
            model: the unwrapped network. Its ``feature_expansion`` and
                ``lseg_feature_extractor`` are used.

        Returns:
            ``(loss, details)``. ``details`` holds 'ssim', 'psnr', 'lpips',
            'image_loss' and 'feature_loss'.
        """
        rendered_images, rendered_feats, gt_images = self._render_views(
            gt1, gt2, target_view, pred1, pred2)

        rendered_feats = model.feature_expansion(rendered_feats)  # presumably (3*B, 512, H//2, W//2)
        # NOTE(review): gt_images have already been mapped with * 0.5 + 0.5.
        # An earlier variant of this method passed the raw view['img'] tensors
        # to extract_features instead. Confirm which input range the LSeg
        # extractor expects.
        gt_feats = model.lseg_feature_extractor.extract_features(gt_images)

        image_loss = torch.abs(rendered_images - gt_images).mean()
        feature_loss = torch.abs(rendered_feats - gt_feats).mean()
        loss = image_loss + feature_loss

        with torch.no_grad():
            ssim = self.ssim(rendered_images, gt_images)
            psnr = self.psnr(rendered_images, gt_images)
            # LPIPS expects inputs in [-1, 1]. normalize=True rescales the
            # [0, 1] images; without it the reported LPIPS is computed on
            # mis-scaled inputs.
            lpips_value = self.lpips_vgg(rendered_images, gt_images, normalize=True).mean()
        return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_value,
                      'image_loss': image_loss, 'feature_loss': feature_loss}
# loss for one batch
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model on one ``(view1, view2, target_view)`` batch and evaluate the criterion.

    Args:
        batch: 3-sequence ``(view1, view2, target_view)`` of view dicts. Entries
            whose key is not in ``ignore_keys`` are moved to ``device`` in place.
        model: the network. If it has a ``.module`` attribute (e.g. a
            data-parallel wrapper), the inner module is used.
        criterion: called as
            ``criterion(view1, view2, target_view, pred1, pred2, model)``.
            May be ``None`` to skip the loss.
        device: device the tensors are moved to.
        symmetrize_batch: if True, also run the swapped pair via
            ``make_batch_symmetric``. ``target_view`` is expanded the same way
            so every prediction keeps its matching target.
        use_amp: run the model forward under CUDA autocast.
        ret: if given, return only ``result[ret]``.

    Returns:
        dict with keys 'view1', 'view2', 'target_view', 'pred1', 'pred2' and
        'loss', or the single entry selected by ``ret``.
    """
    view1, view2, target_view = batch
    ignore_keys = {'depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng', 'pts3d'}
    for view in batch:
        for name in view.keys():  # pseudo_focal
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        # dust3r's make_batch_symmetric unpacks exactly two views, so passing
        # the 3-view batch raised. Symmetrize the input pair, then pair
        # target_view with itself so it gets the identical batch layout.
        view1, view2 = make_batch_symmetric((view1, view2))
        target_view, _ = make_batch_symmetric((target_view, target_view))

    # Get the actual model if it's distributed.
    # NOTE(review): calling the inner module directly bypasses the wrapper's
    # forward. For DistributedDataParallel that is where gradient
    # synchronisation is prepared, so verify gradients are still all-reduced.
    actual_model = model.module if hasattr(model, 'module') else model

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = actual_model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, target_view, pred1, pred2, actual_model) if criterion is not None else None

    result = dict(view1=view1, view2=view2, target_view=target_view, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result