import contextlib
import glob
import os
import re
import subprocess

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


@contextlib.contextmanager
def dist_load(path):
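    """Yield a local path for `path`, staging it in /dev/shm for multi-process runs.

    Single-process runs (and paths already under /dev/shm) pass through
    unchanged. Otherwise local rank 0 copies the checkpoint into shared
    memory, every rank reads from that copy, and rank 0 removes it after
    all ranks have finished.
    """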
|
    if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
        yield path
    else:
        from tts.utils.commons.hparams import hparams
        from tts.utils.commons.trainer import LOCAL_RANK
        tmpdir = '/dev/shm'
        assert len(os.path.basename(path)) > 0
        shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
        if LOCAL_RANK == 0:
            subprocess.check_call(
                f'mkdir -p {os.path.dirname(shm_ckpt_path)}; '
                f'cp -Lr {path} {shm_ckpt_path}', shell=True)
        dist.barrier()
        yield shm_ckpt_path
        dist.barrier()
        if LOCAL_RANK == 0:
            subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True)


def torch_load_dist(path, map_location='cpu'):
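    """torch.load wrapper that routes the read through dist_load."""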
|
    with dist_load(path) as tmp_path:
        checkpoint = torch.load(tmp_path, map_location=map_location)
    return checkpoint


def get_last_checkpoint(work_dir, steps=None):
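    """Return (checkpoint, path) for the newest checkpoint in `work_dir`, or (None, None)."""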
|
    checkpoint = None
    last_ckpt_path = None
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if len(ckpt_paths) > 0:
        last_ckpt_path = ckpt_paths[0]
        checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
    return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
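    """Return checkpoint paths in `work_dir`, sorted from newest to oldest step."""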
|
    if steps is None or steps == 0:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
              silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
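    """Load weights (and optionally optimizer states) into one model or a list of models.

    `ckpt_base_dir` may be a checkpoint file or a directory to search; a
    dotted `model_name` such as 'model.encoder' selects a sub-module's
    weights. Returns the checkpoint's global step when one is found;
    otherwise asserts if `force` is set, else only prints a message.
    """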
|
    if checkpoint is None:
        if os.path.isfile(ckpt_base_dir):
            base_dir = os.path.dirname(ckpt_base_dir)
            ckpt_path = ckpt_base_dir
            checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
        else:
            base_dir = ckpt_base_dir
            if load_opt:
                checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
            else:
                ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
                if os.path.exists(ckpt_path):
                    checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
                else:
                    checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
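    # Drop DDP ('module.') and torch.compile ('_orig_mod.') prefixes so the
    # checkpoint keys match the bare modules.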
|
    if checkpoint is not None:
        state_dict_all = {
            k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
        if not isinstance(cur_model, list):
            cur_models = [cur_model]
            model_names = [model_name]
        else:
            cur_models = cur_model
            model_names = model_name
        for model_name, cur_model in zip(model_names, cur_models):
            if isinstance(cur_model, DistributedDataParallel):
                cur_model = cur_model.module
            device = next(cur_model.parameters()).device
            if '.' not in model_name:
                state_dict = state_dict_all[model_name]
            else:
                base_model_name = model_name.split('.')[0]
                rest_model_name = model_name[len(base_model_name) + 1:]
                state_dict = {
                    k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
                    if k.startswith(f'{rest_model_name}.')}
            state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
            if not strict and delete_unmatch:
                # Try a strict load first; on failure, drop shape-mismatched keys and retry below.
                try:
                    cur_model.load_state_dict(state_dict, strict=True)
                    if not silent:
                        print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
                except Exception:
                    cur_model_state_dict = cur_model.state_dict()
                    cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
                                            cur_model_state_dict.items()}
                    unmatched_keys = []
                    for key, param in state_dict.items():
                        if key in cur_model_state_dict:
                            new_param = cur_model_state_dict[key]
                            if new_param.shape != param.shape:
                                unmatched_keys.append(key)
                                print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
                                      "ckpt model: ", param.shape)
                    for key in unmatched_keys:
                        del state_dict[key]
            load_results = cur_model.load_state_dict(state_dict, strict=strict)
            cur_model.to(device)
            if not silent:
                print(f"| loaded '{model_name}' from '{ckpt_path}'.")
                missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
                print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
|
        if load_opt:
            optimizer_states = checkpoint['optimizer_states']
            assert len(opts) == len(optimizer_states)
            for optimizer, opt_state in zip(opts, optimizer_states):
                opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
                if optimizer is None:
                    return
                try:
                    optimizer.load_state_dict(opt_state)
                    # Move restored optimizer state tensors onto the model's device.
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(device)
                except ValueError:
                    print(f"| WARNING: optimizer {optimizer} parameters do not match!")
        return checkpoint.get('global_step', 0)
|
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            assert False, e_msg
        else:
            print(e_msg)


def load_with_size_mismatch(model, state_dict, prefix=""):
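    """Non-strict load that filters out checkpoint tensors whose shapes disagree with `model`."""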
|
    current_model_dict = model.state_dict()
    cm_keys = current_model_dict.keys()
    mismatch_keys = {k.replace(prefix, "") for k, v in state_dict.items()
                     if k.replace(prefix, "") in cm_keys and v.size() != current_model_dict[k.replace(prefix, "")].size()}
    new_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()
                      if k.replace(prefix, "") in cm_keys and v.size() == current_model_dict[k.replace(prefix, "")].size()}
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys:", mismatch_keys)
    if len(missing_keys) > 0:
        print(f"| missing_keys in dit: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"| unexpected_keys in dit: {unexpected_keys}")
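

# Example usage (a minimal sketch; `MyModel` and the 'checkpoints/my_exp'
# directory are hypothetical placeholders, not part of this module):
#
#   model = MyModel()
#   global_step = load_ckpt(model, 'checkpoints/my_exp', model_name='model',
#                           force=False, strict=False)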