|
import os |
|
import time |
|
from copy import deepcopy |
|
from typing import Optional |
|
|
|
import torch.backends.cudnn |
|
import torch.distributed |
|
import torch.nn as nn |
|
|
|
from ..apps.utils import ( |
|
dist_init, |
|
dump_config, |
|
get_dist_local_rank, |
|
get_dist_rank, |
|
get_dist_size, |
|
init_modules, |
|
is_master, |
|
load_config, |
|
partial_update_config, |
|
zero_last_gamma, |
|
) |
|
from ..models.utils import build_kwargs_from_config, load_state_dict_from_file |
|
|
|
__all__ = [ |
|
"save_exp_config", |
|
"setup_dist_env", |
|
"setup_seed", |
|
"setup_exp_config", |
|
"init_model", |
|
] |
|
|
|
|
|
def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    """Dump the experiment config to ``path/name``, on the master process only.

    Non-master ranks are a no-op so that only one process writes the file.
    """
    if is_master():
        dump_config(exp_config, os.path.join(path, name))
|
|
|
|
|
def setup_dist_env(gpu: Optional[str] = None) -> None:
    """Prepare the (distributed) runtime environment for this process.

    Args:
        gpu: optional value for ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0,1"``);
            when None the environment variable is left untouched.
    """
    if gpu is not None:
        # Must be set before any CUDA/distributed initialization touches devices.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if not torch.distributed.is_initialized():
        dist_init()
    # Let cuDNN auto-tune kernels; beneficial for fixed input shapes.
    torch.backends.cudnn.benchmark = True
    # Pin this process to its node-local GPU.
    torch.cuda.set_device(get_dist_local_rank())
|
|
|
|
|
def setup_seed(manual_seed: int, resume: bool) -> None:
    """Seed the torch CPU and CUDA RNGs with a per-rank seed.

    Args:
        manual_seed: base seed; ignored when ``resume`` is True, in which case
            the current wall-clock time is used instead.
        resume: whether this run resumes from a checkpoint.
    """
    base_seed = int(time.time()) if resume else manual_seed
    # Offset by rank so each distributed worker draws a distinct stream.
    rank_seed = base_seed + get_dist_rank()
    torch.manual_seed(rank_seed)
    torch.cuda.manual_seed_all(rank_seed)
|
|
|
|
|
def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
    """Load an experiment config, optionally layering ancestor defaults and overrides.

    With ``recursive`` enabled, every ancestor directory of ``config_path`` is
    scanned for a ``default`` file with the same extension; those are applied
    root-first, with ``config_path`` applied last, followed by ``opt_args``.

    Args:
        config_path: path to the experiment config file.
        recursive: whether to merge ancestor ``default`` configs underneath it.
        opt_args: optional final partial overrides (e.g. from the command line).

    Returns:
        The merged config as a dict.

    Raises:
        ValueError: if ``config_path`` is not an existing file.
    """
    if not os.path.isfile(config_path):
        raise ValueError(config_path)

    fpaths = [config_path]
    if recursive:
        ext = os.path.splitext(config_path)[1]
        current = config_path
        while True:
            parent = os.path.dirname(current)
            if parent == current:  # reached the filesystem root
                break
            current = parent
            candidate = os.path.join(current, "default" + ext)
            if os.path.isfile(candidate):
                fpaths.append(candidate)
        # Apply from the root-most default down to the experiment config.
        fpaths.reverse()

    exp_config = deepcopy(load_config(fpaths[0]))
    for fpath in fpaths[1:]:
        partial_update_config(exp_config, load_config(fpath))

    if opt_args is not None:
        partial_update_config(exp_config, opt_args)

    return exp_config
|
|
|
|
|
def init_model(
    network: nn.Module,
    init_from: Optional[str] = None,
    backbone_init_from: Optional[str] = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    """Randomly initialize ``network`` in place, then optionally load a checkpoint.

    Args:
        network: model to initialize.
        init_from: checkpoint path for the whole network; used when the file exists.
        backbone_init_from: checkpoint path for ``network.backbone`` only; used
            when ``init_from`` is absent or missing on disk.
        rand_init: random-initialization scheme forwarded to ``init_modules``.
        last_gamma: when not None, forwarded to ``zero_last_gamma``.
    """
    # Random init always runs first; any checkpoint loaded below overwrites it.
    init_modules(network, init_type=rand_init)
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    have_full_ckpt = init_from is not None and os.path.isfile(init_from)
    have_backbone_ckpt = backbone_init_from is not None and os.path.isfile(backbone_init_from)

    if have_full_ckpt:
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
    elif have_backbone_ckpt:
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
    else:
        print(f"Random init ({rand_init}) with last gamma {last_gamma}")
|
|