Spaces:
Configuration error
Configuration error
Upload 69 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- ELR/README.md +44 -0
- ELR/base/__init__.py +3 -0
- ELR/base/base_data_loader.py +83 -0
- ELR/base/base_model.py +25 -0
- ELR/base/base_trainer.py +195 -0
- ELR/config_cifar10.json +75 -0
- ELR/config_cifar100.json +75 -0
- ELR/config_cifar10_asym.json +75 -0
- ELR/data_loader/__pycache__/cifar10.cpython-36.pyc +0 -0
- ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc +0 -0
- ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc +0 -0
- ELR/data_loader/cifar10.py +212 -0
- ELR/data_loader/cifar100.py +317 -0
- ELR/data_loader/data_loaders.py +70 -0
- ELR/logger/__init__.py +2 -0
- ELR/logger/logger.py +22 -0
- ELR/logger/logger_config.json +32 -0
- ELR/logger/visualization.py +154 -0
- ELR/model/ResNet_Zoo.py +133 -0
- ELR/model/loss.py +30 -0
- ELR/model/metric.py +20 -0
- ELR/model/model.py +13 -0
- ELR/parse_config.py +146 -0
- ELR/test.py +82 -0
- ELR/train.py +125 -0
- ELR/trainer/__init__.py +1 -0
- ELR/trainer/trainer.py +278 -0
- ELR/utils/__init__.py +1 -0
- ELR/utils/util.py +75 -0
- ELR_plus/README.md +27 -0
- ELR_plus/base/__init__.py +3 -0
- ELR_plus/base/base_data_loader.py +83 -0
- ELR_plus/base/base_model.py +25 -0
- ELR_plus/base/base_trainer.py +341 -0
- ELR_plus/config_cifar10.json +105 -0
- ELR_plus/config_cifar100.json +104 -0
- ELR_plus/config_cifar10_asym.json +105 -0
- ELR_plus/config_clothing1m.json +102 -0
- ELR_plus/config_webvision.json +103 -0
- ELR_plus/data_loader/cifar10.py +214 -0
- ELR_plus/data_loader/cifar100.py +307 -0
- ELR_plus/data_loader/clothing1m.py +128 -0
- ELR_plus/data_loader/data_loaders.py +137 -0
- ELR_plus/data_loader/webvision.py +140 -0
- ELR_plus/logger/__init__.py +2 -0
- ELR_plus/logger/logger.py +22 -0
- ELR_plus/logger/logger_config.json +32 -0
- ELR_plus/logger/visualization.py +154 -0
- ELR_plus/model/InceptionResNetV2.py +314 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/clean_label_simplexheatmap2.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
images/false_label_simplexheatmap.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
images/illustration_of_ELR.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
images/simplexheatmap.gif filter=lfs diff=lfs merge=lfs -text
|
ELR/README.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ELR
|
2 |
+
This is an official PyTorch implementation of the ELR method proposed in [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151).
|
3 |
+
|
4 |
+
|
5 |
+
## Usage
|
6 |
+
Train the network on the Symmetric Noise CIFAR-10 dataset (noise rate = 0.8):
|
7 |
+
|
8 |
+
```
|
9 |
+
python train.py -c config_cifar10.json --percent 0.8
|
10 |
+
```
|
11 |
+
Train the network on the Asymmetric Noise CIFAR-10 dataset (noise rate = 0.4):
|
12 |
+
|
13 |
+
```
|
14 |
+
python train.py -c config_cifar10_asym.json --percent 0.4 --asym 1
|
15 |
+
```
|
16 |
+
|
17 |
+
Train the network on the Asymmetric Noise CIFAR-100 dataset (noise rate = 0.4):
|
18 |
+
|
19 |
+
```
|
20 |
+
python train.py -c config_cifar100.json --percent 0.4 --asym 1
|
21 |
+
```
|
22 |
+
|
23 |
+
The config files can be modified to adjust hyperparameters and optimization settings.
|
24 |
+
|
25 |
+
## Results
|
26 |
+
### CIFAR10
|
27 |
+
<center>
|
28 |
+
|
29 |
+
| Method | 20% | 40% | 60% | 80% | 40% Asym |
|
30 |
+
| ---------------------- | ----------- | ----------- | ----------- | ----------- | ----------- |
|
31 |
+
| ELR | 91.16% | 89.15% | 86.12% | 73.86% | 90.12% |
|
32 |
+
| ELR (cosine annealing) | 91.12% | 91.43% | 88.87% | 80.69% | 90.35% |
|
33 |
+
|
34 |
+
### CIFAR100
|
35 |
+
|
36 |
+
| Method | 20% | 40% | 60% | 80% | 40% Asym |
|
37 |
+
| ---------------------- | ----------- | ----------- | ----------- | ----------- | ----------- |
|
38 |
+
| ELR | 74.21% | 68.28% | 59.28% | 29.78% | 73.71% |
|
39 |
+
| ELR (cosine annealing) | 74.68% | 68.43% | 60.05% | 30.27% | 73.96% |
|
40 |
+
|
41 |
+
</center>
|
42 |
+
|
43 |
+
## References
|
44 |
+
- S. Liu, J. Niles-Weed, N. Razavian and C. Fernandez-Granda "Early-Learning Regularization Prevents Memorization of Noisy Labels", 2020
|
ELR/base/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .base_data_loader import *
|
2 |
+
from .base_model import *
|
3 |
+
from .base_trainer import *
|
ELR/base/base_data_loader.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union, Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torch.utils.data.dataloader import default_collate
|
6 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
7 |
+
|
8 |
+
|
9 |
+
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Extends ``torch.utils.data.DataLoader`` with a train/validation split.
    The validation data comes either from an explicit ``val_dataset`` or, when
    none is given, from a random subset of ``train_dataset`` selected through
    ``SubsetRandomSampler`` according to ``validation_split``.
    """
    valid_sampler: Optional[SubsetRandomSampler]
    sampler: Optional[SubsetRandomSampler]

    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
                 collate_fn=default_collate, val_dataset=None):
        """
        :param train_dataset: dataset the training loader iterates over
        :param batch_size: batch size of the training loader
        :param shuffle: whether to shuffle the training data; forced to False
            when a sampler-based split is active
        :param validation_split: 0 for no split, an int for an absolute number
            of validation samples, or a float fraction of the dataset
        :param num_workers: forwarded to ``DataLoader``
        :param pin_memory: forwarded to ``DataLoader``
        :param collate_fn: forwarded to ``DataLoader``
        :param val_dataset: optional separate validation dataset; when given,
            no sampler-based split is performed
        """
        self.collate_fn = collate_fn
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.val_dataset = val_dataset

        self.batch_idx = 0
        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
        self.init_kwargs = {
            'dataset': train_dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
            'pin_memory': pin_memory
        }
        if val_dataset is None:
            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
            # _split_sampler switches self.shuffle off when it builds samplers.
            # Re-sync the stored kwargs, otherwise DataLoader receives both
            # ``sampler`` and ``shuffle=True`` and raises ValueError.
            self.init_kwargs['shuffle'] = self.shuffle
            super().__init__(sampler=self.sampler, **self.init_kwargs)
        else:
            # No sampler-based split in this mode; define the attribute so it
            # can always be read.
            self.valid_sampler = None
            super().__init__(**self.init_kwargs)

    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
        """
        Build the train/validation samplers for the requested split.

        :param split: 0 / 0.0 for no split, an int sample count, or a float fraction
        :return: ``(train_sampler, valid_sampler)``, or ``(None, None)`` when
            ``split`` is zero

        NOTE: this reseeds NumPy's global RNG with 0 so the split is identical
        across runs.
        """
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self, bs = 1000):
        """
        Return a loader over the validation data.

        :param bs: batch size, used only when a separate ``val_dataset`` exists
        :return: a ``DataLoader`` over ``val_dataset`` if one was given,
            otherwise a ``DataLoader`` over the training dataset restricted by
            ``valid_sampler`` (using the training loader's own kwargs)
        """
        if self.val_dataset is not None:
            kwargs = {
                'dataset': self.val_dataset,
                'batch_size': bs,
                'shuffle': False,
                'collate_fn': self.collate_fn,
                'num_workers': self.num_workers
            }
            return DataLoader(**kwargs)
        else:
            print('Using sampler to split!')
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
|
81 |
+
|
82 |
+
|
83 |
+
|
ELR/base/base_model.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import numpy as np
|
3 |
+
from abc import abstractmethod
|
4 |
+
|
5 |
+
|
6 |
+
class BaseModel(nn.Module):
    """
    Common parent for every model in the project.
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Run the forward pass; concrete models must override this.

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Render the module as usual, followed by its trainable-parameter count.
        """
        trainable = (p for p in self.parameters() if p.requires_grad)
        # Element count of each tensor is the product of its dimensions.
        n_params = sum([np.prod(p.size()) for p in trainable])
        description = super().__str__()
        return description + '\nTrainable parameters: {}'.format(n_params)
|
ELR/base/base_trainer.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypeVar, List, Tuple
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from abc import abstractmethod
|
5 |
+
from numpy import inf
|
6 |
+
from logger import TensorboardWriter
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class BaseTrainer:
    """
    Base class for all trainers.

    Owns the epoch loop, best-model monitoring / early stopping, device
    placement and checkpoint save / resume. Subclasses supply ``_train_epoch``
    (and ``_warmup_epoch`` when ``config['trainer']['warmup']`` is positive).
    """
    def __init__(self, model, train_criterion, metrics, optimizer, config, val_criterion):
        """
        :param model: model to train; moved to the selected device
        :param train_criterion: training loss module; moved to the selected device
        :param metrics: list of metric functions; their ``__name__`` is used as log key
        :param optimizer: optimizer whose state is saved in / restored from checkpoints
        :param config: project configuration object; must support item access
            and expose ``get_logger``, ``save_dir``, ``log_dir`` and ``resume``
        :param val_criterion: validation loss, stored as given
        """
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)

        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.train_criterion = train_criterion.to(self.device)

        self.val_criterion = val_criterion
        self.metrics = metrics

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            # monitor has the form "<min|max> <metric name>"
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic: run every epoch, log its results, track the
        monitored metric, stop early when it stalls, and save checkpoints.
        """
        not_improved_count = 0

        for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
            if epoch <= self.config['trainer']['warmup']:
                result = self._warmup_epoch(epoch)
            else:
                result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'val_metrics':
                    log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'test_metrics':
                    log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log[key] = value

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        Select the device to train on.

        :param n_gpu_use: number of GPUs requested by the configuration
        :return: ``(device, list_ids)`` where ``device`` is ``cuda:0`` or
            ``cpu`` and ``list_ids`` are the GPU indices to use
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There\'s no GPU available on this machine,"
                                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        Only the best model is written; nothing is saved when ``save_best``
        is False.

        :param epoch: current epoch number
        :param save_best: if True, write the state to 'model_best.pth'
        """
        arch = type(self.model).__name__

        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # _resume_checkpoint compares these two entries against the current
            # configuration, so they must travel with the checkpoint. Only the
            # sections that are read back are stored, not the config object.
            'config': {
                'arch': self.config['arch'],
                'optimizer': self.config['optimizer'],
            },
        }
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        Restores the start epoch, best monitored value, model weights and,
        when the optimizer type is unchanged, the optimizer state.

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        # map_location lets a checkpoint written on GPU be resumed on CPU.
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # Checkpoints written before the configuration was stored have no
        # 'config' entry; resume them without the consistency checks.
        saved_config = checkpoint.get('config')
        if saved_config is None:
            self.logger.warning("Warning: Checkpoint does not contain its configuration. "
                                "Architecture and optimizer type cannot be verified.")

        # load architecture params from checkpoint.
        if saved_config is not None and saved_config['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if saved_config is not None and saved_config['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
|
194 |
+
|
195 |
+
|
ELR/config_cifar10.json
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "cifar10_resnet34_cosine",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed": 123,
|
5 |
+
|
6 |
+
"arch": {
|
7 |
+
"type": "resnet34",
|
8 |
+
"args": {"num_classes":10}
|
9 |
+
},
|
10 |
+
|
11 |
+
"num_classes": 10,
|
12 |
+
|
13 |
+
"data_loader": {
|
14 |
+
"type": "CIFAR10DataLoader",
|
15 |
+
"args":{
|
16 |
+
"data_dir": "/dir_to_data",
|
17 |
+
"batch_size": 128,
|
18 |
+
"shuffle": true,
|
19 |
+
"num_batches": 0,
|
20 |
+
"validation_split": 0,
|
21 |
+
"num_workers": 8,
|
22 |
+
"pin_memory": true
|
23 |
+
}
|
24 |
+
},
|
25 |
+
|
26 |
+
|
27 |
+
"optimizer": {
|
28 |
+
"type": "SGD",
|
29 |
+
"args":{
|
30 |
+
"lr": 0.02,
|
31 |
+
"momentum": 0.9,
|
32 |
+
"weight_decay": 1e-3
|
33 |
+
}
|
34 |
+
},
|
35 |
+
|
36 |
+
"train_loss": {
|
37 |
+
"type": "elr_loss",
|
38 |
+
"args":{
|
39 |
+
"beta": 0.7,
|
40 |
+
"lambda": 3
|
41 |
+
}
|
42 |
+
},
|
43 |
+
|
44 |
+
"val_loss": "cross_entropy",
|
45 |
+
"metrics": [
|
46 |
+
"my_metric", "my_metric2"
|
47 |
+
],
|
48 |
+
|
49 |
+
"lr_scheduler": {
|
50 |
+
"type": "CosineAnnealingWarmRestarts",
|
51 |
+
"args": {
|
52 |
+
"T_0": 10,
|
53 |
+
"eta_min": 0.001
|
54 |
+
}
|
55 |
+
},
|
56 |
+
|
57 |
+
"trainer": {
|
58 |
+
"epochs": 150,
|
59 |
+
"warmup": 0,
|
60 |
+
"save_dir": "saved/",
|
61 |
+
"save_period": 1,
|
62 |
+
"verbosity": 2,
|
63 |
+
"label_dir": "saved/",
|
64 |
+
"monitor": "max val_my_metric",
|
65 |
+
"early_stop": 2000,
|
66 |
+
"tensorboard": false,
|
67 |
+
"mlflow": true,
|
68 |
+
"_percent": "Percentage of noise",
|
69 |
+
"percent": 0.8,
|
70 |
+
"_begin": "When to begin updating labels",
|
71 |
+
"begin": 0,
|
72 |
+
"_asym": "symmetric noise if false",
|
73 |
+
"asym": false
|
74 |
+
}
|
75 |
+
}
|
ELR/config_cifar100.json
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "cifar100_sy_60_resnet34",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed": 123,
|
5 |
+
|
6 |
+
"arch": {
|
7 |
+
"type": "resnet34",
|
8 |
+
"args": {"num_classes":100}
|
9 |
+
},
|
10 |
+
|
11 |
+
"num_classes": 100,
|
12 |
+
|
13 |
+
"data_loader": {
|
14 |
+
"type": "CIFAR100DataLoader",
|
15 |
+
"args":{
|
16 |
+
"data_dir": "/dir_to_data",
|
17 |
+
"batch_size": 128,
|
18 |
+
"shuffle": true,
|
19 |
+
"num_batches": 0,
|
20 |
+
"validation_split": 0,
|
21 |
+
"num_workers": 8,
|
22 |
+
"pin_memory": true
|
23 |
+
}
|
24 |
+
},
|
25 |
+
|
26 |
+
|
27 |
+
"optimizer": {
|
28 |
+
"type": "SGD",
|
29 |
+
"args":{
|
30 |
+
"lr": 0.02,
|
31 |
+
"momentum": 0.9,
|
32 |
+
"weight_decay": 1e-3
|
33 |
+
}
|
34 |
+
},
|
35 |
+
|
36 |
+
"train_loss": {
|
37 |
+
"type": "elr_loss",
|
38 |
+
"args":{
|
39 |
+
"beta": 0.9,
|
40 |
+
"lambda": 7
|
41 |
+
}
|
42 |
+
},
|
43 |
+
|
44 |
+
"val_loss": "cross_entropy",
|
45 |
+
"metrics": [
|
46 |
+
"my_metric", "my_metric2"
|
47 |
+
],
|
48 |
+
|
49 |
+
"lr_scheduler": {
|
50 |
+
"type": "MultiStepLR",
|
51 |
+
"args": {
|
52 |
+
"milestones": [80,120],
|
53 |
+
"gamma": 0.01
|
54 |
+
}
|
55 |
+
},
|
56 |
+
|
57 |
+
"trainer": {
|
58 |
+
"epochs": 150,
|
59 |
+
"warmup": 0,
|
60 |
+
"save_dir": "saved/",
|
61 |
+
"save_period": 1,
|
62 |
+
"verbosity": 2,
|
63 |
+
"label_dir": "saved/",
|
64 |
+
"monitor": "max val_my_metric",
|
65 |
+
"early_stop": 2000,
|
66 |
+
"tensorboard": false,
|
67 |
+
"mlflow": true,
|
68 |
+
"_percent": "Percentage of noise",
|
69 |
+
"percent": 0.6,
|
70 |
+
"_begin": "When to begin updating labels",
|
71 |
+
"begin": 0,
|
72 |
+
"_asym": "symmetric noise if false",
|
73 |
+
"asym": false
|
74 |
+
}
|
75 |
+
}
|
ELR/config_cifar10_asym.json
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "cifar10_resnet34_cosine",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed": 123,
|
5 |
+
|
6 |
+
"arch": {
|
7 |
+
"type": "resnet34",
|
8 |
+
"args": {"num_classes":10}
|
9 |
+
},
|
10 |
+
|
11 |
+
"num_classes": 10,
|
12 |
+
|
13 |
+
"data_loader": {
|
14 |
+
"type": "CIFAR10DataLoader",
|
15 |
+
"args":{
|
16 |
+
"data_dir": "/dir_to_data",
|
17 |
+
"batch_size": 128,
|
18 |
+
"shuffle": true,
|
19 |
+
"num_batches": 0,
|
20 |
+
"validation_split": 0,
|
21 |
+
"num_workers": 8,
|
22 |
+
"pin_memory": true
|
23 |
+
}
|
24 |
+
},
|
25 |
+
|
26 |
+
|
27 |
+
"optimizer": {
|
28 |
+
"type": "SGD",
|
29 |
+
"args":{
|
30 |
+
"lr": 0.02,
|
31 |
+
"momentum": 0.9,
|
32 |
+
"weight_decay": 1e-3
|
33 |
+
}
|
34 |
+
},
|
35 |
+
|
36 |
+
"train_loss": {
|
37 |
+
"type": "elr_loss",
|
38 |
+
"args":{
|
39 |
+
"beta": 0.9,
|
40 |
+
"lambda": 1
|
41 |
+
}
|
42 |
+
},
|
43 |
+
|
44 |
+
"val_loss": "cross_entropy",
|
45 |
+
"metrics": [
|
46 |
+
"my_metric", "my_metric2"
|
47 |
+
],
|
48 |
+
|
49 |
+
"lr_scheduler": {
|
50 |
+
"type": "MultiStepLR",
|
51 |
+
"args": {
|
52 |
+
"milestones": [40,80],
|
53 |
+
"gamma": 0.01
|
54 |
+
}
|
55 |
+
},
|
56 |
+
|
57 |
+
"trainer": {
|
58 |
+
"epochs": 120,
|
59 |
+
"warmup": 0,
|
60 |
+
"save_dir": "saved/",
|
61 |
+
"save_period": 1,
|
62 |
+
"verbosity": 2,
|
63 |
+
"label_dir": "saved/",
|
64 |
+
"monitor": "max val_my_metric",
|
65 |
+
"early_stop": 2000,
|
66 |
+
"tensorboard": false,
|
67 |
+
"mlflow": true,
|
68 |
+
"_percent": "Percentage of noise",
|
69 |
+
"percent": 0.4,
|
70 |
+
"_begin": "When to begin updating labels",
|
71 |
+
"begin": 0,
|
72 |
+
"_asym": "symmetric noise if false",
|
73 |
+
"asym": true
|
74 |
+
}
|
75 |
+
}
|
ELR/data_loader/__pycache__/cifar10.cpython-36.pyc
ADDED
Binary file (7.87 kB). View file
|
|
ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc
ADDED
Binary file (2.19 kB). View file
|
|
ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc
ADDED
Binary file (2.91 kB). View file
|
|
ELR/data_loader/cifar10.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data.dataset import Subset
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
|
14 |
+
def get_cifar10(root, cfg_trainer, train=True,
                transform_train=None, transform_val=None,
                download=False, noise_file = ''):
    """
    Build the CIFAR-10 datasets used by the data loader.

    With ``train`` set, the official training set is split into a train part
    and a validation part, and label noise (asymmetric when
    ``cfg_trainer['asym']`` is truthy, symmetric otherwise) is injected into
    both. Without it, the returned pair is ``([], test_dataset)``.

    :param root: directory holding the CIFAR-10 data
    :param cfg_trainer: trainer configuration, read here for ``'asym'``
    :param train: select the training split (True) or the test split (False)
    :param transform_train: transform applied to training images
    :param transform_val: transform applied to validation / test images
    :param download: forwarded to ``torchvision.datasets.CIFAR10``
    :param noise_file: not used by this function
    :return: ``(train_dataset, val_dataset)``
    """
    base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)

    if not train:
        # Test mode: no training data and no noise injection.
        train_dataset = []
        val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(val_dataset)}")
        return train_dataset, val_dataset

    train_idxs, val_idxs = train_val_split(base_dataset.targets)
    train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
    val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)

    # Same noise model for both splits; train first, then validation.
    noise_method = 'asymmetric_noise' if cfg_trainer['asym'] else 'symmetric_noise'
    for dataset in (train_dataset, val_dataset):
        getattr(dataset, noise_method)()

    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
    return train_dataset, val_dataset
|
38 |
+
|
39 |
+
|
40 |
+
def train_val_split(base_dataset, num_classes=10, train_ratio=0.9):
    """
    Split sample indices into train and validation parts, class by class.

    For every class the matching indices are shuffled; the first
    ``int(len(base_dataset) * train_ratio / num_classes)`` go to the train
    part and the remainder to the validation part. Both index lists are
    shuffled again before being returned. Randomness comes from NumPy's
    global RNG.

    :param base_dataset: sequence of integer class labels, one per sample
        (the caller passes the dataset's ``targets``, not the dataset itself)
    :param num_classes: number of classes; labels are expected in
        ``range(num_classes)``
    :param train_ratio: fraction of the data assigned to the train part
    :return: ``(train_idxs, val_idxs)`` as two lists of indices
    """
    labels = np.array(base_dataset)
    train_n = int(len(labels) * train_ratio / num_classes)
    train_idxs = []
    val_idxs = []

    for i in range(num_classes):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])

    # Mix the classes so neither list is ordered by label.
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs
|
56 |
+
|
57 |
+
|
58 |
+
class CIFAR10_train(torchvision.datasets.CIFAR10):
    """
    Training subset of CIFAR-10 with optional injected label noise.

    Samples are addressed by position within the subset selected by
    ``indexs``. ``train_labels`` holds the (possibly noisy) labels used for
    training and ``train_labels_gt`` the labels as loaded from the dataset.
    """

    # Class flips applied by asymmetric_noise:
    # truck -> automobile, bird -> airplane, cat -> dog, dog -> cat, deer -> horse
    _ASYMMETRIC_FLIPS = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """
        :param root: directory holding the CIFAR-10 data
        :param cfg_trainer: trainer configuration; ``'percent'`` is read by the noise methods
        :param indexs: indices into the full dataset that form this subset
        :param train: forwarded to ``torchvision.datasets.CIFAR10``
        :param transform: transform applied to each image
        :param target_transform: transform applied to each (noisy) label
        :param download: forwarded to ``torchvision.datasets.CIFAR10``
        """
        super(CIFAR10_train, self).__init__(root, train=train,
                                            transform=transform, target_transform=target_transform,
                                            download=download)
        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        # Snapshot the clean labels up front (as CIFAR10_val does) so that
        # __getitem__ works before any noise is applied and repeated noise
        # calls cannot overwrite the ground truth with noisy labels.
        self.train_labels_gt = self.train_labels.copy()
        self.indexs = indexs
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []

    def symmetric_noise(self):
        """
        Replace the labels of a random ``cfg_trainer['percent']`` fraction of
        the samples with a uniformly drawn class and record the affected
        positions in ``noise_indx``. The drawn class may equal the original.
        """
        indices = np.random.permutation(len(self.train_data))
        n_noisy = self.cfg_trainer['percent'] * len(self.train_data)
        for i, idx in enumerate(indices):
            if i >= n_noisy:
                break
            self.noise_indx.append(idx)
            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """
        Flip a ``cfg_trainer['percent']`` fraction of each class to a fixed
        other class according to ``_ASYMMETRIC_FLIPS``.

        NOTE(review): class membership is read from the current (already
        partly flipped) labels, so cats flipped to dog at class 3 are eligible
        to be flipped back when class 5 is processed. Also, ``noise_indx``
        receives the selected positions of every class, including classes that
        have no flip. Both are kept as-is to preserve the existing noise
        realisation - confirm whether this is intended.
        """
        for i in range(self.num_classes):
            indices = np.where(self.train_labels == i)[0]
            np.random.shuffle(indices)
            n_noisy = self.cfg_trainer['percent'] * len(indices)
            flipped = self._ASYMMETRIC_FLIPS.get(i)
            for j, idx in enumerate(indices):
                if j >= n_noisy:
                    break
                self.noise_indx.append(idx)
                if flipped is not None:
                    self.train_labels[idx] = flipped

    def __getitem__(self, index):
        """
        Args:
            index (int): Position within this subset

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) class index and target_gt the label as loaded
            from the dataset.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt

    def __len__(self):
        """Number of samples in this subset."""
        return len(self.train_data)
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
class CIFAR10_val(torchvision.datasets.CIFAR10):
    """
    Validation / test view of CIFAR-10.

    When ``train`` is set, the view is the subset of the training data picked
    by ``indexs``; otherwise it is the whole split loaded by the parent class.
    Label noise can be injected with the two noise methods; the labels as
    loaded are kept in ``train_labels_gt``.
    """

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_val, self).__init__(root, train=train,
                                          transform=transform, target_transform=target_transform,
                                          download=download)

        self.num_classes = 10
        self.cfg_trainer = cfg_trainer

        all_labels = np.array(self.targets)
        if train:
            # held-out part of the training split
            self.train_data = self.data[indexs]
            self.train_labels = all_labels[indexs]
        else:
            # complete split, no sub-selection
            self.train_data = self.data
            self.train_labels = all_labels
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        """
        Overwrite the labels of a random ``cfg_trainer['percent']`` fraction of
        the samples with a uniformly drawn class.
        """
        n_samples = len(self.train_data)
        order = np.random.permutation(n_samples)
        for rank, position in enumerate(order):
            if not rank < self.cfg_trainer['percent'] * n_samples:
                break
            self.train_labels[position] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """
        Flip a ``cfg_trainer['percent']`` fraction of selected classes to a
        fixed partner class.
        """
        # truck -> automobile, bird -> airplane, cat -> dog, dog -> cat, deer -> horse
        partner = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}
        for cls in range(self.num_classes):
            members = np.where(self.train_labels == cls)[0]
            # Shuffle for every class, flipped or not, to keep the RNG stream unchanged.
            np.random.shuffle(members)
            for rank, position in enumerate(members):
                if not rank < self.cfg_trainer['percent'] * len(members):
                    break
                if cls in partner:
                    self.train_labels[position] = partner[cls]

    def __len__(self):
        """Number of samples in this view."""
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Position within this view

        Returns:
            tuple: (image, target, index, target_gt)
        """
        img = self.train_data[index]
        target = self.train_labels[index]
        target_gt = self.train_labels_gt[index]

        # Hand a PIL Image to the transform, as the other datasets do.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt
|
212 |
+
|
ELR/data_loader/cifar100.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data.dataset import Subset
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
import os
|
12 |
+
import json
|
13 |
+
from numpy.testing import assert_array_almost_equal
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def get_cifar100(root, cfg_trainer, train=True,
                 transform_train=None, transform_val=None,
                 download=False, noise_file = ''):
    """Build the CIFAR-100 train/validation datasets, or the test dataset.

    Args:
        root: Dataset directory.
        cfg_trainer: Mapping read for ``'asym'`` (and, by the datasets, for
            ``'percent'``).
        train: True for a noisy train/val split, false for the test split.
        transform_train: Transform for the training subset.
        transform_val: Transform for the validation/test subset.
        download: Forwarded to torchvision.
        noise_file: Accepted but not used by this function.

    Returns:
        tuple: ``(train_dataset, val_dataset)``; ``train_dataset`` is an empty
        list when ``train`` is false.
    """
    base_dataset = torchvision.datasets.CIFAR100(root, train=train, download=download)

    if not train:
        train_dataset = []
        val_dataset = CIFAR100_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(val_dataset)}")
        return train_dataset, val_dataset

    train_idxs, val_idxs = train_val_split(base_dataset.targets)
    train_dataset = CIFAR100_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
    val_dataset = CIFAR100_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)

    # The same kind of label noise is injected into both subsets, train first.
    noise_method = 'asymmetric_noise' if cfg_trainer['asym'] else 'symmetric_noise'
    for dataset in (train_dataset, val_dataset):
        getattr(dataset, noise_method)()

    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
    return train_dataset, val_dataset
|
42 |
+
|
43 |
+
|
44 |
+
def train_val_split(base_dataset, num_classes=100, train_ratio=0.9):
    """Split sample indices into class-balanced train and validation sets.

    Despite its name, ``base_dataset`` is the sequence of integer class
    labels (the caller passes ``dataset.targets``), not a dataset object.

    Args:
        base_dataset: Sequence of integer labels, one per sample.
        num_classes (int): Number of distinct labels ``0..num_classes-1``.
        train_ratio (float): Share of the data assigned to the train set.

    Returns:
        tuple: ``(train_idxs, val_idxs)`` lists of sample indices. Each class
        contributes ``int(len(labels) * train_ratio / num_classes)`` indices
        to the train list and the remainder to the validation list.
    """
    labels = np.array(base_dataset)
    train_n = int(len(labels) * train_ratio / num_classes)
    train_idxs = []
    val_idxs = []

    for i in range(num_classes):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])

    # Shuffle so the result is not grouped by class.
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs
|
60 |
+
|
61 |
+
|
62 |
+
class CIFAR100_train(torchvision.datasets.CIFAR100):
    """Training subset of CIFAR-100 with injectable label noise.

    ``train_labels`` holds the (possibly noisy) labels served to the model,
    ``train_labels_gt`` the labels as loaded.
    """

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-100 and keep the samples selected by ``indexs``.

        Args:
            root: Dataset directory forwarded to torchvision.
            cfg_trainer: Mapping read for the ``'percent'`` noise rate.
            indexs: Indices of the samples to keep.
            train: Forwarded to torchvision.
            transform: Optional image transform.
            target_transform: Optional label transform.
            download: Forwarded to torchvision.
        """
        super(CIFAR100_train, self).__init__(root, train=train,
                                             transform=transform, target_transform=target_transform,
                                             download=download)
        self.num_classes = 100
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        # Keep a clean copy right away so __getitem__ also works when no noise
        # method has been called yet (it used to raise AttributeError).
        self.train_labels_gt = self.train_labels.copy()
        self.indexs = indexs
        # NOTE(review): this is num_samples x 100 x 100 float32 and is not
        # used anywhere in this class -- confirm it is still needed.
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []

        self.count = 0

    def symmetric_noise(self):
        """Replace the labels of a random ``percent`` share with uniform draws.

        The affected sample indices are appended to ``noise_indx``.
        """
        self.train_labels_gt = self.train_labels.copy()
        indices = np.random.permutation(len(self.train_data))
        budget = self.cfg_trainer['percent'] * len(self.train_data)
        for i, idx in enumerate(indices):
            if i < budget:
                self.noise_indx.append(idx)
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def multiclass_noisify(self, y, P, random_state=0):
        """Flip labels according to the transition probability matrix ``P``.

        Args:
            y: 1-D integer array with values in ``0..P.shape[0]-1``.
            P: Square row-stochastic matrix; ``P[i, j]`` is the probability
                of turning label ``i`` into label ``j``.
            random_state: Seed of the private ``RandomState`` used for draws.

        Returns:
            A new array of noisy labels; ``y`` itself is not modified.
        """
        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(y.shape[0]):
            # One multinomial trial yields a one-hot vector over the classes.
            flipped = flipper.multinomial(1, P[y[idx], :], 1)[0]
            # argmax returns a scalar; assigning the 1-element array that
            # np.where(...)[0] produces to a scalar slot is deprecated.
            new_y[idx] = flipped.argmax()

        return new_y

    def build_for_cifar100(self, size, noise):
        """Build a ``size`` x ``size`` matrix flipping to the "next" class.

        Each row keeps its label with probability ``1 - noise`` and moves to
        the following index (wrapping from the last row to the first) with
        probability ``noise``.
        """
        assert (noise >= 0.) and (noise <= 1.)

        P = (1. - noise) * np.eye(size)
        for i in np.arange(size - 1):
            P[i, i + 1] = noise

        # adjust last row
        P[size - 1, 0] = noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def asymmetric_noise(self, asym=False, random_shuffle=False):
        """Flip labels to the next label id inside groups of five ids.

        The 100 label ids are treated as 20 consecutive groups of 5; within
        each group a label moves to the next id with probability ``percent``.
        With ``percent`` equal to 0 the labels are left untouched.

        Args:
            asym: Unused.
            random_shuffle: Unused.
        """
        self.train_labels_gt = self.train_labels.copy()
        n = self.cfg_trainer['percent']
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            P = np.eye(self.num_classes)
            for i in np.arange(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)

            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
                                                    random_state=0)
            actual_noise = (y_train_noisy != self.train_labels).mean()
            assert actual_noise > 0.0
            self.train_labels = y_train_noisy

    def __getitem__(self, index):
        """Fetch one sample.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            possibly noisy label and target_gt the label as loaded.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.train_data)
|
189 |
+
|
190 |
+
|
191 |
+
class CIFAR100_val(torchvision.datasets.CIFAR100):
    """CIFAR-100 subset used for validation/testing, with optional label noise.

    ``train_labels`` holds the (possibly noisy) labels served to the model,
    ``train_labels_gt`` the labels as loaded.
    """

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-100 and keep either a subset or the whole split.

        Args:
            root: Dataset directory forwarded to torchvision.
            cfg_trainer: Mapping read for the ``'percent'`` noise rate.
            indexs: Sample indices to keep; only used when ``train`` is true.
            train: When true keep ``indexs`` of the split, otherwise keep
                every sample.
            transform: Optional image transform.
            target_transform: Optional label transform.
            download: Forwarded to torchvision.
        """
        super(CIFAR100_val, self).__init__(root, train=train,
                                           transform=transform, target_transform=target_transform,
                                           download=download)

        self.num_classes = 100
        self.cfg_trainer = cfg_trainer
        if train:
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        """Replace the labels of a random ``percent`` share with uniform draws."""
        indices = np.random.permutation(len(self.train_data))
        budget = self.cfg_trainer['percent'] * len(self.train_data)
        for i, idx in enumerate(indices):
            if i < budget:
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def multiclass_noisify(self, y, P, random_state=0):
        """Flip labels according to the transition probability matrix ``P``.

        Args:
            y: 1-D integer array with values in ``0..P.shape[0]-1``.
            P: Square row-stochastic matrix; ``P[i, j]`` is the probability
                of turning label ``i`` into label ``j``.
            random_state: Seed of the private ``RandomState`` used for draws.

        Returns:
            A new array of noisy labels; ``y`` itself is not modified.
        """
        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(y.shape[0]):
            # One multinomial trial yields a one-hot vector over the classes.
            flipped = flipper.multinomial(1, P[y[idx], :], 1)[0]
            # argmax returns a scalar; assigning the 1-element array that
            # np.where(...)[0] produces to a scalar slot is deprecated.
            new_y[idx] = flipped.argmax()

        return new_y

    def build_for_cifar100(self, size, noise):
        """Build a ``size`` x ``size`` matrix flipping to the "next" class.

        Each row keeps its label with probability ``1 - noise`` and moves to
        the following index (wrapping from the last row to the first) with
        probability ``noise``.
        """
        assert (noise >= 0.) and (noise <= 1.)

        P = (1. - noise) * np.eye(size)
        for i in np.arange(size - 1):
            P[i, i + 1] = noise

        # adjust last row
        P[size - 1, 0] = noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def asymmetric_noise(self, asym=False, random_shuffle=False):
        """Flip labels to the next label id inside groups of five ids.

        The 100 label ids are treated as 20 consecutive groups of 5; within
        each group a label moves to the next id with probability ``percent``.
        With ``percent`` equal to 0 the labels are left untouched.

        Args:
            asym: Unused.
            random_shuffle: Unused.
        """
        n = self.cfg_trainer['percent']
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            P = np.eye(self.num_classes)
            for i in np.arange(nb_superclasses):
                init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)

            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
                                                    random_state=0)
            actual_noise = (y_train_noisy != self.train_labels).mean()
            assert actual_noise > 0.0
            self.train_labels = y_train_noisy

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.train_data)

    def __getitem__(self, index):
        """Fetch one sample.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            possibly noisy label and target_gt the label as loaded.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt
|
315 |
+
|
316 |
+
|
317 |
+
|
ELR/data_loader/data_loaders.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
from torchvision import datasets, transforms
|
4 |
+
from base import BaseDataLoader
|
5 |
+
from data_loader.cifar10 import get_cifar10
|
6 |
+
from data_loader.cifar100 import get_cifar100
|
7 |
+
from parse_config import ConfigParser
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
class CIFAR10DataLoader(BaseDataLoader):
    """Data loader that builds the noisy CIFAR-10 train/validation datasets."""

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):
        """Create the datasets and initialise the base loader.

        The dataset location and the noise settings are read from the global
        ``ConfigParser`` instance; ``data_dir`` is only stored on the object
        and ``num_batches`` is not used here.
        """
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        dataset_root = config['data_loader']['args']['data_dir']

        # Same normalisation for both pipelines; only training is augmented.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        self.data_dir = data_dir

        noise_file = '%sCIFAR10_%.1f_Asym_%s.json' % (dataset_root, cfg_trainer['percent'], cfg_trainer['asym'])

        self.train_dataset, self.val_dataset = get_cifar10(
            dataset_root, cfg_trainer, train=training,
            transform_train=transform_train, transform_val=transform_val,
            noise_file=noise_file)

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self, batch_size, shuffle, validation_split, num_workers, pin_memory):
        """Re-run the base-class initialisation on the existing datasets."""
        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
class CIFAR100DataLoader(BaseDataLoader):
    """Data loader that builds the noisy CIFAR-100 train/validation datasets."""

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True,num_workers=4, pin_memory=True):
        """Create the datasets and initialise the base loader.

        The dataset location and the noise settings are read from the global
        ``ConfigParser`` instance; ``data_dir`` is only stored on the object
        and ``num_batches`` is not used here.
        """
        # Fetched once; the previous code looked both values up twice.
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        dataset_root = config['data_loader']['args']['data_dir']

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])
        self.data_dir = data_dir

        noise_file = '%sCIFAR100_%.1f_Asym_%s.json' % (dataset_root, cfg_trainer['percent'], cfg_trainer['asym'])

        self.train_dataset, self.val_dataset = get_cifar100(
            dataset_root, cfg_trainer, train=training,
            transform_train=transform_train, transform_val=transform_val,
            noise_file=noise_file)

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self, batch_size, shuffle, validation_split, num_workers, pin_memory):
        """Re-run the base-class initialisation on the existing datasets."""
        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)
|
ELR/logger/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .logger import *
|
2 |
+
from .visualization import *
|
ELR/logger/logger.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import logging.config
|
3 |
+
from pathlib import Path
|
4 |
+
from utils import read_json
|
5 |
+
|
6 |
+
|
7 |
+
def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration

    Reads the JSON file at ``log_config``, prefixes every handler
    ``filename`` with ``save_dir`` (joined with ``/``, so ``save_dir`` must
    support that operator) and applies it with ``dictConfig``. When the file
    is missing, a warning is printed and ``basicConfig`` is used with
    ``default_level``.
    """
    config_path = Path(log_config)

    if not config_path.is_file():
        print("Warning: logging configuration file is not found in {}.".format(config_path))
        logging.basicConfig(level=default_level)
        return

    config = read_json(config_path)
    # modify logging paths based on run config
    for handler in config['handlers'].values():
        if 'filename' in handler:
            handler['filename'] = str(save_dir / handler['filename'])

    logging.config.dictConfig(config)
