import os
import re

import torch
from packaging import version

from tgs.utils.typing import *

def parse_version(ver: str):
    return version.parse(ver)

def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0
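
# Illustrative sketch (not part of the original module): get_rank() checks the
# environment variables in order, so an explicit RANK wins over SLURM_PROCID.
# The values below are hypothetical, for demonstration only:
#
#   os.environ["RANK"] = "2"          # e.g. set by torchrun
#   os.environ["SLURM_PROCID"] = "5"  # stale value from a SLURM allocation
#   assert get_rank() == 2            # "RANK" is checked first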

def get_device():
    return torch.device(f"cuda:{get_rank()}")
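
# Note: get_device() assumes the usual one-GPU-per-rank setup, i.e. the CUDA
# device index equals the process rank. Hypothetical single-process example:
#
#   device = get_device()  # torch.device("cuda:0") when no rank variable is set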

def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> dict:
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # Drop every parameter that belongs to one of the ignored submodules.
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        # Keep only the parameters of the named submodule and strip its prefix.
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load
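
# Minimal usage sketch (assumption: "ckpt.pth" is a Lightning-style checkpoint
# whose "state_dict" keys look like "image_tokenizer.<param>"; the path, module
# name, and model attribute below are hypothetical):
#
#   weights = load_module_weights(
#       "ckpt.pth", module_name="image_tokenizer", map_location="cpu"
#   )
#   model.image_tokenizer.load_state_dict(weights)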

# Convert a function into recursive style to handle nested dict/list/tuple variables.
def make_recursive_func(func):
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, *args, **kwargs) for x in vars]
        elif isinstance(vars, tuple):
            return tuple(wrapper(x, *args, **kwargs) for x in vars)
        elif isinstance(vars, dict):
            return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
        else:
            return func(vars, *args, **kwargs)

    return wrapper
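
# Illustrative sketch (not part of the original module): the decorator lets a
# scalar function transparently map over arbitrarily nested containers. The
# function below is hypothetical:
#
#   @make_recursive_func
#   def double(x):
#       return x * 2
#
#   double({"a": [1, 2], "b": (3,)})  # -> {"a": [2, 4], "b": (6,)}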

@make_recursive_func
def todevice(vars, device="cuda"):
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, (str, bool, float, int)):
        return vars
    else:
        raise NotImplementedError(f"invalid input type {type(vars)} for todevice")
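
# Minimal usage sketch (the batch layout below is hypothetical): with
# @make_recursive_func applied, todevice walks nested containers, moving every
# tensor to the target device while passing scalars and strings through:
#
#   batch = {"rgb": torch.rand(1, 3, 64, 64), "meta": {"scene": "chair", "scale": 1.0}}
#   batch = todevice(batch, device=get_device())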