# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from scipy import ndimage
from scipy.spatial import ConvexHull
from skimage.draw import polygon

from .utils import convert_to_numpy


class MaskDrawAnnotator:
    """Build a filled region mask from a scribble mask or an explicit box.

    Supported modes:

    * ``'maskpoint'`` -- fill the convex hull of the stroke centers found in
      ``mask``.
    * ``'maskbbox'`` -- fill the axis-aligned box around the stroke centers
      found in ``mask`` and crop ``image`` to that box.
    * ``'bbox'`` -- fill the box given as ``(x_min, y_min, x_max, y_max)`` and
      crop ``image`` to that box.
    * ``'mask'`` -- return ``mask`` unchanged.
    """

    _MODES = ('maskpoint', 'maskbbox', 'mask', 'bbox')

    def __init__(self, cfg, device=None):
        """Read the default mode and return style from ``cfg``.

        Args:
            cfg: Mapping-like object with ``get``; ``'MODE'`` defaults to
                ``'maskpoint'`` and ``'RETURN_DICT'`` defaults to ``True``.
            device: Accepted for interface compatibility; not used here.
        """
        self.mode = cfg.get('MODE', 'maskpoint')
        self.return_dict = cfg.get('RETURN_DICT', True)
        assert self.mode in ['maskpoint', 'maskbbox', 'mask', 'bbox']

    @staticmethod
    def _stroke_centers(mask):
        """Return the (x, y) center of mass of every stroke in ``mask``.

        A stroke is a connected component of pixels with value >= 255. The
        mask is transposed first, so column 0 of the result is the x (column)
        coordinate and column 1 is the y (row) coordinate.
        """
        # transpose(1, 0) requires a 2-D mask.
        scribble = mask.transpose(1, 0)
        labeled_array, num_features = ndimage.label(scribble >= 255)
        if num_features == 0:
            # Without this check the callers fail later with an opaque
            # indexing error on an empty array.
            raise ValueError('mask contains no pixel with value >= 255')
        centers = ndimage.center_of_mass(
            scribble, labeled_array, range(1, num_features + 1)
        )
        return np.array(centers)

    @staticmethod
    def _box_slices(x_min, y_min, x_max, y_max):
        """Return (row_slice, col_slice) for an inclusive pixel box.

        Bounds are truncated with ``int`` and clamped to be non-negative so a
        negative coordinate cannot wrap around to the far side of the array.
        """
        rows = slice(max(int(y_min), 0), max(int(y_max) + 1, 0))
        cols = slice(max(int(x_min), 0), max(int(x_max) + 1, 0))
        return rows, cols

    def forward(self, mask=None, image=None, bbox=None, mode=None, return_dict=None):
        """Produce the region mask (and matching image) for the chosen mode.

        Args:
            mask: Scribble / input mask, converted with ``convert_to_numpy``.
                Required for every mode except ``'bbox'`` when ``image`` is
                given.
            image: Optional image, converted with ``convert_to_numpy``. It is
                cropped to the box in ``'maskbbox'`` and ``'bbox'`` modes and
                passed through unchanged in the other modes.
            bbox: ``(x_min, y_min, x_max, y_max)`` with inclusive bounds; only
                used in ``'bbox'`` mode.
            mode: Overrides the configured mode when not ``None``.
            return_dict: Overrides the configured return style when not
                ``None``.

        Returns:
            With ``return_dict`` true: ``{"image": ..., "mask": ...}`` when an
            image was given, else ``{"mask": ...}``. Otherwise the tuple
            ``(image, mask)`` when an image was given, else the mask alone.

        Raises:
            NotImplementedError: If ``mode`` is not a supported mode.
            ValueError: If a required input is missing or the mask contains
                no stroke.
        """
        mode = mode if mode is not None else self.mode
        return_dict = return_dict if return_dict is not None else self.return_dict

        mask = convert_to_numpy(mask) if mask is not None else None
        image = convert_to_numpy(image) if image is not None else None

        if mode not in self._MODES:
            raise NotImplementedError(f'Unsupported mode: {mode!r}')

        if mask is not None:
            mask_shape = mask.shape
        elif mode == 'bbox' and image is not None:
            # The box alone defines the region, so the image's leading two
            # axes (indexed as rows, cols below) give the mask shape.
            mask_shape = image.shape[:2]
        else:
            raise ValueError(f"mode {mode!r} requires a 'mask' input")

        # Modes that do not crop hand the image back untouched; previously
        # the name was unbound in those modes and the return raised.
        out_image = image

        if mode == 'maskpoint':
            centers = self._stroke_centers(mask)
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            hull = ConvexHull(centers)
            hull_vertices = centers[hull.vertices]
            # polygon() takes row coordinates first: y is column 1 of centers.
            rr, cc = polygon(hull_vertices[:, 1], hull_vertices[:, 0], mask_shape)
            out_mask[rr, cc] = 255
        elif mode == 'maskbbox':
            centers = self._stroke_centers(mask)
            rows, cols = self._box_slices(
                centers[:, 0].min(),
                centers[:, 1].min(),
                centers[:, 0].max(),
                centers[:, 1].max(),
            )
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            out_mask[rows, cols] = 255
            if image is not None:
                out_image = image[rows, cols]
        elif mode == 'bbox':
            if bbox is None:
                raise ValueError("mode 'bbox' requires a 'bbox' input")
            if isinstance(bbox, list):
                bbox = np.array(bbox)
            x_min, y_min, x_max, y_max = bbox
            rows, cols = self._box_slices(x_min, y_min, x_max, y_max)
            out_mask = np.zeros(mask_shape, dtype=np.uint8)
            out_mask[rows, cols] = 255
            if image is not None:
                out_image = image[rows, cols]
        else:  # mode == 'mask'
            out_mask = mask

        if return_dict:
            if image is not None:
                return {"image": out_image, "mask": out_mask}
            return {"mask": out_mask}
        if image is not None:
            return out_image, out_mask
        return out_mask