assi1 / ELR /data_loader /cifar10.py
uthurumella's picture
Upload 69 files
72fc481 verified
import sys
import numpy as np
from PIL import Image
import torchvision
from torch.utils.data.dataset import Subset
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
import torch
import torch.nn.functional as F
import random
import json
import os
def get_cifar10(root, cfg_trainer, train=True,
                transform_train=None, transform_val=None,
                download=False, noise_file=''):
    """Build the CIFAR-10 dataset objects for one phase (train/val or test).

    Args:
        root: Dataset root directory, forwarded to every dataset constructor.
        cfg_trainer: Trainer configuration mapping. ``cfg_trainer['asym']``
            is read here to pick the noise type; the mapping is also handed
            to the dataset classes.
        train: When truthy, build a train/validation split of the training
            data and inject label noise into both parts. Otherwise build
            only the evaluation dataset.
        transform_train: Transform given to the training dataset.
        transform_val: Transform given to the validation/test dataset.
        download: Forwarded to the initial ``torchvision`` CIFAR10 object.
        noise_file: Accepted but not used by this function.

    Returns:
        tuple: ``(train_dataset, val_dataset)``. When ``train`` is falsy,
        ``train_dataset`` is an empty list.
    """
    # This is the only constructor call that receives ``download``; its
    # ``targets`` are what the train/val split is computed from.
    base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)

    if not train:
        # Evaluation phase: no split and no label noise.
        test_set = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(test_set)}")
        return [], test_set

    train_idxs, val_idxs = train_val_split(base_dataset.targets)
    train_set = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
    val_set = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)

    # Same kind of noise on both splits; the training split is corrupted
    # first, then the validation split.
    corrupt = "asymmetric_noise" if cfg_trainer['asym'] else "symmetric_noise"
    for split in (train_set, val_set):
        getattr(split, corrupt)()

    # With the 90/10 split of train_val_split and a 50000-sample training
    # set this prints "Train: 45000 Val: 5000".
    print(f"Train: {len(train_set)} Val: {len(val_set)}")
    return train_set, val_set
def train_val_split(base_dataset):
    """Split sample indices into class-balanced train and validation parts.

    Args:
        base_dataset: Despite its name, this is the sequence of integer
            class labels (one per sample), not a dataset object; the caller
            passes ``CIFAR10.targets``. Only labels ``0..9`` are considered.

    Returns:
        tuple: ``(train_idxs, val_idxs)``, two lists of sample indices. For
        each of the 10 classes, ``int(len(labels) * 0.9 / 10)`` randomly
        chosen samples go to the training list and the remaining samples of
        that class go to the validation list. Both lists are shuffled.

    Randomness comes from the global ``np.random`` state, so seed it before
    calling if a reproducible split is required.
    """
    num_classes = 10
    labels = np.array(base_dataset)
    # Per-class training quota: 90% of the data spread evenly over classes.
    train_n = int(len(labels) * 0.9 / num_classes)
    train_idxs = []
    val_idxs = []
    for cls in range(num_classes):
        idxs = np.where(labels == cls)[0]
        np.random.shuffle(idxs)
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])
    # The lists were filled class by class; shuffle to mix the classes.
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)
    return train_idxs, val_idxs
