File size: 3,730 Bytes
690f890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from scipy import ndimage

from .utils import convert_to_numpy


class SAMImageAnnotator:
    """Run a Segment Anything (SAM) predictor on an image from a simple prompt.

    The prompt is selected by ``task_type``:

    * ``'input_box'``  - use ``input_box`` directly as the box prompt.
    * ``'mask_point'`` - use the centre of every connected region of the mask
      whose value is ``>= 255`` as a positive point prompt.
    * ``'mask_box'``   - use the bounding box spanned by those region centres
      as the box prompt.
    * ``'mask'``       - pass the mask itself as ``mask_input``.
    """

    def __init__(self, cfg, device=None):
        """Load the SAM model and wrap it in a ``SamPredictor``.

        Args:
            cfg: Mapping supporting ``.get`` and ``[]``. Keys read here:
                ``TASK_TYPE`` (default ``'input_box'``), ``RETURN_MASK``
                (default ``False``), ``MODEL_NAME`` (default ``'vit_b'``) and
                the required ``PRETRAINED_MODEL`` checkpoint.
            device: Device the model is moved to. When ``None``, CUDA is used
                if available, otherwise CPU.

        Raises:
            ImportError: If the ``segment_anything`` package is not installed.
        """
        try:
            from segment_anything import sam_model_registry, SamPredictor
            from segment_anything.utils.transforms import ResizeLongestSide
        except ImportError as err:
            # Fail here with an actionable message instead of warning and then
            # hitting a NameError on the first use of the missing names.
            raise ImportError(
                "please pip install sam package, or you can refer to models/VACE-Annotators/sam/segment_anything-1.0-py3-none-any.whl"
            ) from err

        self.task_type = cfg.get('TASK_TYPE', 'input_box')
        self.return_mask = cfg.get('RETURN_MASK', False)
        self.transform = ResizeLongestSide(1024)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        model_builder = sam_model_registry[cfg.get('MODEL_NAME', 'vit_b')]
        seg_model = model_builder(checkpoint=cfg['PRETRAINED_MODEL']).eval().to(self.device)
        self.predictor = SamPredictor(seg_model)

    @staticmethod
    def _scribble_centers(mask):
        """Return the centres of the connected ``>= 255`` regions of ``mask``.

        Args:
            mask: 2-D or 3-D numpy array. A 3-D array is reduced to the first
                slice along its last axis (assumed to be (H, W, C) - TODO
                confirm against callers).

        Returns:
            Float array of shape ``(num_regions, 2)``. The array is transposed
            before labelling, so for an (H, W) mask each row is ``(x, y)``.

        Raises:
            ValueError: If ``mask`` is ``None`` or has no value ``>= 255``.
        """
        if mask is None:
            raise ValueError("a mask is required for mask-based task types")
        if mask.ndim == 3:
            scribble = mask.transpose(2, 1, 0)[0]
        else:
            scribble = mask.transpose(1, 0)  # (H, W) -> (W, H)
        labeled_array, num_features = ndimage.label(scribble >= 255)
        if num_features == 0:
            # Without this guard the box branch fails with an opaque
            # IndexError and the point branch feeds an empty prompt to SAM.
            raise ValueError("mask contains no region with value >= 255")
        centers = ndimage.center_of_mass(scribble, labeled_array,
                                         np.arange(1, num_features + 1))
        return np.array(centers)

    @staticmethod
    def _build_prompt(task_type, input_box, mask):
        """Translate ``task_type`` and its inputs into predictor keyword args.

        Raises:
            ValueError: If a mask-based task is requested without a usable mask.
            NotImplementedError: If ``task_type`` is not a supported value.
        """
        if task_type == 'mask_point':
            centers = SAMImageAnnotator._scribble_centers(mask)
            return {
                'point_coords': centers,
                # Every scribble centre is a positive (foreground) point.
                'point_labels': np.ones(len(centers), dtype=int),
            }
        if task_type == 'mask_box':
            centers = SAMImageAnnotator._scribble_centers(mask)
            # (x1, y1, x2, y2) spanned by the region centres.
            x_min, y_min = centers.min(axis=0)
            x_max, y_max = centers.max(axis=0)
            return {'box': np.array([x_min, y_min, x_max, y_max])}
        if task_type == 'input_box':
            if isinstance(input_box, list):
                input_box = np.array(input_box)
            return {'box': input_box}
        if task_type == 'mask':
            if mask is None:
                raise ValueError("a mask is required for task_type 'mask'")
            return {'mask_input': mask[None, :, :]}
        raise NotImplementedError(f"unsupported task_type: {task_type!r}")

    def forward(self,
                image,
                input_box=None,
                mask=None,
                task_type=None,
                return_mask=None):
        """Segment ``image`` using the prompt described by ``task_type``.

        Args:
            image: Image handed unchanged to ``SamPredictor.set_image``.
            input_box: Box prompt for ``'input_box'``; a list is converted to
                a numpy array, anything else is passed through.
            mask: Scribble or mask input for the mask-based task types; it is
                converted with ``convert_to_numpy`` when given.
            task_type: Overrides the configured task type when not ``None``.
            return_mask: Overrides the configured flag when not ``None``.

        Returns:
            The highest-scoring mask when ``return_mask`` is truthy, otherwise
            a dict with ``"masks"``, ``"scores"`` and ``"logits"``, each sorted
            by descending score.

        Raises:
            ValueError: If a mask-based task has no usable mask.
            NotImplementedError: If ``task_type`` is not supported.
        """
        task_type = task_type if task_type is not None else self.task_type
        return_mask = return_mask if return_mask is not None else self.return_mask
        mask = convert_to_numpy(mask) if mask is not None else None

        # Build (and validate) the prompt before the costly image embedding.
        sample = self._build_prompt(task_type, input_box, mask)

        self.predictor.set_image(image)
        masks, scores, logits = self.predictor.predict(
            multimask_output=False,
            **sample
        )
        # Best-scoring prediction first.
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]
        logits = logits[sorted_ind]

        if return_mask:
            return masks[0]
        return {
            "masks": masks,
            "scores": scores,
            "logits": logits
        }