|
ELR/logger/logger_config.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
{
|
3 |
+
"version": 1,
|
4 |
+
"disable_existing_loggers": false,
|
5 |
+
"formatters": {
|
6 |
+
"simple": {"format": "%(message)s"},
|
7 |
+
"datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
|
8 |
+
},
|
9 |
+
"handlers": {
|
10 |
+
"console": {
|
11 |
+
"class": "logging.StreamHandler",
|
12 |
+
"level": "DEBUG",
|
13 |
+
"formatter": "simple",
|
14 |
+
"stream": "ext://sys.stdout"
|
15 |
+
},
|
16 |
+
"info_file_handler": {
|
17 |
+
"class": "logging.handlers.RotatingFileHandler",
|
18 |
+
"level": "INFO",
|
19 |
+
"formatter": "datetime",
|
20 |
+
"filename": "info.log",
|
21 |
+
"maxBytes": 10485760,
|
22 |
+
"backupCount": 20, "encoding": "utf8"
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"root": {
|
26 |
+
"level": "INFO",
|
27 |
+
"handlers": [
|
28 |
+
"console",
|
29 |
+
"info_file_handler"
|
30 |
+
]
|
31 |
+
}
|
32 |
+
}
|
ELR/logger/visualization.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from utils import Timer
|
3 |
+
|
4 |
+
|
5 |
+
class MLFlow:
    """Thin optional wrapper around the ``mlflow`` module.

    ``log_param``, ``log_metric`` and ``start_run`` are forwarded to mlflow
    when it is available and are silent no-ops otherwise.
    """

    def __init__(self, log_dir, logger, enabled):
        """Import mlflow when ``enabled``; warn through ``logger`` if missing.

        Args:
            log_dir: Accepted for signature parity with the other writers;
                not used by this class.
            logger: Object with a ``warning`` method, used only when the
                import fails.
            enabled: Whether to try importing mlflow at all.
        """
        self.mlflow = None

        if enabled:
            # Retrieve visualization writer.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
                    "this machine. Please install mlflow with 'pip install mlflow or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        self.mlflow_ftns = {
            'start_run'
        }
        # Functions whose tag must never get the mode suffix. Previously this
        # attribute was missing, which turned every log_* call made with
        # mlflow installed into a RecursionError.
        self.tag_mode_exceptions = set()

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return the mlflow function with the mode suffix added to the tag.
        Otherwise:
            return a blank function handle that does nothing
        """
        # Read through __dict__ so that a lookup made before __init__ has run
        # (e.g. by copy/pickle) cannot re-enter __getattr__ endlessly.
        state = self.__dict__

        if name in state.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is None:
                    return None
                # add mode(train/valid) tag; skipped while no mode is set so
                # the tag does not end in a bare '/' (presumably rejected by
                # mlflow's name validation -- TODO confirm).
                if self.mode and name not in self.tag_mode_exceptions:
                    tag = '{}/{}'.format(tag, self.mode)
                return add_data(tag, data, *args, **kwargs)

            return wrapper

        if name in state.get('mlflow_ftns', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(*args, **kwargs):
                if add_data is None:
                    return None
                return add_data(*args, **kwargs)

            return wrapper

        # __getattr__ only runs after normal lookup failed, so anything else
        # is a genuinely missing attribute.
        raise AttributeError("type object '{}' has no attribute '{}'".format(type(self).__name__, name))
|
83 |
+
|
84 |
+
|
85 |
+
class TensorboardWriter:
    """Optional wrapper around a tensorboard ``SummaryWriter``.

    The ``add_*`` functions listed in ``tb_writer_ftns`` are forwarded to the
    writer with the current step inserted and the current mode appended to
    the tag; without a writer they are silent no-ops.
    """

    def __init__(self, log_dir, logger, enabled):
        """Create a ``SummaryWriter`` from the first importable backend.

        Args:
            log_dir: Directory handed to ``SummaryWriter`` (converted to str).
            logger: Object with a ``warning`` method, used only when no
                backend can be imported.
            enabled: Whether to create a writer at all.
        """
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve vizualization writer.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                # Recorded before the attempt so the name is also available
                # when the import succeeds and the loop is left early.
                self.selected_module = module
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the step/mode used for later ``add_*`` calls.

        Step 0 resets the timer; any other step logs ``steps_per_sec`` from
        the time returned by ``timer.check()``.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        # Read through __dict__ so that a lookup made before __init__ has run
        # (e.g. by copy/pickle) cannot re-enter __getattr__ endlessly.
        state = self.__dict__

        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        # __getattr__ only runs after normal lookup failed, so anything else
        # is a genuinely missing attribute.
        raise AttributeError("type object '{}' has no attribute '{}'".format(state.get('selected_module', ''), name))
|
ELR/model/ResNet_Zoo.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''ResNet in PyTorch.
|
2 |
+
For Pre-activation ResNet, see 'preact_resnet.py'.
|
3 |
+
Reference:
|
4 |
+
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
5 |
+
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
6 |
+
'''
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Create the two conv/batch-norm pairs and the shortcut.

        Args:
            in_planes (int): Input channels.
            planes (int): Output channels (times ``expansion``).
            stride (int): Stride of the first convolution and of the
                projection shortcut.
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as identity; a 1x1 projection is needed
        # only when the spatial size or channel count changes.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
class Bottleneck(nn.Module):
    """Residual block of 1x1, 3x3 and 1x1 convolutions plus a shortcut branch."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """Create the three conv/batch-norm pairs and the shortcut.

        Args:
            in_planes (int): Input channels.
            planes (int): Width of the inner convolutions; the block outputs
                ``expansion * planes`` channels.
            stride (int): Stride of the 3x3 convolution and of the projection
                shortcut.
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential acts as identity; a 1x1 projection is needed
        # only when the spatial size or channel count changes.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Run the three conv stages, add the shortcut and apply ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)
|
64 |
+
|
65 |
+
|
66 |
+
class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    The gradient flowing into the logits of the latest forward pass is kept
    in ``gradients`` (see ``get_activations_gradient``).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """Build the network.

        Args:
            block: Block class exposing ``expansion`` and the constructor
                ``block(in_planes, planes, stride)``.
            num_blocks: Four integers, the number of blocks per stage.
            num_classes (int): Size of the classifier output.
        """
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stages are registered as layer1..layer4, in this order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for position, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[position], stride=stride)
            setattr(self, 'layer{}'.format(position + 1), stage)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.gradients = None

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def activations_hook(self, grad):
        """Backward hook: remember the gradient received by the logits."""
        self.gradients = grad

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        feat = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
        feat = F.avg_pool2d(feat, 4)
        flat = feat.view(feat.size(0), -1)
        logits = self.linear(flat)
        # Hooks can only be attached to tensors that take part in autograd.
        if logits.requires_grad:
            logits.register_hook(self.activations_hook)
        return logits

    def get_activations_gradient(self):
        """Return the gradient stored by ``activations_hook`` (or ``None``)."""
        return self.gradients