class CIFAR10_train(torchvision.datasets.CIFAR10):
    """Training subset of CIFAR-10 with optional synthetic label noise.

    The subset is selected by ``indexs`` from the ``data`` / ``targets``
    loaded by the torchvision base class. ``train_labels`` holds the labels
    served for training (possibly corrupted by one of the noise methods);
    ``train_labels_gt`` keeps the clean labels of the same samples.
    """

    # Class-dependent flips applied by ``asymmetric_noise``: source -> target.
    _ASYM_FLIPS = {
        9: 1,  # truck -> automobile
        2: 0,  # bird -> airplane
        3: 5,  # cat -> dog
        5: 3,  # dog -> cat
        4: 7,  # deer -> horse
    }

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """
        Args:
            root: Dataset root directory (forwarded to torchvision).
            cfg_trainer: Trainer configuration mapping; the noise methods
                read ``cfg_trainer['percent']`` from it.
            indexs: Indices of the samples that make up this subset.
            train, transform, target_transform, download: Forwarded to
                ``torchvision.datasets.CIFAR10``.
        """
        super().__init__(root, train=train,
                         transform=transform, target_transform=target_transform,
                         download=download)
        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        # Clean labels are captured once, up front, so that __getitem__ also
        # works when no noise method was called and so that repeated noise
        # calls can never overwrite the ground truth with noisy labels.
        self.train_labels_gt = self.train_labels.copy()
        self.indexs = indexs
        # Zero-initialised (N, num_classes, num_classes) float32 buffer; it
        # is not read or written anywhere else in this class.
        self.prediction = np.zeros(
            (len(self.train_data), self.num_classes, self.num_classes),
            dtype=np.float32)
        # Indices (into this subset) selected by the noise methods.
        self.noise_indx = []

    def symmetric_noise(self):
        """Relabel a random ``percent`` fraction of the samples uniformly.

        The new label is drawn uniformly from all classes, so a selected
        sample may keep its original label. Every selected index is
        appended to ``noise_indx``.
        """
        n_samples = len(self.train_data)
        order = np.random.permutation(n_samples)
        cutoff = self.cfg_trainer['percent'] * n_samples
        # The first `cutoff` positions of the random order are corrupted.
        chosen = order[np.arange(n_samples) < cutoff]
        # Labels are drawn one at a time (not as one vectorised call) to
        # keep the sequence of RNG calls the same as before.
        for idx in chosen:
            self.noise_indx.append(idx)
            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """Flip a ``percent`` fraction of each source class to its target.

        The flips are listed in ``_ASYM_FLIPS``. Class membership is decided
        from the clean labels (``train_labels_gt``), not from the labels
        being modified: otherwise cats already flipped to "dog" would join
        the dog pool and could be flipped straight back to "cat".

        For every class, including those without a flip rule, the selected
        ``percent`` fraction of indices is appended to ``noise_indx``.
        """
        percent = self.cfg_trainer['percent']
        for source in range(self.num_classes):
            members = np.where(self.train_labels_gt == source)[0]
            np.random.shuffle(members)
            cutoff = percent * len(members)
            chosen = members[np.arange(len(members)) < cutoff]
            self.noise_indx.extend(chosen)
            target = self._ASYM_FLIPS.get(source)
            if target is not None:
                self.train_labels[chosen] = target

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into this subset.
        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) training label and target_gt the clean label.
        """
        img = self.train_data[index]
        target = self.train_labels[index]
        target_gt = self.train_labels_gt[index]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index, target_gt

    def __len__(self):
        """Return the number of samples in this subset."""
        return len(self.train_data)
class CIFAR10_val(torchvision.datasets.CIFAR10):
    """Validation or test view of CIFAR-10 with optional label noise.

    With a truthy ``train`` the samples selected by ``indexs`` are used;
    otherwise the whole loaded split is used. Despite the ``train_*``
    attribute names, this class holds the held-out data: ``train_labels``
    are the labels served (possibly noisy) and ``train_labels_gt`` the
    clean labels captured at construction time.
    """

    # Class-dependent flips applied by ``asymmetric_noise``: source -> target.
    _ASYM_FLIPS = {
        9: 1,  # truck -> automobile
        2: 0,  # bird -> airplane
        3: 5,  # cat -> dog
        5: 3,  # dog -> cat
        4: 7,  # deer -> horse
    }

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """
        Args:
            root: Dataset root directory (forwarded to torchvision).
            cfg_trainer: Trainer configuration mapping; the noise methods
                read ``cfg_trainer['percent']`` from it.
            indexs: Indices of the samples to keep; only used when
                ``train`` is truthy.
            train, transform, target_transform, download: Forwarded to
                ``torchvision.datasets.CIFAR10``.
        """
        super().__init__(root, train=train,
                         transform=transform, target_transform=target_transform,
                         download=download)
        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        if train:
            # Validation part carved out of the training split.
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            # Test split: use every sample.
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        """Relabel a random ``percent`` fraction of the samples uniformly.

        The new label is drawn uniformly from all classes, so a selected
        sample may keep its original label.
        """
        n_samples = len(self.train_data)
        order = np.random.permutation(n_samples)
        cutoff = self.cfg_trainer['percent'] * n_samples
        # The first `cutoff` positions of the random order are corrupted.
        chosen = order[np.arange(n_samples) < cutoff]
        # Labels are drawn one at a time (not as one vectorised call) to
        # keep the sequence of RNG calls the same as before.
        for idx in chosen:
            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """Flip a ``percent`` fraction of each source class to its target.

        The flips are listed in ``_ASYM_FLIPS``. Class membership is decided
        from the clean labels (``train_labels_gt``), not from the labels
        being modified: otherwise cats already flipped to "dog" would join
        the dog pool and could be flipped straight back to "cat".
        """
        percent = self.cfg_trainer['percent']
        for source in range(self.num_classes):
            members = np.where(self.train_labels_gt == source)[0]
            # Shuffled for every class, even those without a flip rule, so
            # the RNG is consumed in the same per-class pattern as before.
            np.random.shuffle(members)
            target = self._ASYM_FLIPS.get(source)
            if target is None:
                continue
            cutoff = percent * len(members)
            chosen = members[np.arange(len(members)) < cutoff]
            self.train_labels[chosen] = target

    def __len__(self):
        """Return the number of samples in this dataset."""
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into this dataset.
        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) label and target_gt the clean label.
        """
        img = self.train_data[index]
        target = self.train_labels[index]
        target_gt = self.train_labels_gt[index]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index, target_gt