|
|
|
import copy |
|
import os.path as osp |
|
import warnings |
|
from typing import Dict, Iterable, List, Optional, Sequence, Union |
|
|
|
import mmcv |
|
import mmengine |
|
import numpy as np |
|
import torch.nn as nn |
|
from mmengine.dataset import Compose |
|
from mmengine.fileio import (get_file_backend, isdir, join_path, |
|
list_dir_or_file) |
|
from mmengine.infer.infer import BaseInferencer, ModelType |
|
from mmengine.model.utils import revert_sync_batchnorm |
|
from mmengine.registry import init_default_scope |
|
from mmengine.runner.checkpoint import _load_checkpoint_to_model |
|
from mmengine.visualization import Visualizer |
|
from rich.progress import track |
|
|
|
from mmdet.evaluation import INSTANCE_OFFSET |
|
from mmdet.registry import DATASETS |
|
from mmdet.structures import DetDataSample |
|
from mmdet.structures.mask import encode_mask_results, mask2bbox |
|
from mmdet.utils import ConfigType |
|
from ..evaluation import get_classes |
|
|
|
try: |
|
from panopticapi.evaluation import VOID |
|
from panopticapi.utils import id2rgb |
|
except ImportError: |
|
id2rgb = None |
|
VOID = None |
|
|
|
# File suffixes kept when a directory is passed as inference input
# (used as the ``suffix`` filter of ``list_dir_or_file``).
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)

