import torch
from torch import nn
from . import weights_init, l1, l2, hinge_d_loss, vanilla_d_loss, measure_perplexity, square_dist_loss
from .geometric import GeoConverter
from .discriminator import NLayerDiscriminator, LiDARNLayerDiscriminator, LiDARNLayerDiscriminatorV2
from .perceptual import PerceptualLoss
VERSION2DISC = {'v0': NLayerDiscriminator, 'v1': LiDARNLayerDiscriminator, 'v2': LiDARNLayerDiscriminatorV2}


class VQGeoLPIPSWithDiscriminator(nn.Module):
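    """VQ-GAN style loss for LiDAR range images, combining pixel (and optional
    mask) reconstruction, a BEV geometry term, a perceptual term, and a patch
    discriminator whose adversarial weight is balanced adaptively.

    Note that `dataset_config` is read by attribute access (`depth_scale`,
    `log_scale`), so it is expected to be an attribute-style config object
    rather than the plain dict used as the default.
    """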
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
disc_num_layers=3, disc_in_channels=3, disc_out_channels=1, disc_factor=1.0, disc_weight=1.0,
mask_factor=0.0, use_actnorm=False, disc_conditional=False,
disc_ndf=64, disc_loss="hinge", n_classes=None, pixel_loss="l1", disc_version='v1',
geo_factor=1.0, curve_length=4, perceptual_factor=1.0, perceptual_type='rangenet_final',
dataset_config=dict()):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
assert pixel_loss in ["l1", "l2"]
self.codebook_weight = codebook_weight
self.pixel_weight = pixelloss_weight
self.mask_factor = mask_factor
self.geo_factor = geo_factor
        # number of enabled reconstruction terms, used to average the combined loss
        self.rec_scale = 1
if mask_factor > 0:
self.rec_scale += 1.
if geo_factor > 0:
self.rec_scale += 1.
if perceptual_factor > 0:
self.rec_scale += 1.
if pixel_loss == "l1":
self.pixel_loss = l1
else:
self.pixel_loss = l2
self.perceptual_factor = perceptual_factor
if perceptual_factor > 0.:
print(f"{self.__class__.__name__}: Running with LPIPS.")
self.perceptual_loss = PerceptualLoss(perceptual_type, dataset_config.depth_scale,
dataset_config.log_scale).eval()
disc_cls = VERSION2DISC[disc_version]
self.discriminator = disc_cls(input_nc=disc_in_channels,
output_nc=disc_out_channels,
n_layers=disc_num_layers,
use_actnorm=use_actnorm,
ndf=disc_ndf).apply(weights_init)
self.discriminator_iter_start = disc_start
if disc_loss == "hinge":
self.disc_loss = hinge_d_loss
elif disc_loss == "vanilla":
self.disc_loss = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
print(f"VQGeoLPIPSWithDiscriminator running with {disc_loss} loss.")
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.disc_conditional = disc_conditional
self.n_classes = n_classes
self.geometry_converter = GeoConverter(curve_length, False, dataset_config) # force converting xyz output
        self.geo_loss = square_dist_loss

def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
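        """Balance the adversarial term against the reconstruction term by
        matching gradient norms at the (last) decoder layer, as in VQGAN:
        d_weight = ||grad nll|| / (||grad g|| + 1e-4), clamped to [0, 1e4]
        and scaled by `discriminator_weight`.
        """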
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
        return d_weight

def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None, masks=None):
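        """Return (loss, log) for the generator pass (optimizer_idx == 0) or
        the discriminator pass (optimizer_idx == 1).

        `inputs` holds the range image; when mask supervision is enabled,
        `reconstructions` carries the predicted mask as channel 1.
        """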
input_coord = self.geometry_converter(inputs)
rec_coord = self.geometry_converter(reconstructions[:, 0:1].contiguous())
############# Reconstruction #############
# pixel reconstruction loss
if self.mask_factor > 0 and masks is not None:
pixel_rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions[:, 0:1].contiguous())
mask_rec_loss = self.pixel_loss(masks.contiguous(), reconstructions[:, 1:2].contiguous()) * self.mask_factor
        else:
            # channel 0 is always the range image (cf. rec_coord above)
            pixel_rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions[:, 0:1].contiguous())
            mask_rec_loss = torch.tensor(0.0)
# geometry reconstruction loss (bev)
if self.geo_factor > 0:
geo_rec_loss = self.geo_loss(input_coord[:, :2], rec_coord[:, :2]) * self.geo_factor
else:
geo_rec_loss = torch.tensor(0.0)
# perceptual loss
if self.perceptual_factor > 0:
perceptual_loss = self.perceptual_loss((inputs.contiguous(), input_coord),
(reconstructions[:, 0:1].contiguous(), rec_coord)) * self.perceptual_factor
else:
perceptual_loss = torch.tensor(0.0)
# overall reconstruction loss
rec_loss = (pixel_rec_loss + mask_rec_loss + geo_rec_loss + perceptual_loss) / self.rec_scale
        nll_loss = torch.mean(rec_loss)
############# GAN #############
        # enable the adversarial term only once the discriminator warm-up is over
        disc_factor = self.disc_factor if global_step >= self.discriminator_iter_start else 0.
# update generator (input: img, mask, coord, [cond])
if optimizer_idx == 0:
disc_recons = reconstructions.contiguous()
if self.geo_factor > 0:
disc_recons = torch.cat((disc_recons, rec_coord[:, :2].contiguous()), dim=1)
if cond is not None and self.disc_conditional:
disc_recons = torch.cat((disc_recons, cond), dim=1)
logits_fake = self.discriminator(disc_recons)
# adversarial loss
g_loss = -torch.mean(logits_fake)
try:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/pix_rec_loss".format(split): pixel_rec_loss.detach().mean(),
"{}/geo_rec_loss".format(split): geo_rec_loss.detach().mean(),
"{}/mask_rec_loss".format(split): mask_rec_loss.detach().mean(),
"{}/perceptual_loss".format(split): perceptual_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean()}
if predicted_indices is not None:
assert self.n_classes is not None
with torch.no_grad():
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
log[f"{split}/perplexity"] = perplexity
log[f"{split}/cluster_usage"] = cluster_usage
return loss, log
# update discriminator (input: img, mask, coord, [cond])
if optimizer_idx == 1:
disc_inputs, disc_recons = inputs.contiguous().detach(), reconstructions.contiguous().detach()
if self.mask_factor > 0 and masks is not None:
disc_inputs = torch.cat((disc_inputs, masks.contiguous().detach()), dim=1)
if self.geo_factor > 0:
disc_inputs = torch.cat((disc_inputs, input_coord[:, :2].contiguous()), dim=1)
disc_recons = torch.cat((disc_recons, rec_coord[:, :2].contiguous()), dim=1)
            if cond is not None and self.disc_conditional:
disc_inputs = torch.cat((disc_inputs, cond), dim=1)
disc_recons = torch.cat((disc_recons, cond), dim=1)
logits_real = self.discriminator(disc_inputs)
logits_fake = self.discriminator(disc_recons)
# gan loss
d_loss = self.disc_loss(logits_real, logits_fake) * disc_factor
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean()}
return d_loss, log
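

# Usage sketch (illustrative, not part of this module): `model`, `opt_ae`,
# `opt_disc`, and `model.get_last_layer()` are assumed names for a VQ
# autoencoder and its optimizers. The loss is called once per optimizer,
# mirroring the usual two-optimizer VQGAN training step:
#
#   xrec, codebook_loss, indices = model(x)
#   aeloss, log = loss_fn(codebook_loss, x, xrec, optimizer_idx=0,
#                         global_step=step, last_layer=model.get_last_layer(),
#                         split="train", predicted_indices=indices, masks=masks)
#   opt_ae.zero_grad(); aeloss.backward(); opt_ae.step()
#
#   # the discriminator pass detaches inputs/reconstructions internally
#   discloss, log = loss_fn(codebook_loss, x, xrec, optimizer_idx=1,
#                           global_step=step, split="train", masks=masks)
#   opt_disc.zero_grad(); discloss.backward(); opt_disc.step()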