|
110 |
+
|
111 |
+
|
112 |
+
def ResNet18(num_classes=10):
    """Build the ResNet-18 configuration (``BasicBlock``, stages 2-2-2-2).

    Args:
        num_classes (int): Size of the classifier output. Defaults to 10,
            the value that was previously fixed.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
|
114 |
+
|
115 |
+
def ResNet34(num_classes=10):
    """Build the ResNet-34 configuration (``BasicBlock``, stages 3-4-6-3).

    Args:
        num_classes (int): Size of the classifier output. Defaults to 10,
            the value that was previously fixed.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
|
117 |
+
|
118 |
+
def ResNet50(num_classes=10):
    """Build the ResNet-50 configuration (``Bottleneck``, stages 3-4-6-3).

    Args:
        num_classes (int): Size of the classifier output. Defaults to 10,
            the value that was previously fixed.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
|
120 |
+
|
121 |
+
def ResNet101(num_classes=10):
    """Build the ResNet-101 configuration (``Bottleneck``, stages 3-4-23-3).

    Args:
        num_classes (int): Size of the classifier output. Defaults to 10,
            the value that was previously fixed.
    """
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
|
123 |
+
|
124 |
+
def ResNet152():
|
125 |
+
return ResNet(Bottleneck, [3,8,36,3])
|
126 |
+
|
127 |
+
|
128 |
+
def test():
    """Smoke check: push one random 3x32x32 image through ResNet18 and print the output size."""
    sample = torch.randn(1, 3, 32, 32)
    model = ResNet18()
    print(model(sample).size())
|
132 |
+
|
133 |
+
# test()
|
ELR/model/loss.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
import torch
|
3 |
+
from parse_config import ConfigParser
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
def cross_entropy(output, target):
    """Plain cross-entropy between logits ``output`` and class indices ``target``."""
    loss = F.cross_entropy(output, target)
    return loss
|
9 |
+
|
10 |
+
|
11 |
+
class elr_loss(nn.Module):
    """Cross-entropy plus a regulariser against a running per-example target.

    ``self.target`` holds one row per training example and is updated on
    every forward call as an exponential moving average (factor ``beta``)
    of the normalised, detached softmax predictions.
    """

    def __init__(self, num_examp, num_classes=10, beta=0.3):
        """
        :param num_examp: number of rows in the running target table.
        :param num_classes: number of columns in the running target table.
        :param beta: weight kept from the previous target in each update.
        """
        super(elr_loss, self).__init__()
        self.num_classes = num_classes
        self.config = ConfigParser.get_instance()
        self.USE_CUDA = torch.cuda.is_available()
        self.target = torch.zeros(num_examp, self.num_classes)
        if self.USE_CUDA:
            self.target = self.target.cuda()
        self.beta = beta

    def forward(self, index, output, label):
        """
        :param index: indices of the batch examples, used to index ``self.target``.
        :param output: logits; softmax is taken over dim 1.
        :param label: class indices passed to ``F.cross_entropy``.
        :return: cross-entropy + lambda * regulariser, with lambda read from
            ``config['train_loss']['args']['lambda']``.
        """
        # CUDA may be available while the model runs elsewhere; follow the logits.
        if self.target.device != output.device:
            self.target = self.target.to(output.device)

        y_pred = F.softmax(output, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)
        # the running target must not carry gradients
        y_pred_ = y_pred.detach()
        self.target[index] = self.beta * self.target[index] + (1 - self.beta) * (
            y_pred_ / y_pred_.sum(dim=1, keepdim=True))

        ce_loss = F.cross_entropy(output, label)
        elr_reg = ((1 - (self.target[index] * y_pred).sum(dim=1)).log()).mean()
        final_loss = ce_loss + self.config['train_loss']['args']['lambda'] * elr_reg
        return final_loss
|
30 |
+
|
ELR/model/metric.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def my_metric(output, target):
    """Top-1 accuracy: fraction of rows whose argmax over dim 1 equals ``target``."""
    with torch.no_grad():
        predicted = torch.argmax(output, dim=1)
        assert predicted.shape[0] == len(target)
        hits = torch.sum(predicted == target).item()
    return hits / len(target)
|
11 |
+
|
12 |
+
|
13 |
+
def my_metric2(output, target, k=5):
    """Top-k accuracy: fraction of rows whose ``k`` highest-scoring indices contain ``target``."""
    with torch.no_grad():
        top_indices = torch.topk(output, k, dim=1)[1]
        assert top_indices.shape[0] == len(target)
        # count matches column by column over the k ranked predictions
        hits = sum(
            torch.sum(top_indices[:, rank] == target).item()
            for rank in range(k)
        )
    return hits / len(target)
|
ELR/model/model.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from base import BaseModel
|
4 |
+
from .ResNet_Zoo import ResNet, BasicBlock
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
def resnet34(num_classes=10):
    """Build a ResNet from BasicBlock with stage depths [3, 4, 6, 3] and ``num_classes`` outputs."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
