Spaces:
Configuration error
Configuration error
import os | |
import logging | |
from pathlib import Path | |
from functools import reduce | |
from operator import getitem | |
from datetime import datetime | |
from logger import setup_logging | |
from utils import read_json, write_json | |
class ConfigParser:
    """Singleton holding the parsed experiment configuration.

    Direct construction is disabled; obtain the shared instance through
    :meth:`get_instance`. The instance exposes the configuration dict
    (``config`` / ``parser[name]``), the checkpoint and log directories
    (``save_dir`` / ``log_dir``) and a logger factory (``get_logger``).
    """

    __instance = None

    def __new__(cls, args, options='', timestamp=True):
        """Reject direct construction; use :meth:`get_instance` instead."""
        raise NotImplementedError('Cannot initialize via Constructor')

    @classmethod
    def __internal_new__(cls):
        """Allocate a bare, uninitialised instance, bypassing ``__new__``."""
        return super().__new__(cls)

    @classmethod
    def get_instance(cls, args=None, options='', timestamp=True):
        """Return the shared instance, creating it on the first call.

        :param args: parser object providing ``add_argument`` and
            ``parse_args``; required on the first call only.
        :param options: iterable of custom CLI option descriptors, each with
            ``flags``, ``type`` and ``target`` attributes.
        :param timestamp: when true, a timestamp sub-directory is appended to
            the save and log directories.
        :raises NotImplementedError: if no instance exists yet and ``args``
            is ``None``.
        """
        if cls.__instance is None:
            if args is None:
                raise NotImplementedError('Cannot initialize without args')
            # Publish the singleton only after __init__ succeeded, so a failed
            # initialisation never leaves a half-built instance cached.
            instance = cls.__internal_new__()
            instance.__init__(args, options, timestamp)
            cls.__instance = instance
        return cls.__instance

    def __init__(self, args, options='', timestamp=True):
        """Parse the CLI, load the config and prepare the output directories.

        Side effects: may set ``CUDA_VISIBLE_DEVICES``, creates the save and
        log directories, writes the effective config to
        ``<save_dir>/config.json`` and calls ``setup_logging``.

        :param args: parser object providing ``add_argument`` and
            ``parse_args``; the parsed result must expose ``device``,
            ``resume`` and ``config``.
        :param options: iterable of custom CLI option descriptors.
        :param timestamp: append a ``%m%d_%H%M%S`` sub-directory when true.
        """
        # parse default and custom cli options
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        args = args.parse_args()

        if args.device:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device

        if args.resume is None:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            self.cfg_fname = Path(args.config)
            config = read_json(self.cfg_fname)
            self.resume = None
        else:
            # Resuming: start from the config stored next to the checkpoint,
            # then let an explicitly given config file override it.
            self.resume = Path(args.resume)
            resume_cfg_fname = self.resume.parent / 'config.json'
            config = read_json(resume_cfg_fname)
            if args.config is not None:
                config.update(read_json(Path(args.config)))

        # load config file and apply custom cli options
        self._config = _update_config(config, options, args)

        # set save_dir where trained model and log will be saved.
        trainer_cfg = self.config['trainer']
        save_dir = Path(trainer_cfg['save_dir'])
        timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''

        mode_tag = '_asym_' if trainer_cfg['asym'] else '_sym_'
        exper_name = self.config['name'] + mode_tag + str(int(trainer_cfg['percent'] * 100))

        self._save_dir = save_dir / 'models' / exper_name / timestamp
        self._log_dir = save_dir / 'log' / exper_name / timestamp

        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # save updated config file to the checkpoint dir
        write_json(self.config, self.save_dir / 'config.json')

        # configure logging module
        setup_logging(self.log_dir)
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

    def initialize(self, name, module, *args, **kwargs):
        """Instantiate the object described by config entry ``name``.

        Looks up the attribute of ``module`` named by ``self[name]['type']``
        and calls it with ``*args`` plus the keyword arguments stored under
        ``self[name]['args']`` merged with ``**kwargs``. Keyword arguments
        already present in the config entry may not be overridden.
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def __getitem__(self, name):
        """Return the top-level config entry stored under ``name``."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """Return the logger ``name`` with its level set from ``verbosity``.

        ``verbosity`` must be a key of ``self.log_levels``
        (0: WARNING, 1: INFO, 2: DEBUG).
        """
        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
                                                                                       self.log_levels.keys())
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        """The effective configuration dict."""
        return self._config

    @property
    def save_dir(self):
        """Directory where the config copy is written (checkpoint dir)."""
        return self._save_dir

    @property
    def log_dir(self):
        """Directory passed to ``setup_logging``."""
        return self._log_dir
# Helper functions used to update the config dict with custom CLI options.
def _update_config(config, options, args): | |
for opt in options: | |
value = getattr(args, _get_opt_name(opt.flags)) | |
if value is not None: | |
_set_by_path(config, opt.target, value) | |
if 'target2' in opt._fields: | |
_set_by_path(config, opt.target2, value) | |
if 'target3' in opt._fields: | |
_set_by_path(config, opt.target3, value) | |
return config | |
def _get_opt_name(flags): | |
for flg in flags: | |
if flg.startswith('--'): | |
return flg.replace('--', '') | |
return flags[0].replace('--', '') | |
def _set_by_path(tree, keys, value):
    """Store ``value`` in the nested container ``tree`` at the path ``keys``.

    All keys but the last select the parent container; the last key is the
    slot that is assigned.
    """
    parent = _get_by_path(tree, keys[:-1])
    leaf_key = keys[-1]
    parent[leaf_key] = value
def _get_by_path(tree, keys): | |
"""Access a nested object in tree by sequence of keys.""" | |
return reduce(getitem, keys, tree) | |