File size: 2,606 Bytes
717b269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
from dataclasses import dataclass, field

from omegaconf import OmegaConf

from tgs.utils.typing import *

# ============ Register OmegaConf Resolvers ============= #
# Custom resolvers made available to config files, listed as
# (resolver name, implementation) pairs and registered in this order.
_RESOLVERS = (
    # n-th root of ``factor``.
    ("calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)),
    # Basic arithmetic on two operands.
    ("add", lambda a, b: a + b),
    ("sub", lambda a, b: a - b),
    ("mul", lambda a, b: a * b),
    ("div", lambda a, b: a / b),
    ("idiv", lambda a, b: a // b),
    # Final path component of ``p``.
    ("basename", lambda p: os.path.basename(p)),
    # Replace every space in ``s`` with ``sub``.
    ("rmspace", lambda s, sub: s.replace(" ", sub)),
    # Two-element list holding ``float(s)`` twice.
    ("tuple2", lambda s: [float(s), float(s)]),
    # Boolean helpers.
    ("gt0", lambda s: s > 0),
    ("not", lambda s: not s),
    # (sh_degree + 1) squared, times 3.
    ("shsdim", lambda sh_degree: (sh_degree + 1) ** 2 * 3),
)
for _resolver_name, _resolver_fn in _RESOLVERS:
    OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)
# ======================================================= #

# ============== Automatic Name Resolvers =============== #
def get_naming_convention(cfg):
    """Build the automatic name for a configuration.

    Reads ``cfg.system.backbone.num_layers`` and returns the string
    ``"tgs_<num_layers>"``.
    """
    # TODO
    backbone_cfg = cfg.system.backbone
    return "tgs_{}".format(backbone_cfg.num_layers)

# ======================================================= #

@dataclass
class ExperimentConfig:
    """Top-level schema for a loaded experiment configuration.

    ``load_config`` hands the merged configuration to ``parse_structured``,
    which unpacks its top-level keys as keyword arguments to this class, so
    only the fields declared here are accepted at the top level.
    """

    # Presumably the number of GPUs to use (inferred from the name); defaults to 1.
    n_gpus: int = 1
    # Contents of the "data" section; a fresh empty dict per instance by default.
    data: dict = field(default_factory=dict)
    # Contents of the "system" section; a fresh empty dict per instance by default.
    # NOTE(review): get_naming_convention reads system.backbone.num_layers,
    # so that key is presumably expected here -- confirm against the configs.
    system: dict = field(default_factory=dict)

def load_config(
    *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs
) -> Any:
    """Assemble an ``ExperimentConfig`` from YAML sources and overrides.

    Args:
        *yamls: YAML file paths, or YAML text when ``from_string`` is true.
        cli_args: Command-line style overrides, parsed with
            ``OmegaConf.from_cli``.
        from_string: When true, each entry of ``yamls`` is parsed with
            ``OmegaConf.create`` instead of being loaded with
            ``OmegaConf.load``.
        makedirs: Accepted but not used by this function.
        **kwargs: Extra top-level overrides, merged last.

    Returns:
        The structured config produced by ``parse_structured`` for
        ``ExperimentConfig``, with interpolations already resolved.

    Raises:
        AssertionError: If an ``extends`` path does not exist, or if the
            merged configuration is not a ``DictConfig``.
    """
    reader = OmegaConf.create if from_string else OmegaConf.load

    # Sources are collected in merge order; OmegaConf.merge lets later
    # sources take priority over earlier ones.
    layers = []
    for source in yamls:
        parsed = reader(source)
        # "extends" is removed from the config itself; when set, the file it
        # names is loaded and placed right before the config that extends it.
        # Only this single level is followed: the base file's own "extends"
        # key, if any, is left in place.
        base_path = parsed.pop("extends", None)
        if base_path:
            assert os.path.exists(base_path), f"File {base_path} does not exist."
            layers.append(OmegaConf.load(base_path))
        layers.append(parsed)

    layers.append(OmegaConf.from_cli(cli_args))
    layers.append(kwargs)

    merged = OmegaConf.merge(*layers)
    OmegaConf.resolve(merged)
    assert isinstance(merged, DictConfig)
    return parse_structured(ExperimentConfig, merged)


def config_to_primitive(config, resolve: bool = True) -> Any:
    """Convert ``config`` with ``OmegaConf.to_container`` and return the result.

    Args:
        config: The OmegaConf config to convert.
        resolve: Forwarded to ``OmegaConf.to_container`` as ``resolve``.
    """
    primitive = OmegaConf.to_container(config, resolve=resolve)
    return primitive


def dump_config(path: str, config) -> None:
    """Write ``config`` to ``path`` using ``OmegaConf.save``.

    Args:
        path: Destination file path; an existing file is overwritten.
        config: The config object to serialize.
    """
    # Pin UTF-8 instead of relying on the platform's locale encoding:
    # OmegaConf reads and writes files by path as UTF-8, so a dump made on a
    # non-UTF-8 locale could otherwise fail to load back when the config
    # contains non-ASCII text.
    with open(path, "w", encoding="utf-8") as fp:
        OmegaConf.save(config=config, f=fp)


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Instantiate ``fields`` from ``cfg`` and wrap it as a structured config.

    Args:
        fields: A callable (such as a dataclass) that accepts the keys of
            ``cfg`` as keyword arguments.
        cfg: Mapping of values passed to ``fields``. ``None`` means "no
            overrides", so ``fields`` is called without arguments.

    Returns:
        The result of ``OmegaConf.structured`` applied to the new instance.
    """
    # The signature advertises ``cfg=None``, but unpacking ``**None`` raises
    # TypeError, so substitute an empty mapping for that case.
    overrides = {} if cfg is None else cfg
    return OmegaConf.structured(fields(**overrides))