import abc
import os
import typing
from typing import Optional

import cv2
import numpy as np
import torch
from ultralytics import YOLO

from .configs.ViTPose_common import data_cfg
from .sort import Sort
from .vit_models.model import ViTPose
from .vit_utils.inference import draw_bboxes, pad_image
from .vit_utils.top_down_eval import keypoints_from_heatmaps
from .vit_utils.util import dyn_model_import, infer_dataset_by_path
from .vit_utils.visualization import draw_points_and_skeleton, joints_dict
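
# Optional acceleration backends: these are only needed when loading
# .engine / .onnx models, so a missing package is tolerated at import time.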
try:
    import torch_tensorrt
except ModuleNotFoundError:
    pass

try:
    import onnxruntime
except ModuleNotFoundError:
    pass

__all__ = ['VitInference']

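# Compatibility shim: np.bool was removed in NumPy 1.24 but may still be
# referenced by older dependencies (e.g. the bundled SORT tracker).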
np.bool = np.bool_
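
# ImageNet normalization statistics used to preprocess ViTPose inputs.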
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
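
# Map detection class names to the COCO class indices used by the YOLO
# detector ('animals' covers every animal class above).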
DETC_TO_YOLO_YOLOC = {
    'human': [0],
    'cat': [15],
    'dog': [16],
    'horse': [17],
    'sheep': [18],
    'cow': [19],
    'elephant': [20],
    'bear': [21],
    'zebra': [22],
    'giraffe': [23],
    'animals': [15, 16, 17, 18, 19, 20, 21, 22, 23]
}


class VitInference:
    """
    Class for performing inference using ViTPose models with YOLOv8 detection
    and SORT tracking.

    Args:
        model (str): Path to the ViT model file (.pth, .onnx, .engine).
        yolo (str): Path of the YOLOv8 model to load.
        model_name (str, optional): Name of the ViT model architecture to use.
            Valid values are 's', 'b', 'l', 'h'. Defaults to None,
            required when loading .pth checkpoints.
        det_class (str, optional): The detection class. If None it is inferred
            from the dataset. Valid values are 'human', 'cat', 'dog', 'horse',
            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe' and
            'animals' (which is all of the previous except 'human').
        dataset (str, optional): Name of the dataset. If None it is extracted
            from the file name. Valid values are 'coco', 'coco_25',
            'wholebody', 'mpii', 'ap10k', 'apt36k', 'aic'.
        yolo_size (int, optional): Input image size for the YOLOv8 model.
            Defaults to 320.
        device (str, optional): Device to use for inference. Defaults to
            'cuda' if available, then 'mps', else 'cpu'.
        is_video (bool, optional): Flag indicating if the input is a video.
            Defaults to False.
        single_pose (bool, optional): Flag indicating if the video contains a
            single pose (it has no effect on images). In this case the SORT
            tracker is not used (increasing performance), but person ids
            won't be consistent across frames.
        yolo_step (int, optional): For performance the tracker can predict the
            bboxes instead of YOLO; this flag specifies how often YOLO is
            applied (e.g. 1 applies YOLO every frame). It has no effect when
            is_video is False.
    """

    def __init__(self, model: str,
                 yolo: str,
                 model_name: Optional[str] = None,
                 det_class: Optional[str] = None,
                 dataset: Optional[str] = None,
                 yolo_size: int = 320,
                 device: Optional[str] = None,
                 is_video: bool = False,
                 single_pose: bool = False,
                 yolo_step: int = 1):
        assert os.path.isfile(model), f'The model file {model} does not exist'
        assert os.path.isfile(yolo), f'The YOLOv8 model {yolo} does not exist'

        # Device priority is cuda / mps / cpu
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'

        self.device = device
        self.yolo = YOLO(yolo, task='detect')
        self.yolo_size = yolo_size
        self.yolo_step = yolo_step
        self.is_video = is_video
        self.single_pose = single_pose
        self.reset()

        # State saving during inference
        self.save_state = True  # Can be disabled manually
        self._img = None
        self._yolo_res = None
        self._tracker_res = None
        self._keypoints = None

        # Use the extension to decide which kind of model has been loaded
        use_onnx = model.endswith('.onnx')
        use_trt = model.endswith('.engine')

        # Extract the dataset name
        if dataset is None:
            dataset = infer_dataset_by_path(model)
        assert dataset in ['mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k'], \
            'The specified dataset is not valid'

        # Dataset can now be set for visualization
        self.dataset = dataset

        # Now that the dataset is known, pick the matching YOLO classes if not set
        if det_class is None:
            det_class = 'animals' if dataset in ['ap10k', 'apt36k'] else 'human'
        self.yolo_classes = DETC_TO_YOLO_YOLOC[det_class]

        assert model_name in [None, 's', 'b', 'l', 'h'], \
            f'The model name {model_name} is not valid'

        # onnx / trt models do not require the model_cfg specification
        if model_name is None:
            assert use_onnx or use_trt, \
                'Specify the model_name if not using onnx / trt'
        else:
            # Dynamically import the model class
            model_cfg = dyn_model_import(self.dataset, model_name)

        self.target_size = data_cfg['image_size']
        if use_onnx:
            self._ort_session = onnxruntime.InferenceSession(model,
                                                             providers=['CUDAExecutionProvider',
                                                                        'CPUExecutionProvider'])
            inf_fn = self._inference_onnx
        else:
            if use_trt:
                # TensorRT engines are loaded as TorchScript modules and do
                # not need model_cfg (which is undefined when model_name is None)
                self._vit_pose = torch.jit.load(model)
            else:
                self._vit_pose = ViTPose(model_cfg)
                ckpt = torch.load(model, map_location='cpu')
                if 'state_dict' in ckpt:
                    self._vit_pose.load_state_dict(ckpt['state_dict'])
                else:
                    self._vit_pose.load_state_dict(ckpt)
            self._vit_pose.eval()
            self._vit_pose.to(torch.device(device))
            inf_fn = self._inference_torch

        # Override the _inference abstract with the selected engine
        self._inference = inf_fn  # type: ignore

    def reset(self):
        """
        Reset the inference class to be ready for a new video.
        This resets the internal frame counter; on videos
        this is necessary to reset the tracker.
        """
        # When YOLO runs only every yolo_step frames, accept tracks after a
        # single hit so detections are not discarded between YOLO frames.
        min_hits = 3 if self.yolo_step == 1 else 1
        use_tracker = self.is_video and not self.single_pose
        self.tracker = Sort(max_age=self.yolo_step,
                            min_hits=min_hits,
                            iou_threshold=0.3) if use_tracker else None  # TODO: Params

        self.frame_counter = 0

    @classmethod
    def postprocess(cls, heatmaps, org_w, org_h):
        """
        Postprocess the heatmaps to obtain keypoints and their probabilities.

        Args:
            heatmaps (ndarray): Heatmap predictions from the model.
            org_w (int): Original width of the image.
            org_h (int): Original height of the image.

        Returns:
            ndarray: Keypoints of shape (N, K, 3) as (y, x, confidence) rows.
        """
        points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
                                               center=np.array([[org_w // 2,
                                                                 org_h // 2]]),
                                               scale=np.array([[org_w, org_h]]),
                                               unbiased=True, use_udp=True)
        return np.concatenate([points[:, :, ::-1], prob], axis=2)

    @abc.abstractmethod
    def _inference(self, img: np.ndarray) -> np.ndarray:
        """
        Abstract method for performing inference on an image.
        It is overridden by the selected inference engine in __init__.

        Args:
            img (ndarray): Input image for inference.

        Returns:
            ndarray: Inference results.
        """
        raise NotImplementedError

    def inference(self, img: np.ndarray) -> dict[typing.Any, typing.Any]:
        """
        Perform inference on the input image.

        Args:
            img (ndarray): Input image for inference in RGB format.

        Returns:
            dict[typing.Any, typing.Any]: Keypoints for each detected instance,
                keyed by tracker id (or by detection index when no tracker is used).
        """
        # First use YOLOv8 for detection
        res_pd = np.empty((0, 5))
        results = None
        if (self.tracker is None or
           (self.frame_counter % self.yolo_step == 0 or self.frame_counter < 3)):
            results = self.yolo(img, verbose=False, imgsz=self.yolo_size,
                                device=self.device if self.device != 'cuda' else 0,
                                classes=self.yolo_classes)[0]
            res_pd = np.array([r[:5].tolist() for r in  # TODO: Confidence threshold
                               results.boxes.data.cpu().numpy() if r[4] > 0.35]).reshape((-1, 5))
        self.frame_counter += 1

        frame_keypoints = {}
        ids = None
        if self.tracker is not None:
            res_pd = self.tracker.update(res_pd)
            ids = res_pd[:, 5].astype(int).tolist()

        # Prepare the boxes for inference
        bboxes = res_pd[:, :4].round().astype(int)
        scores = res_pd[:, 4].tolist()
        pad_bbox = 10

        if ids is None:
            ids = range(len(bboxes))

        for bbox, id in zip(bboxes, ids):
            # TODO: Slightly bigger bbox
            bbox[[0, 2]] = np.clip(bbox[[0, 2]] + [-pad_bbox, pad_bbox], 0, img.shape[1])
            bbox[[1, 3]] = np.clip(bbox[[1, 3]] + [-pad_bbox, pad_bbox], 0, img.shape[0])

            # Crop the image and pad it to a 3/4 aspect ratio
            img_inf = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
            img_inf, (left_pad, top_pad) = pad_image(img_inf, 3 / 4)

            keypoints = self._inference(img_inf)[0]
            # Transform keypoints back to the original image (rows are (y, x, conf))
            keypoints[:, :2] += bbox[:2][::-1] - [top_pad, left_pad]
            frame_keypoints[id] = keypoints

        if self.save_state:
            self._img = img
            self._yolo_res = results
            self._tracker_res = (bboxes, ids, scores)
            self._keypoints = frame_keypoints

        return frame_keypoints

    def draw(self, show_yolo=True, show_raw_yolo=False, confidence_threshold=0.5):
        """
        Draw keypoints and bounding boxes on the image.

        Args:
            show_yolo (bool, optional): Whether to show YOLOv8 bounding boxes. Default is True.
            show_raw_yolo (bool, optional): Whether to show raw YOLOv8 bounding boxes. Default is False.
            confidence_threshold (float, optional): Minimum confidence required
                to draw a keypoint. Default is 0.5.

        Returns:
            ndarray: Image with keypoints and bounding boxes drawn.
        """
        img = self._img.copy()
        bboxes, ids, scores = self._tracker_res

        if self._yolo_res is not None and (show_raw_yolo or (self.tracker is None and show_yolo)):
            img = np.array(self._yolo_res.plot())

        if show_yolo and self.tracker is not None:
            img = draw_bboxes(img, bboxes, ids, scores)

        img = np.array(img)[..., ::-1]  # RGB to BGR for cv2 modules
        for idx, k in self._keypoints.items():
            img = draw_points_and_skeleton(img.copy(), k,
                                           joints_dict()[self.dataset]['skeleton'],
                                           person_index=idx,
                                           points_color_palette='gist_rainbow',
                                           skeleton_color_palette='jet',
                                           points_palette_samples=10,
                                           confidence_threshold=confidence_threshold)
        return img[..., ::-1]  # Return RGB as original

    def pre_img(self, img):
        """Resize and normalize an RGB crop into a (1, C, H, W) float32 model input."""
        org_h, org_w = img.shape[:2]
        img_input = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR) / 255
        img_input = ((img_input - MEAN) / STD).transpose(2, 0, 1)[None].astype(np.float32)
        return img_input, org_h, org_w

    @torch.no_grad()
    def _inference_torch(self, img: np.ndarray) -> np.ndarray:
        # Prepare the input data
        img_input, org_h, org_w = self.pre_img(img)
        img_input = torch.from_numpy(img_input).to(torch.device(self.device))

        # Feed to the model
        heatmaps = self._vit_pose(img_input).detach().cpu().numpy()
        return self.postprocess(heatmaps, org_w, org_h)

    def _inference_onnx(self, img: np.ndarray) -> np.ndarray:
        # Prepare the input data
        img_input, org_h, org_w = self.pre_img(img)

        # Feed to the model
        ort_inputs = {self._ort_session.get_inputs()[0].name: img_input}
        heatmaps = self._ort_session.run(None, ort_inputs)[0]
        return self.postprocess(heatmaps, org_w, org_h)
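

if __name__ == '__main__':
    # Minimal usage sketch (the file paths are placeholders passed on the
    # command line; checkpoints are not bundled with this module):
    # run single-image pose estimation and save a visualization.
    import sys

    img_path, vit_ckpt, yolo_ckpt = sys.argv[1:4]

    # The model expects RGB input while cv2 loads BGR
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)

    # model_name is required here because this sketch assumes a .pth checkpoint
    model = VitInference(vit_ckpt, yolo_ckpt, model_name='s')
    keypoints = model.inference(img)  # dict: id -> (K, 3) array of (y, x, conf)

    # draw() returns RGB, convert back to BGR for cv2.imwrite
    cv2.imwrite('output.png', cv2.cvtColor(model.draw(), cv2.COLOR_RGB2BGR))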