# CCLAP/utils.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
import PIL.Image as Image

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
mse = nn.MSELoss()

def calc_histogram_loss(A, B, histogram_block):
    # Hellinger distance between the color histograms of A and B,
    # scaled by 1/sqrt(2) and normalized by the batch size.
    input_hist = histogram_block(A)
    target_hist = histogram_block(B)
    histogram_loss = (1 / np.sqrt(2.0) * (torch.sqrt(torch.sum(
        torch.pow(torch.sqrt(target_hist) - torch.sqrt(input_hist), 2)))) /
        input_hist.shape[0])
    return histogram_loss
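
# Usage sketch (hypothetical shapes; assumes `histogram_block` is a
# differentiable color-histogram layer such as HistoGAN's RGBuvHistBlock,
# constructed elsewhere in the repo):
#   histogram_block = RGBuvHistBlock(insz=64, h=256, device=DEVICE)
#   stylized = torch.rand(4, 3, 256, 256, device=DEVICE)
#   style = torch.rand(4, 3, 256, 256, device=DEVICE)
#   h_loss = calc_histogram_loss(stylized, style, histogram_block)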

# B, C, H, W; mean and var are taken over the spatial dims (H, W).
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert len(size) == 4
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    # Instance-normalize feat to zero mean and unit std per channel
    # (no learnable affine).
    size = feat.size()
    mean, std = calc_mean_std(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat
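
# Sanity-check sketch: after normalization every channel of every sample has
# roughly zero mean and unit std over its spatial positions.
#   feat = torch.randn(2, 64, 32, 32)
#   out = mean_variance_norm(feat)
#   out.view(2, 64, -1).mean(dim=2)  # ~0
#   out.view(2, 64, -1).std(dim=2)   # ~1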

def train_transform():
    # Resize the short edge to 512, then take a random 256x256 crop.
    transform_list = [
        transforms.Resize(size=512),
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)

def test_transform():
    # Resize the short edge to 512; no cropping at test time.
    transform_list = [
        transforms.Resize(size=512),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)
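
# Usage sketch (hypothetical image path):
#   img = Image.open('data/content/0001.jpg').convert('RGB')
#   x = train_transform()(img)  # (3, 256, 256), random crop, values in [0, 1]
#   y = test_transform()(img)   # short edge resized to 512, no crop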

# https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/7
def plot_grad_flow(named_parameters):
    '''Reports the gradients flowing through different layers in the net
    during training. Can be used for checking for possible gradient
    vanishing / exploding problems. Unlike the forum snippet linked above,
    this version prints the per-layer statistics rather than plotting them.
    Usage: plug this function into the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to inspect the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and 'bias' not in n:
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print('-' * 82)
            print(n, p.grad.abs().mean(), p.grad.abs().max())
    print('-' * 82)
    return layers, ave_grads, max_grads
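
# Usage sketch inside a training step:
#   loss.backward()
#   plot_grad_flow(network.named_parameters())  # print per-layer grad stats
#   optimizer.step()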

def InfiniteSampler(n):
    # Yield an endless stream of shuffled indices in [0, n).
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            # Reseed from OS entropy so every pass (and every DataLoader
            # worker process) draws a fresh permutation.
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(data.sampler.Sampler):
    # Sampler that never runs out, so a DataLoader over it can be
    # iterated indefinitely.
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        # Effectively infinite; DataLoader only needs a large upper bound.
        return 2 ** 31

class FlatFolderDataset(data.Dataset):
    # Treats every file in a single flat directory as one RGB image.
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        self.paths = os.listdir(self.root)
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'
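
# Usage sketch (hypothetical directory): wrapping the loader with the
# infinite sampler gives an endless iterator, the usual pattern for
# style-transfer training loops.
#   dataset = FlatFolderDataset('data/style', train_transform())
#   loader = iter(data.DataLoader(
#       dataset, batch_size=8,
#       sampler=InfiniteSamplerWrapper(dataset), num_workers=4))
#   batch = next(loader)  # (8, 3, 256, 256); call next() every iteration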

def adjust_learning_rate(optimizer, iteration_count, args):
    """Imitating the original implementation: inverse-time decay,
    lr = args.lr / (1 + 5e-5 * iteration_count)."""
    lr = args.lr / (1.0 + 5e-5 * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
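
# Usage sketch (assumes an argparse namespace with `lr` and `max_iter`):
#   for i in range(args.max_iter):
#       adjust_learning_rate(optimizer, iteration_count=i, args=args)
#       ...  # forward / backward / optimizer.step()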

def cosine_dismat(A, B):
    # Pairwise cosine-distance matrix between the per-position feature
    # vectors of A and B: (N, C, H, W) inputs -> (N, HW_A, HW_B) output.
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)
    # Small eps guards against division by zero for all-zero activations.
    A_norm = torch.sqrt((A**2).sum(1)) + 1e-8
    B_norm = torch.sqrt((B**2).sum(1)) + 1e-8
    A = (A / A_norm.unsqueeze(dim=1).expand(A.shape)).permute(0, 2, 1)
    B = B / B_norm.unsqueeze(dim=1).expand(B.shape)
    dismat = 1. - torch.bmm(A, B)
    return dismat
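
# Shape sketch: two 512-channel 32x32 feature maps give a 1024x1024
# distance matrix per sample, with entries in [0, 2].
#   M = cosine_dismat(torch.randn(2, 512, 32, 32),
#                     torch.randn(2, 512, 32, 32))
#   M.shape  # torch.Size([2, 1024, 1024])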

def calc_remd_loss(A, B):
    # Relaxed earth mover's distance (REMD, as in STROTSS): match each
    # feature vector to its nearest neighbor in the other set, then take
    # the larger of the two directional averages.
    C = cosine_dismat(A, B)
    m1, _ = C.min(1)
    m2, _ = C.min(2)
    remd = torch.max(m1.mean(), m2.mean())
    return remd
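
# Usage sketch (hypothetical features from a frozen VGG encoder):
#   Fs = torch.randn(2, 512, 32, 32)   # style features
#   Fcs = torch.randn(2, 512, 32, 32)  # stylized-output features
#   style_loss = calc_remd_loss(Fcs, Fs)  # non-negative scalar tensor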

def calc_ss_loss(A, B):
    # Self-similarity loss: align the internal cosine-distance structure
    # of A with that of B.
    MA = cosine_dismat(A, A)
    MB = cosine_dismat(B, B)
    Lself_similarity = torch.abs(MA - MB).mean()
    return Lself_similarity
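
# Note: both inputs must share the same spatial size, since the (N, HW, HW)
# self-similarity matrices are compared element-wise.
#   content_loss = calc_ss_loss(Fcs, Fc)  # Fc: content features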

def calc_moment_loss(A, B):
    # First- and second-moment matching: align the per-channel means and
    # the channel covariances, treating each spatial position as a sample
    # (hence the mean over dim 2 and the HW - 1 normalizer below).
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    mu_a = torch.mean(A, dim=2, keepdim=True)
    mu_b = torch.mean(B, dim=2, keepdim=True)
    mu_d = torch.abs(mu_a - mu_b).mean()

    A_c = A - mu_a
    B_c = B - mu_b
    cov_a = torch.bmm(A_c, A_c.permute(0, 2, 1)) / (A.shape[2] - 1)
    cov_b = torch.bmm(B_c, B_c.permute(0, 2, 1)) / (B.shape[2] - 1)
    cov_d = torch.abs(cov_a - cov_b).mean()

    loss = mu_d + cov_d
    return loss
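
# Usage sketch: the moment term is commonly paired with the REMD term
# (as in STROTSS-style objectives):
#   style_loss = calc_remd_loss(Fcs, Fs) + calc_moment_loss(Fcs, Fs)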

def calc_mse_loss(A, B):
    # Thin wrapper around the module-level nn.MSELoss instance.
    return mse(A, B)
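
if __name__ == '__main__':
    # Minimal smoke test on random tensors (hypothetical VGG-like shapes):
    # checks that every loss runs and returns a finite scalar.
    torch.manual_seed(0)
    Fc = torch.randn(2, 512, 32, 32)
    Fs = torch.randn(2, 512, 32, 32)
    Fcs = torch.randn(2, 512, 32, 32)
    print('remd  :', calc_remd_loss(Fcs, Fs).item())
    print('ss    :', calc_ss_loss(Fcs, Fc).item())
    print('moment:', calc_moment_loss(Fcs, Fs).item())
    print('mse   :', calc_mse_loss(Fcs, Fc).item())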