Spaces:
Configuration error
Configuration error
import sys | |
import os | |
import numpy as np | |
from PIL import Image | |
import torchvision | |
from torch.utils.data.dataset import Subset | |
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances | |
import torch | |
import torch.nn.functional as F | |
import random | |
def get_clothing(root, cfg_trainer, num_samples=0, train=True, | |
transform_train=None, transform_val=None): | |
if train: | |
train_dataset = Clothing(root, cfg_trainer, num_samples=num_samples, train=train, transform=transform_train) | |
val_dataset = Clothing(root, cfg_trainer, val=train, transform=transform_val) | |
print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}") | |
else: | |
train_dataset = [] | |
val_dataset = Clothing(root, cfg_trainer, test= (not train), transform=transform_val) | |
print(f"Test: {len(val_dataset)}") | |
return train_dataset, val_dataset | |
class Clothing(torch.utils.data.Dataset): | |
def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class = 14): | |
self.cfg_trainer = cfg_trainer | |
self.root = root | |
self.transform = transform | |
self.train_labels = {} | |
self.test_labels = {} | |
self.val_labels = {} | |
self.train = train | |
self.val = val | |
self.test = test | |
with open('%s/noisy_label_kv.txt'%self.root,'r') as f: | |
lines = f.read().splitlines() | |
for l in lines: | |
entry = l.split() | |
img_path = '%s/'%self.root+entry[0][7:] | |
self.train_labels[img_path] = int(entry[1]) | |
with open('%s/clean_label_kv.txt'%self.root,'r') as f: | |
lines = f.read().splitlines() | |
for l in lines: | |
entry = l.split() | |
img_path = '%s/'%self.root+entry[0][7:] | |
self.test_labels[img_path] = int(entry[1]) | |
if train: | |
train_imgs=[] | |
with open('%s/noisy_train_key_list.txt'%self.root,'r') as f: | |
lines = f.read().splitlines() | |
for i , l in enumerate(lines): | |
img_path = '%s/'%self.root+l[7:] | |
train_imgs.append((i,img_path)) | |
self.num_raw_example = len(train_imgs) | |
random.shuffle(train_imgs) | |
class_num = torch.zeros(num_class) | |
self.train_imgs = [] | |
for id_raw, impath in train_imgs: | |
label = self.train_labels[impath] | |
if class_num[label]<(num_samples/14) and len(self.train_imgs)<num_samples: | |
self.train_imgs.append((id_raw,impath)) | |
class_num[label]+=1 | |
random.shuffle(self.train_imgs) | |
elif test: | |
self.test_imgs = [] | |
with open('%s/clean_test_key_list.txt'%self.root,'r') as f: | |
lines = f.read().splitlines() | |
for l in lines: | |
img_path = '%s/'%self.root+l[7:] | |
self.test_imgs.append(img_path) | |
elif val: | |
self.val_imgs = [] | |
with open('%s/clean_val_key_list.txt'%self.root,'r') as f: | |
lines = f.read().splitlines() | |
for l in lines: | |
img_path = '%s/'%self.root+l[7:] | |
self.val_imgs.append(img_path) | |
def __getitem__(self, index): | |
if self.train: | |
id_raw, img_path = self.train_imgs[index] | |
target = self.train_labels[img_path] | |
elif self.val: | |
img_path = self.val_imgs[index] | |
target = self.test_labels[img_path] | |
elif self.test: | |
img_path = self.test_imgs[index] | |
target = self.test_labels[img_path] | |
image = Image.open(img_path).convert('RGB') | |
if self.train: | |
img0 = self.transform(image) | |
if self.test or self.val: | |
img = self.transform(image) | |
return img, target, index, target | |
else: | |
return img0, target, id_raw, target | |
def __len__(self): | |
if self.test: | |
return len(self.test_imgs) | |
if self.val: | |
return len(self.val_imgs) | |
else: | |
return len(self.train_imgs) | |
def flist_reader(self, flist): | |
imlist = [] | |
with open(flist, 'r') as rf: | |
for line in rf.readlines(): | |
row = line.split(" ") | |
impath = self.root + row[0] | |
imlabel = float(row[1].replace('\n','')) | |
imlist.append((impath, int(imlabel))) | |
return imlist |