import os
import re

import torch
from packaging import version

from tgs.utils.typing import *


def parse_version(ver: str):
    return version.parse(ver)


def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first.
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_device():
    return torch.device(f"cuda:{get_rank()}")
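

# Hedged usage sketch (not part of the original file): under a typical
# `torchrun` launch, each process sees its own RANK/LOCAL_RANK, so
# get_device() resolves to a distinct GPU per process, e.g.:
#
#   rank = get_rank()        # 0 on the first process, 1 on the second, ...
#   device = get_device()    # torch.device("cuda:0"), torch.device("cuda:1"), ...
#   model = MyModel().to(device)  # MyModel is a hypothetical module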


def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> dict:
    # The two filters are mutually exclusive: either drop ignored submodules
    # or extract a single named submodule, but not both.
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot both be set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # Drop every entry whose key lives under one of the ignored submodules.
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        # Keep only entries under `module_name`, stripping the prefix so the
        # resulting dict can be loaded directly into that submodule.
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load
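

# Hedged usage sketch (the checkpoint path and module names are hypothetical):
#
#   # Keep only weights saved under the "backbone." prefix, with the prefix
#   # stripped so the dict fits the submodule directly:
#   backbone_sd = load_module_weights("ckpt.pth", module_name="backbone")
#   model.backbone.load_state_dict(backbone_sd)
#
#   # Or load everything except the decoder:
#   sd = load_module_weights("ckpt.pth", ignore_modules=["decoder"])
#   model.load_state_dict(sd, strict=False)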


# Convert a function into a recursive one that is applied to the leaves of
# nested dict/list/tuple variables.
def make_recursive_func(func):
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, *args, **kwargs) for x in vars]
        elif isinstance(vars, tuple):
            return tuple(wrapper(x, *args, **kwargs) for x in vars)
        elif isinstance(vars, dict):
            return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
        else:
            return func(vars, *args, **kwargs)

    return wrapper
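

# Illustrative example of the pattern above (an assumption, not part of the
# original module): wrapping a leaf-level function so it transparently walks
# nested dict/list/tuple structures. `tensor2numpy` is a hypothetical helper.
@make_recursive_func
def tensor2numpy(vars):
    # Tensor leaves become numpy arrays; any other leaf passes through unchanged.
    if isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy()
    return vars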


@make_recursive_func
def todevice(vars, device="cuda"):
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, (str, bool, float, int)):
        # Non-tensor leaves pass through unchanged.
        return vars
    else:
        raise NotImplementedError(f"invalid input type {type(vars)} for todevice")
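

# Hedged usage sketch (the batch layout is hypothetical): because todevice is
# wrapped by make_recursive_func, one call moves every tensor in a nested
# batch while leaving strings and Python scalars untouched:
#
#   batch = {"rgb": torch.rand(1, 3, 256, 256),
#            "meta": {"name": "scene_0", "scale": 1.0}}
#   batch = todevice(batch, device=get_device())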