from typing import Tuple, Union, Optional

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """

    Base class for all data loaders

    """
    valid_sampler: Optional[SubsetRandomSampler]
    sampler: Optional[SubsetRandomSampler]

    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
                 collate_fn=default_collate, val_dataset=None):
        self.collate_fn = collate_fn
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.val_dataset = val_dataset

        self.batch_idx = 0
        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
        self.init_kwargs = {
            'dataset': train_dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
            'pin_memory': pin_memory
        }
        if val_dataset is None:
            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
            # _split_sampler disables shuffle when a sampler is in use, so keep
            # init_kwargs in sync (sampler and shuffle are mutually exclusive
            # in DataLoader and passing both raises a ValueError)
            self.init_kwargs['shuffle'] = self.shuffle
            super().__init__(sampler=self.sampler, **self.init_kwargs)
        else:
            super().__init__(**self.init_kwargs)

    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # fixed seed so the train/validation split is reproducible across runs
        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            # an integer split is an absolute validation-set size
            assert split > 0, "validation set size must be positive."
            assert split < self.n_samples, "validation set size is configured to be larger than the entire dataset."
            len_valid = split
        else:
            # a float split is interpreted as a fraction of the full dataset
            len_valid = int(self.n_samples * split)

        # the first len_valid shuffled indices form the validation set;
        # the remaining indices stay in the training set
        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self, bs=1000):
        if self.val_dataset is not None:
            kwargs = {
                'dataset': self.val_dataset,
                'batch_size': bs,
                'shuffle': False,
                'collate_fn': self.collate_fn,
                'num_workers': self.num_workers,
                # mirror the training loader's pin_memory setting for consistency
                'pin_memory': self.pin_memory
            }
            return DataLoader(**kwargs)
        else:
            print('Using sampler to split!')
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
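

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of wiring
# BaseDataLoader to a toy TensorDataset and pulling out the validation loader.
# The dataset, shapes, and hyperparameters below are illustrative assumptions
# only, not values taken from the surrounding project.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    from torch.utils.data import TensorDataset

    # 100 random samples with 3 features each and binary labels (hypothetical data)
    toy_dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

    loader = BaseDataLoader(
        train_dataset=toy_dataset,
        batch_size=16,
        shuffle=True,
        validation_split=0.2,  # hold out 20% of the samples via SubsetRandomSampler
        num_workers=0,
        pin_memory=False,
    )
    val_loader = loader.split_validation()
    print(f"train batches: {len(loader)}, val batches: {len(val_loader)}")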