maffia's picture
Upload 94 files
690f890 verified
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import torch
import numpy as np
import torchvision
from .utils import convert_to_numpy
class GDINOAnnotator:
def __init__(self, cfg, device=None):
try:
from groundingdino.util.inference import Model, load_model, load_image, predict
except:
import warnings
warnings.warn("please pip install groundingdino package, or you can refer to models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl")
grounding_dino_config_path = cfg['CONFIG_PATH']
grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL']
grounding_dino_tokenizer_path = cfg['TOKENIZER_PATH'] # TODO
self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25)
self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2)
self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5)
self.use_nms = cfg.get('USE_NMS', True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = Model(model_config_path=grounding_dino_config_path,
model_checkpoint_path=grounding_dino_checkpoint_path,
device=self.device)
def forward(self, image, classes=None, caption=None):
image_bgr = convert_to_numpy(image)[..., ::-1] # bgr
if classes is not None:
classes = [classes] if isinstance(classes, str) else classes
detections = self.model.predict_with_classes(
image=image_bgr,
classes=classes,
box_threshold=self.box_threshold,
text_threshold=self.text_threshold
)
elif caption is not None:
detections, phrases = self.model.predict_with_caption(
image=image_bgr,
caption=caption,
box_threshold=self.box_threshold,
text_threshold=self.text_threshold
)
else:
raise NotImplementedError()
if self.use_nms:
nms_idx = torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
self.iou_threshold
).numpy().tolist()
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx] if detections.class_id is not None else None
boxes = detections.xyxy
confidences = detections.confidence
class_ids = detections.class_id
class_names = [classes[_id] for _id in class_ids] if classes is not None else phrases
ret_data = {
"boxes": boxes.tolist() if boxes is not None else None,
"confidences": confidences.tolist() if confidences is not None else None,
"class_ids": class_ids.tolist() if class_ids is not None else None,
"class_names": class_names if class_names is not None else None,
}
return ret_data
class GDINORAMAnnotator:
def __init__(self, cfg, device=None):
from .ram import RAMAnnotator
from .gdino import GDINOAnnotator
self.ram_model = RAMAnnotator(cfg['RAM'], device=device)
self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)
def forward(self, image):
ram_res = self.ram_model.forward(image)
classes = ram_res['tag_e'] if isinstance(ram_res, dict) else ram_res
gdino_res = self.gdino_model.forward(image, classes=classes)
return gdino_res