ELR/parse_config.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
from functools import reduce
|
5 |
+
from operator import getitem
|
6 |
+
from datetime import datetime
|
7 |
+
from logger import setup_logging
|
8 |
+
from utils import read_json, write_json
|
9 |
+
|
10 |
+
|
11 |
+
class ConfigParser:
    """Singleton wrapper around the parsed experiment configuration.

    Obtain the instance with ``ConfigParser.get_instance(...)``; calling the
    class directly raises ``NotImplementedError``. Initialisation parses the
    CLI, loads the JSON config, creates the save/log directories, writes the
    effective config next to the checkpoints and sets up logging.
    """

    __instance = None

    def __new__(cls, args, options='', timestamp=True):
        raise NotImplementedError('Cannot initialize via Constructor')

    @classmethod
    def __internal_new__(cls):
        # bypasses the guarded __new__ above
        return super().__new__(cls)

    @classmethod
    def get_instance(cls, args=None, options='', timestamp=True):
        """Return the shared instance, creating it on first use.

        :param args: ``argparse.ArgumentParser`` with at least ``config``,
            ``resume`` and ``device`` options; required on the first call.
        :param options: iterable of custom CLI option descriptors.
        :param timestamp: whether to append a timestamp to the run directories.
        :raises NotImplementedError: if no instance exists yet and ``args`` is None.
        """
        if cls.__instance is None:
            if args is None:
                raise NotImplementedError('Cannot initialize without args')
            instance = cls.__internal_new__()
            instance.__init__(args, options, timestamp)
            # publish only after __init__ succeeded, so a failed setup is not cached
            cls.__instance = instance

        return cls.__instance

    def __init__(self, args, options='', timestamp=True):
        # parse default and custom cli options
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        args = args.parse_args()

        if args.device:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        if args.resume is None:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            self.cfg_fname = Path(args.config)
            config = read_json(self.cfg_fname)
            self.resume = None
        else:
            # resuming: start from the config stored beside the checkpoint,
            # optionally overridden by an explicitly given config file
            self.resume = Path(args.resume)
            resume_cfg_fname = self.resume.parent / 'config.json'
            config = read_json(resume_cfg_fname)
            if args.config is not None:
                config.update(read_json(Path(args.config)))

        # load config file and apply custom cli options
        self._config = _update_config(config, options, args)

        # set save_dir where trained model and log will be saved.
        save_dir = Path(self.config['trainer']['save_dir'])
        timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''

        noise_tag = '_asym_' if self.config['trainer']['asym'] else '_sym_'
        exper_name = self.config['name'] + noise_tag + str(int(self.config['trainer']['percent'] * 100))
        self._save_dir = save_dir / 'models' / exper_name / timestamp
        self._log_dir = save_dir / 'log' / exper_name / timestamp

        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # save updated config file to the checkpoint dir
        write_json(self.config, self.save_dir / 'config.json')

        # configure logging module
        setup_logging(self.log_dir)
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

    def initialize(self, name, module, *args, **kwargs):
        """
        finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding keyword args given as 'args'.
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def __getitem__(self, name):
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """Return ``logging.getLogger(name)`` with its level set from ``verbosity`` (0, 1 or 2)."""
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
                                                                                       self.log_levels.keys())
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        return self._config

    @property
    def save_dir(self):
        return self._save_dir

    @property
    def log_dir(self):
        return self._log_dir
|
116 |
+
|
117 |
+
|
118 |
+
# helper functions used to update config dict with custom cli options
|
119 |
+
def _update_config(config, options, args):
|
120 |
+
for opt in options:
|
121 |
+
value = getattr(args, _get_opt_name(opt.flags))
|
122 |
+
if value is not None:
|
123 |
+
_set_by_path(config, opt.target, value)
|
124 |
+
if 'target2' in opt._fields:
|
125 |
+
_set_by_path(config, opt.target2, value)
|
126 |
+
if 'target3' in opt._fields:
|
127 |
+
_set_by_path(config, opt.target3, value)
|
128 |
+
|
129 |
+
return config
|
130 |
+
|
131 |
+
|
132 |
+
def _get_opt_name(flags):
|
133 |
+
for flg in flags:
|
134 |
+
if flg.startswith('--'):
|
135 |
+
return flg.replace('--', '')
|
136 |
+
return flags[0].replace('--', '')
|
137 |
+
|
138 |
+
|
139 |
+
def _set_by_path(tree, keys, value):
|
140 |
+
"""Set a value in a nested object in tree by sequence of keys."""
|
141 |
+
_get_by_path(tree, keys[:-1])[keys[-1]] = value
|
142 |
+
|
143 |
+
|
144 |
+
def _get_by_path(tree, keys):
|
145 |
+
"""Access a nested object in tree by sequence of keys."""
|
146 |
+
return reduce(getitem, keys, tree)
|
ELR/test.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import data_loader.data_loaders as module_data
|
5 |
+
import model.loss as module_loss
|
6 |
+
import model.metric as module_metric
|
7 |
+
import model.model as module_arch
|
8 |
+
from parse_config import ConfigParser
|
9 |
+
|
10 |
+
|
11 |
+
def main(config):
    """Evaluate the checkpoint at ``config.resume`` and log loss and metrics.

    :param config: parsed configuration; used here through item access,
        ``get_logger``, ``initialize`` and the ``resume`` attribute.
    """
    logger = config.get_logger('test')

    # setup data_loader instances
    data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=512,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2
    ).split_validation()

    # build model architecture
    model = config.initialize('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['val_loss'])
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume, map_location='cpu')
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))

    with torch.no_grad():
        for data, target, _, _ in tqdm(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)

            #
            # save sample images, or do something with output here
            #

            # computing loss, metrics on test set; weight by batch size so the
            # division by n_samples below yields per-sample averages
            loss = loss_fn(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for metric_idx, metric in enumerate(metric_fns):
                total_metrics[metric_idx] += metric(output, target) * batch_size

    n_samples = len(data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[metric_idx].item() / n_samples
        for metric_idx, met in enumerate(metric_fns)
    })
    logger.info(log)
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == '__main__':
    # CLI entry point: three optional string flags, then evaluation.
    args = argparse.ArgumentParser(description='PyTorch Template')

    cli_flags = (
        ('-c', '--config', 'config file path (default: None)'),
        ('-r', '--resume', 'path to latest checkpoint (default: None)'),
        ('-d', '--device', 'indices of GPUs to enable (default: all)'),
    )
    for short_flag, long_flag, help_text in cli_flags:
        args.add_argument(short_flag, long_flag, default=None, type=str,
                          help=help_text)

    config = ConfigParser.get_instance(args, '')
    main(config)
|
ELR/train.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import collections
|
3 |
+
import sys
|
4 |
+
import requests
|
5 |
+
import socket
|
6 |
+
import torch
|
7 |
+
import mlflow
|
8 |
+
import mlflow.pytorch
|
9 |
+
import data_loader.data_loaders as module_data
|
10 |
+
import model.loss as module_loss
|
11 |
+
import model.metric as module_metric
|
12 |
+
import model.model as module_arch
|
13 |
+
from parse_config import ConfigParser
|
14 |
+
from trainer import Trainer
|
15 |
+
from collections import OrderedDict
|
16 |
+
import random
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def log_params(conf: OrderedDict, parent_key: str = None):
    """Recursively log every leaf of ``conf`` to mlflow.

    Nested ``OrderedDict`` values are descended into; keys are joined to
    their parent's key with ``-``.
    """
    for key, value in conf.items():
        combined_key = key if parent_key is None else f'{parent_key}-{key}'

        if isinstance(value, OrderedDict):
            log_params(value, combined_key)
            continue
        mlflow.log_param(combined_key, value)
|
31 |
+
|
32 |
+
|
33 |
+
def main(config: ConfigParser):
    """Build loaders, model, losses, optimizer and scheduler from ``config``, then train."""

    logger = config.get_logger('train')

    loader_cls = getattr(module_data, config['data_loader']['type'])
    loader_args = config['data_loader']['args']

    data_loader = loader_cls(
        loader_args['data_dir'],
        batch_size=loader_args['batch_size'],
        shuffle=loader_args['shuffle'],
        validation_split=loader_args['validation_split'],
        num_batches=loader_args['num_batches'],
        training=True,
        num_workers=loader_args['num_workers'],
        pin_memory=loader_args['pin_memory']
    )

    valid_data_loader = data_loader.split_validation()

    # test loader uses fixed evaluation settings rather than the config values
    test_data_loader = getattr(module_data, config['data_loader']['type'])(
        config['data_loader']['args']['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2
    ).split_validation()

    # build model architecture
    model = config.initialize('arch', module_arch)

    logger.info(config.config)

    # size of the per-example table kept by the training loss
    dataset = data_loader.dataset
    if hasattr(dataset, 'num_raw_example'):
        num_examp = dataset.num_raw_example
    else:
        num_examp = len(dataset)

    # get function handles of loss and metrics
    train_loss_cls = getattr(module_loss, config['train_loss']['type'])
    train_loss = train_loss_cls(num_examp=num_examp, num_classes=config['num_classes'],
                                beta=config['train_loss']['args']['beta'])

    val_loss = getattr(module_loss, config['val_loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.initialize('optimizer', torch.optim, [{'params': trainable_params}])
    lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, train_loss, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      test_data_loader=test_data_loader,
                      lr_scheduler=lr_scheduler,
                      val_criterion=val_loss)

    trainer.train()

    # kept from the original: runs after training has finished, results unused here
    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == '__main__':
    # CLI entry point: parse options, seed the RNGs, then train.
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    def _str2bool(text):
        """Parse a CLI boolean. ``type=bool`` would turn any non-empty string,
        including '0' and 'False', into True."""
        lowered = text.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(text))

    # custom cli options to modify configuration from default values given in json file.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
        CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
        CustomArgs(['--lamb', '--lamb'], type=float, target=('train_loss', 'args', 'lambda')),
        CustomArgs(['--beta', '--beta'], type=float, target=('train_loss', 'args', 'beta')),
        CustomArgs(['--percent', '--percent'], type=float, target=('trainer', 'percent')),
        CustomArgs(['--asym', '--asym'], type=_str2bool, target=('trainer', 'asym')),
        CustomArgs(['--name', '--exp_name'], type=str, target=('name',)),
        CustomArgs(['--seed', '--seed'], type=int, target=('seed',))
    ]
    config = ConfigParser.get_instance(args, options)

    # seed Python and torch (CPU and all CUDA devices) from the config
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])
    main(config)
|
ELR/trainer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .trainer import *
|
ELR/trainer/trainer.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from typing import List
|
5 |
+
from torchvision.utils import make_grid
|
6 |
+
from base import BaseTrainer
|
7 |
+
from utils import inf_loop
|
8 |
+
import sys
|
9 |
+
from sklearn.mixture import GaussianMixture
|
10 |
+
|
11 |
+
class Trainer(BaseTrainer):
    """
    Trainer class

    Note:
        Inherited from BaseTrainer.
    """
    def __init__(self, model, train_criterion, metrics, optimizer, config, data_loader,
                 valid_data_loader=None, test_data_loader=None, lr_scheduler=None, len_epoch=None, val_criterion=None):
        super().__init__(model, train_criterion, metrics, optimizer, config, val_criterion)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader

        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        # per-batch loss histories, appended to by the epoch methods below
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []

    def _current_lr(self):
        """Return the scheduler's current learning rate(s), or None when no scheduler is set."""
        if self.lr_scheduler is None:
            return None
        # get_last_lr() is the reporting API; fall back for schedulers that lack it
        if hasattr(self.lr_scheduler, 'get_last_lr'):
            return self.lr_scheduler.get_last_lr()
        return self.lr_scheduler.get_lr()

    def _eval_metrics(self, output, label):
        """Evaluate every metric on one batch, write each to the writer, return them as an array."""
        acc_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            acc_metrics[i] += metric(output, label)
            self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
        return acc_metrics

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log

            The metrics in log must have the key 'metrics'.
        """
        self.model.train()

        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))

        with tqdm(self.data_loader) as progress:
            for batch_idx, (data, label, indexs, _) in enumerate(progress):
                progress.set_description_str(f'Train epoch {epoch}')

                data, label = data.to(self.device), label.long().to(self.device)

                output = self.model(data)

                # the criterion receives the example indices as a plain list
                loss = self.train_criterion(indexs.cpu().detach().numpy().tolist(), output, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.writer.add_scalar('loss', loss.item())
                self.train_loss_list.append(loss.item())
                total_loss += loss.item()
                total_metrics += self._eval_metrics(output, label)

                if batch_idx % self.log_step == 0:
                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
                        self._progress(batch_idx),
                        loss.item()))
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break

        log = {
            'loss': total_loss / self.len_epoch,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'learning rate': self._current_lr()
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(val_log)
        if self.do_test:
            test_log, _ = self._test_epoch(epoch)
            log.update(test_log)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        self.model.eval()

        total_val_loss = 0
        total_val_metrics = np.zeros(len(self.metrics))
        with torch.no_grad():
            with tqdm(self.valid_data_loader) as progress:
                for batch_idx, (data, label, _, _) in enumerate(progress):
                    progress.set_description_str(f'Valid epoch {epoch}')
                    data, label = data.to(self.device), label.to(self.device)
                    output = self.model(data)
                    loss = self.val_criterion(output, label)

                    self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                    self.writer.add_scalar('loss', loss.item())
                    self.val_loss_list.append(loss.item())
                    total_val_loss += loss.item()
                    total_val_metrics += self._eval_metrics(output, label)
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return {
            'val_loss': total_val_loss / len(self.valid_data_loader),
            'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
        }

    def _test_epoch(self, epoch):
        """
        Test after training an epoch

        :return: A tuple ``(log, [results, tar_])``: the log holds 'test_loss'
            and 'test_metrics'; ``results`` stores the model output per test
            example and ``tar_`` the labels, both indexed by example index.
        """
        self.model.eval()
        total_test_loss = 0
        total_test_metrics = np.zeros(len(self.metrics))
        results = np.zeros((len(self.test_data_loader.dataset), self.config['num_classes']), dtype=np.float32)
        tar_ = np.zeros((len(self.test_data_loader.dataset),), dtype=np.float32)
        with torch.no_grad():
            with tqdm(self.test_data_loader) as progress:
                for batch_idx, (data, label, indexs, _) in enumerate(progress):
                    progress.set_description_str(f'Test epoch {epoch}')
                    data, label = data.to(self.device), label.to(self.device)
                    output = self.model(data)

                    loss = self.val_criterion(output, label)

                    self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
                    self.writer.add_scalar('loss', loss.item())
                    self.test_loss_list.append(loss.item())
                    total_test_loss += loss.item()
                    total_test_metrics += self._eval_metrics(output, label)
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                    batch_rows = indexs.cpu().detach().numpy().tolist()
                    results[batch_rows] = output.cpu().detach().numpy().tolist()
                    tar_[batch_rows] = label.cpu().detach().numpy().tolist()

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return {
            'test_loss': total_test_loss / len(self.test_data_loader),
            'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
        }, [results, tar_]

    def _warmup_epoch(self, epoch):
        """Train one epoch with plain cross-entropy while feeding predictions to the criterion.

        NOTE(review): this method unpacks five items per batch and calls
        ``train_criterion.update_hist``; the other epoch methods here unpack
        four items. Confirm the loader and criterion it is meant to be used with.
        """
        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))
        self.model.train()

        data_loader = self.data_loader

        with tqdm(data_loader) as progress:
            for batch_idx, (data, label, _, indexs, _) in enumerate(progress):
                progress.set_description_str(f'Warm up epoch {epoch}')

                data, label = data.to(self.device), label.long().to(self.device)

                self.optimizer.zero_grad()
                output = self.model(data)
                # explicit class dimension; the implicit form is deprecated
                out_prob = torch.nn.functional.softmax(output, dim=1).detach()

                self.train_criterion.update_hist(indexs.cpu().detach().numpy().tolist(), out_prob)

                loss = torch.nn.functional.cross_entropy(output, label)

                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.writer.add_scalar('loss', loss.item())
                self.train_loss_list.append(loss.item())
                total_loss += loss.item()
                total_metrics += self._eval_metrics(output, label)

                if batch_idx % self.log_step == 0:
                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
                        self._progress(batch_idx),
                        loss.item()))
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break
        if hasattr(self.data_loader, 'run'):
            self.data_loader.run()
        log = {
            'loss': total_loss / self.len_epoch,
            'noise detection rate': 0.0,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'learning rate': self._current_lr()
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(val_log)
        if self.do_test:
            test_log, _ = self._test_epoch(epoch)
            log.update(test_log)

        return log

    def _progress(self, batch_idx):
        """Format progress as '[current/total (pct%)]', in samples when the loader exposes ``n_samples``."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
|
ELR/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .util import *
|
ELR/utils/util.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from datetime import datetime
|
4 |
+
from itertools import repeat
|
5 |
+
from collections import OrderedDict
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
def ensure_dir(dirname):
    """Create ``dirname`` (with parents) unless it is already a directory."""
    target = Path(dirname)
    if target.is_dir():
        return
    target.mkdir(parents=True, exist_ok=False)
|
12 |
+
|
13 |
+
|
14 |
+
def read_json(fname):
    """Load JSON from the path object ``fname``; every JSON object becomes an ``OrderedDict``."""
    with fname.open('rt') as stream:
        parsed = json.load(stream, object_hook=OrderedDict)
    return parsed
|
17 |
+
|
18 |
+
|
19 |
+
def write_json(content, fname):
    """Serialise ``content`` as JSON (4-space indent, keys unsorted) to the path object ``fname``."""
    dump_options = {'indent': 4, 'sort_keys': False}
    with fname.open('wt') as stream:
        json.dump(content, stream, **dump_options)
|
22 |
+
|
23 |
+
|
24 |
+
def inf_loop(data_loader):
    ''' wrapper function for endless data loader. '''
    # re-iterate the same loader forever, yielding its items each pass
    while True:
        yield from data_loader
|
28 |
+
|
29 |
+
|
30 |
+
class Timer:
    """Stopwatch measuring the seconds elapsed between consecutive ``check`` calls."""

    def __init__(self):
        # reference point for the next check()
        self.cache = datetime.now()

    def check(self):
        """Return seconds since the last reference point and move the reference to now."""
        moment = datetime.now()
        elapsed = (moment - self.cache).total_seconds()
        self.cache = moment
        return elapsed

    def reset(self):
        """Move the reference point to now without reporting anything."""
        self.cache = datetime.now()
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped ramp-up weight ``exp(-5 * (1 - t) ** 2)``, ``t = current / rampup_length``.

    ``current`` is clipped to ``[0, rampup_length]``, so the result rises from
    ``exp(-5)`` at the start to ``1.0`` once the ramp is complete. A zero
    ``rampup_length`` means "no ramp" and yields ``1.0`` immediately.

    NOTE(review): the original docstring was cut off; this schedule appears to
    be the ramp-up used in Temporal Ensembling (arXiv:1610.02242) - confirm.
    """
    if rampup_length == 0:
        return 1.0
    progress = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - progress / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))
|
53 |
+
|
54 |
+
|
55 |
+
def linear_rampup(current, rampup_length):
    """Linear ramp-up weight: ``current / rampup_length``, capped at ``1.0``.

    Both arguments must be non-negative (checked with ``assert``). A zero
    ``rampup_length`` returns ``1.0`` because ``current >= 0`` always holds.
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length
|
62 |
+
|
63 |
+
|
64 |
+
def cosine_rampdown(current, rampdown_length):
    """Cosine ramp-down weight, from https://arxiv.org/abs/1608.03983.

    ``current`` is clipped to ``[0, rampdown_length]``; the result falls from
    ``1.0`` at the start to ``0.0`` at ``rampdown_length`` along half a cosine
    period. ``rampdown_length`` must be non-zero, since it is used as a divisor.
    """
    position = np.clip(current, 0.0, rampdown_length)
    angle = np.pi * position / rampdown_length
    return float(.5 * (np.cos(angle) + 1))
|
68 |
+
|
69 |
+
|
70 |
+
def cosine_rampup(current, rampup_length):
    """Cosine ramp-up weight rising from ``0.0`` to ``1.0`` over ``rampup_length``.

    ``current`` is clipped to ``[0, rampup_length]`` and mapped along half a
    cosine period: ``0.5 * (1 - cos(pi * current / rampup_length))``.

    :param current: current position in the ramp (e.g. epoch or step).
    :param rampup_length: length of the ramp; ``0`` means "no ramp".
    :return: the weight as a Python ``float``.
    """
    # A zero-length ramp is already complete. Without this guard the division
    # below evaluates 0/0 and silently returns NaN. Returning 1.0 matches what
    # sigmoid_rampup and linear_rampup do for a zero length.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    return float(-.5 * (np.cos(np.pi * current / rampup_length) - 1))
