Spaces:
Configuration error
Configuration error
File size: 3,013 Bytes
72fc481 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import Tuple, Union, Optional
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Wraps ``torch.utils.data.DataLoader`` around a training dataset and offers
    ``split_validation`` to obtain a matching validation loader. Validation data
    comes either from an explicit ``val_dataset`` or, when none is given, from a
    fixed (seed 0) hold-out of ``train_dataset`` selected via index samplers.
    """
    # Sampler over the held-out validation indices; None when no hold-out split is made.
    valid_sampler: Optional[SubsetRandomSampler]
    # Sampler over the remaining training indices when a hold-out split is made.
    sampler: Optional[SubsetRandomSampler]

    def __init__(self, train_dataset, batch_size, shuffle, validation_split: Union[int, float], num_workers,
                 pin_memory, collate_fn=default_collate, val_dataset=None):
        """
        Args:
            train_dataset: Dataset used for training (and for the hold-out split
                when ``val_dataset`` is None).
            batch_size: Batch size of the training loader.
            shuffle: Whether to shuffle the training data. Forced to False when a
                hold-out split is made: the SubsetRandomSampler already randomizes
                order, and DataLoader rejects ``shuffle=True`` with a sampler.
            validation_split: Hold-out size, either a float fraction in (0, 1) or an
                int absolute count. ``0`` / ``0.0`` disables the split. Only used
                when ``val_dataset`` is None.
            num_workers: Number of DataLoader worker processes.
            pin_memory: Forwarded to the DataLoader(s).
            collate_fn: Function that merges samples into a batch.
            val_dataset: Optional separate validation dataset.

        Raises:
            ValueError: If ``validation_split`` is out of range.
        """
        self.collate_fn = collate_fn
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.val_dataset = val_dataset
        self.batch_idx = 0
        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)

        if val_dataset is None:
            # Must run BEFORE init_kwargs is built: it may reset self.shuffle to
            # False, and DataLoader raises if shuffle=True is combined with a sampler.
            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
        else:
            self.sampler, self.valid_sampler = None, None

        self.init_kwargs = {
            'dataset': train_dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
            'pin_memory': pin_memory,
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
        """
        Split ``range(self.n_samples)`` into disjoint train/validation samplers.

        The permutation is deterministic (seed 0), so the same dataset size and
        split always produce the same hold-out set.

        Args:
            split: Float fraction in (0, 1) or int count in (0, n_samples);
                ``0`` / ``0.0`` means "no split".

        Returns:
            ``(None, None)`` when ``split == 0``, otherwise
            ``(train_sampler, valid_sampler)``.

        Side effects:
            When a split is made, sets ``self.shuffle = False`` and
            ``self.n_samples`` to the number of training indices.

        Raises:
            ValueError: If ``split`` is out of range.
        """
        if split == 0.0:
            return None, None

        # A private RandomState seeded with 0 yields exactly the same permutation
        # as ``np.random.seed(0); np.random.shuffle(...)`` (so existing splits are
        # unchanged) without resetting NumPy's global RNG for the rest of the program.
        idx_full = np.arange(self.n_samples)
        np.random.RandomState(0).shuffle(idx_full)

        if isinstance(split, int):
            if split < 0:
                raise ValueError(f"validation_split must be non-negative, got {split}")
            if split >= self.n_samples:
                raise ValueError("validation set size is configured to be larger than entire dataset.")
            len_valid = split
        else:
            if not 0.0 < split < 1.0:
                raise ValueError(f"validation_split fraction must be in (0, 1), got {split}")
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[:len_valid]
        train_idx = idx_full[len_valid:]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)
        return train_sampler, valid_sampler

    def split_validation(self, bs=1000):
        """
        Build the validation DataLoader.

        Args:
            bs: Batch size, used only when a separate ``val_dataset`` was given.
                The hold-out path reuses the training loader's settings,
                including its batch size.

        Returns:
            A DataLoader over ``val_dataset`` (unshuffled) if one was given.
            Otherwise, a DataLoader over the held-out training indices. Note that
            if no split was made (``validation_split == 0``), ``valid_sampler`` is
            None and the returned loader iterates the whole training dataset.
        """
        if self.val_dataset is not None:
            return DataLoader(
                dataset=self.val_dataset,
                batch_size=bs,
                shuffle=False,
                collate_fn=self.collate_fn,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
            )
        print('Using sampler to split!')
        return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
|