|
|
|
import copy |
|
import warnings |
|
from pathlib import Path |
|
from typing import Optional, Sequence, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from mmcv.ops import RoIPool |
|
from mmcv.transforms import Compose |
|
from mmengine.config import Config |
|
from mmengine.model.utils import revert_sync_batchnorm |
|
from mmengine.registry import init_default_scope |
|
from mmengine.runner import load_checkpoint |
|
|
|
from mmdet.registry import DATASETS |
|
from ..evaluation import get_classes |
|
from ..registry import MODELS |
|
from ..structures import DetDataSample, SampleList |
|
from ..utils import get_test_pipeline_cfg |
|
|
|
|
|
def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'none',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Initialize a detector from config file.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object. A config object passed in is
            modified in place (options are merged into it and the backbone
            ``init_cfg`` may be reset).
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights and COCO class names are attached.
        palette (str): Color palette used for visualization. A palette passed
            here (anything other than ``'none'``) has the highest priority.
            Otherwise the palette of the test dataset's metainfo is used,
            then the palette already stored in the checkpoint meta, and
            finally ``'random'``. Currently, supports 'coco', 'voc', 'citys'
            and 'random'. Defaults to 'none'.
        device (str): The device the model is moved to.
            Defaults to cuda:0.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector, switched to eval mode, with
        ``cfg`` and ``dataset_meta`` attributes attached.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config`` object.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    else:
        # NOTE(review): the init_cfg reset only happens when no cfg_options
        # are given; that pairing is kept as-is.
        # Not every model config has a top-level ``backbone`` entry, so look
        # it up defensively instead of failing with an AttributeError.
        backbone_cfg = config.model.get('backbone', None)
        if backbone_cfg is not None and 'init_cfg' in backbone_cfg:
            config.model.backbone.init_cfg = None

    init_default_scope(config.get('default_scope', 'mmdet'))

    model = MODELS.build(config.model)
    model = revert_sync_batchnorm(model)

    if checkpoint is None:
        warnings.simplefilter('once')
        warnings.warn('checkpoint is None, use COCO classes by default.')
        model.dataset_meta = {'classes': get_classes('coco')}
    else:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # The loaded checkpoint may come without a ``meta`` entry.
        checkpoint_meta = checkpoint.get('meta', {})

        # Keep the dataset meta on the model for convenience.
        if 'dataset_meta' in checkpoint_meta:
            # Normalize all keys to lowercase.
            model.dataset_meta = {
                k.lower(): v
                for k, v in checkpoint_meta['dataset_meta'].items()
            }
        elif 'CLASSES' in checkpoint_meta:
            # Alternative meta layout that only stores the class names.
            classes = checkpoint_meta['CLASSES']
            model.dataset_meta = {'classes': classes}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

    # Palette priority: explicit argument -> dataset metainfo -> checkpoint.
    if palette != 'none':
        model.dataset_meta['palette'] = palette
    else:
        test_dataset_cfg = copy.deepcopy(config.test_dataloader.dataset)
        # Only the metainfo is needed, so request lazy initialization.
        test_dataset_cfg['lazy_init'] = True
        metainfo = DATASETS.build(test_dataset_cfg).metainfo
        cfg_palette = metainfo.get('palette', None)
        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        elif 'palette' not in model.dataset_meta:
            warnings.warn(
                'palette does not exist, random is used by default. '
                'You can also set the palette to customize.')
            model.dataset_meta['palette'] = 'random'

    # Keep the config on the model; the inference helpers read it back.
    model.cfg = config
    model.to(device)
    model.eval()
    return model