# A single inference input: a path/URL string or an image array.
InputType = Union[str, np.ndarray]
# Either one input or a sequence of inputs.
InputsType = Union[InputType, Sequence[InputType]]
# One image array or a sequence of image arrays.
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
# Model predictions as a list of data samples.
PredType = List[DetDataSample]
|
|
|
|
|
class DetInferencer(BaseInferencer):
    """Object Detection Inferencer.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "rtmdet-s" or 'rtmdet_s_8xb32-300e_coco' or
            "configs/rtmdet/rtmdet_s_8xb32-300e_coco.py".
            If model is not specified, user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not specified
            and model is a model name of metafile, the weights will be loaded
            from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to mmdet.
        palette (str): Color palette used for visualization. The order of
            priority is palette -> config -> checkpoint. Defaults to 'none'.
    """

    # Names of keyword arguments routed to each stage by ``_dispatch_kwargs``.
    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = {
        'return_vis',
        'show',
        'wait_time',
        'draw_pred',
        'pred_score_thr',
        'img_out_dir',
        'no_save_vis',
    }
    postprocess_kwargs: set = {
        'print_result',
        'pred_out_dir',
        'return_datasample',
        'no_save_pred',
    }

    def __init__(self,
                 model: Optional[Union[ModelType, str]] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmdet',
                 palette: str = 'none') -> None:
        # Running counters used to name outputs for inputs that carry no
        # file name (e.g. ``np.ndarray`` images).
        self.num_visualized_imgs = 0
        self.num_predicted_imgs = 0
        # Set before ``super().__init__`` because ``_load_weights_to_model``
        # reads it. NOTE(review): assumes the base constructor loads the
        # weights during initialization.
        self.palette = palette
        init_default_scope(scope)
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        # Replace SyncBatchNorm layers with regular BatchNorm so the model
        # can run outside a distributed process group.
        self.model = revert_sync_batchnorm(self.model)

    def _load_weights_to_model(self, model: nn.Module,
                               checkpoint: Optional[dict],
                               cfg: Optional[ConfigType]) -> None:
        """Load model weights and meta information from cfg and checkpoint.

        ``model.dataset_meta`` is filled from the checkpoint's
        ``dataset_meta`` (keys lower-cased), falling back to its ``CLASSES``
        entry and finally to the COCO class names. The palette is then
        resolved with the priority: ``palette`` argument -> test dataset
        metainfo from ``cfg`` -> palette already present in the checkpoint
        meta -> ``'random'``.

        Args:
            model (nn.Module): Model to load weights and meta information.
            checkpoint (dict, optional): The loaded checkpoint. If None, the
                weights are not loaded and COCO classes are used.
            cfg (Config or ConfigDict, optional): The loaded config. If None,
                the config palette lookup is skipped.
        """
        if checkpoint is not None:
            _load_checkpoint_to_model(model, checkpoint)
            checkpoint_meta = checkpoint.get('meta', {})

            if 'dataset_meta' in checkpoint_meta:
                # Normalize keys to lower case (e.g. 'CLASSES' -> 'classes').
                model.dataset_meta = {
                    k.lower(): v
                    for k, v in checkpoint_meta['dataset_meta'].items()
                }
            elif 'CLASSES' in checkpoint_meta:
                # Only the class names are stored in the checkpoint meta.
                classes = checkpoint_meta['CLASSES']
                model.dataset_meta = {'classes': classes}
            else:
                warnings.warn(
                    'dataset_meta or class names are not saved in the '
                    'checkpoint\'s meta data, use COCO classes by default.')
                model.dataset_meta = {'classes': get_classes('coco')}
        else:
            warnings.warn('Checkpoint is not loaded, and the inference '
                          'result is calculated by the randomly initialized '
                          'model!')
            warnings.warn('weights is None, use COCO classes by default.')
            model.dataset_meta = {'classes': get_classes('coco')}

        # Priority: palette argument -> config -> checkpoint -> random.
        if self.palette != 'none':
            model.dataset_meta['palette'] = self.palette
            return

        cfg_palette = None
        if cfg is not None:
            test_dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
            # Build lazily: only the metainfo is needed, not annotations.
            test_dataset_cfg['lazy_init'] = True
            metainfo = DATASETS.build(test_dataset_cfg).metainfo
            cfg_palette = metainfo.get('palette', None)

        if cfg_palette is not None:
            model.dataset_meta['palette'] = cfg_palette
        elif 'palette' not in model.dataset_meta:
            warnings.warn(
                'palette does not exist, random is used by default. '
                'You can also set the palette to customize.')
            model.dataset_meta['palette'] = 'random'

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline.

        The config's test pipeline is adapted in place: ``img_id`` is
        removed from the last transform's ``meta_keys`` (if present) and the
        ``LoadImageFromFile`` step is replaced by ``mmdet.InferencerLoader``.

        Args:
            cfg (ConfigType): Config containing ``test_dataloader``.

        Returns:
            Compose: The composed inference pipeline.

        Raises:
            ValueError: If the pipeline has no ``LoadImageFromFile`` step.
        """
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # Raw inference inputs are not expected to provide an ``img_id``,
        # so it is dropped from the packed meta keys.
        if pipeline_cfg and 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'img_id')

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        # NOTE(review): InferencerLoader presumably accepts both paths and
        # arrays, matching ``InputType``; confirm in mmdet transforms.
        pipeline_cfg[load_img_idx]['type'] = 'mmdet.InferencerLoader'
        return Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1

    def _init_visualizer(self, cfg: ConfigType) -> Optional[Visualizer]:
        """Initialize visualizers.

        Args:
            cfg (ConfigType): Config containing the visualizer information.

        Returns:
            Visualizer or None: Visualizer initialized with config, or None
            when the config defines no visualizer.
        """
        visualizer = super()._init_visualizer(cfg)
        # The base implementation may return None; leave it as None so that
        # ``visualize`` can report the missing visualizer explicitly.
        if visualizer is not None:
            visualizer.dataset_meta = self.model.dataset_meta
        return visualizer

    def _inputs_to_list(self, inputs: InputsType) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - str:
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to file, a url or other types of string according
              to the task.

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of input for the :meth:`preprocess`.
        """
        if isinstance(inputs, str):
            backend = get_file_backend(inputs)
            # Only backends exposing ``isdir`` can be queried for
            # directories; other strings are treated as single inputs.
            if hasattr(backend, 'isdir') and isdir(inputs):
                # Non-recursive listing of image files only.
                filename_list = list_dir_or_file(
                    inputs, list_dir=False, suffix=IMG_EXTENSIONS)
                inputs = [
                    join_path(inputs, filename) for filename in filename_list
                ]

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        return list(inputs)

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Customize your preprocess by overriding this method. Preprocess should
        return an iterable object, of which each item will be used as the
        input of ``model.test_step``.

        ``BaseInferencer.preprocess`` will return an iterable chunked data,
        which will be used in __call__ like this:

        .. code-block:: python

            def __call__(self, inputs, batch_size=1, **kwargs):
                chunked_data = self.preprocess(inputs, batch_size, **kwargs)
                for batch in chunked_data:
                    preds = self.forward(batch, **kwargs)

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
        """
        chunked_data = self._get_chunk_data(inputs, batch_size)
        yield from map(self.collate_fn, chunked_data)

    def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
        """Get batch data from inputs.

        Each yielded chunk is a list of ``(original_input, pipeline_output)``
        pairs; the final chunk may be shorter than ``chunk_size``.

        Args:
            inputs (Iterable): An iterable dataset.
            chunk_size (int): Equivalent to batch size.

        Yields:
            list: batch data.
        """
        inputs_iter = iter(inputs)
        while True:
            try:
                chunk_data = []
                for _ in range(chunk_size):
                    inputs_ = next(inputs_iter)
                    chunk_data.append((inputs_, self.pipeline(inputs_)))
                yield chunk_data
            except StopIteration:
                # Flush the trailing partial batch, if any.
                if chunk_data:
                    yield chunk_data
                break

    def __call__(self,
                 inputs: InputsType,
                 batch_size: int = 1,
                 return_vis: bool = False,
                 show: bool = False,
                 wait_time: int = 0,
                 no_save_vis: bool = False,
                 draw_pred: bool = True,
                 pred_score_thr: float = 0.3,
                 return_datasample: bool = False,
                 print_result: bool = False,
                 no_save_pred: bool = True,
                 out_dir: str = '',
                 **kwargs) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            batch_size (int): Inference batch size. Defaults to 1.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the visualization results in a
                popup window. Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            return_datasample (bool): Whether to return results as
                :obj:`DetDataSample`. Defaults to False.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to True.
            out_dir (str): Output directory; visualizations are written to
                ``<out_dir>/vis`` and predictions to ``<out_dir>/preds``.
                If left as empty, no file will be saved. Defaults to ''.

            **kwargs: Other keyword arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
                and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results, with list values under
            the keys ``predictions`` and ``visualization``.
        """
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(
            ori_inputs, batch_size=batch_size, **preprocess_kwargs)

        results_dict = {'predictions': [], 'visualization': []}
        # Each collated batch unpacks into (original inputs, model data).
        for batch_inputs, data in track(inputs, description='Inference'):
            preds = self.forward(data, **forward_kwargs)
            visualization = self.visualize(
                batch_inputs,
                preds,
                return_vis=return_vis,
                show=show,
                wait_time=wait_time,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                no_save_vis=no_save_vis,
                img_out_dir=out_dir,
                **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasample=return_datasample,
                print_result=print_result,
                no_save_pred=no_save_pred,
                pred_out_dir=out_dir,
                **postprocess_kwargs)
            results_dict['predictions'].extend(results['predictions'])
            if results['visualization'] is not None:
                results_dict['visualization'].extend(results['visualization'])
        return results_dict

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  no_save_vis: bool = False,
                  img_out_dir: str = '',
                  **kwargs) -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[:obj:`DetDataSample`]): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            no_save_vis (bool): Whether to force not to save prediction
                vis results. Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Returns visualization results only if
            applicable.

        Raises:
            ValueError: If visualization is needed but no visualizer is
                configured, or an input has an unsupported type.
        """
        if no_save_vis is True:
            img_out_dir = ''

        # Nothing to show, save or return: skip the drawing work entirely.
        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img_bytes = mmengine.fileio.get(single_input)
                img = mmcv.imfrombytes(img_bytes)
                # mmcv decodes to BGR by default; flip to RGB for drawing.
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                # NOTE(review): arrays are drawn without channel reordering;
                # confirm the expected channel order of array inputs.
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, 'vis',
                                img_name) if img_out_dir != '' else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(self.visualizer.get_image())
            self.num_visualized_imgs += 1

        return results

    def postprocess(
        self,
        preds: PredType,
        visualization: Optional[List[np.ndarray]] = None,
        return_datasample: bool = False,
        print_result: bool = False,
        no_save_pred: bool = False,
        pred_out_dir: str = '',
        **kwargs,
    ) -> Dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Args:
            preds (List[:obj:`DetDataSample`]): Predictions of the model.
            visualization (Optional[np.ndarray]): Visualized predictions.
            return_datasample (bool): Whether to use Datasample to store
                inference results. If False, dict will be used.
            print_result (bool): Whether to print the inference result w/o
                visualization to the console. Defaults to False.
            no_save_pred (bool): Whether to force not to save prediction
                results. Defaults to False.
            pred_out_dir: Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``.

            - ``visualization`` (Any): Returned by :meth:`visualize`.
            - ``predictions`` (dict or DataSample): Returned by
                :meth:`forward` and processed in :meth:`postprocess`.
                If ``return_datasample=False``, it usually should be a
                json-serializable dict containing only basic data elements such
                as strings and numbers.
        """
        if no_save_pred is True:
            pred_out_dir = ''

        result_dict = {}
        results = preds
        if not return_datasample:
            results = [self.pred2dict(pred, pred_out_dir) for pred in preds]
        elif pred_out_dir != '':
            warnings.warn('Currently does not support saving datasample '
                          'when return_datasample is set to True. '
                          'Prediction results are not saved!')

        result_dict['predictions'] = results
        # Printed before the visualization is attached on purpose.
        if print_result:
            print(result_dict)
        result_dict['visualization'] = visualization
        return result_dict

    def pred2dict(self,
                  data_sample: DetDataSample,
                  pred_out_dir: str = '') -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable. The input
        ``data_sample`` is not modified.

        Args:
            data_sample (:obj:`DetDataSample`): Predictions of the model.
            pred_out_dir: Dir to save the inference results w/o
                visualization. If left as empty, no file will be saved.
                Defaults to ''.

        Returns:
            dict: Prediction results.

        Raises:
            RuntimeError: If the sample has a panoptic prediction but
                panopticapi is not installed.
        """
        is_save_pred = pred_out_dir != ''

        if is_save_pred and 'img_path' in data_sample:
            img_path = osp.basename(data_sample.img_path)
            img_path = osp.splitext(img_path)[0]
            out_img_path = osp.join(pred_out_dir, 'preds',
                                    img_path + '_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds', img_path + '.json')
        elif is_save_pred:
            # No source file name available: fall back to a running counter.
            out_img_path = osp.join(
                pred_out_dir, 'preds',
                f'{self.num_predicted_imgs}_panoptic_seg.png')
            out_json_path = osp.join(pred_out_dir, 'preds',
                                     f'{self.num_predicted_imgs}.json')
            self.num_predicted_imgs += 1

        result = {}
        if 'pred_instances' in data_sample:
            # Keep the original masks (before ``.numpy()``) for mask2bbox.
            masks = data_sample.pred_instances.get('masks')
            pred_instances = data_sample.pred_instances.numpy()
            result = {
                'bboxes': pred_instances.bboxes.tolist(),
                'labels': pred_instances.labels.tolist(),
                'scores': pred_instances.scores.tolist()
            }
            if masks is not None:
                if pred_instances.bboxes.sum() == 0:
                    # All-zero boxes: derive boxes from the masks instead.
                    bboxes = mask2bbox(masks.cpu()).numpy().tolist()
                    result['bboxes'] = bboxes
                encode_masks = encode_mask_results(pred_instances.masks)
                # RLE counts may be bytes; decode so the dict is JSON-dumpable.
                for encode_mask in encode_masks:
                    if isinstance(encode_mask['counts'], bytes):
                        encode_mask['counts'] = encode_mask['counts'].decode()
                result['masks'] = encode_masks

        if 'pred_panoptic_seg' in data_sample:
            if VOID is None:
                raise RuntimeError(
                    'panopticapi is not installed, please install it by: '
                    'pip install git+https://github.com/cocodataset/'
                    'panopticapi.git.')

            pan = data_sample.pred_panoptic_seg.sem_seg.cpu().numpy()[0]
            # ``Tensor.numpy()`` shares memory with a CPU tensor, so copy
            # before the in-place assignment below to avoid mutating the
            # caller's prediction.
            pan = pan.copy()
            # Segments whose label (id % INSTANCE_OFFSET) equals the number
            # of classes are mapped to VOID (presumably the "no object"
            # label).
            pan[pan % INSTANCE_OFFSET == len(
                self.model.dataset_meta['classes'])] = VOID
            pan = id2rgb(pan).astype(np.uint8)

            if is_save_pred:
                # id2rgb yields RGB; mmcv.imwrite expects BGR.
                mmcv.imwrite(pan[:, :, ::-1], out_img_path)
                result['panoptic_seg_path'] = out_img_path
            else:
                result['panoptic_seg'] = pan

        if is_save_pred:
            mmengine.dump(result, out_json_path)

        return result
|
|