|
74 |
+
|
75 |
+
|
ELR_plus/README.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ELR+
|
2 |
+
This is an official PyTorch implementation of the ELR+ method proposed in [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151).
|
3 |
+
|
4 |
+
|
5 |
+
## Usage
|
6 |
+
Train the network on the Symmetric Noise CIFAR-10 dataset (noise rate = 0.8):
|
7 |
+
|
8 |
+
```
|
9 |
+
python train.py -c config_cifar10.json --percent 0.8
|
10 |
+
```
|
11 |
+
Train the network on the Asymmetric Noise CIFAR-10 dataset (noise rate = 0.4):
|
12 |
+
|
13 |
+
```
|
14 |
+
python train.py -c config_cifar10_asym.json --percent 0.4
|
15 |
+
```
|
16 |
+
|
17 |
+
Train the network on the Asymmetric Noise CIFAR-100 dataset (noise rate = 0.4):
|
18 |
+
|
19 |
+
```
|
20 |
+
python train.py -c config_cifar100.json --percent 0.4 --asym 1
|
21 |
+
```
|
22 |
+
|
23 |
+
The config files can be modified to adjust hyperparameters and optimization settings.
|
24 |
+
|
25 |
+
|
26 |
+
## References
|
27 |
+
- S. Liu, J. Niles-Weed, N. Razavian and C. Fernandez-Granda "Early-Learning Regularization Prevents Memorization of Noisy Labels", 2020
|
ELR_plus/base/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .base_data_loader import *
|
2 |
+
from .base_model import *
|
3 |
+
from .base_trainer import *
|
ELR_plus/base/base_data_loader.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union, Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torch.utils.data.dataloader import default_collate
|
6 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
7 |
+
|
8 |
+
|
9 |
+
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders

    Wraps a training dataset in a ``DataLoader`` and provides a matching
    validation loader through ``split_validation()``. Validation data comes
    either from an explicit ``val_dataset`` or, when none is given, from a
    random subset of the training dataset selected by ``validation_split``.
    """
    valid_sampler: Optional[SubsetRandomSampler]
    sampler: Optional[SubsetRandomSampler]

    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
                 collate_fn=default_collate, val_dataset=None):
        """
        :param train_dataset: dataset served by this loader.
        :param batch_size: training batch size.
        :param shuffle: whether to shuffle the training data; forced off when
            a validation subset is carved out with samplers.
        :param validation_split: ``0`` for no split, an ``int`` for an absolute
            number of validation samples, or a fraction of the dataset.
            Only used when ``val_dataset`` is ``None``.
        :param num_workers: forwarded to ``DataLoader``.
        :param pin_memory: forwarded to ``DataLoader``.
        :param collate_fn: forwarded to ``DataLoader``.
        :param val_dataset: optional separate validation dataset.
        """
        self.collate_fn = collate_fn
        self.validation_split = validation_split
        self.shuffle = shuffle
        self.val_dataset = val_dataset

        self.batch_idx = 0
        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
        self.init_kwargs = {
            'dataset': train_dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers,
            'pin_memory': pin_memory
        }
        if val_dataset is None:
            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
            # _split_sampler() clears self.shuffle when it builds samplers.
            # The kwargs dict was filled in before that, so re-sync it here;
            # otherwise DataLoader rejects the sampler + shuffle=True pair.
            self.init_kwargs['shuffle'] = self.shuffle
            super().__init__(sampler=self.sampler, **self.init_kwargs)
        else:
            super().__init__(**self.init_kwargs)

    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
        """Build disjoint train/validation samplers over the training dataset.

        :param split: ``0`` (no split), an absolute sample count (``int``) or
            a fraction of ``self.n_samples``.
        :return: ``(train_sampler, valid_sampler)``, or ``(None, None)`` when
            ``split`` is zero.
        """
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # Fixed seed keeps the split reproducible across runs.
        # NOTE: this reseeds NumPy's global random state as a side effect.
        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self, bs = 1000):
        """Return the validation ``DataLoader`` paired with this loader.

        :param bs: batch size used when an explicit ``val_dataset`` was given;
            the sampler-based loader reuses the training batch size instead.
        """
        if self.val_dataset is not None:
            kwargs = {
                'dataset': self.val_dataset,
                'batch_size': bs,
                'shuffle': False,
                'collate_fn': self.collate_fn,
                'num_workers': self.num_workers
            }
            return DataLoader(**kwargs)
        else:
            print('Using sampler to split!')
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
|
81 |
+
|
82 |
+
|
83 |
+
|
ELR_plus/base/base_model.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import numpy as np
|
3 |
+
from abc import abstractmethod
|
4 |
+
|
5 |
+
|
6 |
+
class BaseModel(nn.Module):
    """
    Common parent of all models; adds a trainable-parameter count to ``str(model)``.
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass, to be provided by each concrete model.

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Return the standard module description followed by the number of
        parameters that have ``requires_grad`` set.
        """
        trainable_count = sum(
            np.prod(weight.size()) for weight in self.parameters() if weight.requires_grad
        )
        return '{}\nTrainable parameters: {}'.format(super().__str__(), trainable_count)
|
ELR_plus/base/base_trainer.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypeVar, List, Tuple
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from abc import abstractmethod
|
5 |
+
from numpy import inf
|
6 |
+
from logger import TensorboardWriter
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
class BaseTrainer:
    """
    Base class for all trainers

    Drives two networks (``model1``/``model2``) with optional EMA copies,
    handles the epoch loop, metric monitoring with early stopping, and
    checkpoint saving/resuming. Subclasses are expected to provide
    ``_train_epoch`` as well as the attributes and methods this class reads
    but does not define (``_warmup_epoch``, ``_valid_epoch``, ``_test_epoch``,
    ``data_loader1/2``, ``lr_scheduler1/2``, ``do_validation``, ``do_test``).
    """
    def __init__(self, model1, model2, model_ema1, model_ema2, train_criterion1,
                 train_criterion2, metrics, optimizer1, optimizer2, config, val_criterion,
                 model_ema1_copy, model_ema2_copy):
        """
        :param model1: first network; placed on the first configured device.
        :param model2: second network; placed on the last configured device.
        :param model_ema1: EMA network for ``model1`` or ``None``.
        :param model_ema2: EMA network for ``model2`` or ``None``.
        :param train_criterion1: training loss for ``model1``.
        :param train_criterion2: training loss for ``model2``.
        :param metrics: list of metric functions (their ``__name__`` is logged).
        :param optimizer1: optimizer of ``model1``.
        :param optimizer2: optimizer of ``model2``.
        :param config: parsed configuration object (supports item access and
            exposes ``config``, ``get_logger``, ``save_dir``, ``log_dir``, ``resume``).
        :param val_criterion: loss used for validation.
        :param model_ema1_copy: copy of ``model_ema1`` kept on ``model2``'s device.
        :param model_ema2_copy: copy of ``model_ema2`` kept on ``model1``'s device.
        """
        self.config = config.config

        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(config['n_gpu'])

        if len(self.device_ids) > 1:
            print('Using Multi-Processing!')

        self.model1 = model1.to(self.device+str(self.device_ids[0]))
        self.model2 = model2.to(self.device+str(self.device_ids[-1]))

        # Each EMA model lives next to its own network; the "*_copy" of the
        # *other* network's EMA is placed on the same device.
        if model_ema1 is not None:
            self.model_ema1 = model_ema1.to(self.device+str(self.device_ids[0]))
            self.model_ema2_copy = model_ema2_copy.to(self.device+str(self.device_ids[0]))
        else:
            self.model_ema1 = None
            self.model_ema2_copy = None

        if model_ema2 is not None:
            self.model_ema2 = model_ema2.to(self.device+str(self.device_ids[-1]))
            self.model_ema1_copy = model_ema1_copy.to(self.device+str(self.device_ids[-1]))
        else:
            self.model_ema2 = None
            self.model_ema1_copy = None

        # Detach the EMA parameters from autograd.
        if self.model_ema1 is not None:
            for param in self.model_ema1.parameters():
                param.detach_()

            for param in self.model_ema2_copy.parameters():
                param.detach_()

        if self.model_ema2 is not None:
            for param in self.model_ema2.parameters():
                param.detach_()

            for param in self.model_ema1_copy.parameters():
                param.detach_()

        self.train_criterion1 = train_criterion1.to(self.device+str(self.device_ids[0]))
        self.train_criterion2 = train_criterion2.to(self.device+str(self.device_ids[-1]))

        self.val_criterion = val_criterion

        self.metrics = metrics

        self.optimizer1 = optimizer1
        self.optimizer2 = optimizer2

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            # monitor is "<min|max> <metric name>", e.g. "max val_my_metric"
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.global_step = 0

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epochs number
        """
        raise NotImplementedError

    def _evaluate_in_subprocesses(self, epoch):
        """
        Run the enabled validation/test passes in parallel worker processes.

        Only the passes switched on by ``do_validation`` / ``do_test`` are
        created, started and joined, so a disabled pass can never re-start or
        re-join a stale process object.

        :param epoch: current epoch number, forwarded to the workers.
        :return: ``(val_result, test_result)`` - whatever each worker put on
            its queue, or ``None`` for a pass that is disabled.
        """
        import torch.multiprocessing as mp

        device = self.device+str(self.device_ids[0])
        targets = []
        if self.do_validation:
            targets.append(('val', self._valid_epoch))
        if self.do_test:
            targets.append(('test', self._test_epoch))

        running = []
        for tag, target in targets:
            queue = mp.Queue()
            process = mp.Process(target=target, args=(epoch, self.model1, self.model_ema2_copy, device, queue))
            process.start()
            running.append((tag, queue, process))

        # Fetch results before joining: a worker may not exit until its queue
        # has been drained.
        results = {'val': None, 'test': None}
        for tag, queue, _ in running:
            results[tag] = queue.get()
        for _, _, process in running:
            process.join()
        return results['val'], results['test']

    def train(self):
        """
        Full training logic

        Runs warm-up epochs (``epoch <= config['trainer']['warmup']``) and then
        regular epochs for both networks, evaluates, logs the merged results,
        tracks the monitored metric for early stopping and saves checkpoints
        every ``save_period`` epochs. With more than one device the two
        networks are trained in separate processes.
        """
        if len(self.device_ids) > 1:
            import torch.multiprocessing as mp
            mp.set_start_method('spawn', force =True)

        not_improved_count = 0

        for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
            if epoch <= self.config['trainer']['warmup']:
                if len(self.device_ids) > 1:
                    q1 = mp.Queue()
                    q2 = mp.Queue()
                    p1 = mp.Process(target=self._warmup_epoch, args=(epoch, self.model1, self.data_loader1, self.optimizer1, self.train_criterion1, self.lr_scheduler1, self.device+str(self.device_ids[0]), q1 ))
                    p2 = mp.Process(target=self._warmup_epoch, args=(epoch, self.model2, self.data_loader2, self.optimizer2, self.train_criterion2, self.lr_scheduler2, self.device+str(self.device_ids[-1]), q2))
                    p1.start()
                    p2.start()
                    result1 = q1.get()
                    result2 = q2.get()
                    p1.join()
                    p2.join()
                else:
                    result1 = self._warmup_epoch(epoch, self.model1, self.data_loader1, self.optimizer1, self.train_criterion1, self.lr_scheduler1, self.device+str(self.device_ids[0]))
                    result2 = self._warmup_epoch(epoch, self.model2, self.data_loader2, self.optimizer2, self.train_criterion2, self.lr_scheduler2, self.device+str(self.device_ids[-1]))

                if len(self.device_ids) > 1:
                    self.model_ema1_copy.load_state_dict(self.model_ema1.state_dict())
                    self.model_ema2_copy.load_state_dict(self.model_ema2.state_dict())
                    val_log, test_out = self._evaluate_in_subprocesses(epoch)
                    if val_log is not None:
                        result1.update(val_log)
                        result2.update(val_log)
                    if test_out is not None:
                        # NOTE(review): during warm-up the test result is
                        # unpacked as (log, meta), unlike the regular branch
                        # below - confirm against _test_epoch.
                        test_log, test_meta = test_out
                        result1.update(test_log)
                        result2.update(test_log)
                else:
                    if self.do_validation:
                        val_log = self._valid_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
                        result1.update(val_log)
                        result2.update(val_log)
                    if self.do_test:
                        test_log, test_meta = self._test_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
                        result1.update(test_log)
                        result2.update(test_log)
                    else:
                        test_meta = [0,0]

            else:
                if len(self.device_ids) > 1:
                    q1 = mp.Queue()
                    q2 = mp.Queue()
                    p1 = mp.Process(target=self._train_epoch, args=(epoch, self.model1, self.model_ema1, self.model_ema2_copy, self.data_loader1, self.train_criterion1, self.optimizer1, self.lr_scheduler1, self.device+str(self.device_ids[0]), q1 ))
                    p2 = mp.Process(target=self._train_epoch, args=(epoch, self.model2, self.model_ema2, self.model_ema1_copy, self.data_loader2, self.train_criterion2, self.optimizer2, self.lr_scheduler2, self.device+str(self.device_ids[-1]), q2 ))
                    p1.start()
                    p2.start()
                    result1 = q1.get()
                    result2 = q2.get()
                    p1.join()
                    p2.join()
                else:
                    result1 = self._train_epoch(epoch, self.model1, self.model_ema1, self.model_ema2, self.data_loader1, self.train_criterion1, self.optimizer1, self.lr_scheduler1, self.device+str(self.device_ids[0]))
                    result2 = self._train_epoch(epoch, self.model2, self.model_ema2, self.model_ema1, self.data_loader2, self.train_criterion2, self.optimizer2, self.lr_scheduler2, self.device+str(self.device_ids[-1]))

                self.global_step += result1['local_step']
                if len(self.device_ids) > 1:
                    self.model_ema1_copy.load_state_dict(self.model_ema1.state_dict())
                    self.model_ema2_copy.load_state_dict(self.model_ema2.state_dict())
                    val_log, test_log = self._evaluate_in_subprocesses(epoch)
                    if val_log is not None:
                        result1.update(val_log)
                        result2.update(val_log)
                    if test_log is not None:
                        result1.update(test_log)
                        result2.update(test_log)
                else:
                    if self.do_validation:
                        val_log = self._valid_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
                        result1.update(val_log)
                        result2.update(val_log)
                    if self.do_test:
                        test_log = self._test_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
                        result1.update(test_log)
                        result2.update(test_log)

            # save logged informations into log dict
            log = {'epoch': epoch}
            for key, value in result1.items():
                if key == 'metrics':
                    log.update({'Net1' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                    log.update({'Net2' + mtr.__name__: result2[key][i] for i, mtr in enumerate(self.metrics)})
                elif key == 'val_metrics':
                    log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'test_metrics':
                    log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log['Net1'+key] = value
                    log['Net2'+key] = result2[key]

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device

        :param n_gpu_use: number of GPUs requested in the configuration.
        :return: ``(device_prefix, device_ids)``; callers build the device
            string as ``device_prefix + str(device_id)``.
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There\'s no GPU available on this machine,"
                                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        # NOTE(review): the prefix is always 'cuda:' and the id list is empty
        # when no GPU is used, so the device_ids[0] lookups in __init__ fail on
        # a CPU-only machine despite the warning above.
        device = 'cuda:'#torch.device('cuda:' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        """
        arch = type(self.model1).__name__

        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict1': self.model1.state_dict(),
            'state_dict2': self.model2.state_dict(),
            'optimizer1': self.optimizer1.state_dict(),
            'optimizer2': self.optimizer2.state_dict(),
            'monitor_best': self.mnt_best
            #'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        Reads the keys written by ``_save_checkpoint`` (``arch``, ``epoch``,
        ``state_dict1/2``, ``optimizer1/2``, ``monitor_best``). The EMA models
        are not part of the checkpoint and are therefore not restored here.

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # The checkpoint stores the class name of model1 under 'arch'; compare
        # it with the current model instead of a 'config' entry that is never
        # saved.
        if checkpoint['arch'] != type(self.model1).__name__:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model1.load_state_dict(checkpoint['state_dict1'])
        self.model2.load_state_dict(checkpoint['state_dict2'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        # The type can only be compared when the checkpoint carries its config
        # (currently not saved); otherwise the optimizer is assumed unchanged.
        saved_config = checkpoint.get('config')
        for key, optimizer in (('optimizer1', self.optimizer1), ('optimizer2', self.optimizer2)):
            if saved_config is not None and saved_config[key]['type'] != self.config[key]['type']:
                self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                    "Optimizer parameters not being resumed.")
            else:
                optimizer.load_state_dict(checkpoint[key])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
|
341 |
+
|
ELR_plus/config_cifar10.json
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "cifar10_ELR_plus_PreActResNet18",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed":123,
|
5 |
+
|
6 |
+
"arch": {
|
7 |
+
"args": {"num_classes":10}
|
8 |
+
},
|
9 |
+
|
10 |
+
"arch1": {
|
11 |
+
"type": "PreActResNet18",
|
12 |
+
"args": {"num_classes":10}
|
13 |
+
},
|
14 |
+
|
15 |
+
"arch2": {
|
16 |
+
"type": "PreActResNet18",
|
17 |
+
"args": {"num_classes":10}
|
18 |
+
},
|
19 |
+
|
20 |
+
"mixup_alpha": 1,
|
21 |
+
"coef_step": 0,
|
22 |
+
"num_classes": 10,
|
23 |
+
"ema_alpha": 0.997,
|
24 |
+
"ema_update": true,
|
25 |
+
"ema_step": 40000,
|
26 |
+
|
27 |
+
|
28 |
+
"data_loader": {
|
29 |
+
"type": "CIFAR10DataLoader",
|
30 |
+
"args":{
|
31 |
+
"data_dir": "/dir/to/data",
|
32 |
+
"batch_size": 128,
|
33 |
+
"batch_size2": 128,
|
34 |
+
"num_batches": 0,
|
35 |
+
"shuffle": true,
|
36 |
+
"validation_split": 0,
|
37 |
+
"num_workers": 8,
|
38 |
+
"pin_memory": true
|
39 |
+
}
|
40 |
+
},
|
41 |
+
|
42 |
+
|
43 |
+
"optimizer1": {
|
44 |
+
"type": "SGD",
|
45 |
+
"args":{
|
46 |
+
"lr": 0.02,
|
47 |
+
"momentum": 0.9,
|
48 |
+
"weight_decay": 5e-4
|
49 |
+
}
|
50 |
+
},
|
51 |
+
|
52 |
+
"optimizer2": {
|
53 |
+
"type": "SGD",
|
54 |
+
"args":{
|
55 |
+
"lr": 0.02,
|
56 |
+
"momentum": 0.9,
|
57 |
+
"weight_decay": 5e-4
|
58 |
+
}
|
59 |
+
},
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
"train_loss": {
|
64 |
+
"type": "elr_plus_loss",
|
65 |
+
"args":{
|
66 |
+
"beta": 0.7,
|
67 |
+
"lambda": 3
|
68 |
+
}
|
69 |
+
},
|
70 |
+
|
71 |
+
"val_loss": "cross_entropy",
|
72 |
+
"metrics": [
|
73 |
+
"my_metric", "my_metric2"
|
74 |
+
],
|
75 |
+
|
76 |
+
"lr_scheduler": {
|
77 |
+
"type": "MultiStepLR",
|
78 |
+
"args": {
|
79 |
+
"milestones": [150],
|
80 |
+
"gamma": 0.1
|
81 |
+
}
|
82 |
+
},
|
83 |
+
|
84 |
+
"trainer": {
|
85 |
+
"epochs": 200,
|
86 |
+
"warmup": 0,
|
87 |
+
"save_dir": "dir/to/model",
|
88 |
+
"save_period": 1,
|
89 |
+
"verbosity": 2,
|
90 |
+
"label_dir": "saved/",
|
91 |
+
|
92 |
+
"monitor": "max val_my_metric",
|
93 |
+
"early_stop": 2000,
|
94 |
+
|
95 |
+
"tensorboard": false,
|
96 |
+
"mlflow": true,
|
97 |
+
|
98 |
+
"_percent": "Percentage of noise",
|
99 |
+
"percent": 0.8,
|
100 |
+
"_begin": "When to begin updating labels",
|
101 |
+
"begin": 0,
|
102 |
+
"_asym": "symmetric noise if false",
|
103 |
+
"asym": false
|
104 |
+
}
|
105 |
+
}
|
ELR_plus/config_cifar100.json
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "cifar100_ELR_plus_PreActResNet18",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed":123,
|
5 |
+
|
6 |
+
"arch": {
|
7 |
+
"args": {"num_classes":100}
|
8 |
+
},
|
9 |
+
|
10 |
+
"arch1": {
|
11 |
+
"type": "PreActResNet18",
|
12 |
+
"args": {"num_classes":100}
|
13 |
+
},
|
14 |
+
|
15 |
+
"arch2": {
|
16 |
+
"type": "PreActResNet18",
|
17 |
+
"args": {"num_classes":100}
|
18 |
+
},
|
19 |
+
|
20 |
+
"mixup_alpha": 1,
|
21 |
+
"coef_step": 40000,
|
22 |
+
"num_classes": 100,
|
23 |
+
"ema_alpha": 0.997,
|
24 |
+
"ema_update": true,
|
25 |
+
"ema_step": 40000,
|
26 |
+
|
27 |
+
|
28 |
+
"data_loader": {
|
29 |
+
"type": "CIFAR100DataLoader",
|
30 |
+
"args":{
|
31 |
+
"data_dir": "/gpfs/scratch/sl5924/noisy/data/",
|
32 |
+
"batch_size": 128,
|
33 |
+
"batch_size2": 128,
|
34 |
+
"num_batches": 0,
|
35 |
+
"shuffle": true,
|
36 |
+
"validation_split": 0,
|
37 |
+
"num_workers": 8,
|
38 |
+
"pin_memory": true
|
39 |
+
}
|
40 |
+
},
|
41 |
+
|
42 |
+
"optimizer1": {
|
43 |
+
"type": "SGD",
|
44 |
+
"args":{
|
45 |
+
"lr": 0.02,
|
46 |
+
"momentum": 0.9,
|
47 |
+
"weight_decay": 5e-4
|
48 |
+
}
|
49 |
+
},
|
50 |
+
|
51 |
+
"optimizer2": {
|
52 |
+
"type": "SGD",
|
53 |
+
"args":{
|
54 |
+
"lr": 0.02,
|
55 |
+
"momentum": 0.9,
|
56 |
+
"weight_decay": 5e-4
|
57 |
+
}
|
58 |
+
},
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
"train_loss": {
|
63 |
+
"type": "elr_plus_loss",
|
64 |
+
"args":{
|
65 |
+
"beta": 0.9,
|
66 |
+
"lambda": 7
|
67 |
+
}
|
68 |
+
},
|
69 |
+
|
70 |
+
"val_loss": "cross_entropy",
|
71 |
+
"metrics": [
|
72 |
+
"my_metric", "my_metric2"
|
73 |
+
],
|
74 |
+
|
75 |
+
"lr_scheduler": {
|
76 |
+
"type": "MultiStepLR",
|
77 |
+
"args": {
|
78 |
+
"milestones": [200],
|
79 |
+
"gamma": 0.1
|
80 |
+
}
|
81 |
+
},
|
82 |
+
|
83 |
+
"trainer": {
|
84 |
+
"epochs": 250,
|
85 |
+
"warmup": 0,
|
86 |
+
"save_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/saved/",
|
87 |
+
"save_period": 1,
|
88 |
+
"verbosity": 2,
|
89 |
+
"label_dir": "saved/",
|
90 |
+
|
91 |
+
"monitor": "max val_my_metric",
|
92 |
+
"early_stop": 2000,
|
93 |
+
|
94 |
+
"tensorboard": false,
|
95 |
+
"mlflow": true,
|
96 |
+
|
97 |
+
"_percent": "Percentage of noise",
|
98 |
+
"percent": 0.8,
|
99 |
+
"_begin": "When to begin updating labels",
|
100 |
+
"begin": 0,
|
101 |
+
"_asym": "symmetric noise if false",
|
102 |
+
"asym": false
|
103 |
+
}
|
104 |
+
}
|
ELR_plus/config_cifar10_asym.json
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "cifar10_ELR_plus_PreActResNet18",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed":123,
|
5 |
+
|
6 |
+
"arch": {
|
7 |
+
"args": {"num_classes":10}
|
8 |
+
},
|
9 |
+
|
10 |
+
"arch1": {
|
11 |
+
"type": "PreActResNet18",
|
12 |
+
"args": {"num_classes":10}
|
13 |
+
},
|
14 |
+
|
15 |
+
"arch2": {
|
16 |
+
"type": "PreActResNet18",
|
17 |
+
"args": {"num_classes":10}
|
18 |
+
},
|
19 |
+
|
20 |
+
"mixup_alpha": 1,
|
21 |
+
"coef_step": 0,
|
22 |
+
"num_classes": 10,
|
23 |
+
"ema_alpha": 0.997,
|
24 |
+
"ema_update": true,
|
25 |
+
"ema_step": 40000,
|
26 |
+
|
27 |
+
|
28 |
+
"data_loader": {
|
29 |
+
"type": "CIFAR10DataLoader",
|
30 |
+
"args":{
|
31 |
+
"data_dir": "dir/to/data",
|
32 |
+
"batch_size": 128,
|
33 |
+
"batch_size2": 128,
|
34 |
+
"num_batches": 0,
|
35 |
+
"shuffle": true,
|
36 |
+
"validation_split": 0,
|
37 |
+
"num_workers": 8,
|
38 |
+
"pin_memory": true
|
39 |
+
}
|
40 |
+
},
|
41 |
+
|
42 |
+
|
43 |
+
"optimizer1": {
|
44 |
+
"type": "SGD",
|
45 |
+
"args":{
|
46 |
+
"lr": 0.02,
|
47 |
+
"momentum": 0.9,
|
48 |
+
"weight_decay": 5e-4
|
49 |
+
}
|
50 |
+
},
|
51 |
+
|
52 |
+
"optimizer2": {
|
53 |
+
"type": "SGD",
|
54 |
+
"args":{
|
55 |
+
"lr": 0.02,
|
56 |
+
"momentum": 0.9,
|
57 |
+
"weight_decay": 5e-4
|
58 |
+
}
|
59 |
+
},
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
"train_loss": {
|
64 |
+
"type": "elr_plus_loss",
|
65 |
+
"args":{
|
66 |
+
"beta": 0.9,
|
67 |
+
"lambda": 1
|
68 |
+
}
|
69 |
+
},
|
70 |
+
|
71 |
+
"val_loss": "cross_entropy",
|
72 |
+
"metrics": [
|
73 |
+
"my_metric", "my_metric2"
|
74 |
+
],
|
75 |
+
|
76 |
+
"lr_scheduler": {
|
77 |
+
"type": "MultiStepLR",
|
78 |
+
"args": {
|
79 |
+
"milestones": [150],
|
80 |
+
"gamma": 0.1
|
81 |
+
}
|
82 |
+
},
|
83 |
+
|
84 |
+
"trainer": {
|
85 |
+
"epochs": 200,
|
86 |
+
"warmup": 0,
|
87 |
+
"save_dir": "dir/to/model",
|
88 |
+
"save_period": 1,
|
89 |
+
"verbosity": 2,
|
90 |
+
"label_dir": "saved/",
|
91 |
+
|
92 |
+
"monitor": "max val_my_metric",
|
93 |
+
"early_stop": 2000,
|
94 |
+
|
95 |
+
"tensorboard": false,
|
96 |
+
"mlflow": true,
|
97 |
+
|
98 |
+
"_percent": "Percentage of noise",
|
99 |
+
"percent": 0.4,
|
100 |
+
"_begin": "When to begin updating labels",
|
101 |
+
"begin": 0,
|
102 |
+
"_asym": "symmetric noise if false",
|
103 |
+
"asym": true
|
104 |
+
}
|
105 |
+
}
|
ELR_plus/config_clothing1m.json
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "clothing1M_ELR_plus_resnet50",
|
3 |
+
"n_gpu": 1,
|
4 |
+
"seed":123,
|
5 |
+
|
6 |
+
|
7 |
+
"arch1": {
|
8 |
+
"type": "resnet50",
|
9 |
+
"args": {"num_classes":14}
|
10 |
+
},
|
11 |
+
|
12 |
+
"arch2": {
|
13 |
+
"type": "resnet50",
|
14 |
+
"args": {"num_classes":14}
|
15 |
+
},
|
16 |
+
|
17 |
+
"mixup_alpha": 1,
|
18 |
+
"coef_step": 0,
|
19 |
+
"num_classes": 14,
|
20 |
+
"ema_alpha": 0.9999,
|
21 |
+
"ema_update": false,
|
22 |
+
"ema_step": -1,
|
23 |
+
|
24 |
+
|
25 |
+
"data_loader": {
|
26 |
+
"type": "Clothing1MDataLoader",
|
27 |
+
"args":{
|
28 |
+
"data_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/clothing1M/images",
|
29 |
+
"batch_size": 64,
|
30 |
+
"batch_size2": 64,
|
31 |
+
"num_batches": 3000,
|
32 |
+
"shuffle": true,
|
33 |
+
"validation_split": 0,
|
34 |
+
"num_workers": 8,
|
35 |
+
"pin_memory": true
|
36 |
+
}
|
37 |
+
},
|
38 |
+
|
39 |
+
"optimizer1": {
|
40 |
+
"type": "SGD",
|
41 |
+
"args":{
|
42 |
+
"lr": 0.002,
|
43 |
+
"momentum": 0.9,
|
44 |
+
"weight_decay": 1e-3
|
45 |
+
}
|
46 |
+
},
|
47 |
+
|
48 |
+
"optimizer2": {
|
49 |
+
"type": "SGD",
|
50 |
+
"args":{
|
51 |
+
"lr": 0.002,
|
52 |
+
"momentum": 0.9,
|
53 |
+
"weight_decay": 1e-3
|
54 |
+
}
|
55 |
+
},
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
"train_loss": {
|
60 |
+
"type": "elr_plus_loss",
|
61 |
+
"args":{
|
62 |
+
"beta": 0.7,
|
63 |
+
"lambda": 3
|
64 |
+
}
|
65 |
+
},
|
66 |
+
|
67 |
+
"val_loss": "cross_entropy",
|
68 |
+
"metrics": [
|
69 |
+
"my_metric", "my_metric2"
|
70 |
+
],
|
71 |
+
|
72 |
+
"lr_scheduler": {
|
73 |
+
"type": "MultiStepLR",
|
74 |
+
"args": {
|
75 |
+
"milestones": [7],
|
76 |
+
"gamma": 0.1
|
77 |
+
}
|
78 |
+
},
|
79 |
+
|
80 |
+
"trainer": {
|
81 |
+
"epochs": 15,
|
82 |
+
"warmup": 0,
|
83 |
+
"save_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/saved/",
|
84 |
+
"save_period": 1,
|
85 |
+
"verbosity": 2,
|
86 |
+
"label_dir": "saved/",
|
87 |
+
|
88 |
+
"monitor": "max val_my_metric",
|
89 |
+
"early_stop": 2000,
|
90 |
+
|
91 |
+
"tensorboard": true,
|
92 |
+
"mlflow": true,
|
93 |
+
|
94 |
+
"_percent": "Percentage of noise",
|
95 |
+
"percent": 0.8,
|
96 |
+
"_begin": "When to begin updating labels",
|
97 |
+
"begin": 0,
|
98 |
+
"_asym": "symmetric noise if false",
|
99 |
+
"asym": false
|
100 |
+
}
|
101 |
+
}
|
102 |
+
|
ELR_plus/config_webvision.json
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "Webvision_ELR_plus_InceptionResNetV2",
|
3 |
+
"n_gpu": 2,
|
4 |
+
"seed": 123,
|
5 |
+
|
6 |
+
|
7 |
+
"arch": {
|
8 |
+
"args": {"num_classes":50}
|
9 |
+
},
|
10 |
+
|
11 |
+
"arch1": {
|
12 |
+
"type": "InceptionResNetV2",
|
13 |
+
"args": {"num_classes":50}
|
14 |
+
},
|
15 |
+
|
16 |
+
"arch2": {
|
17 |
+
"type": "InceptionResNetV2",
|
18 |
+
"args": {"num_classes":50}
|
19 |
+
},
|
20 |
+
|
21 |
+
|
22 |
+
"mixup_alpha": 1.5,
|
23 |
+
"mixup_ramp": false,
|
24 |
+
"num_classes": 50,
|
25 |
+
"ema_alpha": 0.997,
|
26 |
+
"ema_update": false,
|
27 |
+
"ema_step": 40000,
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
"data_loader": {
|
32 |
+
"type": "WebvisionDataLoader",
|
33 |
+
"args":{
|
34 |
+
"data_dir": "/dir/to/data",
|
35 |
+
"batch_size": 32,
|
36 |
+
"batch_size2": 32,
|
37 |
+
"shuffle": true,
|
38 |
+
"num_batches": 0,
|
39 |
+
"validation_split": 0,
|
40 |
+
"num_workers": 8,
|
41 |
+
"pin_memory": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
|
45 |
+
"optimizer1": {
|
46 |
+
"type": "SGD",
|
47 |
+
"args":{
|
48 |
+
"lr": 0.02,
|
49 |
+
"momentum": 0.9,
|
50 |
+
"weight_decay": 5e-4
|
51 |
+
}
|
52 |
+
},
|
53 |
+
|
54 |
+
"optimizer2": {
|
55 |
+
"type": "SGD",
|
56 |
+
"args":{
|
57 |
+
"lr": 0.02,
|
58 |
+
"momentum": 0.9,
|
59 |
+
"weight_decay": 5e-4
|
60 |
+
}
|
61 |
+
},
|
62 |
+
|
63 |
+
|
64 |
+
"train_loss": {
|
65 |
+
"type": "elr_plus_loss",
|
66 |
+
"args":{
|
67 |
+
"beta": 0.7,
|
68 |
+
"lambda": 3
|
69 |
+
}
|
70 |
+
},
|
71 |
+
"val_loss": "cross_entropy",
|
72 |
+
"metrics": [
|
73 |
+
"my_metric", "my_metric2"
|
74 |
+
],
|
75 |
+
"lr_scheduler": {
|
76 |
+
"type": "MultiStepLR",
|
77 |
+
"args": {
|
78 |
+
"milestones": [50],
|
79 |
+
"gamma": 0.1
|
80 |
+
}
|
81 |
+
},
|
82 |
+
"trainer": {
|
83 |
+
"epochs": 100,
|
84 |
+
"warmup": 0,
|
85 |
+
"save_dir": "/dir/to/data",
|
86 |
+
"save_period": 1,
|
87 |
+
"verbosity": 2,
|
88 |
+
"label_dir": "saved/",
|
89 |
+
|
90 |
+
"monitor": "max val_my_metric",
|
91 |
+
"early_stop": 2000,
|
92 |
+
|
93 |
+
"tensorboard": false,
|
94 |
+
"mlflow": true,
|
95 |
+
|
96 |
+
"_percent": "Percentage of noise",
|
97 |
+
"percent": 0.9,
|
98 |
+
"_begin": "When to begin updating labels",
|
99 |
+
"begin": 0,
|
100 |
+
"_asym": "symmetric noise if false",
|
101 |
+
"asym": false
|
102 |
+
}
|
103 |
+
}
|
ELR_plus/data_loader/cifar10.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data.dataset import Subset
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
|
14 |
+
|
15 |
+
def get_cifar10(root, cfg_trainer, train=True,
                transform_train=None, transform_val=None,
                download=False, noise_file = ''):
    """Build the CIFAR-10 datasets used by the data loader.

    Args:
        root: directory handed to the torchvision CIFAR-10 dataset.
        cfg_trainer: trainer config mapping; ``cfg_trainer['asym']`` selects
            asymmetric (truthy) or symmetric (falsy) label noise.
        train: when truthy, return a noisy train/validation pair; otherwise
            return only the test set.
        transform_train: transform applied to training images.
        transform_val: transform applied to validation / test images.
        download: forwarded to the stock torchvision dataset.
        noise_file: accepted for interface compatibility; not used here.

    Returns:
        ``(train_dataset, val_dataset)``. In test mode ``train_dataset`` is an
        empty list and ``val_dataset`` is the test set.
    """
    # The stock dataset is always instantiated first; it is the only place
    # where the ``download`` flag is consumed.
    stock = torchvision.datasets.CIFAR10(root, train=train, download=download)

    if not train:
        val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(val_dataset)}")
        return [], val_dataset

    train_idxs, val_idxs = train_val_split(stock.targets)
    train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
    val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)

    # Both splits are corrupted with the same noise model, training split first.
    corrupt = 'asymmetric_noise' if cfg_trainer['asym'] else 'symmetric_noise'
    for split in (train_dataset, val_dataset):
        getattr(split, corrupt)()

    print(f"Train: {len(train_idxs)} Val: {len(val_idxs)}") # Train: 45000 Val: 5000
    return train_dataset, val_dataset