|
|
|
|
|
# Accepted image inputs for inference: one image given as a path or an
# already-loaded array, or a sequence of paths / a sequence of arrays.
ImagesType = Union[
    str,
    np.ndarray,
    Sequence[str],
    Sequence[np.ndarray],
]
|
|
|
|
|
def inference_detector(
    model: nn.Module,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None
) -> Union[DetDataSample, SampleList]:
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. It must carry the ``cfg``
            attribute when ``test_pipeline`` is not given.
        imgs (str, ndarray, Sequence[str/ndarray]):
            Either image files or loaded images.
        test_pipeline (:obj:`Compose`, optional): Test pipeline. If None, it
            is built from the test pipeline config found in ``model.cfg``;
            the loading transform is chosen from the type of the first image.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
        If imgs is a list or tuple, the same length list type results
        will be returned, otherwise return the detection results directly.
    """
    is_batch = isinstance(imgs, (list, tuple))
    if not is_batch:
        imgs = [imgs]

    if test_pipeline is None:
        # Work on a deep copy: ``Config.copy()`` is not guaranteed to copy
        # nested entries, so editing the pipeline in place could leak the
        # loader swap into ``model.cfg`` and affect later calls.
        pipeline_cfg = copy.deepcopy(get_test_pipeline_cfg(model.cfg))
        # An empty batch has no first image to inspect; keep the configured
        # loader in that case instead of raising IndexError.
        if imgs and isinstance(imgs[0], np.ndarray):
            # The ``mmdet.`` scope prefix is spelled out explicitly.
            pipeline_cfg[0].type = 'mmdet.LoadImageFromNDArray'
        test_pipeline = Compose(pipeline_cfg)

    if model.data_preprocessor.device.type == 'cpu':
        # Kept as an assert so callers still see AssertionError.
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    result_list = []
    for img in imgs:
        # Loaded arrays go in under ``img``, file paths under ``img_path``.
        if isinstance(img, np.ndarray):
            data_ = dict(img=img, img_id=0)
        else:
            data_ = dict(img_path=img, img_id=0)

        data_ = test_pipeline(data_)

        # Wrap the single sample into a batch of size one.
        data_['inputs'] = [data_['inputs']]
        data_['data_samples'] = [data_['data_samples']]

        with torch.no_grad():
            results = model.test_step(data_)[0]

        result_list.append(results)

    return result_list if is_batch else result_list[0]
|
|
|
|
|
|
|
async def async_inference_detector(model, imgs):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. Its ``cfg`` attribute must
            provide ``data.test.pipeline``.
        imgs (str | ndarray | Sequence[str/ndarray]): Either image files or
            loaded images. A single image is wrapped into a list.

    Returns:
        Awaitable detection results, i.e. the awaited return value of
        ``model.aforward_test``.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    pipeline_cfg = model.cfg.data.test.pipeline
    if isinstance(imgs[0], np.ndarray):
        # Swap the loading transform on a deep copy: ``cfg.copy()`` is not
        # guaranteed to copy nested entries, so an in-place edit could leak
        # into ``model.cfg`` and affect later calls.
        pipeline_cfg = copy.deepcopy(pipeline_cfg)
        pipeline_cfg[0].type = 'LoadImageFromNDArray'

    test_pipeline = Compose(pipeline_cfg)

    datas = []
    for img in imgs:
        # Loaded arrays go in under ``img``, file names under ``img_info``.
        if isinstance(img, np.ndarray):
            data = dict(img=img)
        else:
            data = dict(img_info=dict(filename=img), img_prefix=None)

        data = test_pipeline(data)
        datas.append(data)

    # NOTE(review): unlike ``inference_detector`` this check is not limited
    # to CPU models although the message talks about CPU inference - confirm.
    for m in model.modules():
        assert not isinstance(
            m,
            RoIPool), 'CPU inference with RoIPool is not supported currently.'

    # Gradient tracking is switched off here and not restored afterwards.
    torch.set_grad_enabled(False)
    # NOTE(review): only ``data`` (the output for the last image) is
    # forwarded; ``datas`` is collected but never used. This looks like a
    # missing batching step - confirm what ``aforward_test`` expects before
    # changing it.
    results = await model.aforward_test(data, rescale=True)
    return results
|
|