Spaces:
Configuration error
Configuration error
File size: 5,851 Bytes
72fc481 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import importlib
from utils import Timer
class MLFlow:
    """Optional wrapper around the ``mlflow`` tracking module.

    When enabled and ``mlflow`` can be imported, attribute access for the
    whitelisted function names is forwarded to the ``mlflow`` module.
    Otherwise those names resolve to no-op callables, so training code can
    call them unconditionally.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        :param log_dir: Accepted for interface parity with TensorboardWriter;
            not used by this class.
        :param logger: Object with a ``warning(message)`` method, used to
            report that ``mlflow`` is not installed.
        :param enabled: If falsy, ``mlflow`` is never imported and every
            forwarded call is a no-op.
        """
        self.mlflow = None
        if enabled:
            # Retrieve visualization writer.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = (
                    "Warning: visualization (mlflow) is configured to use, but currently not installed on "
                    "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in "
                    "the 'config.json' file."
                )
                logger.warning(message)

        self.step = 0
        # Optional suffix ("train"/"valid", ...) appended to tags; "" means none.
        self.mode = ''

        # mlflow functions called as fn(tag, value, ...); the tag may get a
        # "/<mode>" suffix.
        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        # mlflow functions forwarded with their arguments untouched.
        self.mlflow_ftns = {
            'start_run'
        }
        # Tagged functions that must NOT receive the mode suffix. The wrapper
        # built in __getattr__ reads this, so it must always be defined.
        self.tag_mode_exceptions = set()

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        * Names in ``mlflow_ftns_with_tag_and_value`` return a wrapper
          ``fn(tag, data, *args, **kwargs)``. It appends ``/<mode>`` to
          ``tag`` when a mode is set and the name is not listed in
          ``tag_mode_exceptions``. It then calls the matching ``mlflow``
          function.
        * Names in ``mlflow_ftns`` return a wrapper that forwards all
          arguments unchanged.

        Both wrappers return whatever the underlying mlflow function
        returns. If mlflow is disabled or unavailable, they do nothing and
        return None.

        :raises AttributeError: for any other name.
        """
        # Read configuration from the instance dict directly: this method can
        # run on instances whose __init__ never ran (copy/pickle build bare
        # objects), and ``self.<missing attr>`` here would recurse forever.
        state = self.__dict__

        if name in state.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is None:
                    return None
                # Add mode (train/valid) tag. Skip it when no mode is set:
                # "tag/" normalizes to a different path than the key itself,
                # which mlflow's metric/param name validation rejects.
                if self.mode and name not in self.tag_mode_exceptions:
                    tag = '{}/{}'.format(tag, self.mode)
                return add_data(tag, data, *args, **kwargs)
            return wrapper

        if name in state.get('mlflow_ftns', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(*args, **kwargs):
                if add_data is None:
                    return None
                # Return the result so callers can use e.g. the object
                # returned by start_run.
                return add_data(*args, **kwargs)
            return wrapper

        # __getattr__ only runs after normal lookup has failed, so anything
        # defined on the class or instance never reaches this point.
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
class TensorboardWriter:
    """Optional TensorBoard writer that tags data with the current mode and step.

    ``add_*`` calls listed in ``tb_writer_ftns`` are forwarded to a
    ``SummaryWriter`` from ``torch.utils.tensorboard`` (tried first) or
    ``tensorboardX``. ``self.step`` is inserted as the step argument and
    ``/<mode>`` is appended to the tag. When disabled, or when neither
    backend imports, those calls are no-ops.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        :param log_dir: Directory handed (as ``str``) to ``SummaryWriter``.
        :param logger: Object with a ``warning(message)`` method, used when no
            backend can be imported.
        :param enabled: If falsy, no backend is imported and all ``add_*``
            calls are no-ops.
        """
        self.writer = None
        # Name of the backend module actually in use; "" when none is.
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: the first backend that imports wins.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend that succeeded (not the one that failed).
                self.selected_module = module
                break
            else:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # These functions keep their tag as given (no "/<mode>" suffix).
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the step and mode used by subsequent ``add_*`` calls.

        ``step == 0`` resets the timer. Any other step logs a
        ``steps_per_sec`` scalar equal to ``1 / self.timer.check()``.
        (Presumably ``check()`` returns the seconds elapsed since the last
        check/reset; ``Timer`` lives in ``utils``, so confirm there.)
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            # A zero interval (e.g. coarse clock resolution) would raise
            # ZeroDivisionError; skip the rate sample in that case.
            if duration > 0:
                self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """
        Resolve attributes that normal lookup did not find.

        If visualization is configured to use:
            return the backend's add_data() method wrapped so that the current
            step and (except for ``tag_mode_exceptions``) the ``/<mode>`` tag
            suffix are added.
        Otherwise:
            return a blank function handle that does nothing.

        :raises AttributeError: for names not in ``tb_writer_ftns``.
        """
        # Use the instance dict directly: this can run on instances whose
        # __init__ never ran (copy/pickle/hasattr on bare objects), where
        # ``self.<missing attr>`` would re-enter __getattr__ forever.
        state = self.__dict__

        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        # __getattr__ only runs after normal lookup failed, so methods defined
        # on this class (set_step, ...) never reach this point.
        raise AttributeError("type object '{}' has no attribute '{}'".format(
            state.get('selected_module', ''), name))
|