|
37 |
+
|
38 |
+
|
39 |
+
def train_val_split(base_dataset, num_classes=10, train_ratio=0.9):
    """Split sample indices into class-balanced train / validation lists.

    Args:
        base_dataset: sequence of integer class labels, one per sample. The
            caller passes ``dataset.targets``, not a dataset object.
        num_classes: number of distinct labels; classes are ``0..num_classes-1``.
        train_ratio: fraction of the whole set that goes to training. The
            per-class training count is
            ``int(len(labels) * train_ratio / num_classes)``.

    Returns:
        ``(train_idxs, val_idxs)``: two lists of sample indices, each shuffled
        with the global NumPy random state.
    """
    labels = np.array(base_dataset)
    train_n = int(len(labels) * train_ratio / num_classes)
    train_idxs = []
    val_idxs = []

    for cls in range(num_classes):
        idxs = np.where(labels == cls)[0]
        np.random.shuffle(idxs)
        # The first train_n shuffled members of the class train, the rest validate.
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])

    # Mix the classes so neither list is ordered by label.
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs
|
55 |
+
|
56 |
+
|
57 |
+
class CIFAR10_train(torchvision.datasets.CIFAR10):
    """CIFAR-10 training subset whose labels can be synthetically corrupted.

    ``train_data`` holds the selected images, ``train_labels`` the (possibly
    noisy) labels and ``train_labels_gt`` the untouched labels of the subset.
    """

    # Class-dependent flips applied by asymmetric_noise (true -> noisy):
    # truck -> automobile, bird -> airplane, cat -> dog, dog -> cat, deer -> horse.
    _ASYM_FLIPS = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-10 and keep only the samples selected by ``indexs``.

        Args:
            root: dataset directory forwarded to torchvision.
            cfg_trainer: trainer config mapping; ``cfg_trainer['percent']`` is
                read by the noise methods.
            indexs: indices (into the full split) of the samples to keep.
            train: forwarded to torchvision to choose the split.
            transform: optional image transform.
            target_transform: optional transform applied to the noisy label.
            download: forwarded to torchvision.
        """
        super(CIFAR10_train, self).__init__(root, train=train,
                                            transform=transform,
                                            target_transform=target_transform,
                                            download=download)
        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        # Snapshot the clean labels up front so __getitem__ works even when no
        # noise method has been called, and so repeated noise calls cannot
        # overwrite the ground truth with already-corrupted labels.
        self.train_labels_gt = self.train_labels.copy()
        self.indexs = indexs
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []

    def symmetric_noise(self):
        """Replace the labels of a random ``percent`` share of samples.

        Each selected sample gets a label drawn uniformly from all classes
        (which may coincide with its true label). Selected positions are
        appended to ``noise_indx``.
        """
        threshold = self.cfg_trainer['percent'] * len(self.train_data)
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i >= threshold:
                break
            self.noise_indx.append(idx)
            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """Flip a ``percent`` share of each class to its confusable class.

        Classes without an entry in ``_ASYM_FLIPS`` keep their labels, but
        their selected positions are still appended to ``noise_indx``.
        """
        percent = self.cfg_trainer['percent']
        for cls in range(self.num_classes):
            # Select members by their clean label. Selecting on train_labels
            # would also pick samples already flipped into this class (cats
            # relabelled as dogs) and flip them a second time.
            indices = np.where(self.train_labels_gt == cls)[0]
            np.random.shuffle(indices)
            threshold = percent * len(indices)
            noisy_label = self._ASYM_FLIPS.get(cls)
            for j, idx in enumerate(indices):
                if j >= threshold:
                    break
                self.noise_indx.append(idx)
                if noisy_label is not None:
                    self.train_labels[idx] = noisy_label

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) class index and target_gt the clean class index.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt

    def __len__(self):
        """Return the number of samples kept in this subset."""
        return len(self.train_data)
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class CIFAR10_val(torchvision.datasets.CIFAR10):
    """CIFAR-10 validation subset (or full test set) with optional label noise.

    ``train_labels`` holds the (possibly noisy) labels and ``train_labels_gt``
    the clean labels captured at construction time.
    """

    # Class-dependent flips applied by asymmetric_noise (true -> noisy):
    # truck -> automobile, bird -> airplane, cat -> dog, dog -> cat, deer -> horse.
    _ASYM_FLIPS = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-10 and keep the requested samples.

        Args:
            root: dataset directory forwarded to torchvision.
            cfg_trainer: trainer config mapping; ``cfg_trainer['percent']`` is
                read by the noise methods.
            indexs: indices of the samples to keep when ``train`` is truthy;
                ``None`` keeps every sample.
            train: forwarded to torchvision to choose the split.
            transform: optional image transform.
            target_transform: optional transform applied to the noisy label.
            download: forwarded to torchvision.
        """
        super(CIFAR10_val, self).__init__(root, train=train,
                                          transform=transform,
                                          target_transform=target_transform,
                                          download=download)

        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        # Indexing with None would add a leading axis instead of selecting
        # samples, so a missing index list means "use the whole split".
        if train and indexs is not None:
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        """Replace the labels of a random ``percent`` share of samples with
        labels drawn uniformly from all classes."""
        threshold = self.cfg_trainer['percent'] * len(self.train_data)
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i >= threshold:
                break
            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """Flip a ``percent`` share of each class to its confusable class.

        Classes without an entry in ``_ASYM_FLIPS`` are left unchanged.
        """
        percent = self.cfg_trainer['percent']
        for cls in range(self.num_classes):
            # Select members by their clean label so samples already flipped
            # into this class are not picked up and flipped again.
            indices = np.where(self.train_labels_gt == cls)[0]
            # Shuffle even for unmapped classes to keep the random stream
            # consumption the same for every class.
            np.random.shuffle(indices)
            noisy_label = self._ASYM_FLIPS.get(cls)
            if noisy_label is None:
                continue
            threshold = percent * len(indices)
            for j, idx in enumerate(indices):
                if j >= threshold:
                    break
                self.train_labels[idx] = noisy_label

    def __len__(self):
        """Return the number of samples in this dataset."""
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) class index and target_gt the clean class index.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt
|
214 |
+
|
ELR_plus/data_loader/cifar100.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data.dataset import Subset
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
from numpy.testing import assert_array_almost_equal
|
12 |
+
import os
|
13 |
+
import json
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
# Install a module-import-time filter that ignores every DeprecationWarning.
# NOTE(review): this also hides deprecations raised by this module's own
# NumPy usage, so they will not show up in logs.
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
17 |
+
|
18 |
+
def get_cifar100(root, cfg_trainer, train=True,
                 transform_train=None, transform_val=None,
                 download=False, noise_file = ''):
    """Build the CIFAR-100 datasets used by the data loader.

    Args:
        root: directory handed to the torchvision CIFAR-100 dataset.
        cfg_trainer: trainer config mapping; ``cfg_trainer['asym']`` selects
            asymmetric (truthy) or symmetric (falsy) label noise.
        train: when truthy, return a noisy train/validation pair; otherwise
            return only the test set.
        transform_train: transform applied to training images.
        transform_val: transform applied to validation / test images.
        download: forwarded to the stock torchvision dataset.
        noise_file: accepted for interface compatibility; not used here.

    Returns:
        ``(train_dataset, val_dataset)``. In test mode ``train_dataset`` is an
        empty list and ``val_dataset`` is the test set.
    """
    # The stock dataset is always instantiated first; it is the only place
    # where the ``download`` flag is consumed.
    stock = torchvision.datasets.CIFAR100(root, train=train, download=download)

    if not train:
        val_dataset = CIFAR100_val(root, cfg_trainer, None, train=train, transform=transform_val)
        print(f"Test: {len(val_dataset)}\n")
        return [], val_dataset

    train_idxs, val_idxs = train_val_split(stock.targets)
    train_dataset = CIFAR100_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
    val_dataset = CIFAR100_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)

    # Both splits are corrupted with the same noise model, training split first.
    corrupt = 'asymmetric_noise' if cfg_trainer['asym'] else 'symmetric_noise'
    for split in (train_dataset, val_dataset):
        getattr(split, corrupt)()

    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}\n") # Train: 45000 Val: 5000
    return train_dataset, val_dataset
