# Source: assi1 / ELR_plus / data_loader / webvision.py
# Hosting-page header captured with the file (not part of the program):
#   uthurumella's picture
#   Upload 69 files
#   72fc481 verified
import sys
import os
import numpy as np
from PIL import Image
import torchvision
from torch.utils.data.dataset import Subset
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
import torch
import torch.nn.functional as F
import random
def get_webvision(root, cfg_trainer, num_samples=0, train=True,
                  transform_train=None, transform_val=None, num_class=50):
    """Build the (train, validation) dataset pair and print their sizes.

    Args:
        root: Dataset root, forwarded unchanged to the dataset classes.
        cfg_trainer: Trainer configuration, forwarded to ``Webvision``.
        num_samples: Forwarded to ``Webvision``.
        train: When truthy, build the WebVision train/val pair; otherwise
            build only the ImageNet validation set.
        transform_train: Transform handed to the training dataset.
        transform_val: Transform handed to the validation dataset.
        num_class: Class-count limit forwarded to the dataset classes.

    Returns:
        ``(train_dataset, val_dataset)``. In the non-training case the first
        element is an empty list and the second is an ``ImagenetVal``.
    """
    # Evaluation-only path: no training data, ImageNet validation images.
    if not train:
        imagenet_val = ImagenetVal(root, transform=transform_val, num_class=num_class)
        print(f"Imagnet Val: {len(imagenet_val)}")
        return [], imagenet_val

    # Training path: both splits come from WebVision. The (truthy) ``train``
    # value doubles as the flag that switches each dataset into its split.
    shared_kwargs = dict(num_samples=num_samples, num_class=num_class)
    webvision_train = Webvision(root, cfg_trainer, train=train,
                                transform=transform_train, **shared_kwargs)
    webvision_val = Webvision(root, cfg_trainer, val=train,
                              transform=transform_val, **shared_kwargs)
    print(f"Train: {len(webvision_train)} WebVision Val: {len(webvision_val)}")
    return webvision_train, webvision_val
class ImagenetVal(torch.utils.data.Dataset):
    """ImageNet validation images restricted to the first ``num_class`` labels.

    The file ``<root>/imagenet/imagenet_val.txt`` is expected to hold one
    ``<image name> <integer label>`` pair per line; images are opened from
    ``<root>/imagenet/val/``.

    Attributes:
        root: ``<root>/imagenet/`` (always ends with a path separator).
        transform: Callable applied to each RGB image in ``__getitem__``.
        val_imgs: Image names whose label is ``< num_class``, in file order.
        val_labels: Mapping from image name to its integer label.
    """

    def __init__(self, root, transform, num_class):
        """Read the validation file list and keep entries with label < num_class.

        Args:
            root: Dataset root directory (``str`` or path-like); a trailing
                separator is optional.
            transform: Callable applied to each loaded image.
            num_class: Only samples whose label is below this value are kept.
        """
        # All paths below are built by concatenation, so make sure the root
        # ends with a separator; an empty root keeps meaning "current dir".
        root = os.fspath(root)
        if root and not root.endswith(('/', os.sep)):
            root += '/'
        self.root = root + 'imagenet/'
        self.transform = transform
        self.val_imgs = []
        self.val_labels = {}
        with open(self.root + 'imagenet_val.txt') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                img, target = line.split()
                target = int(target)
                if target < num_class:
                    self.val_imgs.append(img)
                    self.val_labels[img] = target

    def __getitem__(self, index):
        """Return ``(img, target, index, target)`` for the sample at ``index``.

        The label appears twice in the tuple, exactly as in the original
        interface that callers unpack.
        """
        img_path = self.val_imgs[index]
        target = self.val_labels[img_path]
        # convert() produces an independent image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(self.root + 'val/' + img_path) as raw:
            image = raw.convert('RGB')
        img = self.transform(image)
        return img, target, index, target

    def __len__(self):
        """Return the number of retained validation images."""
        return len(self.val_imgs)
