"""Visualization writers (MLflow and TensorBoard) with a no-op fallback.

Both classes expose logging functions through ``__getattr__``. When the
backend was imported successfully, the call is forwarded to it. When the
backend is disabled or not installed, the returned function does nothing.
"""
import importlib
from utils import Timer


class MLFlow:
    """Proxy around the ``mlflow`` module.

    Attribute access for a name in ``mlflow_ftns_with_tag_and_value`` or
    ``mlflow_ftns`` returns a wrapper around the ``mlflow`` function of the
    same name. If mlflow is disabled or cannot be imported, the wrappers are
    no-ops that return ``None``.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        Args:
            log_dir: accepted for interface parity with ``TensorboardWriter``;
                not used by this class.
            logger: object with a ``warning`` method; used when mlflow is
                enabled but cannot be imported.
            enabled: whether to try to import and use mlflow.
        """
        self.mlflow = None
        # Name of the backing module; only used in AttributeError messages.
        self.selected_module = ""

        if enabled:
            # Retrieve visualization writer.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
                    "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)
            else:
                self.selected_module = "mlflow"

        self.step = 0
        self.mode = ''

        # Functions called as f(tag, data, ...); the tag gets a '/<mode>' suffix.
        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        # Functions forwarded unchanged.
        self.mlflow_ftns = {
            'start_run'
        }
        # Tagged functions that should NOT get the mode suffix. Currently none,
        # but __getattr__ consults this set, so it must exist.
        self.tag_mode_exceptions = set()

    def set_step(self, step, mode='train'):
        """Record the current step and mode (e.g. 'train'/'valid').

        The mode is appended to tags by the wrappers returned from
        ``__getattr__``. The step is stored for parity with
        ``TensorboardWriter``, but those wrappers do not forward it to mlflow.
        """
        self.mode = mode
        self.step = step

    def __getattr__(self, name):
        """
        Return a wrapper for a configured mlflow function.

        If mlflow is unavailable, the wrapper does nothing. Any other
        unknown name raises AttributeError.
        """
        # __getattr__ only runs after normal lookup fails. Read configuration
        # straight from the instance dict so that a partially initialised
        # instance (e.g. created by copy/pickle without __init__) raises
        # AttributeError instead of recursing into __getattr__ forever.
        state = self.__dict__

        if name in state.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # Add the mode (train/valid) suffix. Skip it when no mode
                    # is set, so the key does not end with a trailing '/'.
                    if name not in self.tag_mode_exceptions and self.mode:
                        tag = '{}/{}'.format(tag, self.mode)
                    return add_data(tag, data, *args, **kwargs)
                return None
            return wrapper

        if name in state.get('mlflow_ftns', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(*args, **kwargs):
                # Propagate the result, e.g. the run object returned by
                # mlflow.start_run, so callers can use it as a context manager.
                if add_data is not None:
                    return add_data(*args, **kwargs)
                return None
            return wrapper

        raise AttributeError("type object '{}' has no attribute '{}'".format(
            state.get('selected_module', ''), name))


class TensorboardWriter:
    """Proxy around a TensorBoard ``SummaryWriter``.

    The writer comes from ``torch.utils.tensorboard`` if available, otherwise
    from ``tensorboardX``. Attribute access for a name in ``tb_writer_ftns``
    returns a wrapper that appends '/<mode>' to the tag and passes the current
    step. If no writer could be created, the wrappers are no-ops.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        Args:
            log_dir: directory passed (as str) to ``SummaryWriter``.
            logger: object with a ``warning`` method; used when no TensorBoard
                backend can be imported.
            enabled: whether to try to create a writer.
        """
        self.writer = None
        # Module that provided the writer; used in AttributeError messages.
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: try each backend in order.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the module that actually succeeded.
                self.selected_module = module
                break
            else:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        # SummaryWriter methods exposed through __getattr__.
        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # Methods whose tag does NOT get the '/<mode>' suffix.
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the current step and mode, and log throughput.

        On step 0 the timer is reset. On any other step, a 'steps_per_sec'
        scalar is logged from the duration reported by ``self.timer.check()``.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            # Avoid ZeroDivisionError when no measurable time has elapsed
            # (presumably possible with a coarse clock; depends on utils.Timer).
            if duration > 0:
                self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """
        Return a wrapper for a SummaryWriter ``add_*`` method.

        The wrapper appends the current mode to the tag and passes
        ``self.step``. If no writer exists, the wrapper does nothing. Any
        other unknown name raises AttributeError.
        """
        # Read configuration from the instance dict so a partially initialised
        # instance raises AttributeError instead of recursing indefinitely.
        state = self.__dict__

        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        raise AttributeError("type object '{}' has no attribute '{}'".format(
            state.get('selected_module', ''), name))