|
43 |
+
|
44 |
+
|
45 |
+
def train_val_split(base_dataset, num_classes=100, train_ratio=0.9):
    """Split sample indices into class-balanced train / validation lists.

    Args:
        base_dataset: sequence of integer class labels, one per sample. The
            caller passes ``dataset.targets``, not a dataset object.
        num_classes: number of distinct labels; classes are ``0..num_classes-1``.
        train_ratio: fraction of the whole set that goes to training. The
            per-class training count is
            ``int(len(labels) * train_ratio / num_classes)``.

    Returns:
        ``(train_idxs, val_idxs)``: two lists of sample indices, each shuffled
        with the global NumPy random state.
    """
    labels = np.array(base_dataset)
    train_n = int(len(labels) * train_ratio / num_classes)
    train_idxs = []
    val_idxs = []

    for cls in range(num_classes):
        idxs = np.where(labels == cls)[0]
        np.random.shuffle(idxs)
        # The first train_n shuffled members of the class train, the rest validate.
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])

    # Mix the classes so neither list is ordered by label.
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs
|
61 |
+
|
62 |
+
|
63 |
+
class CIFAR100_train(torchvision.datasets.CIFAR100):
    """CIFAR-100 training subset whose labels can be synthetically corrupted.

    ``train_data`` holds the selected images, ``train_labels`` the (possibly
    noisy) labels and ``train_labels_gt`` the untouched labels of the subset.
    """

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-100 and keep only the samples selected by ``indexs``.

        Args:
            root: dataset directory forwarded to torchvision.
            cfg_trainer: trainer config mapping; ``cfg_trainer['percent']`` is
                read by the noise methods.
            indexs: indices (into the full split) of the samples to keep.
            train: forwarded to torchvision to choose the split.
            transform: optional image transform.
            target_transform: optional transform applied to the noisy label.
            download: forwarded to torchvision.
        """
        super(CIFAR100_train, self).__init__(root, train=train,
                                             transform=transform,
                                             target_transform=target_transform,
                                             download=download)
        self.num_classes = 100
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        # Snapshot the clean labels up front so __getitem__ works even when no
        # noise method has been called, and so repeated noise calls cannot
        # overwrite the ground truth with already-corrupted labels.
        self.train_labels_gt = self.train_labels.copy()
        self.indexs = indexs
        self.soft_labels = np.zeros((len(self.train_data), self.num_classes), dtype=np.float32)
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []

        self.count = 0

    def symmetric_noise(self):
        """Replace the labels of a random ``percent`` share of samples.

        Reseeds the global NumPy generator with 888 first. Each selected
        sample gets a label drawn uniformly from all classes and its position
        is appended to ``noise_indx``; every sample then gets a one-hot row in
        ``soft_labels`` for its resulting label.
        """
        np.random.seed(seed=888)
        threshold = self.cfg_trainer['percent'] * len(self.train_data)
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < threshold:
                self.noise_indx.append(idx)
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
            self.soft_labels[idx][self.train_labels[idx]] = 1.

    def multiclass_noisify(self, y, P, random_state=0):
        """ Flip classes according to transition probability matrix P.
        It expects a number between 0 and the number of classes - 1.

        Args:
            y: 1-D array of integer labels.
            P: square row-stochastic matrix; ``P[i, j]`` is the probability
                that label ``i`` becomes label ``j``.
            random_state: seed of the private RandomState used for sampling.

        Returns:
            A copy of ``y`` with labels resampled according to ``P``.
        """

        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        m = y.shape[0]
        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(m):
            i = y[idx]
            # draw a vector with only an 1
            flipped = flipper.multinomial(1, P[i, :], 1)[0]
            # Take the scalar position of the 1; assigning the length-1 index
            # array itself relies on deprecated array-to-scalar conversion.
            new_y[idx] = np.where(flipped == 1)[0][0]

        return new_y

    def build_for_cifar100(self, size, noise):
        """ The noise matrix flips to the "next" class with probability 'noise'.

        Returns a ``size`` x ``size`` matrix with ``1 - noise`` on the diagonal
        and ``noise`` on the cyclic super-diagonal.
        """

        assert(noise >= 0.) and (noise <= 1.)

        P = (1. - noise) * np.eye(size)
        for i in np.arange(size - 1):
            P[i, i + 1] = noise

        # adjust last row
        P[size - 1, 0] = noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def asymmetric_noise(self, asym=False, random_shuffle=False):
        """Flip labels to the next class inside each block of 5 consecutive
        class ids with probability ``cfg_trainer['percent']``.

        ``asym`` and ``random_shuffle`` are accepted but not used.
        """
        P = np.eye(self.num_classes)
        n = self.cfg_trainer['percent']
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            for i in np.arange(nb_superclasses):
                init, end = i * nb_subclasses, (i+1) * nb_subclasses
                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)

            # Noise is always derived from the clean labels.
            y_train_noisy = self.multiclass_noisify(self.train_labels_gt, P=P,
                                                    random_state=0)
            actual_noise = (y_train_noisy != self.train_labels_gt).mean()
            assert actual_noise > 0.0
            self.train_labels = y_train_noisy

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) class index and target_gt the clean class index.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index, target_gt

    def __len__(self):
        """Return the number of samples kept in this subset."""
        return len(self.train_data)

    def rotate_img(self, img, rot):
        """Rotate an image array by 0, 90, 180 or 270 degrees.

        Raises:
            ValueError: for any other value of ``rot``.
        """
        if rot == 0: # 0 degrees rotation
            return img
        elif rot == 90: # 90 degrees rotation
            return np.flipud(np.transpose(img, (1,0,2)))
        elif rot == 180: # 180 degrees rotation
            return np.fliplr(np.flipud(img))
        elif rot == 270: # 270 degrees rotation / or -90
            return np.transpose(np.flipud(img), (1,0,2))
        else:
            raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
|
196 |
+
|
197 |
+
|
198 |
+
class CIFAR100_val(torchvision.datasets.CIFAR100):
    """CIFAR-100 validation subset (or full test set) with optional label noise.

    ``train_labels`` holds the (possibly noisy) labels and ``train_labels_gt``
    the clean labels captured at construction time.
    """

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Load CIFAR-100 and keep the requested samples.

        Args:
            root: dataset directory forwarded to torchvision.
            cfg_trainer: trainer config mapping; ``cfg_trainer['percent']`` is
                read by the noise methods.
            indexs: indices of the samples to keep when ``train`` is truthy;
                ``None`` keeps every sample.
            train: forwarded to torchvision to choose the split.
            transform: optional image transform.
            target_transform: optional transform applied to the noisy label.
            download: forwarded to torchvision.
        """
        super(CIFAR100_val, self).__init__(root, train=train,
                                           transform=transform,
                                           target_transform=target_transform,
                                           download=download)

        self.num_classes = 100
        self.cfg_trainer = cfg_trainer
        # Indexing with None would add a leading axis instead of selecting
        # samples, so a missing index list means "use the whole split".
        if train and indexs is not None:
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        """Replace the labels of a random ``percent`` share of samples with
        labels drawn uniformly from all classes."""
        threshold = self.cfg_trainer['percent'] * len(self.train_data)
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i >= threshold:
                break
            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def multiclass_noisify(self, y, P, random_state=0):
        """ Flip classes according to transition probability matrix P.
        It expects a number between 0 and the number of classes - 1.

        Args:
            y: 1-D array of integer labels.
            P: square row-stochastic matrix; ``P[i, j]`` is the probability
                that label ``i`` becomes label ``j``.
            random_state: seed of the private RandomState used for sampling.

        Returns:
            A copy of ``y`` with labels resampled according to ``P``.
        """

        assert P.shape[0] == P.shape[1]
        assert np.max(y) < P.shape[0]

        # row stochastic matrix
        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
        assert (P >= 0.0).all()

        m = y.shape[0]
        new_y = y.copy()
        flipper = np.random.RandomState(random_state)

        for idx in np.arange(m):
            i = y[idx]
            # draw a vector with only an 1
            flipped = flipper.multinomial(1, P[i, :], 1)[0]
            # Take the scalar position of the 1; assigning the length-1 index
            # array itself relies on deprecated array-to-scalar conversion.
            new_y[idx] = np.where(flipped == 1)[0][0]

        return new_y

    def build_for_cifar100(self, size, noise):
        """ The noise matrix flips to the "next" class with probability 'noise'.

        Returns a ``size`` x ``size`` matrix with ``1 - noise`` on the diagonal
        and ``noise`` on the cyclic super-diagonal.
        """

        assert(noise >= 0.) and (noise <= 1.)

        P = (1. - noise) * np.eye(size)
        for i in np.arange(size - 1):
            P[i, i + 1] = noise

        # adjust last row
        P[size - 1, 0] = noise

        assert_array_almost_equal(P.sum(axis=1), 1, 1)
        return P

    def asymmetric_noise(self, asym=False, random_shuffle=False):
        """Flip labels to the next class inside each block of 5 consecutive
        class ids with probability ``cfg_trainer['percent']``.

        ``asym`` and ``random_shuffle`` are accepted but not used.
        """
        P = np.eye(self.num_classes)
        n = self.cfg_trainer['percent']
        nb_superclasses = 20
        nb_subclasses = 5

        if n > 0.0:
            for i in np.arange(nb_superclasses):
                init, end = i * nb_subclasses, (i+1) * nb_subclasses
                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)

            # Noise is always derived from the clean labels.
            y_train_noisy = self.multiclass_noisify(self.train_labels_gt, P=P,
                                                    random_state=0)
            actual_noise = (y_train_noisy != self.train_labels_gt).mean()
            assert actual_noise > 0.0
            self.train_labels = y_train_noisy

    def __len__(self):
        """Return the number of samples in this dataset."""
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index, target_gt) where target is the
            (possibly noisy) class index and target_gt the clean class index.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt
|
ELR_plus/data_loader/clothing1m.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data.dataset import Subset
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
|
12 |
+
def get_clothing(root, cfg_trainer, num_samples=0, train=True,
                 transform_train=None, transform_val=None):
    """Build the Clothing1M datasets used by the data loader.

    Args:
        root: dataset directory containing the label / key-list text files.
        cfg_trainer: trainer config mapping, stored on each dataset.
        num_samples: maximum number of training images to sample.
        train: when truthy, return a (train, validation) pair; otherwise
            return only the test set.
        transform_train: transform applied to training images.
        transform_val: transform applied to validation / test images.

    Returns:
        ``(train_dataset, val_dataset)``. In test mode ``train_dataset`` is an
        empty list and ``val_dataset`` is the test set.
    """
    if not train:
        val_dataset = Clothing(root, cfg_trainer, test= (not train), transform=transform_val)
        print(f"Test: {len(val_dataset)}")
        return [], val_dataset

    train_dataset = Clothing(root, cfg_trainer, num_samples=num_samples, train=train, transform=transform_train)
    val_dataset = Clothing(root, cfg_trainer, val=train, transform=transform_val)
    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")
    return train_dataset, val_dataset
|
26 |
+
|
27 |
+
class Clothing(torch.utils.data.Dataset):
|
28 |
+
|
29 |
+
def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class = 14):
|
30 |
+
self.cfg_trainer = cfg_trainer
|
31 |
+
self.root = root
|
32 |
+
self.transform = transform
|
33 |
+
self.train_labels = {}
|
34 |
+
self.test_labels = {}
|
35 |
+
self.val_labels = {}
|
36 |
+
|
37 |
+
self.train = train
|
38 |
+
self.val = val
|
39 |
+
self.test = test
|
40 |
+
|
41 |
+
with open('%s/noisy_label_kv.txt'%self.root,'r') as f:
|
42 |
+
lines = f.read().splitlines()
|
43 |
+
for l in lines:
|
44 |
+
entry = l.split()
|
45 |
+
img_path = '%s/'%self.root+entry[0][7:]
|
46 |
+
self.train_labels[img_path] = int(entry[1])
|
47 |
+
with open('%s/clean_label_kv.txt'%self.root,'r') as f:
|
48 |
+
lines = f.read().splitlines()
|
49 |
+
for l in lines:
|
50 |
+
entry = l.split()
|
51 |
+
img_path = '%s/'%self.root+entry[0][7:]
|
52 |
+
self.test_labels[img_path] = int(entry[1])
|
53 |
+
|
54 |
+
if train:
|
55 |
+
train_imgs=[]
|
56 |
+
with open('%s/noisy_train_key_list.txt'%self.root,'r') as f:
|
57 |
+
lines = f.read().splitlines()
|
58 |
+
for i , l in enumerate(lines):
|
59 |
+
img_path = '%s/'%self.root+l[7:]
|
60 |
+
train_imgs.append((i,img_path))
|
61 |
+
self.num_raw_example = len(train_imgs)
|
62 |
+
random.shuffle(train_imgs)
|
63 |
+
class_num = torch.zeros(num_class)
|
64 |
+
self.train_imgs = []
|
65 |
+
for id_raw, impath in train_imgs:
|
66 |
+
label = self.train_labels[impath]
|
67 |
+
if class_num[label]<(num_samples/14) and len(self.train_imgs)<num_samples:
|
68 |
+
self.train_imgs.append((id_raw,impath))
|
69 |
+
class_num[label]+=1
|
70 |
+
random.shuffle(self.train_imgs)
|
71 |
+
|
72 |
+
elif test:
|
73 |
+
self.test_imgs = []
|
74 |
+
with open('%s/clean_test_key_list.txt'%self.root,'r') as f:
|
75 |
+
lines = f.read().splitlines()
|
76 |
+
for l in lines:
|
77 |
+
img_path = '%s/'%self.root+l[7:]
|
78 |
+
self.test_imgs.append(img_path)
|
79 |
+
elif val:
|
80 |
+
self.val_imgs = []
|
81 |
+
with open('%s/clean_val_key_list.txt'%self.root,'r') as f:
|
82 |
+
lines = f.read().splitlines()
|
83 |
+
for l in lines:
|
84 |
+
img_path = '%s/'%self.root+l[7:]
|
85 |
+
self.val_imgs.append(img_path)
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
def __getitem__(self, index):
|
90 |
+
if self.train:
|
91 |
+
id_raw, img_path = self.train_imgs[index]
|
92 |
+
target = self.train_labels[img_path]
|
93 |
+
elif self.val:
|
94 |
+
img_path = self.val_imgs[index]
|
95 |
+
target = self.test_labels[img_path]
|
96 |
+
elif self.test:
|
97 |
+
img_path = self.test_imgs[index]
|
98 |
+
target = self.test_labels[img_path]
|
99 |
+
image = Image.open(img_path).convert('RGB')
|
100 |
+
if self.train:
|
101 |
+
img0 = self.transform(image)
|
102 |
+
|
103 |
+
if self.test or self.val:
|
104 |
+
img = self.transform(image)
|
105 |
+
return img, target, index, target
|
106 |
+
else:
|
107 |
+
return img0, target, id_raw, target
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
def __len__(self):
|
112 |
+
if self.test:
|
113 |
+
return len(self.test_imgs)
|
114 |
+
if self.val:
|
115 |
+
return len(self.val_imgs)
|
116 |
+
else:
|
117 |
+
return len(self.train_imgs)
|
118 |
+
|
119 |
+
|
120 |
+
def flist_reader(self, flist):
    """Parse a ``<relative path> <label>`` list file.

    Args:
        flist: path of the text file to read, one sample per line.

    Returns:
        A list of ``(self.root + relative_path, int_label)`` tuples in file
        order. Labels may be written as floats (e.g. ``"3.0"``); they are
        truncated to ``int``.

    Blank lines are skipped and any run of whitespace is accepted as the
    field separator.
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf:
            row = line.split()
            if not row:
                # Tolerate empty / whitespace-only lines (e.g. a trailing newline).
                continue
            impath = self.root + row[0]
            imlist.append((impath, int(float(row[1]))))
    return imlist
|
ELR_plus/data_loader/data_loaders.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
from torchvision import datasets, transforms
|
4 |
+
from base import BaseDataLoader
|
5 |
+
from data_loader.cifar10 import get_cifar10
|
6 |
+
from data_loader.cifar100 import get_cifar100
|
7 |
+
from data_loader.clothing1m import get_clothing
|
8 |
+
from data_loader.webvision import get_webvision
|
9 |
+
from parse_config import ConfigParser
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class CIFAR10DataLoader(BaseDataLoader):
    """Build the CIFAR-10 datasets via ``get_cifar10`` and hand them to ``BaseDataLoader``.

    The dataset root and the trainer settings come from the global
    ``ConfigParser`` instance; the ``data_dir`` argument is only stored on
    the instance. ``num_batches`` is accepted but not used here.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']

        # Per-channel statistics shared by the train and validation pipelines.
        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2023, 0.1994, 0.2010)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        self.data_dir = data_dir

        dataset_root = config['data_loader']['args']['data_dir']
        # The noise-file name encodes the noise percentage and the asym flag.
        noise_file = '%sCIFAR10_%.1f_Asym_%s.json' % (
            dataset_root, cfg_trainer['percent'], cfg_trainer['asym'])

        datasets_pair = get_cifar10(
            dataset_root, cfg_trainer, train=training,
            transform_train=transform_train, transform_val=transform_val,
            noise_file=noise_file)
        self.train_dataset, self.val_dataset = datasets_pair

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split,
                         num_workers, pin_memory, val_dataset=self.val_dataset)
|
41 |
+
|
42 |
+
|
43 |
+
class CIFAR100DataLoader(BaseDataLoader):
    """Build the CIFAR-100 datasets via ``get_cifar100`` and hand them to ``BaseDataLoader``.

    The dataset root and the trainer settings come from the global
    ``ConfigParser`` instance; the ``data_dir`` argument is only stored on
    the instance. ``num_batches`` is accepted but not used here. After the
    base initialisation the integer batch size is kept in ``batch_size_``.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True,num_workers=4, pin_memory=True):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']

        # Per-channel statistics shared by the train and validation pipelines.
        channel_mean = (0.5071, 0.4867, 0.4408)
        channel_std = (0.2675, 0.2565, 0.2761)

        transform_train = transforms.Compose([
            #transforms.ColorJitter(brightness= 0.4, contrast= 0.4, saturation= 0.4, hue= 0.1),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        self.data_dir = data_dir

        dataset_root = config['data_loader']['args']['data_dir']
        # The noise-file name encodes the noise percentage and the asym flag.
        noise_file = '%sCIFAR100_%.1f_Asym_%s.json' % (
            dataset_root, cfg_trainer['percent'], cfg_trainer['asym'])

        datasets_pair = get_cifar100(
            dataset_root, cfg_trainer, train=training,
            transform_train=transform_train, transform_val=transform_val,
            noise_file=noise_file)
        self.train_dataset, self.val_dataset = datasets_pair

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split,
                         num_workers, pin_memory, val_dataset=self.val_dataset)

        self.batch_size_ = int(batch_size)
|
70 |
+
|
71 |
+
|
72 |
+
class Clothing1MDataLoader(BaseDataLoader):
    """Build the Clothing1M datasets via ``get_clothing`` and hand them to ``BaseDataLoader``.

    ``num_batches * batch_size`` is forwarded to ``get_clothing`` as
    ``num_samples``. The dataset root and trainer settings come from the
    global ``ConfigParser`` instance; ``data_dir`` is only stored.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_batches = num_batches
        self.training = training

        # Per-channel statistics shared by both pipelines.
        channel_mean = (0.6959, 0.6537, 0.6371)
        channel_std = (0.3113, 0.3192, 0.3214)

        # Training: resize, random 224 crop, random horizontal flip.
        self.transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        # Validation: resize, deterministic centre crop.
        self.transform_val = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        self.data_dir = data_dir
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        dataset_root = config['data_loader']['args']['data_dir']

        datasets_pair = get_clothing(
            dataset_root, cfg_trainer,
            num_samples=num_batches * batch_size, train=training,
            transform_train=self.transform_train, transform_val=self.transform_val)
        self.train_dataset, self.val_dataset = datasets_pair

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split,
                         num_workers, pin_memory, val_dataset=self.val_dataset)
