def get_webvision(root, cfg_trainer, num_samples=0, train=True, transform_train=None, transform_val=None, num_class=50):
    """Build the (train, validation) dataset pair for a WebVision experiment.

    Args:
        root: Dataset root, forwarded unchanged to the dataset constructors.
        cfg_trainer: Trainer configuration, forwarded to ``Webvision`` only.
        num_samples: Forwarded to ``Webvision`` as ``num_samples``.
        train: When truthy, both returned datasets are ``Webvision``
            instances; otherwise only an ``ImagenetVal`` dataset is built.
        transform_train: Transform given to the training dataset.
        transform_val: Transform given to the validation dataset.
        num_class: Forwarded to the dataset constructors as ``num_class``.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple. When ``train`` is falsy
        the first element is an empty list.
    """
    # Evaluation-only path: there is no training split, just ImagenetVal.
    if not train:
        imagenet_val = ImagenetVal(root, transform=transform_val, num_class=num_class)
        print(f"Imagnet Val: {len(imagenet_val)}")
        return [], imagenet_val

    # Keyword arguments common to both Webvision splits.
    shared_kwargs = {"num_samples": num_samples, "num_class": num_class}

    # The caller's ``train`` value is passed through as-is, to the ``train``
    # flag for the first dataset and to the ``val`` flag for the second.
    webvision_train = Webvision(
        root,
        cfg_trainer,
        train=train,
        transform=transform_train,
        **shared_kwargs,
    )
    webvision_val = Webvision(
        root,
        cfg_trainer,
        val=train,
        transform=transform_val,
        **shared_kwargs,
    )
    print(f"Train: {len(webvision_train)} WebVision Val: {len(webvision_val)}")
    return webvision_train, webvision_val