from typing import Tuple, Union, Optional

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Wraps the training dataset directly and exposes `split_validation()` to
    obtain a DataLoader over the held-out validation data.
    """
    valid_sampler: Optional[SubsetRandomSampler]
    sampler: Optional[SubsetRandomSampler]

    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
                 collate_fn=default_collate, val_dataset=None):
        self.collate_fn = collate_fn
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.val_dataset = val_dataset
        self.batch_idx = 0
        # If a separate validation dataset is supplied, n_samples covers both splits;
        # otherwise it is the size of the dataset that _split_sampler will divide.
        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
        self.init_kwargs = {
            'dataset': train_dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
            'pin_memory': pin_memory
        }
        if val_dataset is None:
            # Carve a validation subset out of the training data via samplers.
            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
            # _split_sampler disables shuffling (mutually exclusive with a sampler),
            # so refresh the kwarg before handing it to DataLoader.
            self.init_kwargs['shuffle'] = self.shuffle
            super().__init__(sampler=self.sampler, **self.init_kwargs)
        else:
            super().__init__(**self.init_kwargs)

    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # Fixed seed so the train/validation split is reproducible across runs.
        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            # An integer split is treated as an absolute number of validation samples.
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            # A float split is treated as a fraction of the dataset.
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")

        # Turn off the shuffle option, which is mutually exclusive with a sampler.
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self, bs=1000):
        if self.val_dataset is not None:
            # A dedicated validation dataset was provided: wrap it in its own loader.
            kwargs = {
                'dataset': self.val_dataset,
                'batch_size': bs,
                'shuffle': False,
                'collate_fn': self.collate_fn,
                'num_workers': self.num_workers
            }
            return DataLoader(**kwargs)
        else:
            # Otherwise reuse the held-out indices produced by _split_sampler.
            print('Using sampler to split!')
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
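

if __name__ == '__main__':
    # Illustrative usage sketch (not part of the original file): exercise the
    # loader on a synthetic TensorDataset so the split logic can be checked
    # end to end. The dataset, sizes, and hyperparameters below are assumptions
    # chosen only for this demo.
    import torch
    from torch.utils.data import TensorDataset

    toy_dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
    loader = BaseDataLoader(toy_dataset, batch_size=16, shuffle=True,
                            validation_split=0.2, num_workers=0, pin_memory=False)

    # 20% of the indices were held out; split_validation() wraps them in
    # their own DataLoader via the stored SubsetRandomSampler.
    val_loader = loader.split_validation()

    for features, labels in loader:   # iterates over the remaining 80 samples
        pass
    print(f"train batches: {len(loader)}, val batches: {len(val_loader)}")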