assi1 / ELR_plus /logger /visualization.py
uthurumella's picture
Upload 69 files
72fc481 verified
import importlib
from utils import Timer
class MLFlow:
    """Forward a few tracking calls to the ``mlflow`` module.

    ``log_param`` / ``log_metric`` calls take ``(tag, data, ...)``. The tag gets
    ``'/<mode>'`` appended when ``self.mode`` is non-empty, unless the method
    name is listed in ``tag_mode_exceptions``. ``start_run`` is passed through
    unchanged. When visualization is disabled, or mlflow is not installed,
    every forwarded call is a no-op.

    Args:
        log_dir: Accepted for interface parity with ``TensorboardWriter``; it
            is not used by this class.
        logger: Object with a ``warning`` method. It is used when mlflow cannot
            be imported.
        enabled: Whether to try importing mlflow at all.
    """

    def __init__(self, log_dir, logger, enabled):
        self.mlflow = None

        if enabled:
            # Retrieve visualization module.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
                    "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        # Suffix appended to logged tags ('' means "no suffix").
        self.mode = ''
        # mlflow functions called as fn(tag, data, *args, **kwargs).
        self.mlflow_ftns_with_tag_and_value = {'log_param', 'log_metric'}
        # mlflow functions forwarded verbatim.
        self.mlflow_ftns = {'start_run'}
        # Tagged functions whose tag must NOT get the mode suffix. The
        # wrapper reads this attribute, so it must always exist.
        self.tag_mode_exceptions = set()

    def __getattr__(self, name):
        """Return a forwarding wrapper for a supported mlflow function.

        This method only runs after normal attribute lookup fails. Instance
        state is therefore read through ``self.__dict__``. Plain ``self.x``
        access would re-enter ``__getattr__`` and recurse forever when the
        attribute is missing, e.g. on an instance created by ``copy``/
        ``pickle`` before ``__init__`` state is restored.

        Raises:
            AttributeError: if ``name`` is not a supported mlflow function.
        """
        attrs = self.__dict__
        mlflow = attrs.get('mlflow')

        if name in attrs.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(mlflow, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # Add mode (train/valid) suffix only when a mode is set.
                    # An empty mode would otherwise produce keys such as
                    # 'loss/', which MLflow's name validation may reject.
                    if self.mode and name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, *args, **kwargs)
            return wrapper

        if name in attrs.get('mlflow_ftns', ()):
            add_data = getattr(mlflow, name, None)

            def wrapper(*args, **kwargs):
                if add_data is not None:
                    # Propagate the result (e.g. the run object returned by
                    # start_run) so callers can use it; None when disabled.
                    return add_data(*args, **kwargs)
                return None
            return wrapper

        raise AttributeError("type object '{}' has no attribute '{}'".format('mlflow', name))
class TensorboardWriter:
    """Forward ``add_*`` calls to a TensorBoard ``SummaryWriter``.

    The writer comes from ``torch.utils.tensorboard`` when it can be imported,
    and from ``tensorboardX`` otherwise. Forwarded calls get the current mode
    appended to the tag (``'<tag>/<mode>'``), except for the methods listed in
    ``tag_mode_exceptions``. The current ``self.step`` is passed as the global
    step. When visualization is disabled, or neither package is installed,
    the ``add_*`` methods are no-ops.

    Args:
        log_dir: Directory handed to ``SummaryWriter`` (converted with ``str``).
        logger: Object with a ``warning`` method. It is used when no writer
            package can be imported.
        enabled: Whether to try to create a writer at all.
    """

    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        # Name of the module that provided the writer ("" if none).
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer, preferring the one bundled with PyTorch.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the module that actually provided the writer.
                self.selected_module = module
                break
            else:
                # Loop finished without break: no candidate could be imported.
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        # SummaryWriter methods that are forwarded by __getattr__.
        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # Methods whose tag is passed through without the mode suffix.
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the global step and mode used by subsequent ``add_*`` calls.

        At step 0 the timer is reset. At any other step, ``1 / duration`` is
        logged as ``steps_per_sec``. Here ``duration`` is the value returned
        by ``self.timer.check()``. It is presumably the number of seconds
        since the previous reset/check; confirm against ``utils.Timer``.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
            return

        duration = self.timer.check()
        # A coarse clock can report a zero-length interval. Skip the sample
        # rather than raising ZeroDivisionError in the training loop.
        if duration > 0:
            self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """Return a wrapper around the writer's ``name`` method.

        The wrapper adds the mode suffix to the tag and passes ``self.step``
        along. It does nothing when no writer is available.

        This method only runs after normal attribute lookup fails. Instance
        state is therefore read through ``self.__dict__``. Plain ``self.x``
        access would re-enter ``__getattr__`` and recurse forever on an
        instance whose attributes are not populated yet, e.g. during
        ``copy``/``pickle``.

        Raises:
            AttributeError: if ``name`` is not a forwarded writer method.
        """
        attrs = self.__dict__

        if name in attrs.get('tb_writer_ftns', ()):
            add_data = getattr(attrs.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        raise AttributeError("type object '{}' has no attribute '{}'".format(
            attrs.get('selected_module', ''), name))