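# Adversarial (hinge) loss helpers for a generator/discriminator pair.
# NOTE: the web-page banner "Spaces: Runtime error" that preceded this code was
# scrape residue, not part of the module, and is kept here only as a comment.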
import torch
import torch.nn.functional as F
def generator_loss(discriminator, inputs, reconstructions, cond=None):
    """Return the adversarial loss for the generator.

    The loss is the negated mean of the discriminator's logits on the
    reconstructions, so minimizing it pushes those logits upward.

    Args:
        discriminator: Callable mapping a tensor to logits.
        inputs: Not used by this function; accepted so the signature
            parallels ``discriminator_loss``.
        reconstructions: Generated tensor scored by the discriminator.
            Gradients are intentionally NOT detached here so they flow back
            into whatever produced the reconstructions.
        cond: Optional conditioning tensor. When given, it is concatenated to
            the reconstructions along dimension 1 (presumably the channel
            axis -- confirm against the discriminator's expected input).

    Returns:
        Scalar tensor ``-mean(discriminator(fake))``.
    """
    fake = reconstructions.contiguous()
    if cond is not None:
        fake = torch.cat((fake, cond), dim=1)
    logits_fake = discriminator(fake)
    return -torch.mean(logits_fake)
def hinge_d_loss(logits_real, logits_fake):
    """Return the hinge loss for a discriminator.

    Computes ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``.
    Real logits at or above +1 and fake logits at or below -1 contribute zero.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        Scalar tensor with the averaged hinge loss.
    """
    loss_real = F.relu(1.0 - logits_real).mean()
    loss_fake = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (loss_real + loss_fake)
def discriminator_loss(discriminator, inputs, reconstructions, cond=None):
    """Return the hinge loss for the discriminator on a real/fake pair.

    Both ``inputs`` and ``reconstructions`` are detached before scoring, so
    this loss does not backpropagate into whatever produced them. ``cond`` is
    concatenated as-is (it is not detached here).

    Args:
        discriminator: Callable mapping a tensor to logits.
        inputs: Real samples.
        reconstructions: Generated samples.
        cond: Optional conditioning tensor. When given, it is concatenated to
            both the real and the generated samples along dimension 1
            (presumably the channel axis -- confirm against the
            discriminator's expected input).

    Returns:
        Scalar tensor with the hinge discriminator loss.
    """
    real = inputs.contiguous().detach()
    fake = reconstructions.contiguous().detach()
    if cond is not None:
        real = torch.cat((real, cond), dim=1)
        fake = torch.cat((fake, cond), dim=1)

    # Score real first, then fake, matching the original call order in case
    # the discriminator is stateful.
    logits_real = discriminator(real)
    logits_fake = discriminator(fake)

    # hinge_d_loss already reduces to a scalar; the trailing .mean() is kept
    # from the original and is a no-op on that scalar.
    return hinge_d_loss(logits_real, logits_fake).mean()