|
102 |
+
|
103 |
+
class WebvisionDataLoader(BaseDataLoader):
    """Build the WebVision datasets via ``get_webvision`` and hand them to ``BaseDataLoader``.

    ``num_batches * batch_size`` is forwarded to ``get_webvision`` as
    ``num_samples`` and ``num_class`` is passed through unchanged. The
    dataset root and trainer settings come from the global ``ConfigParser``
    instance; ``data_dir`` is only stored. ``transform_imagenet`` is built
    and stored but not used inside this class.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True, num_class = 50):

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_batches = num_batches
        self.training = training

        # Per-channel statistics shared by all three pipelines.
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        self.transform_train = transforms.Compose([
            transforms.RandomCrop(227),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        self.transform_val = transforms.Compose([
            transforms.CenterCrop(227),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        self.transform_imagenet = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(227),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        self.data_dir = data_dir
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        dataset_root = config['data_loader']['args']['data_dir']

        datasets_pair = get_webvision(
            dataset_root, cfg_trainer,
            num_samples=num_batches * batch_size, train=training,
            transform_train=self.transform_train, transform_val=self.transform_val,
            num_class=num_class)
        self.train_dataset, self.val_dataset = datasets_pair

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split,
                         num_workers, pin_memory, val_dataset=self.val_dataset)
|
137 |
+
|
ELR_plus/data_loader/webvision.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data.dataset import Subset
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import random
|
11 |
+
|
12 |
+
def get_webvision(root, cfg_trainer, num_samples=0, train=True,
                  transform_train=None, transform_val=None, num_class = 50):
    """Return the ``(train_dataset, val_dataset)`` pair for WebVision.

    With ``train`` truthy both entries are ``Webvision`` datasets (train
    split and val split). Otherwise the train entry is an empty list and
    the val entry is an ``ImagenetVal`` dataset. The sizes are printed.
    """
    if not train:
        imagenet_val = ImagenetVal(root, transform=transform_val, num_class=num_class)
        print(f"Imagnet Val: {len(imagenet_val)}")
        return [], imagenet_val

    train_dataset = Webvision(root, cfg_trainer, num_samples=num_samples, train=train,
                              transform=transform_train, num_class=num_class)
    val_dataset = Webvision(root, cfg_trainer, num_samples=num_samples, val=train,
                            transform=transform_val, num_class=num_class)
    print(f"Train: {len(train_dataset)} WebVision Val: {len(val_dataset)}")
    return train_dataset, val_dataset
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
class ImagenetVal(torch.utils.data.Dataset):
    """Validation images listed in ``<root>imagenet/imagenet_val.txt``.

    Each line of the list is ``<image name> <integer label>``; only entries
    whose label is below ``num_class`` are kept. Images are opened from
    ``<root>imagenet/val/``.
    """

    def __init__(self, root, transform, num_class):
        self.root = root+'imagenet/'
        self.transform = transform

        with open(self.root+'imagenet_val.txt') as listing:
            entries = listing.readlines()

        kept_names = []
        kept_labels = {}
        for entry in entries:
            name, raw_label = entry.split()
            label = int(raw_label)
            if label >= num_class:
                continue
            kept_names.append(name)
            kept_labels[name] = label
        self.val_imgs = kept_names
        self.val_labels = kept_labels

    def __getitem__(self, index):
        """Return ``(transformed image, target, index, target)`` for one sample."""
        name = self.val_imgs[index]
        label = self.val_labels[name]
        rgb = Image.open(self.root+'val/'+name).convert('RGB')
        return self.transform(rgb), label, index, label

    def __len__(self):
        """Number of kept validation images."""
        return len(self.val_imgs)
|
58 |
+
|
59 |
+
|
60 |
+
class Webvision(torch.utils.data.Dataset):
    """WebVision image dataset restricted to the first ``num_class`` labels.

    Exactly one split is loaded, chosen with the same precedence everywhere
    (``__init__``, ``__getitem__`` and ``__len__``): ``val`` first, then
    ``test``, otherwise the training split.

    * val / test: names and labels come from ``<root>info/val_filelist.txt``
      and images are opened from ``<root>val_images_256/``.
    * train: names and labels come from
      ``<root>info/train_filelist_google.txt`` and images are opened from
      ``<root>`` directly.

    Args:
        root: dataset root; it is concatenated as a string prefix, so it
            must end with a path separator.
        cfg_trainer: trainer configuration, stored unchanged.
        num_samples: accepted for interface compatibility; not used here.
        train, val, test: split selection flags.
        transform: callable applied to the RGB image.
        num_class: only entries with ``label < num_class`` are kept.
    """

    def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class = 50):
        self.cfg_trainer = cfg_trainer
        self.root = root
        self.transform = transform
        self.train_labels = {}
        self.test_labels = {}
        self.val_labels = {}

        self.train = train
        self.val = val
        self.test = test

        if self.val:
            self.val_imgs, self.val_labels = self._read_filelist(
                self.root+'info/val_filelist.txt', num_class)
        elif self.test:
            self.test_imgs, self.test_labels = self._read_filelist(
                self.root+'info/val_filelist.txt', num_class)
        else:
            self.train_imgs, self.train_labels = self._read_filelist(
                self.root+'info/train_filelist_google.txt', num_class)

    @staticmethod
    def _read_filelist(path, num_class):
        """Read ``<name> <label>`` lines, keeping entries with label < num_class."""
        imgs = []
        labels = {}
        with open(path) as f:
            for line in f:
                if not line.strip():
                    # Skip blank lines instead of failing on the unpack below.
                    continue
                img, target = line.split()
                target = int(target)
                if target < num_class:
                    imgs.append(img)
                    labels[img] = target
        return imgs, labels

    def __getitem__(self, index):
        """Return ``(transformed image, target, index, target)`` for one sample."""
        # Same precedence as __init__, so the list that was loaded is the
        # one that gets indexed (previously no flag at all returned None).
        if self.val:
            img_path = self.val_imgs[index]
            target = self.val_labels[img_path]
            image_file = self.root+'val_images_256/'+img_path
        elif self.test:
            img_path = self.test_imgs[index]
            target = self.test_labels[img_path]
            image_file = self.root+'val_images_256/'+img_path
        else:
            img_path = self.train_imgs[index]
            target = self.train_labels[img_path]
            image_file = self.root+img_path

        image = Image.open(image_file).convert('RGB')
        img = self.transform(image)
        return img, target, index, target

    def __len__(self):
        """Number of samples in the loaded split."""
        if self.val:
            return len(self.val_imgs)
        if self.test:
            return len(self.test_imgs)
        return len(self.train_imgs)
|
ELR_plus/logger/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .logger import *
|
2 |
+
from .visualization import *
|
ELR_plus/logger/logger.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import logging.config
|
3 |
+
from pathlib import Path
|
4 |
+
from utils import read_json
|
5 |
+
|
6 |
+
|
7 |
+
def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration

    If ``log_config`` exists it is loaded with ``read_json``, every handler
    that declares a ``filename`` gets that name re-rooted under ``save_dir``,
    and the result is applied with ``logging.config.dictConfig``. Otherwise a
    warning is printed and ``logging.basicConfig(level=default_level)`` is
    used instead.

    :param save_dir: directory for log files; a ``str`` or a ``Path``.
    :param log_config: path of the JSON logging configuration.
    :param default_level: level used by the fallback basic configuration.
    """
    log_config = Path(log_config)
    if not log_config.is_file():
        print("Warning: logging configuration file is not found in {}.".format(log_config))
        logging.basicConfig(level=default_level)
        return

    config = read_json(log_config)
    # Accept plain strings too; ``str / str`` would raise TypeError.
    save_dir = Path(save_dir)
    # modify logging paths based on run config
    for handler in config['handlers'].values():
        if 'filename' in handler:
            handler['filename'] = str(save_dir / handler['filename'])

    logging.config.dictConfig(config)
|
ELR_plus/logger/logger_config.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
{
|
3 |
+
"version": 1,
|
4 |
+
"disable_existing_loggers": false,
|
5 |
+
"formatters": {
|
6 |
+
"simple": {"format": "%(message)s"},
|
7 |
+
"datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
|
8 |
+
},
|
9 |
+
"handlers": {
|
10 |
+
"console": {
|
11 |
+
"class": "logging.StreamHandler",
|
12 |
+
"level": "DEBUG",
|
13 |
+
"formatter": "simple",
|
14 |
+
"stream": "ext://sys.stdout"
|
15 |
+
},
|
16 |
+
"info_file_handler": {
|
17 |
+
"class": "logging.handlers.RotatingFileHandler",
|
18 |
+
"level": "INFO",
|
19 |
+
"formatter": "datetime",
|
20 |
+
"filename": "info.log",
|
21 |
+
"maxBytes": 10485760,
|
22 |
+
"backupCount": 20, "encoding": "utf8"
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"root": {
|
26 |
+
"level": "INFO",
|
27 |
+
"handlers": [
|
28 |
+
"console",
|
29 |
+
"info_file_handler"
|
30 |
+
]
|
31 |
+
}
|
32 |
+
}
|
ELR_plus/logger/visualization.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from utils import Timer
|
3 |
+
|
4 |
+
|
5 |
+
class MLFlow:
    """Optional proxy for a small set of ``mlflow`` module functions.

    ``log_param`` / ``log_metric`` and ``start_run`` are forwarded to the
    ``mlflow`` module when it could be imported; when logging is disabled
    or mlflow is missing they are silent no-ops. Any other unknown
    attribute raises ``AttributeError``.

    Args:
        log_dir: accepted for interface parity with the other writers;
            not used by this class.
        logger: object with a ``warning`` method, used when mlflow is
            requested but not installed.
        enabled: whether to try importing mlflow at all.
    """

    def __init__(self, log_dir, logger, enabled):
        self.mlflow = None

        if enabled:
            # Import lazily so mlflow stays an optional dependency.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
                    "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        self.mlflow_ftns = {
            'start_run'
        }
        # Functions whose key never receives the "/<mode>" suffix. This must
        # exist: the wrapper below reads it, and a missing attribute used to
        # send __getattr__ into unbounded recursion.
        self.tag_mode_exceptions = set()

    def __getattr__(self, name):
        """
        If logging is configured to use:
            return the mlflow function ``name`` wrapped so that the mode
            (when one is set) is appended to the tag.
        Otherwise:
            return a blank function handle that does nothing
        """
        # Read instance state through __dict__ so a lookup on an instance
        # whose __init__ has not finished cannot re-enter __getattr__.
        state = self.__dict__
        if name in state.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is None:
                    return None
                # add mode(train/valid) tag; an empty mode is skipped so the
                # key does not end with a dangling '/'.
                # NOTE(review): mlflow appears to reject keys that do not
                # normalise to themselves, e.g. "loss/" - confirm.
                if self.mode and name not in self.tag_mode_exceptions:
                    tag = '{}/{}'.format(tag, self.mode)
                return add_data(tag, data, *args, **kwargs)

            return wrapper

        if name in state.get('mlflow_ftns', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(*args, **kwargs):
                if add_data is None:
                    return None
                return add_data(*args, **kwargs)

            return wrapper

        # __getattr__ is only called after normal lookup failed, so the
        # attribute really does not exist.
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
|
83 |
+
|
84 |
+
|
85 |
+
class TensorboardWriter:
    """Optional proxy around a tensorboard ``SummaryWriter``.

    The ``add_*`` functions listed in ``tb_writer_ftns`` are forwarded to
    the writer with the current step inserted as third positional argument
    and, except for the names in ``tag_mode_exceptions``, with
    ``"/<mode>"`` appended to the tag. When no writer is available the
    same calls are silent no-ops.

    Args:
        log_dir: directory handed to ``SummaryWriter`` (converted to ``str``).
        logger: object with a ``warning`` method, used when no tensorboard
            backend can be imported.
        enabled: whether to try creating a writer at all.
    """

    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        # Name of the backend module actually in use; empty when none is.
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: first backend that imports wins.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend on success (it used to be recorded only
                # after a failed attempt, i.e. never for the module in use).
                self.selected_module = module
                succeeded = True
                break

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the step/mode used by later ``add_*`` calls.

        Step 0 resets the timer; any other step logs ``steps_per_sec``
        computed from the time since the previous check.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        # Read instance state through __dict__ so a lookup on an instance
        # whose __init__ has not run cannot re-enter __getattr__ forever.
        state = self.__dict__
        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        # __getattr__ is only called after normal lookup failed, so the
        # attribute really does not exist.
        raise AttributeError("type object '{}' has no attribute '{}'".format(
            state.get('selected_module', ''), name))
|
ELR_plus/model/InceptionResNetV2.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function, division, absolute_import
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
|
8 |
+
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and a non-inplace ``ReLU``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, bias=False)  # verify bias false
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        norm_options = dict(
            eps=0.001,      # value found in tensorflow
            momentum=0.1,   # default pytorch value
            affine=True,
        )
        self.bn = nn.BatchNorm2d(out_planes, **norm_options)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))
|
26 |
+
|
27 |
+
|
28 |
+
class Mixed_5b(nn.Module):
    """Four parallel branches over a 192-channel input, concatenated on dim 1.

    Branch outputs carry 96, 64, 96 and 64 channels (320 in total).
    """

    def __init__(self):
        super(Mixed_5b, self).__init__()

        # 1x1 projection.
        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)

        # 1x1 reduction followed by a padded 5x5 convolution.
        reduce_1 = BasicConv2d(192, 48, kernel_size=1, stride=1)
        conv5_1 = BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
        self.branch1 = nn.Sequential(reduce_1, conv5_1)

        # 1x1 reduction followed by two padded 3x3 convolutions.
        reduce_2 = BasicConv2d(192, 64, kernel_size=1, stride=1)
        conv3_2a = BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
        conv3_2b = BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
        self.branch2 = nn.Sequential(reduce_2, conv3_2a, conv3_2b)

        # Average pooling followed by a 1x1 projection.
        pool_3 = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        proj_3 = BasicConv2d(192, 64, kernel_size=1, stride=1)
        self.branch3 = nn.Sequential(pool_3, proj_3)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
|
58 |
+
|
59 |
+
|
60 |
+
class Block35(nn.Module):
    """Residual block over 320 channels.

    Three branches (32 + 32 + 64 channels) are concatenated, projected back
    to 320 channels by a 1x1 convolution, multiplied by ``scale``, added to
    the input and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super(Block35, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        reduce_1 = BasicConv2d(320, 32, kernel_size=1, stride=1)
        conv3_1 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.branch1 = nn.Sequential(reduce_1, conv3_1)

        reduce_2 = BasicConv2d(320, 32, kernel_size=1, stride=1)
        conv3_2a = BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1)
        conv3_2b = BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
        self.branch2 = nn.Sequential(reduce_2, conv3_2a, conv3_2b)

        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(concat(branches(x))) * scale + x)``."""
        branches = (self.branch0, self.branch1, self.branch2)
        merged = torch.cat([branch(x) for branch in branches], 1)
        residual = self.conv2d(merged)
        return self.relu(residual * self.scale + x)
|
92 |
+
|
93 |
+
|
94 |
+
class Mixed_6a(nn.Module):
    """Stride-2 reduction block over a 320-channel input.

    Two convolutional branches (384 channels each) and a max-pooling branch
    (which keeps the 320 input channels) are concatenated on dim 1.
    """

    def __init__(self):
        super(Mixed_6a, self).__init__()

        # Direct strided 3x3 convolution.
        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)

        # 1x1 reduction, padded 3x3, then strided 3x3.
        reduce_1 = BasicConv2d(320, 256, kernel_size=1, stride=1)
        conv3_1 = BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1)
        down_1 = BasicConv2d(256, 384, kernel_size=3, stride=2)
        self.branch1 = nn.Sequential(reduce_1, conv3_1, down_1)

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)
|
115 |
+
|
116 |
+
|
117 |
+
class Block17(nn.Module):
    """Residual block over 1088 channels.

    Two branches (192 + 192 channels) are concatenated, projected back to
    1088 channels by a 1x1 convolution, multiplied by ``scale``, added to
    the input and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super(Block17, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)

        # 1x1 reduction followed by a factorised 7x7 (1x7 then 7x1).
        reduce_1 = BasicConv2d(1088, 128, kernel_size=1, stride=1)
        conv_1x7 = BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3))
        conv_7x1 = BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
        self.branch1 = nn.Sequential(reduce_1, conv_1x7, conv_7x1)

        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(concat(branches(x))) * scale + x)``."""
        merged = torch.cat([branch(x) for branch in (self.branch0, self.branch1)], 1)
        residual = self.conv2d(merged)
        return self.relu(residual * self.scale + x)
|
143 |
+
|
144 |
+
|
145 |
+
class Mixed_7a(nn.Module):
    """Stride-2 reduction block over a 1088-channel input.

    Three convolutional branches (384, 288 and 320 channels) and a
    max-pooling branch (which keeps the 1088 input channels) are
    concatenated on dim 1.
    """

    def __init__(self):
        super(Mixed_7a, self).__init__()

        reduce_0 = BasicConv2d(1088, 256, kernel_size=1, stride=1)
        down_0 = BasicConv2d(256, 384, kernel_size=3, stride=2)
        self.branch0 = nn.Sequential(reduce_0, down_0)

        reduce_1 = BasicConv2d(1088, 256, kernel_size=1, stride=1)
        down_1 = BasicConv2d(256, 288, kernel_size=3, stride=2)
        self.branch1 = nn.Sequential(reduce_1, down_1)

        reduce_2 = BasicConv2d(1088, 256, kernel_size=1, stride=1)
        conv3_2 = BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1)
        down_2 = BasicConv2d(288, 320, kernel_size=3, stride=2)
        self.branch2 = nn.Sequential(reduce_2, conv3_2, down_2)

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
|
175 |
+
|
176 |
+
|
177 |
+
class Block8(nn.Module):
    """Residual block over 2080 channels.

    Two branches (192 + 256 channels) are concatenated, projected back to
    2080 channels by a 1x1 convolution, multiplied by ``scale`` and added to
    the input. The final ReLU is created and applied only when ``noReLU``
    is false.
    """

    def __init__(self, scale=1.0, noReLU=False):
        super(Block8, self).__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)

        # 1x1 reduction followed by a factorised 3x3 (1x3 then 3x1).
        reduce_1 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
        conv_1x3 = BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1))
        conv_3x1 = BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
        self.branch1 = nn.Sequential(reduce_1, conv_1x3, conv_3x1)

        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``conv2d(concat(branches(x))) * scale + x``, with ReLU unless disabled."""
        merged = torch.cat([branch(x) for branch in (self.branch0, self.branch1)], 1)
        out = self.conv2d(merged) * self.scale + x
        if self.noReLU:
            return out
        return self.relu(out)
|
206 |
+
|
207 |
+
|
208 |
+
class InceptionResNetV2(nn.Module):
    """Inception-ResNet-v2 style classifier with ``num_classes`` outputs.

    ``features`` runs the convolutional stem, ten ``Block35``, ``Mixed_6a``,
    twenty ``Block17``, ``Mixed_7a``, nine ``Block8`` plus one ReLU-free
    ``Block8`` and a final 1x1 convolution to 1536 channels. ``logits``
    global-average-pools those features and applies the linear classifier.
    """

    def __init__(self, num_classes=50):
        super(InceptionResNetV2, self).__init__()
        # Special attributes
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.mean = None
        self.std = None
        # Modules: stem
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        # Modules: residual stages separated by reduction blocks
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
        self.block8 = Block8(noReLU=True)
        # Modules: head
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.avgpool_1a = nn.AdaptiveAvgPool2d((1, 1))#nn.AvgPool2d(8, count_include_pad=False)
        self.last_linear = nn.Linear(1536, num_classes)

    def features(self, input):
        """Return the 1536-channel feature map produced by the convolutional trunk."""
        trunk = (
            self.conv2d_1a, self.conv2d_2a, self.conv2d_2b, self.maxpool_3a,
            self.conv2d_3b, self.conv2d_4a, self.maxpool_5a,
            self.mixed_5b, self.repeat,
            self.mixed_6a, self.repeat_1,
            self.mixed_7a, self.repeat_2,
            self.block8, self.conv2d_7b,
        )
        x = input
        for stage in trunk:
            x = stage(x)
        return x

    def logits(self, features):
        """Pool ``features`` to 1x1, flatten per sample and apply the classifier."""
        pooled = self.avgpool_1a(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.last_linear(flat)

    def forward(self, input):
        """Return the class logits for ``input``."""
        return self.logits(self.features(input))
|
308 |
+
|
309 |
+
|
310 |
+
def test():
    """Smoke test: run one random 227x227 image through the network on CUDA and print the output size."""
    model = InceptionResNetV2().cuda()
    sample = torch.randn(1,3,227,227).cuda()
    logits = model(sample)
    print(logits.size())
|
314 |
+
#test()
|