|
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torchvision import transforms

# Fall back to CPU so the module is importable on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

mse = nn.MSELoss()
|
|
|
|
|
def calc_histogram_loss(A, B, histogram_block):
    """Hellinger distance between the color histograms of A and B, averaged
    over the batch. `histogram_block` is any module that maps an image batch
    to per-image histograms."""
    input_hist = histogram_block(A)
    target_hist = histogram_block(B)
    histogram_loss = (1 / np.sqrt(2.0)
                      * torch.sqrt(torch.sum(
                          (torch.sqrt(target_hist) - torch.sqrt(input_hist)) ** 2))
                      / input_hist.shape[0])

    return histogram_loss
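
# Usage sketch (the histogram block named below is an assumption, e.g.
# HistoGAN's RGBuvHistBlock; any differentiable histogram module with the
# same interface works):
#
#   histogram_block = RGBuvHistBlock(...).to(DEVICE)
#   loss_hist = calc_histogram_loss(stylized, style_image, histogram_block)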
|
|
|
|
|
def calc_mean_std(feat, eps=1e-5):
    """Per-channel mean and std of a (N, C, H, W) feature map.
    `eps` is added to the variance to avoid division by zero."""
    size = feat.size()
    assert len(size) == 4
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def mean_variance_norm(feat):
    """Normalize each channel of `feat` to zero mean and unit variance."""
    size = feat.size()
    mean, std = calc_mean_std(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat
|
|
|
def train_transform():
    transform_list = [
        transforms.Resize(size=512),
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)


def test_transform():
    transform_list = [
        transforms.Resize(size=512),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)
|
|
|
|
|
def plot_grad_flow(named_parameters):
    '''Prints the mean and max gradient magnitude of each layer in the net
    during training. Can be used for checking for possible gradient
    vanishing / exploding problems.

    Usage: plug this function into the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to inspect the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        # Skip biases and parameters that never received a gradient.
        if p.requires_grad and 'bias' not in n and p.grad is not None:
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print('-' * 82)
            print(n, p.grad.abs().mean(), p.grad.abs().max())
    print('-' * 82)
|
|
|
def InfiniteSampler(n):
    """Yield dataset indices in random order, reshuffling forever."""
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that never raises StopIteration, so a DataLoader built on it
    can be iterated indefinitely (useful for step-based training loops)."""

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31
|
|
|
class FlatFolderDataset(data.Dataset):
    """Dataset over a flat directory of images (no class subfolders)."""

    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        # Sorted so the index-to-file mapping is deterministic across runs.
        self.paths = sorted(os.listdir(self.root))
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'
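
# Typical usage (a sketch; the path and batch size are placeholders):
#
#   dataset = FlatFolderDataset('path/to/images', train_transform())
#   loader = iter(data.DataLoader(
#       dataset, batch_size=8,
#       sampler=InfiniteSamplerWrapper(dataset), num_workers=4))
#   images = next(loader)  # -> an endless stream of (8, 3, 256, 256) batches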
|
|
|
def adjust_learning_rate(optimizer, iteration_count, args):
    """Inverse-time learning-rate decay, imitating the original implementation."""
    lr = args.lr / (1.0 + 5e-5 * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
|
|
|
def cosine_dismat(A, B, eps=1e-8):
    """Pairwise cosine-distance matrix between the spatial features of A and B.

    A, B: (N, C, H, W) feature maps. Returns a (N, H*W, H*W) matrix whose
    (i, j) entry is 1 - cos(A[:, :, i], B[:, :, j]).
    """
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    # eps guards against division by zero for all-zero feature vectors.
    A_norm = torch.sqrt((A ** 2).sum(1)) + eps
    B_norm = torch.sqrt((B ** 2).sum(1)) + eps

    A = (A / A_norm.unsqueeze(dim=1).expand(A.shape)).permute(0, 2, 1)
    B = B / B_norm.unsqueeze(dim=1).expand(B.shape)
    dismat = 1. - torch.bmm(A, B)

    return dismat
|
|
|
def calc_remd_loss(A, B):
    """Relaxed earth mover's distance (REMD) over the cosine-distance matrix."""
    C = cosine_dismat(A, B)
    m1, _ = C.min(1)  # best match in A for each feature of B
    m2, _ = C.min(2)  # best match in B for each feature of A

    remd = torch.max(m1.mean(), m2.mean())

    return remd


def calc_ss_loss(A, B):
    """Self-similarity loss: L1 distance between the self-cosine-distance
    matrices of A and B."""
    MA = cosine_dismat(A, A)
    MB = cosine_dismat(B, B)
    Lself_similarity = torch.abs(MA - MB).mean()

    return Lself_similarity
|
|
|
def calc_moment_loss(A, B):
    """First- and second-moment matching loss: L1 distance between the
    channel means and between the channel covariance matrices of A and B."""
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    # Mean over the spatial dimension (dim 2), consistent with the covariance
    # below, which treats the spatial positions as samples.
    mu_a = torch.mean(A, 2, keepdim=True)
    mu_b = torch.mean(B, 2, keepdim=True)
    mu_d = torch.abs(mu_a - mu_b).mean()

    A_c = A - mu_a
    B_c = B - mu_b
    cov_a = torch.bmm(A_c, A_c.permute(0, 2, 1)) / (A.shape[2] - 1)
    cov_b = torch.bmm(B_c, B_c.permute(0, 2, 1)) / (B.shape[2] - 1)
    cov_d = torch.abs(cov_a - cov_b).mean()
    loss = mu_d + cov_d
    return loss


def calc_mse_loss(A, B):
    return mse(A, B)
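

# Minimal smoke test (a sketch with illustrative shapes; not part of training):
if __name__ == '__main__':
    feat_a = torch.randn(2, 64, 16, 16)
    feat_b = torch.randn(2, 64, 16, 16)
    print('remd  :', calc_remd_loss(feat_a, feat_b).item())
    print('ss    :', calc_ss_loss(feat_a, feat_b).item())
    print('moment:', calc_moment_loss(feat_a, feat_b).item())
    print('mse   :', calc_mse_loss(feat_a, feat_b).item())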
|
|
|
|