Spaces:
Configuration error
Configuration error
File size: 5,288 Bytes
72fc481 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
import logging
from pathlib import Path
from functools import reduce
from operator import getitem
from datetime import datetime
from logger import setup_logging
from utils import read_json, write_json
class ConfigParser:
    """Singleton holding the experiment configuration.

    The instance parses command-line arguments, loads a JSON config file
    (or the ``config.json`` stored next to a checkpoint when resuming),
    applies command-line overrides, creates the model and log directories,
    writes the effective config into the model directory and configures
    logging.

    Instances must be obtained through :meth:`get_instance`; calling the
    class directly raises ``NotImplementedError``.
    """

    # The one shared instance; stays None until the first successful
    # get_instance() call.
    __instance = None

    def __new__(cls, args, options='', timestamp=True):
        """Block direct construction; use :meth:`get_instance` instead."""
        raise NotImplementedError('Cannot initialize via Constructor')

    @classmethod
    def __internal_new__(cls):
        """Allocate a bare instance, bypassing the blocking ``__new__``."""
        return super().__new__(cls)

    @classmethod
    def get_instance(cls, args=None, options='', timestamp=True):
        """Return the shared instance, creating it on the first call.

        :param args: parser object offering ``add_argument`` and
            ``parse_args`` (presumably an ``argparse.ArgumentParser``).
            Required on the first call, ignored afterwards.
        :param options: iterable of option records exposing ``flags``,
            ``type`` and ``target`` attributes, registered as extra
            command-line options.
        :param timestamp: when true, a ``%m%d_%H%M%S`` timestamp is appended
            to the output directories.
        :raises NotImplementedError: if no instance exists yet and ``args``
            is None.
        """
        if cls.__instance is None:
            if args is None:
                raise NotImplementedError('Cannot initialize without args')
            instance = cls.__internal_new__()
            instance.__init__(args, options, timestamp)
            # Publish only after __init__ succeeded, so that a failed
            # initialisation does not leave a half-built singleton behind.
            cls.__instance = instance
        return cls.__instance

    def __init__(self, args, options='', timestamp=True):
        """Parse the command line, load the config and prepare output dirs.

        See :meth:`get_instance` for the meaning of the parameters.
        """
        # parse default and custom cli options
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        args = args.parse_args()

        if args.device:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device

        if args.resume is None:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            self.cfg_fname = Path(args.config)
            config = read_json(self.cfg_fname)
            self.resume = None
        else:
            # Resuming: start from the config saved beside the checkpoint
            # and let an explicitly given config file override its entries.
            # NOTE(review): self.cfg_fname is not set on this path.
            self.resume = Path(args.resume)
            resume_cfg_fname = self.resume.parent / 'config.json'
            config = read_json(resume_cfg_fname)
            if args.config is not None:
                config.update(read_json(Path(args.config)))

        # load config file and apply custom cli options
        self._config = _update_config(config, options, args)

        # set save_dir where trained model and log will be saved.
        save_dir = Path(self.config['trainer']['save_dir'])
        timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''

        trainer_cfg = self.config['trainer']
        noise_tag = '_asym_' if trainer_cfg['asym'] else '_sym_'
        # NOTE(review): int() truncates, so a percent such as 0.29 yields 28
        # (0.29 * 100 == 28.999999999999996). Kept unchanged so directory
        # names stay stable.
        exper_name = self.config['name'] + noise_tag + str(int(trainer_cfg['percent'] * 100))

        self._save_dir = save_dir / 'models' / exper_name / timestamp
        self._log_dir = save_dir / 'log' / exper_name / timestamp

        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # save updated config file to the checkpoint dir
        write_json(self.config, self.save_dir / 'config.json')

        # configure logging module
        setup_logging(self.log_dir)
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

    def initialize(self, name, module, *args, **kwargs):
        """
        finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding keyword args given as 'args'.

        Positional ``args`` are passed through unchanged; ``kwargs`` are
        merged into the keyword arguments from the config and must not
        collide with them (an ``AssertionError`` is raised otherwise).
        """
        module_name = self[name]['type']
        # Copy so that the stored config is never mutated by the merge below.
        module_args = dict(self[name]['args'])
        assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def __getitem__(self, name):
        """Return the top-level config entry stored under ``name``."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """Return the logger called ``name`` with its level set from ``verbosity``.

        ``verbosity`` must be a key of ``self.log_levels``; an
        ``AssertionError`` is raised otherwise.
        """
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
                                                                                       self.log_levels.keys())
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        """The effective configuration dictionary."""
        return self._config

    @property
    def save_dir(self):
        """Directory the config copy is written to (and, per the original comment, trained models)."""
        return self._save_dir

    @property
    def log_dir(self):
        """Directory passed to ``setup_logging``."""
        return self._log_dir
# helper functions used to update config dict with custom cli options
def _update_config(config, options, args):
for opt in options:
value = getattr(args, _get_opt_name(opt.flags))
if value is not None:
_set_by_path(config, opt.target, value)
if 'target2' in opt._fields:
_set_by_path(config, opt.target2, value)
if 'target3' in opt._fields:
_set_by_path(config, opt.target3, value)
return config
def _get_opt_name(flags):
for flg in flags:
if flg.startswith('--'):
return flg.replace('--', '')
return flags[0].replace('--', '')
def _set_by_path(tree, keys, value):
    """Store ``value`` in the nested container ``tree`` at the key path ``keys``.

    All keys except the last select the container to write into; the last
    key is the slot that receives ``value``.
    """
    *parent_path, leaf_key = keys
    parent = _get_by_path(tree, parent_path)
    parent[leaf_key] = value
def _get_by_path(tree, keys):
"""Access a nested object in tree by sequence of keys."""
return reduce(getitem, keys, tree)
|