Commit 72fc481 (verified) by uthurumella · 1 parent: 50e7d19

Upload 69 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete set.
Files changed (50):
  1. .gitattributes +4 -0
  2. ELR/README.md +44 -0
  3. ELR/base/__init__.py +3 -0
  4. ELR/base/base_data_loader.py +83 -0
  5. ELR/base/base_model.py +25 -0
  6. ELR/base/base_trainer.py +195 -0
  7. ELR/config_cifar10.json +75 -0
  8. ELR/config_cifar100.json +75 -0
  9. ELR/config_cifar10_asym.json +75 -0
  10. ELR/data_loader/__pycache__/cifar10.cpython-36.pyc +0 -0
  11. ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc +0 -0
  12. ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc +0 -0
  13. ELR/data_loader/cifar10.py +212 -0
  14. ELR/data_loader/cifar100.py +317 -0
  15. ELR/data_loader/data_loaders.py +70 -0
  16. ELR/logger/__init__.py +2 -0
  17. ELR/logger/logger.py +22 -0
  18. ELR/logger/logger_config.json +32 -0
  19. ELR/logger/visualization.py +154 -0
  20. ELR/model/ResNet_Zoo.py +133 -0
  21. ELR/model/loss.py +30 -0
  22. ELR/model/metric.py +20 -0
  23. ELR/model/model.py +13 -0
  24. ELR/parse_config.py +146 -0
  25. ELR/test.py +82 -0
  26. ELR/train.py +125 -0
  27. ELR/trainer/__init__.py +1 -0
  28. ELR/trainer/trainer.py +278 -0
  29. ELR/utils/__init__.py +1 -0
  30. ELR/utils/util.py +75 -0
  31. ELR_plus/README.md +27 -0
  32. ELR_plus/base/__init__.py +3 -0
  33. ELR_plus/base/base_data_loader.py +83 -0
  34. ELR_plus/base/base_model.py +25 -0
  35. ELR_plus/base/base_trainer.py +341 -0
  36. ELR_plus/config_cifar10.json +105 -0
  37. ELR_plus/config_cifar100.json +104 -0
  38. ELR_plus/config_cifar10_asym.json +105 -0
  39. ELR_plus/config_clothing1m.json +102 -0
  40. ELR_plus/config_webvision.json +103 -0
  41. ELR_plus/data_loader/cifar10.py +214 -0
  42. ELR_plus/data_loader/cifar100.py +307 -0
  43. ELR_plus/data_loader/clothing1m.py +128 -0
  44. ELR_plus/data_loader/data_loaders.py +137 -0
  45. ELR_plus/data_loader/webvision.py +140 -0
  46. ELR_plus/logger/__init__.py +2 -0
  47. ELR_plus/logger/logger.py +22 -0
  48. ELR_plus/logger/logger_config.json +32 -0
  49. ELR_plus/logger/visualization.py +154 -0
  50. ELR_plus/model/InceptionResNetV2.py +314 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/clean_label_simplexheatmap2.gif filter=lfs diff=lfs merge=lfs -text
+ images/false_label_simplexheatmap.gif filter=lfs diff=lfs merge=lfs -text
+ images/illustration_of_ELR.png filter=lfs diff=lfs merge=lfs -text
+ images/simplexheatmap.gif filter=lfs diff=lfs merge=lfs -text
ELR/README.md ADDED
@@ -0,0 +1,44 @@
# ELR
This is an official PyTorch implementation of the ELR method proposed in [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151).

## Usage
Train the network on the symmetric-noise CIFAR-10 dataset (noise rate = 0.8):

```
python train.py -c config_cifar10.json --percent 0.8
```

Train the network on the asymmetric-noise CIFAR-10 dataset (noise rate = 0.4):

```
python train.py -c config_cifar10_asym.json --percent 0.4 --asym 1
```

Train the network on the asymmetric-noise CIFAR-100 dataset (noise rate = 0.4):

```
python train.py -c config_cifar100.json --percent 0.4 --asym 1
```

The config files can be modified to adjust hyperparameters and optimization settings.

## Results
### CIFAR10
<center>

| Method                 | 20%    | 40%    | 60%    | 80%    | 40% Asym |
| ---------------------- | ------ | ------ | ------ | ------ | -------- |
| ELR                    | 91.16% | 89.15% | 86.12% | 73.86% | 90.12%   |
| ELR (cosine annealing) | 91.12% | 91.43% | 88.87% | 80.69% | 90.35%   |

### CIFAR100

| Method                 | 20%    | 40%    | 60%    | 80%    | 40% Asym |
| ---------------------- | ------ | ------ | ------ | ------ | -------- |
| ELR                    | 74.21% | 68.28% | 59.28% | 29.78% | 73.71%   |
| ELR (cosine annealing) | 74.68% | 68.43% | 60.05% | 30.27% | 73.96%   |

</center>

## References
- S. Liu, J. Niles-Weed, N. Razavian and C. Fernandez-Granda, "Early-Learning Regularization Prevents Memorization of Noisy Labels", 2020
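For quick reference, the loss implemented in `ELR/model/loss.py` (later in this diff) combines cross-entropy with the early-learning regularizer. Writing $p_i$ for the softmax output on example $i$ and $t_i$ for its running (temporally ensembled) prediction, the update and objective are:

$$t_i \leftarrow \beta\, t_i + (1-\beta)\, p_i, \qquad \mathcal{L} = \mathcal{L}_{\mathrm{CE}} + \frac{\lambda}{n}\sum_{i=1}^{n}\log\bigl(1 - \langle t_i, p_i\rangle\bigr)$$

Here $\beta$ and $\lambda$ are the `beta` and `lambda` entries under `train_loss.args` in the config files.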
ELR/base/__init__.py ADDED
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
ELR/base/base_data_loader.py ADDED
@@ -0,0 +1,83 @@
from typing import Tuple, Union, Optional

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """
    valid_sampler: Optional[SubsetRandomSampler]
    sampler: Optional[SubsetRandomSampler]

    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
                 collate_fn=default_collate, val_dataset=None):
        self.collate_fn = collate_fn
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.val_dataset = val_dataset

        self.batch_idx = 0
        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
        self.init_kwargs = {
            'dataset': train_dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
            'pin_memory': pin_memory
        }
        if val_dataset is None:
            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
            # _split_sampler sets self.shuffle = False when a sampler is created; refresh
            # init_kwargs, since DataLoader rejects shuffle=True together with a sampler
            self.init_kwargs['shuffle'] = self.shuffle
            super().__init__(sampler=self.sampler, **self.init_kwargs)
        else:
            super().__init__(**self.init_kwargs)

    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")

        # turn off shuffle option, which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self, bs=1000):
        if self.val_dataset is not None:
            kwargs = {
                'dataset': self.val_dataset,
                'batch_size': bs,
                'shuffle': False,
                'collate_fn': self.collate_fn,
                'num_workers': self.num_workers
            }
            return DataLoader(**kwargs)
        else:
            print('Using sampler to split!')
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
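A minimal usage sketch of `BaseDataLoader` with a hypothetical toy dataset (run from the `ELR/` root so the `base` package is importable):

```python
import torch
from torch.utils.data import TensorDataset

from base import BaseDataLoader

# toy dataset: 100 samples with 3 features and binary labels
ds = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

# 10% of the indices go to the validation sampler; shuffling is disabled internally
loader = BaseDataLoader(ds, batch_size=16, shuffle=True, validation_split=0.1,
                        num_workers=0, pin_memory=False)
val_loader = loader.split_validation()  # no val_dataset was given, so the sampler path is used
print(len(loader.sampler), len(loader.valid_sampler))  # 90 10
```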
ELR/base/base_model.py ADDED
@@ -0,0 +1,25 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
ELR/base/base_trainer.py ADDED
@@ -0,0 +1,195 @@
import torch
from tqdm import tqdm
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, train_criterion, metrics, optimizer, config, val_criterion):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)

        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.train_criterion = train_criterion.to(self.device)

        self.val_criterion = val_criterion
        self.metrics = metrics

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0

        for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
            if epoch <= self.config['trainer']['warmup']:
                result = self._warmup_epoch(epoch)
            else:
                result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'val_metrics':
                    log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'test_metrics':
                    log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log[key] = value

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        Setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There's no GPU available on this machine, "
                                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning("Warning: The number of GPUs configured to use is {}, but only {} are available "
                                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__

        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config  # required by _resume_checkpoint, which reads checkpoint['config']
        }
        # filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        # torch.save(state, filename)
        # self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
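Given the keys written by `_save_checkpoint` above, the best checkpoint can be reloaded outside the trainer. A sketch; the path is illustrative and assumes the checkpoint was trained with the repo's `resnet34`:

```python
import torch
from model.model import resnet34  # run from the ELR/ root

ckpt = torch.load('saved/models/cifar10_resnet34_cosine_sym_80/0101_000000/model_best.pth',
                  map_location='cpu')
print(ckpt['arch'], ckpt['epoch'], ckpt['monitor_best'])

model = resnet34(num_classes=10)
model.load_state_dict(ckpt['state_dict'])  # must match the architecture named in ckpt['arch']
```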
ELR/config_cifar10.json ADDED
@@ -0,0 +1,75 @@
{
    "name": "cifar10_resnet34_cosine",
    "n_gpu": 1,
    "seed": 123,

    "arch": {
        "type": "resnet34",
        "args": {"num_classes": 10}
    },

    "num_classes": 10,

    "data_loader": {
        "type": "CIFAR10DataLoader",
        "args": {
            "data_dir": "/dir_to_data",
            "batch_size": 128,
            "shuffle": true,
            "num_batches": 0,
            "validation_split": 0,
            "num_workers": 8,
            "pin_memory": true
        }
    },

    "optimizer": {
        "type": "SGD",
        "args": {
            "lr": 0.02,
            "momentum": 0.9,
            "weight_decay": 1e-3
        }
    },

    "train_loss": {
        "type": "elr_loss",
        "args": {
            "beta": 0.7,
            "lambda": 3
        }
    },

    "val_loss": "cross_entropy",
    "metrics": [
        "my_metric", "my_metric2"
    ],

    "lr_scheduler": {
        "type": "CosineAnnealingWarmRestarts",
        "args": {
            "T_0": 10,
            "eta_min": 0.001
        }
    },

    "trainer": {
        "epochs": 150,
        "warmup": 0,
        "save_dir": "saved/",
        "save_period": 1,
        "verbosity": 2,
        "label_dir": "saved/",
        "monitor": "max val_my_metric",
        "early_stop": 2000,
        "tensorboard": false,
        "mlflow": true,
        "_percent": "Percentage of noise",
        "percent": 0.8,
        "_begin": "When to begin updating labels",
        "begin": 0,
        "_asym": "symmetric noise if false",
        "asym": false
    }
}
ELR/config_cifar100.json ADDED
@@ -0,0 +1,75 @@
{
    "name": "cifar100_sy_60_resnet34",
    "n_gpu": 1,
    "seed": 123,

    "arch": {
        "type": "resnet34",
        "args": {"num_classes": 100}
    },

    "num_classes": 100,

    "data_loader": {
        "type": "CIFAR100DataLoader",
        "args": {
            "data_dir": "/dir_to_data",
            "batch_size": 128,
            "shuffle": true,
            "num_batches": 0,
            "validation_split": 0,
            "num_workers": 8,
            "pin_memory": true
        }
    },

    "optimizer": {
        "type": "SGD",
        "args": {
            "lr": 0.02,
            "momentum": 0.9,
            "weight_decay": 1e-3
        }
    },

    "train_loss": {
        "type": "elr_loss",
        "args": {
            "beta": 0.9,
            "lambda": 7
        }
    },

    "val_loss": "cross_entropy",
    "metrics": [
        "my_metric", "my_metric2"
    ],

    "lr_scheduler": {
        "type": "MultiStepLR",
        "args": {
            "milestones": [80, 120],
            "gamma": 0.01
        }
    },

    "trainer": {
        "epochs": 150,
        "warmup": 0,
        "save_dir": "saved/",
        "save_period": 1,
        "verbosity": 2,
        "label_dir": "saved/",
        "monitor": "max val_my_metric",
        "early_stop": 2000,
        "tensorboard": false,
        "mlflow": true,
        "_percent": "Percentage of noise",
        "percent": 0.6,
        "_begin": "When to begin updating labels",
        "begin": 0,
        "_asym": "symmetric noise if false",
        "asym": false
    }
}
ELR/config_cifar10_asym.json ADDED
@@ -0,0 +1,75 @@
{
    "name": "cifar10_resnet34_cosine",
    "n_gpu": 1,
    "seed": 123,

    "arch": {
        "type": "resnet34",
        "args": {"num_classes": 10}
    },

    "num_classes": 10,

    "data_loader": {
        "type": "CIFAR10DataLoader",
        "args": {
            "data_dir": "/dir_to_data",
            "batch_size": 128,
            "shuffle": true,
            "num_batches": 0,
            "validation_split": 0,
            "num_workers": 8,
            "pin_memory": true
        }
    },

    "optimizer": {
        "type": "SGD",
        "args": {
            "lr": 0.02,
            "momentum": 0.9,
            "weight_decay": 1e-3
        }
    },

    "train_loss": {
        "type": "elr_loss",
        "args": {
            "beta": 0.9,
            "lambda": 1
        }
    },

    "val_loss": "cross_entropy",
    "metrics": [
        "my_metric", "my_metric2"
    ],

    "lr_scheduler": {
        "type": "MultiStepLR",
        "args": {
            "milestones": [40, 80],
            "gamma": 0.01
        }
    },

    "trainer": {
        "epochs": 120,
        "warmup": 0,
        "save_dir": "saved/",
        "save_period": 1,
        "verbosity": 2,
        "label_dir": "saved/",
        "monitor": "max val_my_metric",
        "early_stop": 2000,
        "tensorboard": false,
        "mlflow": true,
        "_percent": "Percentage of noise",
        "percent": 0.4,
        "_begin": "When to begin updating labels",
        "begin": 0,
        "_asym": "symmetric noise if false",
        "asym": true
    }
}
ELR/data_loader/__pycache__/cifar10.cpython-36.pyc ADDED
Binary file (7.87 kB).
 
ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc ADDED
Binary file (2.19 kB).
 
ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc ADDED
Binary file (2.91 kB).
 
ELR/data_loader/cifar10.py ADDED
@@ -0,0 +1,212 @@
import numpy as np
from PIL import Image
import torchvision


def get_cifar10(root, cfg_trainer, train=True,
                transform_train=None, transform_val=None,
                download=False, noise_file=''):
    base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)
    if train:
        train_idxs, val_idxs = train_val_split(base_dataset.targets)
        train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
        val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
        if cfg_trainer['asym']:
            train_dataset.asymmetric_noise()
            val_dataset.asymmetric_noise()
        else:
            train_dataset.symmetric_noise()
            val_dataset.symmetric_noise()

        print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
    else:
        train_dataset = []
        val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(val_dataset)}")

    return train_dataset, val_dataset


def train_val_split(train_labels):
    # takes the list of training targets and builds a class-balanced 90/10 index split
    num_classes = 10
    train_labels = np.array(train_labels)
    train_n = int(len(train_labels) * 0.9 / num_classes)
    train_idxs = []
    val_idxs = []

    for i in range(num_classes):
        idxs = np.where(train_labels == i)[0]
        np.random.shuffle(idxs)
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs


class CIFAR10_train(torchvision.datasets.CIFAR10):
    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_train, self).__init__(root, train=train,
                                            transform=transform, target_transform=target_transform,
                                            download=download)
        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        self.indexs = indexs
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []

    def symmetric_noise(self):
        self.train_labels_gt = self.train_labels.copy()
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < self.cfg_trainer['percent'] * len(self.train_data):
                self.noise_indx.append(idx)
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        self.train_labels_gt = self.train_labels.copy()
        for i in range(self.num_classes):
            indices = np.where(self.train_labels == i)[0]
            np.random.shuffle(indices)
            for j, idx in enumerate(indices):
                if j < self.cfg_trainer['percent'] * len(indices):
                    self.noise_indx.append(idx)
                    # truck -> automobile
                    if i == 9:
                        self.train_labels[idx] = 1
                    # bird -> airplane
                    elif i == 2:
                        self.train_labels[idx] = 0
                    # cat -> dog
                    elif i == 3:
                        self.train_labels[idx] = 5
                    # dog -> cat
                    elif i == 5:
                        self.train_labels[idx] = 3
                    # deer -> horse
                    elif i == 4:
                        self.train_labels[idx] = 7

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt

    def __len__(self):
        return len(self.train_data)


class CIFAR10_val(torchvision.datasets.CIFAR10):

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_val, self).__init__(root, train=train,
                                          transform=transform, target_transform=target_transform,
                                          download=download)

        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        if train:
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < self.cfg_trainer['percent'] * len(self.train_data):
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        for i in range(self.num_classes):
            indices = np.where(self.train_labels == i)[0]
            np.random.shuffle(indices)
            for j, idx in enumerate(indices):
                if j < self.cfg_trainer['percent'] * len(indices):
                    # truck -> automobile
                    if i == 9:
                        self.train_labels[idx] = 1
                    # bird -> airplane
                    elif i == 2:
                        self.train_labels[idx] = 0
                    # cat -> dog
                    elif i == 3:
                        self.train_labels[idx] = 5
                    # dog -> cat
                    elif i == 5:
                        self.train_labels[idx] = 3
                    # deer -> horse
                    elif i == 4:
                        self.train_labels[idx] = 7

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt
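Note that `symmetric_noise` draws the replacement label uniformly over all 10 classes, so a selected example can be reassigned its true label; the realized flip rate is therefore about `percent * 9/10`. A standalone numpy sketch of the recipe:

```python
import numpy as np

np.random.seed(0)
labels = np.random.randint(10, size=50000)
noisy = labels.copy()
percent = 0.8

indices = np.random.permutation(len(labels))
for i, idx in enumerate(indices):
    if i < percent * len(labels):
        noisy[idx] = np.random.randint(10)

print((noisy != labels).mean())  # ~0.72, i.e. 0.8 * 9/10
```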
ELR/data_loader/cifar100.py ADDED
@@ -0,0 +1,317 @@
import numpy as np
from PIL import Image
import torchvision
from numpy.testing import assert_array_almost_equal


def get_cifar100(root, cfg_trainer, train=True,
                 transform_train=None, transform_val=None,
                 download=False, noise_file=''):
    base_dataset = torchvision.datasets.CIFAR100(root, train=train, download=download)
    if train:
        train_idxs, val_idxs = train_val_split(base_dataset.targets)
        train_dataset = CIFAR100_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
        val_dataset = CIFAR100_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
        if cfg_trainer['asym']:
            train_dataset.asymmetric_noise()
            val_dataset.asymmetric_noise()
        else:
            train_dataset.symmetric_noise()
            val_dataset.symmetric_noise()

        print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
    else:
        train_dataset = []
        val_dataset = CIFAR100_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(val_dataset)}")

    return train_dataset, val_dataset


def train_val_split(train_labels):
    # takes the list of training targets and builds a class-balanced 90/10 index split
    num_classes = 100
    train_labels = np.array(train_labels)
    train_n = int(len(train_labels) * 0.9 / num_classes)
    train_idxs = []
    val_idxs = []

    for i in range(num_classes):
        idxs = np.where(train_labels == i)[0]
        np.random.shuffle(idxs)
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs


class CIFAR100_train(torchvision.datasets.CIFAR100):
    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR100_train, self).__init__(root, train=train,
                                             transform=transform, target_transform=target_transform,
                                             download=download)
        self.num_classes = 100
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        self.indexs = indexs
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []

        self.count = 0

    def symmetric_noise(self):
        self.train_labels_gt = self.train_labels.copy()
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < self.cfg_trainer['percent'] * len(self.train_data):
                self.noise_indx.append(idx)
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def multiclass_noisify(self, y, P, random_state=0):
        """ Flip classes according to transition probability matrix P.
        It expects a number between 0 and the number of classes - 1.
        """

        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        m = y.shape[0]
        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(m):
            i = y[idx]
            # draw a one-hot vector from row i of P
            flipped = flipper.multinomial(1, P[i, :], 1)[0]
            new_y[idx] = np.where(flipped == 1)[0]

        return new_y

    # An alternative scheme kept for reference: random flip between two random classes.
    # def build_for_cifar100(self, size, noise):
    #     assert (noise >= 0.) and (noise <= 1.)
    #     P = np.eye(size)
    #     cls1, cls2 = np.random.choice(range(size), size=2, replace=False)
    #     P[cls1, cls2] = noise
    #     P[cls2, cls1] = noise
    #     P[cls1, cls1] = 1.0 - noise
    #     P[cls2, cls2] = 1.0 - noise
    #     assert_array_almost_equal(P.sum(axis=1), 1, 1)
    #     return P

    def build_for_cifar100(self, size, noise):
        """ The noise matrix flips to the "next" class with probability 'noise'.
        """

        assert (noise >= 0.) and (noise <= 1.)

        P = (1. - noise) * np.eye(size)
        for i in np.arange(size - 1):
            P[i, i + 1] = noise

        # adjust last row
        P[size - 1, 0] = noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def asymmetric_noise(self, asym=False, random_shuffle=False):
        self.train_labels_gt = self.train_labels.copy()
        P = np.eye(self.num_classes)
        n = self.cfg_trainer['percent']
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            for i in np.arange(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)

            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
                                                    random_state=0)
            actual_noise = (y_train_noisy != self.train_labels).mean()
            assert actual_noise > 0.0
            self.train_labels = y_train_noisy

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt

    def __len__(self):
        return len(self.train_data)


class CIFAR100_val(torchvision.datasets.CIFAR100):

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR100_val, self).__init__(root, train=train,
                                           transform=transform, target_transform=target_transform,
                                           download=download)

        self.num_classes = 100
        self.cfg_trainer = cfg_trainer
        if train:
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < self.cfg_trainer['percent'] * len(self.train_data):
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def multiclass_noisify(self, y, P, random_state=0):
        """ Flip classes according to transition probability matrix P.
        It expects a number between 0 and the number of classes - 1.
        """

        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        m = y.shape[0]
        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(m):
            i = y[idx]
            # draw a one-hot vector from row i of P
            flipped = flipper.multinomial(1, P[i, :], 1)[0]
            new_y[idx] = np.where(flipped == 1)[0]

        return new_y

    def build_for_cifar100(self, size, noise):
        """ The noise matrix flips to the "next" class with probability 'noise'.
        """

        assert (noise >= 0.) and (noise <= 1.)

        P = (1. - noise) * np.eye(size)
        for i in np.arange(size - 1):
            P[i, i + 1] = noise

        # adjust last row
        P[size - 1, 0] = noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def asymmetric_noise(self, asym=False, random_shuffle=False):
        P = np.eye(self.num_classes)
        n = self.cfg_trainer['percent']
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            for i in np.arange(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)

            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
                                                    random_state=0)
            actual_noise = (y_train_noisy != self.train_labels).mean()
            assert actual_noise > 0.0
            self.train_labels = y_train_noisy

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt
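The asymmetric CIFAR-100 noise flips each label to the "next" subclass inside its 5-class superclass block. A standalone sketch of one block produced by `build_for_cifar100`:

```python
import numpy as np

size, noise = 5, 0.4
P = (1. - noise) * np.eye(size)
for i in np.arange(size - 1):
    P[i, i + 1] = noise
P[size - 1, 0] = noise  # the last subclass wraps around to the first

print(P)                # stay with prob. 0.6, move to the next subclass with prob. 0.4
print(P.sum(axis=1))    # every row sums to 1 (row-stochastic)
```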
ELR/data_loader/data_loaders.py ADDED
@@ -0,0 +1,70 @@
from torchvision import transforms
from base import BaseDataLoader
from data_loader.cifar10 import get_cifar10
from data_loader.cifar100 import get_cifar100
from parse_config import ConfigParser


class CIFAR10DataLoader(BaseDataLoader):
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True,
                 num_workers=4, pin_memory=True):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        self.data_dir = data_dir

        noise_file = '%sCIFAR10_%.1f_Asym_%s.json' % (config['data_loader']['args']['data_dir'],
                                                      cfg_trainer['percent'], cfg_trainer['asym'])

        self.train_dataset, self.val_dataset = get_cifar10(config['data_loader']['args']['data_dir'], cfg_trainer,
                                                           train=training, transform_train=transform_train,
                                                           transform_val=transform_val, noise_file=noise_file)

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self, batch_size, shuffle, validation_split, num_workers, pin_memory):
        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)


class CIFAR100DataLoader(BaseDataLoader):
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True,
                 num_workers=4, pin_memory=True):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']

        transform_train = transforms.Compose([
            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])
        self.data_dir = data_dir

        noise_file = '%sCIFAR100_%.1f_Asym_%s.json' % (config['data_loader']['args']['data_dir'],
                                                       cfg_trainer['percent'], cfg_trainer['asym'])

        self.train_dataset, self.val_dataset = get_cifar100(config['data_loader']['args']['data_dir'], cfg_trainer,
                                                            train=training, transform_train=transform_train,
                                                            transform_val=transform_val, noise_file=noise_file)

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self, batch_size, shuffle, validation_split, num_workers, pin_memory):
        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)
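The `noise_file` paths above are built by plain `%s` concatenation, so the configured `data_dir` must end with a path separator for the file to land inside it. A separator-safe equivalent (a sketch, not what the loaders use):

```python
import os

data_dir = '/dir_to_data'  # from config['data_loader']['args']['data_dir']
noise_file = os.path.join(data_dir, 'CIFAR10_%.1f_Asym_%s.json' % (0.8, False))
print(noise_file)  # /dir_to_data/CIFAR10_0.8_Asym_False.json
```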
ELR/logger/__init__.py ADDED
@@ -0,0 +1,2 @@
from .logger import *
from .visualization import *
ELR/logger/logger.py ADDED
@@ -0,0 +1,22 @@
import logging
import logging.config
from pathlib import Path
from utils import read_json


def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration
    """
    log_config = Path(log_config)
    if log_config.is_file():
        config = read_json(log_config)
        # modify logging paths based on run config
        for _, handler in config['handlers'].items():
            if 'filename' in handler:
                handler['filename'] = str(save_dir / handler['filename'])

        logging.config.dictConfig(config)
    else:
        print("Warning: logging configuration file is not found in {}.".format(log_config))
        logging.basicConfig(level=default_level)
ELR/logger/logger_config.json ADDED
@@ -0,0 +1,32 @@
{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "simple": {"format": "%(message)s"},
        "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout"
        },
        "info_file_handler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "INFO",
            "formatter": "datetime",
            "filename": "info.log",
            "maxBytes": 10485760,
            "backupCount": 20,
            "encoding": "utf8"
        }
    },
    "root": {
        "level": "INFO",
        "handlers": [
            "console",
            "info_file_handler"
        ]
    }
}
ELR/logger/visualization.py ADDED
@@ -0,0 +1,154 @@
import importlib
from utils import Timer


class MLFlow:
    def __init__(self, log_dir, logger, enabled):
        self.mlflow = None

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer.
            try:
                self.mlflow = importlib.import_module("mlflow")
                succeeded = True
            except ImportError:
                succeeded = False

            if not succeeded:
                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
                          "this machine. Please install mlflow with 'pip install mlflow' or turn off the option " \
                          "in the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        self.mlflow_ftns = {
            'start_run'
        }
        # no tag/mode exceptions for mlflow, but __getattr__ below expects this attribute
        self.tag_mode_exceptions = set()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return logging methods of mlflow with additional information (mode tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.mlflow_ftns_with_tag_and_value:
            add_data = getattr(self.mlflow, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, *args, **kwargs)

            return wrapper
        elif name in self.mlflow_ftns:
            add_data = getattr(self.mlflow, name, None)

            def wrapper(*args, **kwargs):
                if add_data is not None:
                    add_data(*args, **kwargs)

            return wrapper
        else:
            # default action for returning methods defined in this class.
            try:
                attr = object.__getattribute__(self, name)
            except AttributeError:
                raise AttributeError("type object 'mlflow' has no attribute '{}'".format(name))
            return attr


class TensorboardWriter:
    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False
                self.selected_module = module

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                          "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                          "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                          "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

        self.timer = Timer()

    def set_step(self, step, mode='train'):
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper
        else:
            # default action for returning methods defined in this class, set_step() for instance.
            try:
                attr = object.__getattribute__(self, name)
            except AttributeError:
                raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
            return attr
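A minimal usage sketch of `TensorboardWriter` (assumes `torch.utils.tensorboard` or `tensorboardX` is installed and that the repo's `utils.Timer` is importable; the log directory is illustrative):

```python
import logging
from logger import TensorboardWriter

writer = TensorboardWriter('saved/log/demo', logging.getLogger(__name__), enabled=True)
for step in range(3):
    writer.set_step(step, mode='train')          # step 0 resets the internal timer
    writer.add_scalar('loss', 1.0 / (step + 1))  # recorded under the tag 'loss/train'
```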
ELR/model/ResNet_Zoo.py ADDED
@@ -0,0 +1,133 @@
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

        self.gradients = None

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def activations_hook(self, grad):
        self.gradients = grad

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        y = out.view(out.size(0), -1)
        out = self.linear(y)
        if out.requires_grad:
            out.register_hook(self.activations_hook)
        return out

    def get_activations_gradient(self):
        return self.gradients


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])

def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])

def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

# test()
ELR/model/loss.py ADDED
@@ -0,0 +1,30 @@
import torch.nn.functional as F
import torch
from parse_config import ConfigParser
import torch.nn as nn


def cross_entropy(output, target):
    return F.cross_entropy(output, target)


class elr_loss(nn.Module):
    def __init__(self, num_examp, num_classes=10, beta=0.3):
        super(elr_loss, self).__init__()
        self.num_classes = num_classes
        self.config = ConfigParser.get_instance()
        self.USE_CUDA = torch.cuda.is_available()
        # running (temporally ensembled) predictions, one row per training example
        self.target = torch.zeros(num_examp, self.num_classes).cuda() if self.USE_CUDA \
            else torch.zeros(num_examp, self.num_classes)
        self.beta = beta

    def forward(self, index, output, label):
        y_pred = F.softmax(output, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)
        y_pred_ = y_pred.detach()
        self.target[index] = self.beta * self.target[index] + (1 - self.beta) * (y_pred_ / y_pred_.sum(dim=1, keepdim=True))
        ce_loss = F.cross_entropy(output, label)
        elr_reg = ((1 - (self.target[index] * y_pred).sum(dim=1)).log()).mean()
        final_loss = ce_loss + self.config['train_loss']['args']['lambda'] * elr_reg
        return final_loss
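A standalone recomputation of the loss above for a single batch, with the `ConfigParser` dependency replaced by explicit `beta`/`lambda` values (a sketch, not the class itself):

```python
import torch
import torch.nn.functional as F

beta, lam, num_classes = 0.7, 3.0, 10
output = torch.randn(4, num_classes, requires_grad=True)  # logits for a batch of 4
label = torch.randint(0, num_classes, (4,))
target = torch.zeros(4, num_classes)  # running ensembled predictions for these examples

y_pred = torch.clamp(F.softmax(output, dim=1), 1e-4, 1.0 - 1e-4)
y_pred_ = y_pred.detach()
target = beta * target + (1 - beta) * (y_pred_ / y_pred_.sum(dim=1, keepdim=True))
loss = F.cross_entropy(output, label) + lam * ((1 - (target * y_pred).sum(dim=1)).log()).mean()
loss.backward()
print(loss.item())
```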
ELR/model/metric.py ADDED
@@ -0,0 +1,20 @@
import torch


def my_metric(output, target):
    # top-1 accuracy
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = 0
        correct += torch.sum(pred == target).item()
    return correct / len(target)


def my_metric2(output, target, k=5):
    # top-k accuracy
    with torch.no_grad():
        pred = torch.topk(output, k, dim=1)[1]
        assert pred.shape[0] == len(target)
        correct = 0
        for i in range(k):
            correct += torch.sum(pred[:, i] == target).item()
    return correct / len(target)
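A quick sanity check of the two metrics above (`my_metric` is top-1 accuracy, `my_metric2` top-k; run from the `ELR/` root):

```python
import torch
from model.metric import my_metric, my_metric2

output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])
target = torch.tensor([1, 1])

print(my_metric(output, target))         # 0.5: only the first row's argmax matches
print(my_metric2(output, target, k=2))   # 1.0: both targets appear in the top 2
```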
ELR/model/model.py ADDED
@@ -0,0 +1,13 @@
from .ResNet_Zoo import ResNet, BasicBlock


def resnet34(num_classes=10):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
ELR/parse_config.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ import logging
+ from pathlib import Path
+ from functools import reduce
+ from operator import getitem
+ from datetime import datetime
+ from logger import setup_logging
+ from utils import read_json, write_json
+
+
+ class ConfigParser:
+     __instance = None
+
+     def __new__(cls, args, options='', timestamp=True):
+         raise NotImplementedError('Cannot initialize via Constructor')
+
+     @classmethod
+     def __internal_new__(cls):
+         return super().__new__(cls)
+
+     @classmethod
+     def get_instance(cls, args=None, options='', timestamp=True):
+         if not cls.__instance:
+             if args is None:
+                 raise NotImplementedError('Cannot initialize without args')
+             cls.__instance = cls.__internal_new__()
+             cls.__instance.__init__(args, options)
+         return cls.__instance
+
+     def __init__(self, args, options='', timestamp=True):
+         # parse default and custom cli options
+         for opt in options:
+             args.add_argument(*opt.flags, default=None, type=opt.type)
+         args = args.parse_args()
+
+         if args.device:
+             os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+         if args.resume is None:
+             msg_no_cfg = "A configuration file needs to be specified. Add '-c config.json', for example."
+             assert args.config is not None, msg_no_cfg
+             self.cfg_fname = Path(args.config)
+             config = read_json(self.cfg_fname)
+             self.resume = None
+         else:
+             self.resume = Path(args.resume)
+             resume_cfg_fname = self.resume.parent / 'config.json'
+             config = read_json(resume_cfg_fname)
+             if args.config is not None:
+                 config.update(read_json(Path(args.config)))
+
+         # load config file and apply custom cli options
+         self._config = _update_config(config, options, args)
+
+         # set save_dir where trained model and log will be saved.
+         save_dir = Path(self.config['trainer']['save_dir'])
+         timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
+
+         if self.config['trainer']['asym']:
+             exper_name = self.config['name'] + '_asym_' + str(int(self.config['trainer']['percent']*100))
+         else:
+             exper_name = self.config['name'] + '_sym_' + str(int(self.config['trainer']['percent']*100))
+         self._save_dir = save_dir / 'models' / exper_name / timestamp
+         self._log_dir = save_dir / 'log' / exper_name / timestamp
+
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+         self.log_dir.mkdir(parents=True, exist_ok=True)
+
+         # save updated config file to the checkpoint dir
+         write_json(self.config, self.save_dir / 'config.json')
+
+         # configure logging module
+         setup_logging(self.log_dir)
+         self.log_levels = {
+             0: logging.WARNING,
+             1: logging.INFO,
+             2: logging.DEBUG
+         }
+
+     def initialize(self, name, module, *args, **kwargs):
+         """
+         Finds a function handle with the name given as 'type' in config, and returns the
+         instance initialized with corresponding keyword args given as 'args'.
+         """
+         module_name = self[name]['type']
+         module_args = dict(self[name]['args'])
+         assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
+         module_args.update(kwargs)
+         return getattr(module, module_name)(*args, **module_args)
+
+     def __getitem__(self, name):
+         return self.config[name]
+
+     def get_logger(self, name, verbosity=2):
+         msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
+         assert verbosity in self.log_levels, msg_verbosity
+         logger = logging.getLogger(name)
+         logger.setLevel(self.log_levels[verbosity])
+         return logger
+
+     # setting read-only attributes
+     @property
+     def config(self):
+         return self._config
+
+     @property
+     def save_dir(self):
+         return self._save_dir
+
+     @property
+     def log_dir(self):
+         return self._log_dir
+
+
+ # helper functions used to update config dict with custom cli options
+ def _update_config(config, options, args):
+     for opt in options:
+         value = getattr(args, _get_opt_name(opt.flags))
+         if value is not None:
+             _set_by_path(config, opt.target, value)
+             if 'target2' in opt._fields:
+                 _set_by_path(config, opt.target2, value)
+             if 'target3' in opt._fields:
+                 _set_by_path(config, opt.target3, value)
+     return config
+
+
+ def _get_opt_name(flags):
+     for flg in flags:
+         if flg.startswith('--'):
+             return flg.replace('--', '')
+     return flags[0].replace('--', '')
+
+
+ def _set_by_path(tree, keys, value):
+     """Set a value in a nested object in tree by sequence of keys."""
+     _get_by_path(tree, keys[:-1])[keys[-1]] = value
+
+
+ def _get_by_path(tree, keys):
+     """Access a nested object in tree by sequence of keys."""
+     return reduce(getitem, keys, tree)
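The custom-option machinery above boils down to `_set_by_path`: each `CustomArgs` target is a key path into the nested config dict. A standalone illustration (the config dict below is a stub, not a full config file):

```
# '--lr 0.01' with target ('optimizer', 'args', 'lr') rewrites the nested value in place.
from functools import reduce
from operator import getitem

def _set_by_path(tree, keys, value):
    reduce(getitem, keys[:-1], tree)[keys[-1]] = value

config = {'optimizer': {'type': 'SGD', 'args': {'lr': 0.02}}}
_set_by_path(config, ('optimizer', 'args', 'lr'), 0.01)
print(config['optimizer']['args']['lr'])  # 0.01
```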
ELR/test.py ADDED
@@ -0,0 +1,82 @@
+ import argparse
+ import torch
+ from tqdm import tqdm
+ import data_loader.data_loaders as module_data
+ import model.loss as module_loss
+ import model.metric as module_metric
+ import model.model as module_arch
+ from parse_config import ConfigParser
+
+
+ def main(config):
+     logger = config.get_logger('test')
+
+     # setup data_loader instances
+     data_loader = getattr(module_data, config['data_loader']['type'])(
+         config['data_loader']['args']['data_dir'],
+         batch_size=512,
+         shuffle=False,
+         validation_split=0.0,
+         training=False,
+         num_workers=2
+     ).split_validation()
+
+     # build model architecture
+     model = config.initialize('arch', module_arch)
+     logger.info(model)
+
+     # get function handles of loss and metrics
+     loss_fn = getattr(module_loss, config['val_loss'])
+     metric_fns = [getattr(module_metric, met) for met in config['metrics']]
+
+     logger.info('Loading checkpoint: {} ...'.format(config.resume))
+     checkpoint = torch.load(config.resume, map_location='cpu')
+     state_dict = checkpoint['state_dict']
+     if config['n_gpu'] > 1:
+         model = torch.nn.DataParallel(model)
+     model.load_state_dict(state_dict)
+
+     # prepare model for testing
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = model.to(device)
+     model.eval()
+
+     total_loss = 0.0
+     total_metrics = torch.zeros(len(metric_fns))
+
+     with torch.no_grad():
+         for batch_idx, (data, target, _, _) in enumerate(tqdm(data_loader)):
+             data, target = data.to(device), target.to(device)
+             output = model(data)
+
+             #
+             # save sample images, or do something with output here
+             #
+
+             # computing loss, metrics on test set
+             loss = loss_fn(output, target)
+             batch_size = data.shape[0]
+             total_loss += loss.item() * batch_size
+             for i, metric in enumerate(metric_fns):
+                 total_metrics[i] += metric(output, target) * batch_size
+
+     n_samples = len(data_loader.sampler)
+     log = {'loss': total_loss / n_samples}
+     log.update({
+         met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
+     })
+     logger.info(log)
+
+
+ if __name__ == '__main__':
+     args = argparse.ArgumentParser(description='PyTorch Template')
+
+     args.add_argument('-c', '--config', default=None, type=str,
+                       help='config file path (default: None)')
+     args.add_argument('-r', '--resume', default=None, type=str,
+                       help='path to latest checkpoint (default: None)')
+     args.add_argument('-d', '--device', default=None, type=str,
+                       help='indices of GPUs to enable (default: all)')
+     config = ConfigParser.get_instance(args, '')
+     # config = ConfigParser(args)
+     main(config)
ELR/train.py ADDED
@@ -0,0 +1,125 @@
+ import argparse
+ import collections
+ import sys
+ import requests
+ import socket
+ import torch
+ import mlflow
+ import mlflow.pytorch
+ import data_loader.data_loaders as module_data
+ import model.loss as module_loss
+ import model.metric as module_metric
+ import model.model as module_arch
+ from parse_config import ConfigParser
+ from trainer import Trainer
+ from collections import OrderedDict
+ import random
+
+
+ def log_params(conf: OrderedDict, parent_key: str = None):
+     for key, value in conf.items():
+         if parent_key is not None:
+             combined_key = f'{parent_key}-{key}'
+         else:
+             combined_key = key
+
+         if not isinstance(value, OrderedDict):
+             mlflow.log_param(combined_key, value)
+         else:
+             log_params(value, combined_key)
+
+
+ def main(config: ConfigParser):
+     logger = config.get_logger('train')
+
+     data_loader = getattr(module_data, config['data_loader']['type'])(
+         config['data_loader']['args']['data_dir'],
+         batch_size=config['data_loader']['args']['batch_size'],
+         shuffle=config['data_loader']['args']['shuffle'],
+         validation_split=config['data_loader']['args']['validation_split'],
+         num_batches=config['data_loader']['args']['num_batches'],
+         training=True,
+         num_workers=config['data_loader']['args']['num_workers'],
+         pin_memory=config['data_loader']['args']['pin_memory']
+     )
+
+     valid_data_loader = data_loader.split_validation()
+
+     # test_data_loader = None
+     test_data_loader = getattr(module_data, config['data_loader']['type'])(
+         config['data_loader']['args']['data_dir'],
+         batch_size=128,
+         shuffle=False,
+         validation_split=0.0,
+         training=False,
+         num_workers=2
+     ).split_validation()
+
+     # build model architecture, then print to console
+     model = config.initialize('arch', module_arch)
+
+     # get function handles of loss and metrics
+     logger.info(config.config)
+     if hasattr(data_loader.dataset, 'num_raw_example'):
+         num_examp = data_loader.dataset.num_raw_example
+     else:
+         num_examp = len(data_loader.dataset)
+
+     train_loss = getattr(module_loss, config['train_loss']['type'])(num_examp=num_examp,
+                                                                     num_classes=config['num_classes'],
+                                                                     beta=config['train_loss']['args']['beta'])
+     val_loss = getattr(module_loss, config['val_loss'])
+     metrics = [getattr(module_metric, met) for met in config['metrics']]
+
+     # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
+     trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+     optimizer = config.initialize('optimizer', torch.optim, [{'params': trainable_params}])
+     lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)
+
+     trainer = Trainer(model, train_loss, metrics, optimizer,
+                       config=config,
+                       data_loader=data_loader,
+                       valid_data_loader=valid_data_loader,
+                       test_data_loader=test_data_loader,
+                       lr_scheduler=lr_scheduler,
+                       val_criterion=val_loss)
+
+     trainer.train()
+     logger = config.get_logger('trainer', config['trainer']['verbosity'])
+     cfg_trainer = config['trainer']
+
+
+ if __name__ == '__main__':
+     args = argparse.ArgumentParser(description='PyTorch Template')
+     args.add_argument('-c', '--config', default=None, type=str,
+                       help='config file path (default: None)')
+     args.add_argument('-r', '--resume', default=None, type=str,
+                       help='path to latest checkpoint (default: None)')
+     args.add_argument('-d', '--device', default=None, type=str,
+                       help='indices of GPUs to enable (default: all)')
+
+     # custom cli options to modify configuration from default values given in json file.
+     CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+     options = [
+         CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
+         CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
+         CustomArgs(['--lamb', '--lamb'], type=float, target=('train_loss', 'args', 'lambda')),
+         CustomArgs(['--beta', '--beta'], type=float, target=('train_loss', 'args', 'beta')),
+         CustomArgs(['--percent', '--percent'], type=float, target=('trainer', 'percent')),
+         CustomArgs(['--asym', '--asym'], type=bool, target=('trainer', 'asym')),
+         CustomArgs(['--name', '--exp_name'], type=str, target=('name',)),
+         CustomArgs(['--seed', '--seed'], type=int, target=('seed',))
+     ]
+     config = ConfigParser.get_instance(args, options)
+
+     random.seed(config['seed'])
+     torch.manual_seed(config['seed'])
+     torch.cuda.manual_seed_all(config['seed'])
+     main(config)
ELR/trainer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .trainer import *
ELR/trainer/trainer.py ADDED
@@ -0,0 +1,278 @@
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from typing import List
+ from torchvision.utils import make_grid
+ from base import BaseTrainer
+ from utils import inf_loop
+ import sys
+ from sklearn.mixture import GaussianMixture
+
+
+ class Trainer(BaseTrainer):
+     """
+     Trainer class
+
+     Note:
+         Inherited from BaseTrainer.
+     """
+     def __init__(self, model, train_criterion, metrics, optimizer, config, data_loader,
+                  valid_data_loader=None, test_data_loader=None, lr_scheduler=None, len_epoch=None, val_criterion=None):
+         super().__init__(model, train_criterion, metrics, optimizer, config, val_criterion)
+         self.config = config
+         self.data_loader = data_loader
+         if len_epoch is None:
+             # epoch-based training
+             self.len_epoch = len(self.data_loader)
+         else:
+             # iteration-based training
+             self.data_loader = inf_loop(data_loader)
+             self.len_epoch = len_epoch
+         self.valid_data_loader = valid_data_loader
+         self.test_data_loader = test_data_loader
+         self.do_validation = self.valid_data_loader is not None
+         self.do_test = self.test_data_loader is not None
+         self.lr_scheduler = lr_scheduler
+         self.log_step = int(np.sqrt(data_loader.batch_size))
+         self.train_loss_list: List[float] = []
+         self.val_loss_list: List[float] = []
+         self.test_loss_list: List[float] = []
+         # Visdom visualization
+
+     def _eval_metrics(self, output, label):
+         acc_metrics = np.zeros(len(self.metrics))
+         for i, metric in enumerate(self.metrics):
+             acc_metrics[i] += metric(output, label)
+             self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
+         return acc_metrics
+
+     def _train_epoch(self, epoch):
+         """
+         Training logic for an epoch
+
+         :param epoch: Current training epoch.
+         :return: A log that contains all information you want to save.
+
+         Note:
+             If you have additional information to record, for example:
+                 > additional_log = {"x": x, "y": y}
+             merge it with log before returning, i.e.
+                 > log = {**log, **additional_log}
+                 > return log
+
+             The metrics in log must have the key 'metrics'.
+         """
+         self.model.train()
+
+         total_loss = 0
+         total_metrics = np.zeros(len(self.metrics))
+
+         with tqdm(self.data_loader) as progress:
+             for batch_idx, (data, label, indexs, _) in enumerate(progress):
+                 progress.set_description_str(f'Train epoch {epoch}')
+
+                 data, label = data.to(self.device), label.long().to(self.device)
+
+                 output = self.model(data)
+
+                 loss = self.train_criterion(indexs.cpu().detach().numpy().tolist(), output, label)
+                 self.optimizer.zero_grad()
+                 loss.backward()
+                 self.optimizer.step()
+
+                 self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+                 self.writer.add_scalar('loss', loss.item())
+                 self.train_loss_list.append(loss.item())
+                 total_loss += loss.item()
+                 total_metrics += self._eval_metrics(output, label)
+
+                 if batch_idx % self.log_step == 0:
+                     progress.set_postfix_str(' {} Loss: {:.6f}'.format(
+                         self._progress(batch_idx),
+                         loss.item()))
+                     self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+                 if batch_idx == self.len_epoch:
+                     break
+         # if hasattr(self.data_loader, 'run'):
+         #     self.data_loader.run()
+
+         log = {
+             'loss': total_loss / self.len_epoch,
+             'metrics': (total_metrics / self.len_epoch).tolist(),
+             'learning rate': self.lr_scheduler.get_lr()
+         }
+
+         if self.do_validation:
+             val_log = self._valid_epoch(epoch)
+             log.update(val_log)
+         if self.do_test:
+             test_log, test_meta = self._test_epoch(epoch)
+             log.update(test_log)
+         else:
+             test_meta = [0, 0]
+
+         if self.lr_scheduler is not None:
+             self.lr_scheduler.step()
+
+         return log
+
+     def _valid_epoch(self, epoch):
+         """
+         Validate after training an epoch
+
+         :return: A log that contains information about validation
+
+         Note:
+             The validation metrics in log must have the key 'val_metrics'.
+         """
+         self.model.eval()
+
+         total_val_loss = 0
+         total_val_metrics = np.zeros(len(self.metrics))
+         with torch.no_grad():
+             with tqdm(self.valid_data_loader) as progress:
+                 for batch_idx, (data, label, _, _) in enumerate(progress):
+                     progress.set_description_str(f'Valid epoch {epoch}')
+                     data, label = data.to(self.device), label.to(self.device)
+                     output = self.model(data)
+                     loss = self.val_criterion(output, label)
+
+                     self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
+                     self.writer.add_scalar('loss', loss.item())
+                     self.val_loss_list.append(loss.item())
+                     total_val_loss += loss.item()
+                     total_val_metrics += self._eval_metrics(output, label)
+                     self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+         # add histogram of model parameters to the tensorboard
+         for name, p in self.model.named_parameters():
+             self.writer.add_histogram(name, p, bins='auto')
+
+         return {
+             'val_loss': total_val_loss / len(self.valid_data_loader),
+             'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
+         }
+
+     def _test_epoch(self, epoch):
+         """
+         Test after training an epoch
+
+         :return: A log that contains information about the test set
+
+         Note:
+             The test metrics in log must have the key 'test_metrics'.
+         """
+         self.model.eval()
+         total_test_loss = 0
+         total_test_metrics = np.zeros(len(self.metrics))
+         results = np.zeros((len(self.test_data_loader.dataset), self.config['num_classes']), dtype=np.float32)
+         tar_ = np.zeros((len(self.test_data_loader.dataset),), dtype=np.float32)
+         with torch.no_grad():
+             with tqdm(self.test_data_loader) as progress:
+                 for batch_idx, (data, label, indexs, _) in enumerate(progress):
+                     progress.set_description_str(f'Test epoch {epoch}')
+                     data, label = data.to(self.device), label.to(self.device)
+                     output = self.model(data)
+
+                     loss = self.val_criterion(output, label)
+
+                     self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
+                     self.writer.add_scalar('loss', loss.item())
+                     self.test_loss_list.append(loss.item())
+                     total_test_loss += loss.item()
+                     total_test_metrics += self._eval_metrics(output, label)
+                     self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+                     results[indexs.cpu().detach().numpy().tolist()] = output.cpu().detach().numpy().tolist()
+                     tar_[indexs.cpu().detach().numpy().tolist()] = label.cpu().detach().numpy().tolist()
+
+         # add histogram of model parameters to the tensorboard
+         for name, p in self.model.named_parameters():
+             self.writer.add_histogram(name, p, bins='auto')
+
+         return {
+             'test_loss': total_test_loss / len(self.test_data_loader),
+             'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
+         }, [results, tar_]
+
+     def _warmup_epoch(self, epoch):
+         total_loss = 0
+         total_metrics = np.zeros(len(self.metrics))
+         self.model.train()
+
+         data_loader = self.data_loader  # self.loader.run('warmup')
+
+         with tqdm(data_loader) as progress:
+             for batch_idx, (data, label, _, indexs, _) in enumerate(progress):
+                 progress.set_description_str(f'Warm up epoch {epoch}')
+
+                 data, label = data.to(self.device), label.long().to(self.device)
+
+                 self.optimizer.zero_grad()
+                 output = self.model(data)
+                 out_prob = torch.nn.functional.softmax(output, dim=1).data.detach()
+
+                 self.train_criterion.update_hist(indexs.cpu().detach().numpy().tolist(), out_prob)
+
+                 loss = torch.nn.functional.cross_entropy(output, label)
+
+                 loss.backward()
+                 self.optimizer.step()
+
+                 self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+                 self.writer.add_scalar('loss', loss.item())
+                 self.train_loss_list.append(loss.item())
+                 total_loss += loss.item()
+                 total_metrics += self._eval_metrics(output, label)
+
+                 if batch_idx % self.log_step == 0:
+                     progress.set_postfix_str(' {} Loss: {:.6f}'.format(
+                         self._progress(batch_idx),
+                         loss.item()))
+                     self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+                 if batch_idx == self.len_epoch:
+                     break
+         if hasattr(self.data_loader, 'run'):
+             self.data_loader.run()
+         log = {
+             'loss': total_loss / self.len_epoch,
+             'noise detection rate': 0.0,
+             'metrics': (total_metrics / self.len_epoch).tolist(),
+             'learning rate': self.lr_scheduler.get_lr()
+         }
+
+         if self.do_validation:
+             val_log = self._valid_epoch(epoch)
+             log.update(val_log)
+         if self.do_test:
+             test_log, test_meta = self._test_epoch(epoch)
+             log.update(test_log)
+         else:
+             test_meta = [0, 0]
+
+         return log
+
+     def _progress(self, batch_idx):
+         base = '[{}/{} ({:.0f}%)]'
+         if hasattr(self.data_loader, 'n_samples'):
+             current = batch_idx * self.data_loader.batch_size
+             total = self.data_loader.n_samples
+         else:
+             current = batch_idx
+             total = self.len_epoch
+         return base.format(current, total, 100.0 * current / total)
ELR/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .util import *
ELR/utils/util.py ADDED
@@ -0,0 +1,75 @@
+ import json
+ from pathlib import Path
+ from datetime import datetime
+ from itertools import repeat
+ from collections import OrderedDict
+ import numpy as np
+
+
+ def ensure_dir(dirname):
+     dirname = Path(dirname)
+     if not dirname.is_dir():
+         dirname.mkdir(parents=True, exist_ok=False)
+
+
+ def read_json(fname):
+     with fname.open('rt') as handle:
+         return json.load(handle, object_hook=OrderedDict)
+
+
+ def write_json(content, fname):
+     with fname.open('wt') as handle:
+         json.dump(content, handle, indent=4, sort_keys=False)
+
+
+ def inf_loop(data_loader):
+     ''' wrapper function for endless data loader. '''
+     for loader in repeat(data_loader):
+         yield from loader
+
+
+ class Timer:
+     def __init__(self):
+         self.cache = datetime.now()
+
+     def check(self):
+         now = datetime.now()
+         duration = now - self.cache
+         self.cache = now
+         return duration.total_seconds()
+
+     def reset(self):
+         self.cache = datetime.now()
+
+
+ def sigmoid_rampup(current, rampup_length):
+     """Exponential sigmoid rampup"""
+     if rampup_length == 0:
+         return 1.0
+     else:
+         current = np.clip(current, 0.0, rampup_length)
+         phase = 1.0 - current / rampup_length
+         return float(np.exp(-5.0 * phase * phase))
+
+
+ def linear_rampup(current, rampup_length):
+     """Linear rampup"""
+     assert current >= 0 and rampup_length >= 0
+     if current >= rampup_length:
+         return 1.0
+     else:
+         return current / rampup_length
+
+
+ def cosine_rampdown(current, rampdown_length):
+     """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
+     current = np.clip(current, 0.0, rampdown_length)
+     return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
+
+
+ def cosine_rampup(current, rampup_length):
+     """Cosine rampup"""
+     current = np.clip(current, 0.0, rampup_length)
+     return float(-.5 * (np.cos(np.pi * current / rampup_length) - 1))
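For intuition, each ramp function maps a step counter into [0, 1]. A few sample values, computed directly from the definitions above (the import assumes the `ELR/` directory is on the path):

```
from utils.util import sigmoid_rampup, linear_rampup, cosine_rampdown

print(sigmoid_rampup(0, 100))    # ~0.0067 = exp(-5)
print(sigmoid_rampup(50, 100))   # ~0.29   = exp(-1.25)
print(sigmoid_rampup(100, 100))  # 1.0
print(linear_rampup(25, 100))    # 0.25
print(cosine_rampdown(50, 100))  # 0.5
```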
ELR_plus/README.md ADDED
@@ -0,0 +1,27 @@
+ # ELR+
+ This is an official PyTorch implementation of the ELR+ method proposed in [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151).
+
+
+ ## Usage
+ Train the network on the Symmetric Noise CIFAR-10 dataset (noise rate = 0.8):
+
+ ```
+ python train.py -c config_cifar10.json --percent 0.8
+ ```
+ Train the network on the Asymmetric Noise CIFAR-10 dataset (noise rate = 0.4):
+
+ ```
+ python train.py -c config_cifar10_asym.json --percent 0.4
+ ```
+
+ Train the network on the Asymmetric Noise CIFAR-100 dataset (noise rate = 0.4):
+
+ ```
+ python train.py -c config_cifar100.json --percent 0.4 --asym 1
+ ```
+
+ The config files can be modified to adjust hyperparameters and optimization settings.
+
+
+ ## References
+ - S. Liu, J. Niles-Weed, N. Razavian and C. Fernandez-Granda, "Early-Learning Regularization Prevents Memorization of Noisy Labels", 2020
ELR_plus/base/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base_data_loader import *
+ from .base_model import *
+ from .base_trainer import *
ELR_plus/base/base_data_loader.py ADDED
@@ -0,0 +1,83 @@
+ from typing import Tuple, Union, Optional
+
+ import numpy as np
+ from torch.utils.data import DataLoader
+ from torch.utils.data.dataloader import default_collate
+ from torch.utils.data.sampler import SubsetRandomSampler
+
+
+ class BaseDataLoader(DataLoader):
+     """
+     Base class for all data loaders
+     """
+     valid_sampler: Optional[SubsetRandomSampler]
+     sampler: Optional[SubsetRandomSampler]
+
+     def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
+                  collate_fn=default_collate, val_dataset=None):
+         self.collate_fn = collate_fn
+         self.validation_split = validation_split
+         self.shuffle = shuffle
+         self.val_dataset = val_dataset
+
+         self.batch_idx = 0
+         self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
+
+         # split before building init_kwargs: _split_sampler turns shuffle off
+         # when a sampler is used, and shuffle + sampler are mutually exclusive
+         if val_dataset is None:
+             self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
+
+         self.init_kwargs = {
+             'dataset': train_dataset,
+             'batch_size': batch_size,
+             'shuffle': self.shuffle,
+             'collate_fn': collate_fn,
+             'num_workers': num_workers,
+             'pin_memory': pin_memory
+         }
+         if val_dataset is None:
+             super().__init__(sampler=self.sampler, **self.init_kwargs)
+         else:
+             super().__init__(**self.init_kwargs)
+
+     def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
+         if split == 0.0:
+             return None, None
+
+         idx_full = np.arange(self.n_samples)
+
+         np.random.seed(0)
+         np.random.shuffle(idx_full)
+
+         if isinstance(split, int):
+             assert split > 0
+             assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
+             len_valid = split
+         else:
+             len_valid = int(self.n_samples * split)
+
+         valid_idx = idx_full[0:len_valid]
+         train_idx = np.delete(idx_full, np.arange(0, len_valid))
+
+         train_sampler = SubsetRandomSampler(train_idx)
+         valid_sampler = SubsetRandomSampler(valid_idx)
+         print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")
+
+         # turn off shuffle option which is mutually exclusive with sampler
+         self.shuffle = False
+         self.n_samples = len(train_idx)
+
+         return train_sampler, valid_sampler
+
+     def split_validation(self, bs=1000):
+         if self.val_dataset is not None:
+             kwargs = {
+                 'dataset': self.val_dataset,
+                 'batch_size': bs,
+                 'shuffle': False,
+                 'collate_fn': self.collate_fn,
+                 'num_workers': self.num_workers
+             }
+             return DataLoader(**kwargs)
+         else:
+             print('Using sampler to split!')
+             return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
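A minimal sketch of the sampler-based split path above: with no explicit `val_dataset`, `validation_split` carves a `SubsetRandomSampler` pair out of the training set. The dataset below is an illustrative stand-in, not one of the repo's loaders:

```
import torch
from torch.utils.data import TensorDataset
from base.base_data_loader import BaseDataLoader

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
loader = BaseDataLoader(dataset, batch_size=16, shuffle=True,
                        validation_split=0.1, num_workers=0, pin_memory=False)
val_loader = loader.split_validation()           # served via SubsetRandomSampler
print(loader.n_samples, len(val_loader.sampler))  # 90 10
```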
ELR_plus/base/base_model.py ADDED
@@ -0,0 +1,25 @@
+ import torch.nn as nn
+ import numpy as np
+ from abc import abstractmethod
+
+
+ class BaseModel(nn.Module):
+     """
+     Base class for all models
+     """
+     @abstractmethod
+     def forward(self, *inputs):
+         """
+         Forward pass logic
+
+         :return: Model output
+         """
+         raise NotImplementedError
+
+     def __str__(self):
+         """
+         Model prints with number of trainable parameters
+         """
+         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+         params = sum([np.prod(p.size()) for p in model_parameters])
+         return super().__str__() + '\nTrainable parameters: {}'.format(params)
ELR_plus/base/base_trainer.py ADDED
@@ -0,0 +1,341 @@
+ from typing import TypeVar, List, Tuple
+ import torch
+ from tqdm import tqdm
+ from abc import abstractmethod
+ from numpy import inf
+ from logger import TensorboardWriter
+ import numpy as np
+
+
+ class BaseTrainer:
+     """
+     Base class for all trainers
+     """
+     def __init__(self, model1, model2, model_ema1, model_ema2, train_criterion1,
+                  train_criterion2, metrics, optimizer1, optimizer2, config, val_criterion,
+                  model_ema1_copy, model_ema2_copy):
+         self.config = config.config
+
+         self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
+
+         # setup GPU device if available, move model into configured device
+         self.device, self.device_ids = self._prepare_device(config['n_gpu'])
+
+         if len(self.device_ids) > 1:
+             print('Using Multi-Processing!')
+
+         self.model1 = model1.to(self.device + str(self.device_ids[0]))
+         self.model2 = model2.to(self.device + str(self.device_ids[-1]))
+
+         if model_ema1 is not None:
+             self.model_ema1 = model_ema1.to(self.device + str(self.device_ids[0]))
+             self.model_ema2_copy = model_ema2_copy.to(self.device + str(self.device_ids[0]))
+         else:
+             self.model_ema1 = None
+             self.model_ema2_copy = None
+
+         if model_ema2 is not None:
+             self.model_ema2 = model_ema2.to(self.device + str(self.device_ids[-1]))
+             self.model_ema1_copy = model_ema1_copy.to(self.device + str(self.device_ids[-1]))
+         else:
+             self.model_ema2 = None
+             self.model_ema1_copy = None
+
+         if self.model_ema1 is not None:
+             for param in self.model_ema1.parameters():
+                 param.detach_()
+             for param in self.model_ema2_copy.parameters():
+                 param.detach_()
+
+         if self.model_ema2 is not None:
+             for param in self.model_ema2.parameters():
+                 param.detach_()
+             for param in self.model_ema1_copy.parameters():
+                 param.detach_()
+
+         self.train_criterion1 = train_criterion1.to(self.device + str(self.device_ids[0]))
+         self.train_criterion2 = train_criterion2.to(self.device + str(self.device_ids[-1]))
+
+         self.val_criterion = val_criterion
+
+         self.metrics = metrics
+
+         self.optimizer1 = optimizer1
+         self.optimizer2 = optimizer2
+
+         cfg_trainer = config['trainer']
+         self.epochs = cfg_trainer['epochs']
+         self.save_period = cfg_trainer['save_period']
+         self.monitor = cfg_trainer.get('monitor', 'off')
+
+         # configuration to monitor model performance and save best
+         if self.monitor == 'off':
+             self.mnt_mode = 'off'
+             self.mnt_best = 0
+         else:
+             self.mnt_mode, self.mnt_metric = self.monitor.split()
+             assert self.mnt_mode in ['min', 'max']
+
+             self.mnt_best = inf if self.mnt_mode == 'min' else -inf
+             self.early_stop = cfg_trainer.get('early_stop', inf)
+
+         self.start_epoch = 1
+
+         self.global_step = 0
+
+         self.checkpoint_dir = config.save_dir
+
+         # setup visualization writer instance
+         self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
+
+         if config.resume is not None:
+             self._resume_checkpoint(config.resume)
+
+     @abstractmethod
+     def _train_epoch(self, epoch):
+         """
+         Training logic for an epoch
+
+         :param epoch: Current epoch number
+         """
+         raise NotImplementedError
+
+     def train(self):
+         """
+         Full training logic
+         """
+         if len(self.device_ids) > 1:
+             import torch.multiprocessing as mp
+             mp.set_start_method('spawn', force=True)
+
+         not_improved_count = 0
+
+         for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
+             if epoch <= self.config['trainer']['warmup']:
+                 if len(self.device_ids) > 1:
+                     q1 = mp.Queue()
+                     q2 = mp.Queue()
+                     p1 = mp.Process(target=self._warmup_epoch, args=(epoch, self.model1, self.data_loader1, self.optimizer1, self.train_criterion1, self.lr_scheduler1, self.device + str(self.device_ids[0]), q1))
+                     p2 = mp.Process(target=self._warmup_epoch, args=(epoch, self.model2, self.data_loader2, self.optimizer2, self.train_criterion2, self.lr_scheduler2, self.device + str(self.device_ids[-1]), q2))
+                     p1.start()
+                     p2.start()
+                     result1 = q1.get()
+                     result2 = q2.get()
+                     p1.join()
+                     p2.join()
+                 else:
+                     result1 = self._warmup_epoch(epoch, self.model1, self.data_loader1, self.optimizer1, self.train_criterion1, self.lr_scheduler1, self.device + str(self.device_ids[0]))
+                     result2 = self._warmup_epoch(epoch, self.model2, self.data_loader2, self.optimizer2, self.train_criterion2, self.lr_scheduler2, self.device + str(self.device_ids[-1]))
+
+                 if len(self.device_ids) > 1:
+                     self.model_ema1_copy.load_state_dict(self.model_ema1.state_dict())
+                     self.model_ema2_copy.load_state_dict(self.model_ema2.state_dict())
+                     if self.do_validation:
+                         q1 = mp.Queue()
+                         p1 = mp.Process(target=self._valid_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device + str(self.device_ids[0]), q1))
+                     if self.do_test:
+                         q2 = mp.Queue()
+                         p2 = mp.Process(target=self._test_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device + str(self.device_ids[0]), q2))
+                     p1.start()
+                     p2.start()
+                     val_log = q1.get()
+                     test_log, test_meta = q2.get()
+                     result1.update(val_log)
+                     result2.update(val_log)
+                     result1.update(test_log)
+                     result2.update(test_log)
+                     p1.join()
+                     p2.join()
+                 else:
+                     if self.do_validation:
+                         val_log = self._valid_epoch(epoch, self.model1, self.model2, self.device + str(self.device_ids[0]))
+                         result1.update(val_log)
+                         result2.update(val_log)
+                     if self.do_test:
+                         test_log, test_meta = self._test_epoch(epoch, self.model1, self.model2, self.device + str(self.device_ids[0]))
+                         result1.update(test_log)
+                         result2.update(test_log)
+                     else:
+                         test_meta = [0, 0]
+
+             else:
+                 if len(self.device_ids) > 1:
+                     q1 = mp.Queue()
+                     q2 = mp.Queue()
+                     p1 = mp.Process(target=self._train_epoch, args=(epoch, self.model1, self.model_ema1, self.model_ema2_copy, self.data_loader1, self.train_criterion1, self.optimizer1, self.lr_scheduler1, self.device + str(self.device_ids[0]), q1))
+                     p2 = mp.Process(target=self._train_epoch, args=(epoch, self.model2, self.model_ema2, self.model_ema1_copy, self.data_loader2, self.train_criterion2, self.optimizer2, self.lr_scheduler2, self.device + str(self.device_ids[-1]), q2))
+                     p1.start()
+                     p2.start()
+                     result1 = q1.get()
+                     result2 = q2.get()
+                     p1.join()
+                     p2.join()
+                 else:
+                     result1 = self._train_epoch(epoch, self.model1, self.model_ema1, self.model_ema2, self.data_loader1, self.train_criterion1, self.optimizer1, self.lr_scheduler1, self.device + str(self.device_ids[0]))
+                     result2 = self._train_epoch(epoch, self.model2, self.model_ema2, self.model_ema1, self.data_loader2, self.train_criterion2, self.optimizer2, self.lr_scheduler2, self.device + str(self.device_ids[-1]))
+
+                 self.global_step += result1['local_step']
+                 if len(self.device_ids) > 1:
+                     self.model_ema1_copy.load_state_dict(self.model_ema1.state_dict())
+                     self.model_ema2_copy.load_state_dict(self.model_ema2.state_dict())
+                     if self.do_validation:
+                         q1 = mp.Queue()
+                         p1 = mp.Process(target=self._valid_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device + str(self.device_ids[0]), q1))
+                     if self.do_test:
+                         q2 = mp.Queue()
+                         p2 = mp.Process(target=self._test_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device + str(self.device_ids[0]), q2))
+                     p1.start()
+                     p2.start()
+                     val_log = q1.get()
+                     test_log = q2.get()
+                     result1.update(val_log)
+                     result2.update(val_log)
+                     result1.update(test_log)
+                     result2.update(test_log)
+                     p1.join()
+                     p2.join()
+                 else:
+                     if self.do_validation:
+                         val_log = self._valid_epoch(epoch, self.model1, self.model2, self.device + str(self.device_ids[0]))
+                         result1.update(val_log)
+                         result2.update(val_log)
+                     if self.do_test:
+                         test_log = self._test_epoch(epoch, self.model1, self.model2, self.device + str(self.device_ids[0]))
+                         result1.update(test_log)
+                         result2.update(test_log)
+
+             # save logged informations into log dict
+             log = {'epoch': epoch}
+             for key, value in result1.items():
+                 if key == 'metrics':
+                     log.update({'Net1' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+                     log.update({'Net2' + mtr.__name__: result2[key][i] for i, mtr in enumerate(self.metrics)})
+                 elif key == 'val_metrics':
+                     log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+                 elif key == 'test_metrics':
+                     log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+                 else:
+                     log['Net1' + key] = value
+                     log['Net2' + key] = result2[key]
+
+             # print logged informations to the screen
+             for key, value in log.items():
+                 self.logger.info('    {:15s}: {}'.format(str(key), value))
+
+             # evaluate model performance according to configured metric, save best checkpoint as model_best
+             best = False
+             if self.mnt_mode != 'off':
+                 try:
+                     # check whether model performance improved or not, according to specified metric (mnt_metric)
+                     improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
+                                (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
+                 except KeyError:
+                     self.logger.warning("Warning: Metric '{}' is not found. "
+                                         "Model performance monitoring is disabled.".format(self.mnt_metric))
+                     self.mnt_mode = 'off'
+                     improved = False
+
+                 if improved:
+                     self.mnt_best = log[self.mnt_metric]
+                     not_improved_count = 0
+                     best = True
+                 else:
+                     not_improved_count += 1
+
+                 if not_improved_count > self.early_stop:
+                     self.logger.info("Validation performance didn't improve for {} epochs. "
+                                      "Training stops.".format(self.early_stop))
+                     break
+
+             if epoch % self.save_period == 0:
+                 self._save_checkpoint(epoch, save_best=best)
+
+     def _prepare_device(self, n_gpu_use):
+         """
+         Setup GPU device if available, move model into configured device
+         """
+         n_gpu = torch.cuda.device_count()
+         if n_gpu_use > 0 and n_gpu == 0:
+             self.logger.warning("Warning: There's no GPU available on this machine, "
+                                 "training will be performed on CPU.")
+             n_gpu_use = 0
+         if n_gpu_use > n_gpu:
+             self.logger.warning("Warning: The number of GPUs configured to use is {}, but only {} are available "
+                                 "on this machine.".format(n_gpu_use, n_gpu))
+             n_gpu_use = n_gpu
+         device = 'cuda:'  # torch.device('cuda:' if n_gpu_use > 0 else 'cpu')
+         list_ids = list(range(n_gpu_use))
+         return device, list_ids
+
+     def _save_checkpoint(self, epoch, save_best=False):
+         """
+         Saving checkpoints
+
+         :param epoch: current epoch number
+         :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
+         """
+         arch = type(self.model1).__name__
+
+         state = {
+             'arch': arch,
+             'epoch': epoch,
+             'state_dict1': self.model1.state_dict(),
+             'state_dict2': self.model2.state_dict(),
+             'optimizer1': self.optimizer1.state_dict(),
+             'optimizer2': self.optimizer2.state_dict(),
+             'monitor_best': self.mnt_best
+             # 'config': self.config
+         }
+         filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
+         torch.save(state, filename)
+         self.logger.info("Saving checkpoint: {} ...".format(filename))
+         if save_best:
+             best_path = str(self.checkpoint_dir / 'model_best.pth')
+             torch.save(state, best_path)
+             self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))
+
+     def _resume_checkpoint(self, resume_path):
+         """
+         Resume from saved checkpoints
+
+         :param resume_path: Checkpoint path to be resumed
+         """
+         resume_path = str(resume_path)
+         self.logger.info("Loading checkpoint: {} ...".format(resume_path))
+         checkpoint = torch.load(resume_path)
+         self.start_epoch = checkpoint['epoch'] + 1
+         self.mnt_best = checkpoint['monitor_best']
+
+         # load architecture params from checkpoint; _save_checkpoint stores one
+         # state dict per network ('state_dict1'/'state_dict2'), so both are restored here
+         if checkpoint['arch'] != self.config['arch1']['type']:
+             self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
+                                 "checkpoint. This may yield an exception while state_dict is being loaded.")
+         self.model1.load_state_dict(checkpoint['state_dict1'])
+         self.model2.load_state_dict(checkpoint['state_dict2'])
+
+         # load optimizer states from checkpoint
+         self.optimizer1.load_state_dict(checkpoint['optimizer1'])
+         self.optimizer2.load_state_dict(checkpoint['optimizer2'])
+
+         self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
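Since `_save_checkpoint` stores one state dict per network (`state_dict1`/`state_dict2`), loading a saved ELR+ checkpoint outside the trainer looks roughly like the sketch below. The path and the `PreActResNet18` constructor (resolved from `model.model` via the configs' `"arch1"`/`"arch2"` entries) are placeholder assumptions:

```
import torch
from model.model import PreActResNet18  # assumed location, per the config "type" fields

checkpoint = torch.load('saved/model_best.pth', map_location='cpu')
model1 = PreActResNet18(num_classes=10)
model1.load_state_dict(checkpoint['state_dict1'])   # first network
model2 = PreActResNet18(num_classes=10)
model2.load_state_dict(checkpoint['state_dict2'])   # second network
print(checkpoint['epoch'], checkpoint['monitor_best'])
```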
ELR_plus/config_cifar10.json ADDED
@@ -0,0 +1,105 @@
+ {
+     "name": "cifar10_ELR_plus_PreActResNet18",
+     "n_gpu": 1,
+     "seed": 123,
+
+     "arch": {
+         "args": {"num_classes": 10}
+     },
+
+     "arch1": {
+         "type": "PreActResNet18",
+         "args": {"num_classes": 10}
+     },
+
+     "arch2": {
+         "type": "PreActResNet18",
+         "args": {"num_classes": 10}
+     },
+
+     "mixup_alpha": 1,
+     "coef_step": 0,
+     "num_classes": 10,
+     "ema_alpha": 0.997,
+     "ema_update": true,
+     "ema_step": 40000,
+
+     "data_loader": {
+         "type": "CIFAR10DataLoader",
+         "args": {
+             "data_dir": "/dir/to/data",
+             "batch_size": 128,
+             "batch_size2": 128,
+             "num_batches": 0,
+             "shuffle": true,
+             "validation_split": 0,
+             "num_workers": 8,
+             "pin_memory": true
+         }
+     },
+
+     "optimizer1": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "optimizer2": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "train_loss": {
+         "type": "elr_plus_loss",
+         "args": {
+             "beta": 0.7,
+             "lambda": 3
+         }
+     },
+
+     "val_loss": "cross_entropy",
+     "metrics": [
+         "my_metric", "my_metric2"
+     ],
+
+     "lr_scheduler": {
+         "type": "MultiStepLR",
+         "args": {
+             "milestones": [150],
+             "gamma": 0.1
+         }
+     },
+
+     "trainer": {
+         "epochs": 200,
+         "warmup": 0,
+         "save_dir": "dir/to/model",
+         "save_period": 1,
+         "verbosity": 2,
+         "label_dir": "saved/",
+
+         "monitor": "max val_my_metric",
+         "early_stop": 2000,
+
+         "tensorboard": false,
+         "mlflow": true,
+
+         "_percent": "Percentage of noise",
+         "percent": 0.8,
+         "_begin": "When to begin updating labels",
+         "begin": 0,
+         "_asym": "symmetric noise if false",
+         "asym": false
+     }
+ }
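The `ema_alpha` key above controls the weight-averaged (temporal ensembling) copy of each network. The update it parameterizes is, in the usual formulation, the sketch below; this is an illustration of the standard EMA rule, not the repo's exact trainer code (which may additionally ramp alpha with the global step when `ema_update` is true):

```
import torch

@torch.no_grad()
def ema_update(ema_model, model, alpha=0.997):
    # ema = alpha * ema + (1 - alpha) * current weights
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(alpha).add_(p, alpha=1 - alpha)
```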
ELR_plus/config_cifar100.json ADDED
@@ -0,0 +1,104 @@
+ {
+     "name": "cifar100_ELR_plus_PreActResNet18",
+     "n_gpu": 1,
+     "seed": 123,
+
+     "arch": {
+         "args": {"num_classes": 100}
+     },
+
+     "arch1": {
+         "type": "PreActResNet18",
+         "args": {"num_classes": 100}
+     },
+
+     "arch2": {
+         "type": "PreActResNet18",
+         "args": {"num_classes": 100}
+     },
+
+     "mixup_alpha": 1,
+     "coef_step": 40000,
+     "num_classes": 100,
+     "ema_alpha": 0.997,
+     "ema_update": true,
+     "ema_step": 40000,
+
+     "data_loader": {
+         "type": "CIFAR100DataLoader",
+         "args": {
+             "data_dir": "/gpfs/scratch/sl5924/noisy/data/",
+             "batch_size": 128,
+             "batch_size2": 128,
+             "num_batches": 0,
+             "shuffle": true,
+             "validation_split": 0,
+             "num_workers": 8,
+             "pin_memory": true
+         }
+     },
+
+     "optimizer1": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "optimizer2": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "train_loss": {
+         "type": "elr_plus_loss",
+         "args": {
+             "beta": 0.9,
+             "lambda": 7
+         }
+     },
+
+     "val_loss": "cross_entropy",
+     "metrics": [
+         "my_metric", "my_metric2"
+     ],
+
+     "lr_scheduler": {
+         "type": "MultiStepLR",
+         "args": {
+             "milestones": [200],
+             "gamma": 0.1
+         }
+     },
+
+     "trainer": {
+         "epochs": 250,
+         "warmup": 0,
+         "save_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/saved/",
+         "save_period": 1,
+         "verbosity": 2,
+         "label_dir": "saved/",
+
+         "monitor": "max val_my_metric",
+         "early_stop": 2000,
+
+         "tensorboard": false,
+         "mlflow": true,
+
+         "_percent": "Percentage of noise",
+         "percent": 0.8,
+         "_begin": "When to begin updating labels",
+         "begin": 0,
+         "_asym": "symmetric noise if false",
+         "asym": false
+     }
+ }
ELR_plus/config_cifar10_asym.json ADDED
@@ -0,0 +1,105 @@
+ {
+     "name": "cifar10_ELR_plus_PreActResNet18",
+     "n_gpu": 1,
+     "seed": 123,
+
+     "arch": {
+         "args": {"num_classes": 10}
+     },
+
+     "arch1": {
+         "type": "PreActResNet18",
+         "args": {"num_classes": 10}
+     },
+
+     "arch2": {
+         "type": "PreActResNet18",
+         "args": {"num_classes": 10}
+     },
+
+     "mixup_alpha": 1,
+     "coef_step": 0,
+     "num_classes": 10,
+     "ema_alpha": 0.997,
+     "ema_update": true,
+     "ema_step": 40000,
+
+     "data_loader": {
+         "type": "CIFAR10DataLoader",
+         "args": {
+             "data_dir": "dir/to/data",
+             "batch_size": 128,
+             "batch_size2": 128,
+             "num_batches": 0,
+             "shuffle": true,
+             "validation_split": 0,
+             "num_workers": 8,
+             "pin_memory": true
+         }
+     },
+
+     "optimizer1": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "optimizer2": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "train_loss": {
+         "type": "elr_plus_loss",
+         "args": {
+             "beta": 0.9,
+             "lambda": 1
+         }
+     },
+
+     "val_loss": "cross_entropy",
+     "metrics": [
+         "my_metric", "my_metric2"
+     ],
+
+     "lr_scheduler": {
+         "type": "MultiStepLR",
+         "args": {
+             "milestones": [150],
+             "gamma": 0.1
+         }
+     },
+
+     "trainer": {
+         "epochs": 200,
+         "warmup": 0,
+         "save_dir": "dir/to/model",
+         "save_period": 1,
+         "verbosity": 2,
+         "label_dir": "saved/",
+
+         "monitor": "max val_my_metric",
+         "early_stop": 2000,
+
+         "tensorboard": false,
+         "mlflow": true,
+
+         "_percent": "Percentage of noise",
+         "percent": 0.4,
+         "_begin": "When to begin updating labels",
+         "begin": 0,
+         "_asym": "symmetric noise if false",
+         "asym": true
+     }
+ }
ELR_plus/config_clothing1m.json ADDED
@@ -0,0 +1,102 @@
+ {
+     "name": "clothing1M_ELR_plus_resnet50",
+     "n_gpu": 1,
+     "seed": 123,
+
+     "arch1": {
+         "type": "resnet50",
+         "args": {"num_classes": 14}
+     },
+
+     "arch2": {
+         "type": "resnet50",
+         "args": {"num_classes": 14}
+     },
+
+     "mixup_alpha": 1,
+     "coef_step": 0,
+     "num_classes": 14,
+     "ema_alpha": 0.9999,
+     "ema_update": false,
+     "ema_step": -1,
+
+     "data_loader": {
+         "type": "Clothing1MDataLoader",
+         "args": {
+             "data_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/clothing1M/images",
+             "batch_size": 64,
+             "batch_size2": 64,
+             "num_batches": 3000,
+             "shuffle": true,
+             "validation_split": 0,
+             "num_workers": 8,
+             "pin_memory": true
+         }
+     },
+
+     "optimizer1": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.002,
+             "momentum": 0.9,
+             "weight_decay": 1e-3
+         }
+     },
+
+     "optimizer2": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.002,
+             "momentum": 0.9,
+             "weight_decay": 1e-3
+         }
+     },
+
+     "train_loss": {
+         "type": "elr_plus_loss",
+         "args": {
+             "beta": 0.7,
+             "lambda": 3
+         }
+     },
+
+     "val_loss": "cross_entropy",
+     "metrics": [
+         "my_metric", "my_metric2"
+     ],
+
+     "lr_scheduler": {
+         "type": "MultiStepLR",
+         "args": {
+             "milestones": [7],
+             "gamma": 0.1
+         }
+     },
+
+     "trainer": {
+         "epochs": 15,
+         "warmup": 0,
+         "save_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/saved/",
+         "save_period": 1,
+         "verbosity": 2,
+         "label_dir": "saved/",
+
+         "monitor": "max val_my_metric",
+         "early_stop": 2000,
+
+         "tensorboard": true,
+         "mlflow": true,
+
+         "_percent": "Percentage of noise",
+         "percent": 0.8,
+         "_begin": "When to begin updating labels",
+         "begin": 0,
+         "_asym": "symmetric noise if false",
+         "asym": false
+     }
+ }
ELR_plus/config_webvision.json ADDED
@@ -0,0 +1,103 @@
+ {
+     "name": "Webvision_ELR_plus_InceptionResNetV2",
+     "n_gpu": 2,
+     "seed": 123,
+
+     "arch": {
+         "args": {"num_classes": 50}
+     },
+
+     "arch1": {
+         "type": "InceptionResNetV2",
+         "args": {"num_classes": 50}
+     },
+
+     "arch2": {
+         "type": "InceptionResNetV2",
+         "args": {"num_classes": 50}
+     },
+
+     "mixup_alpha": 1.5,
+     "mixup_ramp": false,
+     "num_classes": 50,
+     "ema_alpha": 0.997,
+     "ema_update": false,
+     "ema_step": 40000,
+
+     "data_loader": {
+         "type": "WebvisionDataLoader",
+         "args": {
+             "data_dir": "/dir/to/data",
+             "batch_size": 32,
+             "batch_size2": 32,
+             "shuffle": true,
+             "num_batches": 0,
+             "validation_split": 0,
+             "num_workers": 8,
+             "pin_memory": true
+         }
+     },
+
+     "optimizer1": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "optimizer2": {
+         "type": "SGD",
+         "args": {
+             "lr": 0.02,
+             "momentum": 0.9,
+             "weight_decay": 5e-4
+         }
+     },
+
+     "train_loss": {
+         "type": "elr_plus_loss",
+         "args": {
+             "beta": 0.7,
+             "lambda": 3
+         }
+     },
+     "val_loss": "cross_entropy",
+     "metrics": [
+         "my_metric", "my_metric2"
+     ],
+     "lr_scheduler": {
+         "type": "MultiStepLR",
+         "args": {
+             "milestones": [50],
+             "gamma": 0.1
+         }
+     },
+     "trainer": {
+         "epochs": 100,
+         "warmup": 0,
+         "save_dir": "/dir/to/data",
+         "save_period": 1,
+         "verbosity": 2,
+         "label_dir": "saved/",
+
+         "monitor": "max val_my_metric",
+         "early_stop": 2000,
+
+         "tensorboard": false,
+         "mlflow": true,
+
+         "_percent": "Percentage of noise",
+         "percent": 0.9,
+         "_begin": "When to begin updating labels",
+         "begin": 0,
+         "_asym": "symmetric noise if false",
+         "asym": false
+     }
+ }
ELR_plus/data_loader/cifar10.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import sys
+
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+import json
+import os
+
+
+def get_cifar10(root, cfg_trainer, train=True,
+                transform_train=None, transform_val=None,
+                download=False, noise_file=''):
+    base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)
+    if train:
+        train_idxs, val_idxs = train_val_split(base_dataset.targets)
+        train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
+        val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
+        if cfg_trainer['asym']:
+            train_dataset.asymmetric_noise()
+            val_dataset.asymmetric_noise()
+        else:
+            train_dataset.symmetric_noise()
+            val_dataset.symmetric_noise()
+
+        print(f"Train: {len(train_idxs)} Val: {len(val_idxs)}")  # Train: 45000 Val: 5000
+    else:
+        train_dataset = []
+        val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
+        print(f"Test: {len(val_dataset)}")
+
+    return train_dataset, val_dataset
+
+
+def train_val_split(base_dataset):
+    # class-balanced 90/10 split of the training targets
+    num_classes = 10
+    base_dataset = np.array(base_dataset)
+    train_n = int(len(base_dataset) * 0.9 / num_classes)
+    train_idxs = []
+    val_idxs = []
+
+    for i in range(num_classes):
+        idxs = np.where(base_dataset == i)[0]
+        np.random.shuffle(idxs)
+        train_idxs.extend(idxs[:train_n])
+        val_idxs.extend(idxs[train_n:])
+    np.random.shuffle(train_idxs)
+    np.random.shuffle(val_idxs)
+
+    return train_idxs, val_idxs
+
+
+class CIFAR10_train(torchvision.datasets.CIFAR10):
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR10_train, self).__init__(root, train=train,
+                                            transform=transform, target_transform=target_transform,
+                                            download=download)
+        self.num_classes = 10
+        self.cfg_trainer = cfg_trainer
+        self.train_data = self.data[indexs]
+        self.train_labels = np.array(self.targets)[indexs]
+        self.indexs = indexs
+        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
+        self.noise_indx = []
+
+    def symmetric_noise(self):
+        self.train_labels_gt = self.train_labels.copy()
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.noise_indx.append(idx)
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def asymmetric_noise(self):
+        self.train_labels_gt = self.train_labels.copy()
+        for i in range(self.num_classes):
+            indices = np.where(self.train_labels == i)[0]
+            np.random.shuffle(indices)
+            for j, idx in enumerate(indices):
+                if j < self.cfg_trainer['percent'] * len(indices):
+                    self.noise_indx.append(idx)
+                    # truck -> automobile
+                    if i == 9:
+                        self.train_labels[idx] = 1
+                    # bird -> airplane
+                    elif i == 2:
+                        self.train_labels[idx] = 0
+                    # cat -> dog
+                    elif i == 3:
+                        self.train_labels[idx] = 5
+                    # dog -> cat
+                    elif i == 5:
+                        self.train_labels[idx] = 3
+                    # deer -> horse
+                    elif i == 4:
+                        self.train_labels[idx] = 7
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the (possibly
+            noisy) index of the target class and target_gt is the original label.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
+
+    def __len__(self):
+        return len(self.train_data)
+
+
+class CIFAR10_val(torchvision.datasets.CIFAR10):
+
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR10_val, self).__init__(root, train=train,
+                                          transform=transform, target_transform=target_transform,
+                                          download=download)
+        self.num_classes = 10
+        self.cfg_trainer = cfg_trainer
+        if train:
+            self.train_data = self.data[indexs]
+            self.train_labels = np.array(self.targets)[indexs]
+        else:
+            self.train_data = self.data
+            self.train_labels = np.array(self.targets)
+        self.train_labels_gt = self.train_labels.copy()
+
+    def symmetric_noise(self):
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def asymmetric_noise(self):
+        for i in range(self.num_classes):
+            indices = np.where(self.train_labels == i)[0]
+            np.random.shuffle(indices)
+            for j, idx in enumerate(indices):
+                if j < self.cfg_trainer['percent'] * len(indices):
+                    # truck -> automobile
+                    if i == 9:
+                        self.train_labels[idx] = 1
+                    # bird -> airplane
+                    elif i == 2:
+                        self.train_labels[idx] = 0
+                    # cat -> dog
+                    elif i == 3:
+                        self.train_labels[idx] = 5
+                    # dog -> cat
+                    elif i == 5:
+                        self.train_labels[idx] = 3
+                    # deer -> horse
+                    elif i == 4:
+                        self.train_labels[idx] = 7
+
+    def __len__(self):
+        return len(self.train_data)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the (possibly
+            noisy) index of the target class and target_gt is the original label.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
+
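A minimal usage sketch for `get_cifar10` (the data directory, noise settings, and `ToTensor` transforms are illustrative assumptions, not repo defaults):

```python
from torchvision import transforms
from data_loader.cifar10 import get_cifar10

# Build noisy train/val splits with 40% asymmetric label noise.
cfg_trainer = {'asym': True, 'percent': 0.4}  # subset of the trainer config used here
train_set, val_set = get_cifar10('./data', cfg_trainer, train=True,
                                 transform_train=transforms.ToTensor(),
                                 transform_val=transforms.ToTensor(),
                                 download=True)
img, noisy_target, index, true_target = train_set[0]
```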
ELR_plus/data_loader/cifar100.py ADDED
@@ -0,0 +1,307 @@
+import sys
+
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+from numpy.testing import assert_array_almost_equal
+import os
+import json
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+
+def get_cifar100(root, cfg_trainer, train=True,
+                 transform_train=None, transform_val=None,
+                 download=False, noise_file=''):
+    base_dataset = torchvision.datasets.CIFAR100(root, train=train, download=download)
+    if train:
+        train_idxs, val_idxs = train_val_split(base_dataset.targets)
+        train_dataset = CIFAR100_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
+        val_dataset = CIFAR100_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
+        if cfg_trainer['asym']:
+            train_dataset.asymmetric_noise()
+            val_dataset.asymmetric_noise()
+        else:
+            train_dataset.symmetric_noise()
+            val_dataset.symmetric_noise()
+
+        print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}\n")  # Train: 45000 Val: 5000
+    else:
+        train_dataset = []
+        val_dataset = CIFAR100_val(root, cfg_trainer, None, train=train, transform=transform_val)
+        print(f"Test: {len(val_dataset)}\n")
+
+    return train_dataset, val_dataset
+
+
+def train_val_split(base_dataset):
+    # class-balanced 90/10 split of the training targets
+    num_classes = 100
+    base_dataset = np.array(base_dataset)
+    train_n = int(len(base_dataset) * 0.9 / num_classes)
+    train_idxs = []
+    val_idxs = []
+
+    for i in range(num_classes):
+        idxs = np.where(base_dataset == i)[0]
+        np.random.shuffle(idxs)
+        train_idxs.extend(idxs[:train_n])
+        val_idxs.extend(idxs[train_n:])
+    np.random.shuffle(train_idxs)
+    np.random.shuffle(val_idxs)
+
+    return train_idxs, val_idxs
+
+
+class CIFAR100_train(torchvision.datasets.CIFAR100):
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR100_train, self).__init__(root, train=train,
+                                             transform=transform, target_transform=target_transform,
+                                             download=download)
+        self.num_classes = 100
+        self.cfg_trainer = cfg_trainer
+        self.train_data = self.data[indexs]
+        self.train_labels = np.array(self.targets)[indexs]
+        self.indexs = indexs
+        self.soft_labels = np.zeros((len(self.train_data), self.num_classes), dtype=np.float32)
+        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
+        self.noise_indx = []
+
+        self.count = 0
+
+    def symmetric_noise(self):
+        self.train_labels_gt = self.train_labels.copy()
+        np.random.seed(seed=888)
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.noise_indx.append(idx)
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+                self.soft_labels[idx][self.train_labels[idx]] = 1.
+
+    def multiclass_noisify(self, y, P, random_state=0):
+        """Flip classes according to transition probability matrix P.
+        It expects a number between 0 and the number of classes - 1.
+        """
+        assert P.shape[0] == P.shape[1]
+        assert np.max(y) < P.shape[0]
+
+        # row stochastic matrix
+        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
+        assert (P >= 0.0).all()
+
+        m = y.shape[0]
+        new_y = y.copy()
+        flipper = np.random.RandomState(random_state)
+
+        for idx in np.arange(m):
+            i = y[idx]
+            # draw a one-hot vector from the i-th row of P
+            flipped = flipper.multinomial(1, P[i, :], 1)[0]
+            new_y[idx] = np.where(flipped == 1)[0]
+
+        return new_y
+
+    def build_for_cifar100(self, size, noise):
+        """The noise matrix flips to the "next" class with probability 'noise'."""
+        assert (noise >= 0.) and (noise <= 1.)
+
+        P = (1. - noise) * np.eye(size)
+        for i in np.arange(size - 1):
+            P[i, i + 1] = noise
+
+        # adjust last row: wrap around to the first class
+        P[size - 1, 0] = noise
+
+        assert_array_almost_equal(P.sum(axis=1), 1, 1)
+        return P
+
+    def asymmetric_noise(self, asym=False, random_shuffle=False):
+        self.train_labels_gt = self.train_labels.copy()
+        P = np.eye(self.num_classes)
+        n = self.cfg_trainer['percent']
+        nb_superclasses = 20
+        nb_subclasses = 5
+
+        if n > 0.0:
+            for i in np.arange(nb_superclasses):
+                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
+                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)
+
+            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
+                                                    random_state=0)
+            actual_noise = (y_train_noisy != self.train_labels).mean()
+            assert actual_noise > 0.0
+            self.train_labels = y_train_noisy
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the (possibly
+            noisy) index of the target class and target_gt is the original label.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target, index, target_gt
+
+    def rotate_img(self, img, rot):
+        if rot == 0:  # 0 degrees rotation
+            return img
+        elif rot == 90:  # 90 degrees rotation
+            return np.flipud(np.transpose(img, (1, 0, 2)))
+        elif rot == 180:  # 180 degrees rotation
+            return np.fliplr(np.flipud(img))
+        elif rot == 270:  # 270 degrees rotation / or -90
+            return np.transpose(np.flipud(img), (1, 0, 2))
+        else:
+            raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
+
+    def __len__(self):
+        return len(self.train_data)
+
+
+class CIFAR100_val(torchvision.datasets.CIFAR100):
+
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR100_val, self).__init__(root, train=train,
+                                           transform=transform, target_transform=target_transform,
+                                           download=download)
+        self.num_classes = 100
+        self.cfg_trainer = cfg_trainer
+        if train:
+            self.train_data = self.data[indexs]
+            self.train_labels = np.array(self.targets)[indexs]
+        else:
+            self.train_data = self.data
+            self.train_labels = np.array(self.targets)
+        self.train_labels_gt = self.train_labels.copy()
+
+    def symmetric_noise(self):
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def multiclass_noisify(self, y, P, random_state=0):
+        """Flip classes according to transition probability matrix P.
+        It expects a number between 0 and the number of classes - 1.
+        """
+        assert P.shape[0] == P.shape[1]
+        assert np.max(y) < P.shape[0]
+
+        # row stochastic matrix
+        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
+        assert (P >= 0.0).all()
+
+        m = y.shape[0]
+        new_y = y.copy()
+        flipper = np.random.RandomState(random_state)
+
+        for idx in np.arange(m):
+            i = y[idx]
+            # draw a one-hot vector from the i-th row of P
+            flipped = flipper.multinomial(1, P[i, :], 1)[0]
+            new_y[idx] = np.where(flipped == 1)[0]
+
+        return new_y
+
+    def build_for_cifar100(self, size, noise):
+        """The noise matrix flips to the "next" class with probability 'noise'."""
+        assert (noise >= 0.) and (noise <= 1.)
+
+        P = (1. - noise) * np.eye(size)
+        for i in np.arange(size - 1):
+            P[i, i + 1] = noise
+
+        # adjust last row: wrap around to the first class
+        P[size - 1, 0] = noise
+
+        assert_array_almost_equal(P.sum(axis=1), 1, 1)
+        return P
+
+    def asymmetric_noise(self, asym=False, random_shuffle=False):
+        P = np.eye(self.num_classes)
+        n = self.cfg_trainer['percent']
+        nb_superclasses = 20
+        nb_subclasses = 5
+
+        if n > 0.0:
+            for i in np.arange(nb_superclasses):
+                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
+                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)
+
+            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
+                                                    random_state=0)
+            actual_noise = (y_train_noisy != self.train_labels).mean()
+            assert actual_noise > 0.0
+            self.train_labels = y_train_noisy
+
+    def __len__(self):
+        return len(self.train_data)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the (possibly
+            noisy) index of the target class and target_gt is the original label.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
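To make the asymmetric scheme above concrete: within each 5-class superclass block, `build_for_cifar100` keeps a label with probability `1 - noise` and flips it to the next subclass (wrapping around) with probability `noise`. An illustrative check with an assumed `noise=0.4`:

```python
import numpy as np

# Within-superclass transition matrix: keep a label with prob 0.6,
# flip it to the "next" class with prob 0.4.
noise, size = 0.4, 5
P = (1. - noise) * np.eye(size)
for i in np.arange(size - 1):
    P[i, i + 1] = noise
P[size - 1, 0] = noise  # last class wraps around to the first
print(P[0])  # [0.6 0.4 0.  0.  0. ]
```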
ELR_plus/data_loader/clothing1m.py ADDED
@@ -0,0 +1,128 @@
+import sys
+import os
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+
+
+def get_clothing(root, cfg_trainer, num_samples=0, train=True,
+                 transform_train=None, transform_val=None):
+
+    if train:
+        train_dataset = Clothing(root, cfg_trainer, num_samples=num_samples, train=train, transform=transform_train)
+        val_dataset = Clothing(root, cfg_trainer, val=train, transform=transform_val)
+        print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")
+    else:
+        train_dataset = []
+        val_dataset = Clothing(root, cfg_trainer, test=(not train), transform=transform_val)
+        print(f"Test: {len(val_dataset)}")
+
+    return train_dataset, val_dataset
+
+
+class Clothing(torch.utils.data.Dataset):
+
+    def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class=14):
+        self.cfg_trainer = cfg_trainer
+        self.root = root
+        self.transform = transform
+        self.train_labels = {}
+        self.test_labels = {}
+        self.val_labels = {}
+
+        self.train = train
+        self.val = val
+        self.test = test
+
+        with open('%s/noisy_label_kv.txt' % self.root, 'r') as f:
+            lines = f.read().splitlines()
+            for l in lines:
+                entry = l.split()
+                img_path = '%s/' % self.root + entry[0][7:]
+                self.train_labels[img_path] = int(entry[1])
+        with open('%s/clean_label_kv.txt' % self.root, 'r') as f:
+            lines = f.read().splitlines()
+            for l in lines:
+                entry = l.split()
+                img_path = '%s/' % self.root + entry[0][7:]
+                self.test_labels[img_path] = int(entry[1])
+
+        if train:
+            train_imgs = []
+            with open('%s/noisy_train_key_list.txt' % self.root, 'r') as f:
+                lines = f.read().splitlines()
+                for i, l in enumerate(lines):
+                    img_path = '%s/' % self.root + l[7:]
+                    train_imgs.append((i, img_path))
+            self.num_raw_example = len(train_imgs)
+            random.shuffle(train_imgs)
+            class_num = torch.zeros(num_class)
+            self.train_imgs = []
+            for id_raw, impath in train_imgs:
+                label = self.train_labels[impath]
+                if class_num[label] < (num_samples / 14) and len(self.train_imgs) < num_samples:
+                    self.train_imgs.append((id_raw, impath))
+                    class_num[label] += 1
+            random.shuffle(self.train_imgs)
+        elif test:
+            self.test_imgs = []
+            with open('%s/clean_test_key_list.txt' % self.root, 'r') as f:
+                lines = f.read().splitlines()
+                for l in lines:
+                    img_path = '%s/' % self.root + l[7:]
+                    self.test_imgs.append(img_path)
+        elif val:
+            self.val_imgs = []
+            with open('%s/clean_val_key_list.txt' % self.root, 'r') as f:
+                lines = f.read().splitlines()
+                for l in lines:
+                    img_path = '%s/' % self.root + l[7:]
+                    self.val_imgs.append(img_path)
+
+    def __getitem__(self, index):
+        if self.train:
+            id_raw, img_path = self.train_imgs[index]
+            target = self.train_labels[img_path]
+        elif self.val:
+            img_path = self.val_imgs[index]
+            target = self.test_labels[img_path]
+        elif self.test:
+            img_path = self.test_imgs[index]
+            target = self.test_labels[img_path]
+        image = Image.open(img_path).convert('RGB')
+        if self.train:
+            img0 = self.transform(image)
+
+        if self.test or self.val:
+            img = self.transform(image)
+            return img, target, index, target
+        else:
+            return img0, target, id_raw, target
+
+    def __len__(self):
+        if self.test:
+            return len(self.test_imgs)
+        if self.val:
+            return len(self.val_imgs)
+        else:
+            return len(self.train_imgs)
+
+    def flist_reader(self, flist):
+        imlist = []
+        with open(flist, 'r') as rf:
+            for line in rf.readlines():
+                row = line.split(" ")
+                impath = self.root + row[0]
+                imlabel = float(row[1].replace('\n', ''))
+                imlist.append((impath, int(imlabel)))
+        return imlist
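A minimal sketch of loading the clean Clothing1M test split with this class (the dataset root and `ToTensor` transform are assumptions; the root must contain the Clothing1M key-list files read above):

```python
from torchvision import transforms
from data_loader.clothing1m import get_clothing

cfg_trainer = {}  # the Clothing dataset only stores this config
_, test_set = get_clothing('/path/to/clothing1m', cfg_trainer, train=False,
                           transform_val=transforms.ToTensor())
img, target, index, _ = test_set[0]
```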
ELR_plus/data_loader/data_loaders.py ADDED
@@ -0,0 +1,137 @@
+import sys
+
+from torchvision import datasets, transforms
+from base import BaseDataLoader
+from data_loader.cifar10 import get_cifar10
+from data_loader.cifar100 import get_cifar100
+from data_loader.clothing1m import get_clothing
+from data_loader.webvision import get_webvision
+from parse_config import ConfigParser
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+class CIFAR10DataLoader(BaseDataLoader):
+    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):
+        config = ConfigParser.get_instance()
+        cfg_trainer = config['trainer']
+
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+        ])
+        transform_val = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+        ])
+        self.data_dir = data_dir
+
+        noise_file = '%sCIFAR10_%.1f_Asym_%s.json' % (config['data_loader']['args']['data_dir'], cfg_trainer['percent'], cfg_trainer['asym'])
+
+        self.train_dataset, self.val_dataset = get_cifar10(config['data_loader']['args']['data_dir'], cfg_trainer, train=training,
+                                                           transform_train=transform_train, transform_val=transform_val, noise_file=noise_file)
+
+        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+                         val_dataset=self.val_dataset)
+
+
+class CIFAR100DataLoader(BaseDataLoader):
+    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):
+        config = ConfigParser.get_instance()
+        cfg_trainer = config['trainer']
+
+        transform_train = transforms.Compose([
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+        ])
+        transform_val = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+        ])
+        self.data_dir = data_dir
+
+        noise_file = '%sCIFAR100_%.1f_Asym_%s.json' % (config['data_loader']['args']['data_dir'], cfg_trainer['percent'], cfg_trainer['asym'])
+
+        self.train_dataset, self.val_dataset = get_cifar100(config['data_loader']['args']['data_dir'], cfg_trainer, train=training,
+                                                            transform_train=transform_train, transform_val=transform_val, noise_file=noise_file)
+
+        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+                         val_dataset=self.val_dataset)
+
+        self.batch_size_ = int(batch_size)
+
+
+class Clothing1MDataLoader(BaseDataLoader):
+    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.num_batches = num_batches
+        self.training = training
+
+        self.transform_train = transforms.Compose([
+            transforms.Resize(256),
+            transforms.RandomCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
+        ])
+        self.transform_val = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)),
+        ])
+
+        self.data_dir = data_dir
+        config = ConfigParser.get_instance()
+        cfg_trainer = config['trainer']
+        self.train_dataset, self.val_dataset = get_clothing(config['data_loader']['args']['data_dir'], cfg_trainer, num_samples=self.num_batches * self.batch_size, train=training,
+                                                            transform_train=self.transform_train, transform_val=self.transform_val)
+
+        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+                         val_dataset=self.val_dataset)
+
+
+class WebvisionDataLoader(BaseDataLoader):
+    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True, num_class=50):
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.num_batches = num_batches
+        self.training = training
+
+        self.transform_train = transforms.Compose([
+            transforms.RandomCrop(227),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        ])
+        self.transform_val = transforms.Compose([
+            transforms.CenterCrop(227),
+            transforms.ToTensor(),
+            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        ])
+        self.transform_imagenet = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(227),
+            transforms.ToTensor(),
+            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        ])
+
+        self.data_dir = data_dir
+        config = ConfigParser.get_instance()
+        cfg_trainer = config['trainer']
+        self.train_dataset, self.val_dataset = get_webvision(config['data_loader']['args']['data_dir'], cfg_trainer, num_samples=self.num_batches * self.batch_size, train=training,
+                                                             transform_train=self.transform_train, transform_val=self.transform_val, num_class=num_class)
+
+        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+                         val_dataset=self.val_dataset)
+
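Note that every loader here reads the trainer section through the `ConfigParser` singleton rather than through its own arguments, so the global config must exist before a loader is constructed. A sketch of that implicit contract (assuming `train.py` has already initialized the singleton):

```python
import data_loader.data_loaders as module_data
from parse_config import ConfigParser

# The loaders call ConfigParser.get_instance() internally, so this must
# already have been set up (train.py does this) before construction.
config = ConfigParser.get_instance()
args = config['data_loader']['args']
loader = module_data.CIFAR10DataLoader(args['data_dir'], args['batch_size'])
```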
ELR_plus/data_loader/webvision.py ADDED
@@ -0,0 +1,140 @@
+import sys
+import os
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+
+
+def get_webvision(root, cfg_trainer, num_samples=0, train=True,
+                  transform_train=None, transform_val=None, num_class=50):
+
+    if train:
+        train_dataset = Webvision(root, cfg_trainer, num_samples=num_samples, train=train, transform=transform_train, num_class=num_class)
+        val_dataset = Webvision(root, cfg_trainer, num_samples=num_samples, val=train, transform=transform_val, num_class=num_class)
+        print(f"Train: {len(train_dataset)} WebVision Val: {len(val_dataset)}")
+    else:
+        train_dataset = []
+        val_dataset = ImagenetVal(root, transform=transform_val, num_class=num_class)
+        print(f"Imagenet Val: {len(val_dataset)}")
+
+    return train_dataset, val_dataset
+
+
+class ImagenetVal(torch.utils.data.Dataset):
+    def __init__(self, root, transform, num_class):
+        self.root = root + 'imagenet/'
+        self.transform = transform
+
+        with open(self.root + 'imagenet_val.txt') as f:
+            lines = f.readlines()
+        self.val_imgs = []
+        self.val_labels = {}
+        for line in lines:
+            img, target = line.split()
+            target = int(target)
+            if target < num_class:
+                self.val_imgs.append(img)
+                self.val_labels[img] = target
+
+    def __getitem__(self, index):
+        img_path = self.val_imgs[index]
+        target = self.val_labels[img_path]
+        image = Image.open(self.root + 'val/' + img_path).convert('RGB')
+        img = self.transform(image)
+
+        return img, target, index, target
+
+    def __len__(self):
+        return len(self.val_imgs)
+
+
+class Webvision(torch.utils.data.Dataset):
+
+    def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class=50):
+        self.cfg_trainer = cfg_trainer
+        self.root = root
+        self.transform = transform
+        self.train_labels = {}
+        self.test_labels = {}
+        self.val_labels = {}
+
+        self.train = train
+        self.val = val
+        self.test = test
+
+        if self.val:
+            with open(self.root + 'info/val_filelist.txt') as f:
+                lines = f.readlines()
+            self.val_imgs = []
+            self.val_labels = {}
+            for line in lines:
+                img, target = line.split()
+                target = int(target)
+                if target < num_class:
+                    self.val_imgs.append(img)
+                    self.val_labels[img] = target
+        elif self.test:
+            with open(self.root + 'info/val_filelist.txt') as f:
+                lines = f.readlines()
+            self.test_imgs = []
+            self.test_labels = {}
+            for line in lines:
+                img, target = line.split()
+                target = int(target)
+                if target < num_class:
+                    self.test_imgs.append(img)
+                    self.test_labels[img] = target
+        else:
+            with open(self.root + 'info/train_filelist_google.txt') as f:
+                lines = f.readlines()
+            train_imgs = []
+            self.train_labels = {}
+            for line in lines:
+                img, target = line.split()
+                target = int(target)
+                if target < num_class:
+                    train_imgs.append(img)
+                    self.train_labels[img] = target
+
+            self.train_imgs = train_imgs
+
+    def __getitem__(self, index):
+        if self.train:
+            img_path = self.train_imgs[index]
+            target = self.train_labels[img_path]
+            image = Image.open(self.root + img_path)
+            img0 = image.convert('RGB')
+            img0 = self.transform(img0)
+            return img0, target, index, target
+        elif self.val:
+            img_path = self.val_imgs[index]
+            target = self.val_labels[img_path]
+            image = Image.open(self.root + 'val_images_256/' + img_path).convert('RGB')
+            img = self.transform(image)
+            return img, target, index, target
+        elif self.test:
+            img_path = self.test_imgs[index]
+            target = self.test_labels[img_path]
+            image = Image.open(self.root + 'val_images_256/' + img_path).convert('RGB')
+            img = self.transform(image)
+            return img, target, index, target
+
+    def __len__(self):
+        if self.test:
+            return len(self.test_imgs)
+        if self.val:
+            return len(self.val_imgs)
+        else:
+            return len(self.train_imgs)
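A minimal usage sketch for `get_webvision` (the root layout and transforms are assumptions; note that `train=True` returns the WebVision validation split, while `train=False` returns the ImageNet validation split instead):

```python
from torchvision import transforms
from data_loader.webvision import get_webvision

cfg_trainer = {}  # stored but not otherwise used by the Webvision dataset
train_set, val_set = get_webvision('/path/to/webvision/', cfg_trainer, train=True,
                                   transform_train=transforms.ToTensor(),
                                   transform_val=transforms.ToTensor(),
                                   num_class=50)
img, target, index, _ = train_set[0]
```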
ELR_plus/logger/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .logger import *
+from .visualization import *
ELR_plus/logger/logger.py ADDED
@@ -0,0 +1,22 @@
+import logging
+import logging.config
+from pathlib import Path
+from utils import read_json
+
+
+def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
+    """
+    Setup logging configuration
+    """
+    log_config = Path(log_config)
+    if log_config.is_file():
+        config = read_json(log_config)
+        # modify logging paths based on run config
+        for _, handler in config['handlers'].items():
+            if 'filename' in handler:
+                handler['filename'] = str(save_dir / handler['filename'])
+
+        logging.config.dictConfig(config)
+    else:
+        print("Warning: logging configuration file was not found at {}.".format(log_config))
+        logging.basicConfig(level=default_level)
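A minimal usage sketch (run from the repo root so the default `logger/logger_config.json` path resolves; the save directory is an assumption):

```python
import logging
from pathlib import Path
from logger import setup_logging

save_dir = Path('saved/log')
save_dir.mkdir(parents=True, exist_ok=True)
setup_logging(save_dir)  # routes info.log into save_dir per logger_config.json
logging.getLogger('train').info('logging configured')
```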
ELR_plus/logger/logger_config.json ADDED
@@ -0,0 +1,32 @@
+{
+    "version": 1,
+    "disable_existing_loggers": false,
+    "formatters": {
+        "simple": {"format": "%(message)s"},
+        "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "simple",
+            "stream": "ext://sys.stdout"
+        },
+        "info_file_handler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "INFO",
+            "formatter": "datetime",
+            "filename": "info.log",
+            "maxBytes": 10485760,
+            "backupCount": 20,
+            "encoding": "utf8"
+        }
+    },
+    "root": {
+        "level": "INFO",
+        "handlers": [
+            "console",
+            "info_file_handler"
+        ]
+    }
+}
ELR_plus/logger/visualization.py ADDED
@@ -0,0 +1,154 @@
+import importlib
+from utils import Timer
+
+
+class MLFlow:
+    def __init__(self, log_dir, logger, enabled):
+        self.mlflow = None
+
+        if enabled:
+            log_dir = str(log_dir)
+
+            # Retrieve visualization writer.
+            try:
+                self.mlflow = importlib.import_module("mlflow")
+                succeeded = True
+            except ImportError:
+                succeeded = False
+
+            if not succeeded:
+                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
+                    "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in " \
+                    "the 'config.json' file."
+                logger.warning(message)
+
+        self.step = 0
+        self.mode = ''
+
+        self.mlflow_ftns_with_tag_and_value = {
+            'log_param', 'log_metric'
+        }
+        self.mlflow_ftns = {
+            'start_run'
+        }
+        # required by the wrapper in __getattr__; no mlflow function is exempt from mode tagging
+        self.tag_mode_exceptions = set()
+
+    def __getattr__(self, name):
+        """
+        If visualization is configured to use:
+            return the mlflow logging methods with additional information (tag) added.
+        Otherwise:
+            return a blank function handle that does nothing
+        """
+        if name in self.mlflow_ftns_with_tag_and_value:
+            add_data = getattr(self.mlflow, name, None)
+
+            def wrapper(tag, data, *args, **kwargs):
+                if add_data is not None:
+                    # add mode(train/valid) tag
+                    if name not in self.tag_mode_exceptions:
+                        tag = '{}/{}'.format(tag, self.mode)
+                    add_data(tag, data, *args, **kwargs)
+
+            return wrapper
+        elif name in self.mlflow_ftns:
+            add_data = getattr(self.mlflow, name, None)
+
+            def wrapper(*args, **kwargs):
+                if add_data is not None:
+                    add_data(*args, **kwargs)
+
+            return wrapper
+        else:
+            # default action for returning methods defined in this class, set_step() for instance.
+            try:
+                attr = object.__getattribute__(self, name)
+            except AttributeError:
+                raise AttributeError("type object '{}' has no attribute '{}'".format('MLFlow', name))
+            return attr
+
+
+class TensorboardWriter:
+    def __init__(self, log_dir, logger, enabled):
+        self.writer = None
+        self.selected_module = ""
+
+        if enabled:
+            log_dir = str(log_dir)
+
+            # Retrieve visualization writer.
+            succeeded = False
+            for module in ["torch.utils.tensorboard", "tensorboardX"]:
+                try:
+                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
+                    succeeded = True
+                    break
+                except ImportError:
+                    succeeded = False
+                self.selected_module = module
+
+            if not succeeded:
+                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
+                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
+                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
+                    "the 'config.json' file."
+                logger.warning(message)
+
+        self.step = 0
+        self.mode = ''
+
+        self.tb_writer_ftns = {
+            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
+            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
+        }
+        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
+
+        self.timer = Timer()
+
+    def set_step(self, step, mode='train'):
+        self.mode = mode
+        self.step = step
+        if step == 0:
+            self.timer.reset()
+        else:
+            duration = self.timer.check()
+            self.add_scalar('steps_per_sec', 1 / duration)
+
+    def __getattr__(self, name):
+        """
+        If visualization is configured to use:
+            return add_data() methods of tensorboard with additional information (step, tag) added.
+        Otherwise:
+            return a blank function handle that does nothing
+        """
+        if name in self.tb_writer_ftns:
+            add_data = getattr(self.writer, name, None)
+
+            def wrapper(tag, data, *args, **kwargs):
+                if add_data is not None:
+                    # add mode(train/valid) tag
+                    if name not in self.tag_mode_exceptions:
+                        tag = '{}/{}'.format(tag, self.mode)
+                    add_data(tag, data, self.step, *args, **kwargs)
+            return wrapper
+        else:
+            # default action for returning methods defined in this class, set_step() for instance.
+            try:
+                attr = object.__getattribute__(self, name)
+            except AttributeError:
+                raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
+            return attr
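A minimal sketch of the `TensorboardWriter` contract (log directory and values are assumptions; requires `torch.utils.tensorboard` or `tensorboardX`):

```python
import logging
from logger.visualization import TensorboardWriter

writer = TensorboardWriter('saved/runs', logging.getLogger('demo'), enabled=True)
writer.set_step(10, mode='valid')
writer.add_scalar('loss', 0.42)  # logged under the tag 'loss/valid' at step 10
```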
ELR_plus/model/InceptionResNetV2.py ADDED
@@ -0,0 +1,314 @@
+from __future__ import print_function, division, absolute_import
+import torch
+import torch.nn as nn
+import os
+import sys
+
+
+class BasicConv2d(nn.Module):
+
+    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_planes, out_planes,
+                              kernel_size=kernel_size, stride=stride,
+                              padding=padding, bias=False)  # verify bias false
+        self.bn = nn.BatchNorm2d(out_planes,
+                                 eps=0.001,  # value found in tensorflow
+                                 momentum=0.1,  # default pytorch value
+                                 affine=True)
+        self.relu = nn.ReLU(inplace=False)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
+class Mixed_5b(nn.Module):
+
+    def __init__(self):
+        super(Mixed_5b, self).__init__()
+
+        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
+
+        self.branch1 = nn.Sequential(
+            BasicConv2d(192, 48, kernel_size=1, stride=1),
+            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
+        )
+
+        self.branch2 = nn.Sequential(
+            BasicConv2d(192, 64, kernel_size=1, stride=1),
+            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
+            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
+        )
+
+        self.branch3 = nn.Sequential(
+            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+            BasicConv2d(192, 64, kernel_size=1, stride=1)
+        )
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+        out = torch.cat((x0, x1, x2, x3), 1)
+        return out
+
+
+class Block35(nn.Module):
+
+    def __init__(self, scale=1.0):
+        super(Block35, self).__init__()
+
+        self.scale = scale
+
+        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
+
+        self.branch1 = nn.Sequential(
+            BasicConv2d(320, 32, kernel_size=1, stride=1),
+            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
+        )
+
+        self.branch2 = nn.Sequential(
+            BasicConv2d(320, 32, kernel_size=1, stride=1),
+            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
+            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
+        )
+
+        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
+        self.relu = nn.ReLU(inplace=False)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        out = torch.cat((x0, x1, x2), 1)
+        out = self.conv2d(out)
+        out = out * self.scale + x
+        out = self.relu(out)
+        return out
+
+
+class Mixed_6a(nn.Module):
+
+    def __init__(self):
+        super(Mixed_6a, self).__init__()
+
+        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
+
+        self.branch1 = nn.Sequential(
+            BasicConv2d(320, 256, kernel_size=1, stride=1),
+            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
+            BasicConv2d(256, 384, kernel_size=3, stride=2)
+        )
+
+        self.branch2 = nn.MaxPool2d(3, stride=2)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        out = torch.cat((x0, x1, x2), 1)
+        return out
+
+
+class Block17(nn.Module):
+
+    def __init__(self, scale=1.0):
+        super(Block17, self).__init__()
+
+        self.scale = scale
+
+        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
+
+        self.branch1 = nn.Sequential(
+            BasicConv2d(1088, 128, kernel_size=1, stride=1),
+            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
+        )
+
+        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
+        self.relu = nn.ReLU(inplace=False)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        out = torch.cat((x0, x1), 1)
+        out = self.conv2d(out)
+        out = out * self.scale + x
+        out = self.relu(out)
+        return out
+
+
+class Mixed_7a(nn.Module):
+
+    def __init__(self):
+        super(Mixed_7a, self).__init__()
+
+        self.branch0 = nn.Sequential(
+            BasicConv2d(1088, 256, kernel_size=1, stride=1),
+            BasicConv2d(256, 384, kernel_size=3, stride=2)
+        )
+
+        self.branch1 = nn.Sequential(
+            BasicConv2d(1088, 256, kernel_size=1, stride=1),
+            BasicConv2d(256, 288, kernel_size=3, stride=2)
+        )
+
+        self.branch2 = nn.Sequential(
+            BasicConv2d(1088, 256, kernel_size=1, stride=1),
+            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
+            BasicConv2d(288, 320, kernel_size=3, stride=2)
+        )
+
+        self.branch3 = nn.MaxPool2d(3, stride=2)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+        out = torch.cat((x0, x1, x2, x3), 1)
+        return out
+
+
+class Block8(nn.Module):
+
+    def __init__(self, scale=1.0, noReLU=False):
+        super(Block8, self).__init__()
+
+        self.scale = scale
+        self.noReLU = noReLU
+
+        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
+
+        self.branch1 = nn.Sequential(
+            BasicConv2d(2080, 192, kernel_size=1, stride=1),
+            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
+            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+        )
+
+        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
+        if not self.noReLU:
+            self.relu = nn.ReLU(inplace=False)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        out = torch.cat((x0, x1), 1)
+        out = self.conv2d(out)
+        out = out * self.scale + x
+        if not self.noReLU:
+            out = self.relu(out)
+        return out
+
+
+class InceptionResNetV2(nn.Module):
+
+    def __init__(self, num_classes=50):
+        super(InceptionResNetV2, self).__init__()
+        # Special attributes
+        self.input_space = None
+        self.input_size = (299, 299, 3)
+        self.mean = None
+        self.std = None
+        # Modules
+        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
+        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
+        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
+        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
+        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
+        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
+        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
+        self.mixed_5b = Mixed_5b()
+        # 10 x Block35(scale=0.17)
+        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
+        self.mixed_6a = Mixed_6a()
+        # 20 x Block17(scale=0.10)
+        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
+        self.mixed_7a = Mixed_7a()
+        # 9 x Block8(scale=0.20)
+        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
+        self.block8 = Block8(noReLU=True)
+        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
+        self.avgpool_1a = nn.AdaptiveAvgPool2d((1, 1))  # nn.AvgPool2d(8, count_include_pad=False)
+        self.last_linear = nn.Linear(1536, num_classes)
+
+    def features(self, input):
+        x = self.conv2d_1a(input)
+        x = self.conv2d_2a(x)
+        x = self.conv2d_2b(x)
+        x = self.maxpool_3a(x)
+        x = self.conv2d_3b(x)
+        x = self.conv2d_4a(x)
+        x = self.maxpool_5a(x)
+        x = self.mixed_5b(x)
+        x = self.repeat(x)
+        x = self.mixed_6a(x)
+        x = self.repeat_1(x)
+        x = self.mixed_7a(x)
+        x = self.repeat_2(x)
+        x = self.block8(x)
+        x = self.conv2d_7b(x)
+        return x
+
+    def logits(self, features):
+        x = self.avgpool_1a(features)
+        x = x.view(x.size(0), -1)
+        out = self.last_linear(x)
+        return out
+
+    def forward(self, input):
+        x = self.features(input)
+        out = self.logits(x)
+        return out
+
+
+def test():
+    net = InceptionResNetV2().cuda()
+    y = net(torch.randn(1, 3, 227, 227).cuda())
+    print(y.size())
+
+#test()
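A CPU-friendly variant of the `test()` helper above (a sketch: the 227x227 input matches the WebVision crops and is accepted despite the declared 299x299 `input_size` because the final pooling is adaptive):

```python
import torch
from model.InceptionResNetV2 import InceptionResNetV2

net = InceptionResNetV2(num_classes=50)
with torch.no_grad():
    out = net(torch.randn(1, 3, 227, 227))
print(out.shape)  # torch.Size([1, 50])
```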