import torch
from collections import OrderedDict

from PIL.Image import Image

from scepter.modules.model.registry import BACKBONES, EMBEDDERS, MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.utils.registry import Registry, build_from_config
from scepter.studio.utils.env import get_available_memory


def get_model(model_tuple):
    assert 'model' in model_tuple
    return model_tuple['model']


class BaseInference(object):
    '''
    Supports loading components dynamically: each model is created and its
    weights are loaded the first time it is run, and it can later be moved
    to CPU or unloaded entirely to free memory.
    '''
    def __init__(self, cfg, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME

    def init_from_modules(self, modules):
        for k, v in modules.items():
            setattr(self, k, v)

    def infer_model(self, cfg, module_paras=None):
        # Build a lazy module descriptor; the model itself is only
        # instantiated by load() the first time it is needed.
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            input_dict = {}
            for inp in function.get('INPUT', []):
                if inp.lower() in self.input:
                    input_dict[inp.lower()] = self.input[inp.lower()]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module
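
    # A minimal sketch of the per-module config infer_model() consumes
    # (the DIFFUSION_MODEL name and the field values are hypothetical;
    # the real keys come from the pipeline's YAML config):
    #
    #   MODULES_PARAS:
    #     DIFFUSION_MODEL:
    #       PARAS:
    #         SAMPLE_STEPS: 50
    #       FUNCTION:
    #         - NAME: diffusion
    #           DTYPE: float16
    #           INPUT: [SAMPLE_STEPS, GUIDE_SCALE]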

    def init_from_ckpt(self, path, model, ignore_keys=list()):
        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)

        # Drop any entry whose key contains one of ignore_keys.
        new_sd = OrderedDict()
        for k, v in sd.items():
            ignored = False
            for ik in ignore_keys:
                if ik in k:
                    if we.rank == 0:
                        self.logger.info(
                            'Ignore key {} from state_dict.'.format(k))
                    ignored = True
                    break
            if not ignored:
                new_sd[k] = v

        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing '
                f'and {len(unexpected)} unexpected keys')
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
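
    # For example (the path and ignore pattern here are hypothetical):
    #
    #   self.init_from_ckpt('weights/unet.safetensors', model,
    #                       ignore_keys=['lora_'])
    #
    # loads every tensor except those whose key contains 'lora_'.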

    def load(self, module):
        if module['device'] == 'offline':
            # Resolve the module's class from whichever registry knows it.
            from scepter.modules.utils.import_utils import LazyImportModule
            if (LazyImportModule.get_module_type(
                    ('MODELS', module['cfg'].NAME))
                    or module['cfg'].NAME in MODELS.class_map):
                model = MODELS.build(module['cfg'], logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(
                    ('BACKBONES', module['cfg'].NAME))
                    or module['cfg'].NAME in BACKBONES.class_map):
                model = BACKBONES.build(module['cfg'],
                                        logger=self.logger).eval()
            elif (LazyImportModule.get_module_type(
                    ('EMBEDDERS', module['cfg'].NAME))
                    or module['cfg'].NAME in EMBEDDERS.class_map):
                model = EMBEDDERS.build(module['cfg'],
                                        logger=self.logger).eval()
            else:
                raise NotImplementedError(
                    f'{module["cfg"].NAME} is not registered in MODELS, '
                    f'BACKBONES or EMBEDDERS.')
            if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
                model = model.to(getattr(torch, module['cfg'].DTYPE))
            if module['cfg'].get('RELOAD_MODEL', None):
                self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module

    def unload(self, module):
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        # With less than half of system memory free, delete the model
        # entirely; otherwise just park it on the CPU.
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                del module['model']
                module['model'] = None
                module['device'] = 'offline'
                self.logger.info('Deleted module {}'.format(module['name']))
        else:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                module['device'] = 'cpu'
            else:
                module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module

    def dynamic_load(self, module=None, name=''):
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if module['cfg'] != self.loaded_model[name]['cfg']:
                    # Config changed: replace the cached instance.
                    self.unload(self.loaded_model[name])
                    module = self.load(module)
                    self.loaded_model[name] = module
                    return module
                elif module['device'] == 'cpu' or module['device'] == 'offline':
                    module = self.load(module)
                    return module
                else:
                    return module
            else:
                module = self.load(module)
                self.loaded_model[name] = module
                return module
        else:
            return self.load(module)

    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            for subname, submodule in self.loaded_model.items():
                self.loaded_model[subname] = self.unload(submodule)
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if not skip_loaded:
                    module = self.unload(self.loaded_model[name])
                    self.loaded_model[name] = module
            else:
                self.unload(module)
        else:
            self.unload(module)
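
    # The typical lifecycle, as a sketch (the 'diffusion_model' name and
    # the attribute layout are hypothetical and depend on the subclass):
    #
    #   module = self.infer_model(cfg, module_paras)  # 'offline' stub
    #   module = self.dynamic_load(module, 'diffusion_model')  # on GPU
    #   ...run inference...
    #   self.dynamic_unload(module, 'diffusion_model')  # back to CPU/offline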

    def load_default(self, cfg):
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            # For structured INPUT entries, keep only their DEFAULT value.
            self.input = {
                k.lower():
                dict(v).get('DEFAULT', None) if isinstance(
                    v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras
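
    # A sketch of the top-level config load_default() consumes (the field
    # names under INPUT/OUTPUT are illustrative, not fixed by this class):
    #
    #   PARAS: {...}
    #   INPUT:
    #     PROMPT: {DEFAULT: ''}   # structured entry -> default '' kept
    #     SAMPLE_STEPS: 50        # plain entry -> kept as-is
    #   OUTPUT:
    #     IMAGES: {}
    #   MODULES_PARAS: {...}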

    def load_image(self, image, num_samples=1):
        # Hook for subclasses: normalize an input image (a torch.Tensor
        # or a PIL.Image) into the batch format the model expects.
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass

    def get_function_info(self, module, function_name=None):
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            # With a single registered function, it is the implicit default.
            for k, v in all_function.items():
                return k, v['dtype']

    @torch.no_grad()
    def __call__(self, input, **kwargs):
        # Subclasses implement the actual inference entry point here.
        return


def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """ After building the model, load the pretrained weights if the key
    `pretrain` exists.

    pretrain (str, dict): Describes how to load the pretrained model.
        str: treat `pretrain` as the model path;
        dict: should contain the key `path`, plus other parameters taken
        by the function load_pretrained().
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be of type Config, got {type(cfg)}')
    model = build_from_config(cfg, registry, logger=logger, *args, **kwargs)
    return model


INFERENCES = Registry('INFERENCE', build_func=build_inference)
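
# A minimal registration sketch (the DemoInference class and its config
# are hypothetical; real pipelines subclass BaseInference similarly):
#
#   @INFERENCES.register_class()
#   class DemoInference(BaseInference):
#       @torch.no_grad()
#       def __call__(self, input, **kwargs):
#           ...
#
#   inference = INFERENCES.build(cfg)   # cfg is a Config with NAME:
#                                       # 'DemoInference'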