zerchen's picture
init test without models
717b269
import functools
import os
import re

import torch
from packaging import version

from tgs.utils.typing import *
def parse_version(ver: str):
    """Parse the version string *ver* with ``packaging.version.parse``.

    Args:
        ver: Version string to parse.

    Returns:
        Whatever ``packaging.version.parse`` returns for *ver*.
    """
    parsed = version.parse(ver)
    return parsed
def get_rank():
    """Return the process rank read from the environment, or 0 if none is set.

    The variables ``RANK``, ``LOCAL_RANK``, ``SLURM_PROCID`` and
    ``JSM_NAMESPACE_RANK`` are consulted in that order; the first one that is
    present wins and is converted with ``int``.
    """
    # SLURM_PROCID can be set even when SLURM is not managing the
    # multiprocessing, which is why LOCAL_RANK is listed before it.
    env_names = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    present = (os.environ.get(name) for name in env_names)
    first_set = next((value for value in present if value is not None), None)
    if first_set is None:
        return 0
    return int(first_set)
def get_device():
    """Return the CUDA ``torch.device`` whose index is this process's rank.

    The index comes from ``get_rank()``.
    """
    # NOTE(review): get_rank() may return a global rank (e.g. from RANK);
    # on multi-node jobs that may not be a valid local GPU index - confirm.
    device_index = get_rank()
    return torch.device(f"cuda:{device_index}")
def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> dict:
    """Load a checkpoint and return its ``state_dict``, optionally filtered.

    Args:
        path: Checkpoint location, passed unchanged to ``torch.load``.
        module_name: If given, keep only the entries whose key starts with
            ``"<module_name>."`` and strip that prefix from the returned keys.
            The name is matched literally (dots are not wildcards).
        ignore_modules: If given, an iterable of module names; every entry
            whose key starts with ``"<name>."`` for one of them is dropped.
        map_location: Forwarded to ``torch.load``. Defaults to the device
            returned by ``get_device()``.

    Returns:
        A dict mapping state-dict keys to the stored values. When neither
        filter is given, this is the checkpoint's ``state_dict`` itself.

    Raises:
        ValueError: If both ``module_name`` and ``ignore_modules`` are set.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if ignore_modules is not None:
        # str.startswith accepts a tuple, so all prefixes are tested in one call.
        ignored_prefixes = tuple(name + "." for name in ignore_modules)
        return {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(ignored_prefixes)
        }

    if module_name is not None:
        # Literal prefix match, consistent with the ignore_modules branch;
        # a regex built from module_name would treat "." as "any character".
        prefix = module_name + "."
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    return state_dict
def make_recursive_func(func):
    """Decorator that applies *func* recursively over nested containers.

    The returned wrapper walks its first argument: lists, tuples and dicts are
    rebuilt (as ``list``, ``tuple`` and ``dict`` respectively) with the wrapper
    applied to every element / value; anything else is passed to *func*.
    Extra positional and keyword arguments are forwarded unchanged at every
    level.

    Args:
        func: Callable taking the leaf value as its first argument.

    Returns:
        The recursive wrapper, carrying *func*'s name and docstring.
    """

    # functools.wraps keeps the decorated function's __name__/__doc__ intact.
    @functools.wraps(func)
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(item, *args, **kwargs) for item in vars]
        if isinstance(vars, tuple):
            return tuple(wrapper(item, *args, **kwargs) for item in vars)
        if isinstance(vars, dict):
            return {
                key: wrapper(value, *args, **kwargs)
                for key, value in vars.items()
            }
        return func(vars, *args, **kwargs)

    return wrapper
@make_recursive_func
def todevice(vars, device="cuda"):
    """Move tensors to *device*, leaving plain Python scalars and strings as is.

    Because of the ``make_recursive_func`` decorator, *vars* may also be a
    (nested) list, tuple or dict of such values.

    Args:
        vars: A ``torch.Tensor``, ``str``, ``bool``, ``float`` or ``int``.
        device: Target passed to ``Tensor.to``. Defaults to ``"cuda"``.

    Returns:
        The tensor moved with ``.to(device)``, or the value itself for the
        pass-through types.

    Raises:
        NotImplementedError: If a value is of any other type.
    """
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    if isinstance(vars, (str, bool, float, int)):
        return vars
    # The message previously named "tensor2numpy", a different function.
    raise NotImplementedError(
        "invalid input type {} for todevice".format(type(vars))
    )