import sys
import os
import numpy as np
from PIL import Image
import torchvision
from torch.utils.data.dataset import Subset
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
import torch
import torch.nn.functional as F
import random
def get_webvision(root, cfg_trainer, num_samples=0, train=True,
                  transform_train=None, transform_val=None, num_class=50):
    """Build the WebVision train/val datasets, or the ImageNet val set for evaluation."""
    if train:
        train_dataset = Webvision(root, cfg_trainer, num_samples=num_samples, train=train,
                                  transform=transform_train, num_class=num_class)
        val_dataset = Webvision(root, cfg_trainer, num_samples=num_samples, val=train,
                                transform=transform_val, num_class=num_class)
        print(f"Train: {len(train_dataset)} WebVision Val: {len(val_dataset)}")
    else:
        # At evaluation time the WebVision-trained model is tested on the ImageNet val set.
        train_dataset = []
        val_dataset = ImagenetVal(root, transform=transform_val, num_class=num_class)
        print(f"ImageNet Val: {len(val_dataset)}")

    return train_dataset, val_dataset
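# Minimal usage sketch for get_webvision. The root path, crop size, and transform
# choices below are illustrative assumptions, not values taken from this project's
# config; cfg_trainer is only stored by the datasets, so None is enough here.
#
#   transform_train = torchvision.transforms.Compose([
#       torchvision.transforms.RandomResizedCrop(227),
#       torchvision.transforms.RandomHorizontalFlip(),
#       torchvision.transforms.ToTensor(),
#   ])
#   transform_val = torchvision.transforms.Compose([
#       torchvision.transforms.Resize(256),
#       torchvision.transforms.CenterCrop(227),
#       torchvision.transforms.ToTensor(),
#   ])
#   train_set, val_set = get_webvision('/data/webvision/', cfg_trainer=None, train=True,
#                                      transform_train=transform_train,
#                                      transform_val=transform_val, num_class=50)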
class ImagenetVal(torch.utils.data.Dataset):
    """ImageNet validation set restricted to the first `num_class` classes."""

    def __init__(self, root, transform, num_class):
        self.root = root + 'imagenet/'
        self.transform = transform

        # imagenet_val.txt lists "<image_name> <label>" pairs, one per line.
        with open(self.root + 'imagenet_val.txt') as f:
            lines = f.readlines()

        self.val_imgs = []
        self.val_labels = {}
        for line in lines:
            img, target = line.split()
            target = int(target)
            if target < num_class:
                self.val_imgs.append(img)
                self.val_labels[img] = target

    def __getitem__(self, index):
        img_path = self.val_imgs[index]
        target = self.val_labels[img_path]
        image = Image.open(self.root + 'val/' + img_path).convert('RGB')
        img = self.transform(image)
        # index and target are returned twice to match the tuple layout of Webvision below.
        return img, target, index, target

    def __len__(self):
        return len(self.val_imgs)
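# Sketch of iterating the ImageNet validation split (assumes the layout read above,
# i.e. root + 'imagenet/imagenet_val.txt' and images under root + 'imagenet/val/'):
#
#   imagenet_val = ImagenetVal('/data/webvision/', transform=transform_val, num_class=50)
#   val_loader = torch.utils.data.DataLoader(imagenet_val, batch_size=64,
#                                            shuffle=False, num_workers=4)
#   for img, target, index, _ in val_loader:
#       ...  # run the model and accumulate accuracy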
class Webvision(torch.utils.data.Dataset):
    """WebVision dataset (Google image subset) restricted to the first `num_class` classes."""

    def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False,
                 test=False, transform=None, num_class=50):
        # num_samples is accepted for interface compatibility but is not used here.
        self.cfg_trainer = cfg_trainer
        self.root = root
        self.transform = transform
        self.train_labels = {}
        self.test_labels = {}
        self.val_labels = {}
        self.train = train
        self.val = val
        self.test = test

        if self.val:
            with open(self.root + 'info/val_filelist.txt') as f:
                lines = f.readlines()
            self.val_imgs = []
            self.val_labels = {}
            for line in lines:
                img, target = line.split()
                target = int(target)
                if target < num_class:
                    self.val_imgs.append(img)
                    self.val_labels[img] = target
        elif self.test:
            # WebVision ships no separate test split, so the validation file list is reused.
            with open(self.root + 'info/val_filelist.txt') as f:
                lines = f.readlines()
            self.test_imgs = []
            self.test_labels = {}
            for line in lines:
                img, target = line.split()
                target = int(target)
                if target < num_class:
                    self.test_imgs.append(img)
                    self.test_labels[img] = target
        else:
            with open(self.root + 'info/train_filelist_google.txt') as f:
                lines = f.readlines()
            train_imgs = []
            self.train_labels = {}
            for line in lines:
                img, target = line.split()
                target = int(target)
                if target < num_class:
                    train_imgs.append(img)
                    self.train_labels[img] = target
            self.train_imgs = train_imgs
    def __getitem__(self, index):
        if self.train:
            img_path = self.train_imgs[index]
            target = self.train_labels[img_path]
            image = Image.open(self.root + img_path)
            img0 = image.convert('RGB')
            img0 = self.transform(img0)
            return img0, target, index, target
        elif self.val:
            img_path = self.val_imgs[index]
            target = self.val_labels[img_path]
            image = Image.open(self.root + 'val_images_256/' + img_path).convert('RGB')
            img = self.transform(image)
            return img, target, index, target
        elif self.test:
            img_path = self.test_imgs[index]
            target = self.test_labels[img_path]
            image = Image.open(self.root + 'val_images_256/' + img_path).convert('RGB')
            img = self.transform(image)
            return img, target, index, target

    def __len__(self):
        if self.test:
            return len(self.test_imgs)
        if self.val:
            return len(self.val_imgs)
        else:
            return len(self.train_imgs)
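

if __name__ == '__main__':
    # Minimal smoke-test sketch: the root path is an assumption, and cfg_trainer is
    # passed as None because this file only stores it and never reads it.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((227, 227)),
        torchvision.transforms.ToTensor(),
    ])
    train_set, val_set = get_webvision('/data/webvision/', cfg_trainer=None, train=True,
                                       transform_train=transform, transform_val=transform,
                                       num_class=50)
    loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=0)
    img, target, index, _ = next(iter(loader))
    print(img.shape, target, index)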