class Webvision(torch.utils.data.Dataset):
    """WebVision split (train, val or test) restricted to labels < ``num_class``.

    Exactly one split is loaded per instance. When several flags are truthy
    the priority is ``val``, then ``test``, then train; when none is set the
    training list is loaded. The same resolved split is used by ``__init__``,
    ``__getitem__`` and ``__len__``.

    Attributes:
        cfg_trainer: Trainer configuration, stored as given.
        root: Dataset root (ends with a path separator unless empty).
        transform: Callable applied to each RGB image in ``__getitem__``.
        train, val, test: The flags as passed by the caller.
        train_labels, val_labels, test_labels: Image path -> label; only the
            dict of the loaded split is populated, the others stay empty.
        train_imgs / val_imgs / test_imgs: Image paths of the loaded split
            (only the attribute for that split is created).
    """

    # File list (relative to root) that defines each split. The test split
    # reads the same list as the validation split.
    _FILELISTS = {
        'train': 'info/train_filelist_google.txt',
        'val': 'info/val_filelist.txt',
        'test': 'info/val_filelist.txt',
    }
    # Directory prefix (relative to root) prepended to the listed image paths.
    _IMAGE_DIRS = {
        'train': '',
        'val': 'val_images_256/',
        'test': 'val_images_256/',
    }

    def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class = 50):
        """Load the file list of the selected split.

        Args:
            root: Dataset root directory (``str`` or path-like); a trailing
                separator is optional.
            cfg_trainer: Trainer configuration; stored, not otherwise used here.
            num_samples: Accepted for interface compatibility; not used.
            train: Select the training split.
            val: Select the validation split (takes priority over the others).
            test: Select the test split (takes priority over train).
            transform: Callable applied to each loaded image.
            num_class: Only samples whose label is below this value are kept.
        """
        # Paths are built by concatenation, so make sure the root ends with a
        # separator; an empty root keeps meaning "current directory".
        root = os.fspath(root)
        if root and not root.endswith(('/', os.sep)):
            root += '/'
        self.cfg_trainer = cfg_trainer
        self.root = root
        self.transform = transform
        self.train_labels = {}
        self.test_labels = {}
        self.val_labels = {}
        self.train = train
        self.val = val
        self.test = test

        # Resolve the split once so every method agrees on it.
        if self.val:
            self._mode = 'val'
        elif self.test:
            self._mode = 'test'
        else:
            self._mode = 'train'

        imgs, labels = self._read_filelist(
            self.root + self._FILELISTS[self._mode], num_class)
        # Populates train_imgs/train_labels, val_imgs/val_labels or
        # test_imgs/test_labels depending on the split.
        setattr(self, self._mode + '_imgs', imgs)
        setattr(self, self._mode + '_labels', labels)

    @staticmethod
    def _read_filelist(path, num_class):
        """Parse ``<image path> <label>`` lines, keeping labels < ``num_class``."""
        imgs = []
        labels = {}
        with open(path) as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                img, target = line.split()
                target = int(target)
                if target < num_class:
                    imgs.append(img)
                    labels[img] = target
        return imgs, labels

    def __getitem__(self, index):
        """Return ``(img, target, index, target)`` for the sample at ``index``.

        The label appears twice in the tuple, exactly as in the original
        interface that callers unpack.
        """
        mode = self._mode
        # Look the containers up by name on every call so that callers who
        # replace e.g. ``train_imgs`` after construction are honoured.
        img_path = getattr(self, mode + '_imgs')[index]
        target = getattr(self, mode + '_labels')[img_path]
        # convert() produces an independent image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(self.root + self._IMAGE_DIRS[mode] + img_path) as raw:
            image = raw.convert('RGB')
        img = self.transform(image)
        return img, target, index, target

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(getattr(self, self._mode + '_imgs'))