pose_demo_01 / easy_ViTPose /inference.py
Maksym-Lysyi's picture
initial commit
e3641b1
import abc
import os
from typing import Optional
import typing
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from .configs.ViTPose_common import data_cfg
from .sort import Sort
from .vit_models.model import ViTPose
from .vit_utils.inference import draw_bboxes, pad_image
from .vit_utils.top_down_eval import keypoints_from_heatmaps
from .vit_utils.util import dyn_model_import, infer_dataset_by_path
from .vit_utils.visualization import draw_points_and_skeleton, joints_dict
try:
import torch_tensorrt
except ModuleNotFoundError:
pass
try:
import onnxruntime
except ModuleNotFoundError:
pass
__all__ = ['VitInference']
np.bool = np.bool_
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
DETC_TO_YOLO_YOLOC = {
'human': [0],
'cat': [15],
'dog': [16],
'horse': [17],
'sheep': [18],
'cow': [19],
'elephant': [20],
'bear': [21],
'zebra': [22],
'giraffe': [23],
'animals': [15, 16, 17, 18, 19, 20, 21, 22, 23]
}
class VitInference:
"""
Class for performing inference using ViTPose models with YOLOv8 human detection and SORT tracking.
Args:
model (str): Path to the ViT model file (.pth, .onnx, .engine).
yolo (str): Path of the YOLOv8 model to load.
model_name (str, optional): Name of the ViT model architecture to use.
Valid values are 's', 'b', 'l', 'h'.
Defaults to None, is necessary when using .pth checkpoints.
det_class (str, optional): the detection class. if None it is inferred by the dataset.
valid values are 'human', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'animals' (which is all previous but human)
dataset (str, optional): Name of the dataset. If None it's extracted from the file name.
Valid values are 'coco', 'coco_25', 'wholebody', 'mpii',
'ap10k', 'apt36k', 'aic'
yolo_size (int, optional): Size of the input image for YOLOv8 model. Defaults to 320.
device (str, optional): Device to use for inference. Defaults to 'cuda' if available, else 'cpu'.
is_video (bool, optional): Flag indicating if the input is video. Defaults to False.
single_pose (bool, optional): Flag indicating if the video (on images this flag has no effect)
will contain a single pose.
In this case the SORT tracker is not used (increasing performance)
but people id tracking
won't be consistent among frames.
yolo_step (int, optional): The tracker can be used to predict the bboxes instead of yolo for performance,
this flag specifies how often yolo is applied (e.g. 1 applies yolo every frame).
This does not have any effect when is_video is False.
"""
def __init__(self, model: str,
yolo: str,
model_name: Optional[str] = None,
det_class: Optional[str] = None,
dataset: Optional[str] = None,
yolo_size: Optional[int] = 320,
device: Optional[str] = None,
is_video: Optional[bool] = False,
single_pose: Optional[bool] = False,
yolo_step: Optional[int] = 1):
assert os.path.isfile(model), f'The model file {model} does not exist'
assert os.path.isfile(yolo), f'The YOLOv8 model {yolo} does not exist'
# Device priority is cuda / mps / cpu
if device is None:
if torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
self.device = device
self.yolo = YOLO(yolo, task='detect')
self.yolo_size = yolo_size
self.yolo_step = yolo_step
self.is_video = is_video
self.single_pose = single_pose
self.reset()
# State saving during inference
self.save_state = True # Can be disabled manually
self._img = None
self._yolo_res = None
self._tracker_res = None
self._keypoints = None
# Use extension to decide which kind of model has been loaded
use_onnx = model.endswith('.onnx')
use_trt = model.endswith('.engine')
# Extract dataset name
if dataset is None:
dataset = infer_dataset_by_path(model)
assert dataset in ['mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k'], \
'The specified dataset is not valid'
# Dataset can now be set for visualization
self.dataset = dataset
# if we picked the dataset switch to correct yolo classes if not set
if det_class is None:
det_class = 'animals' if dataset in ['ap10k', 'apt36k'] else 'human'
self.yolo_classes = DETC_TO_YOLO_YOLOC[det_class]
assert model_name in [None, 's', 'b', 'l', 'h'], \
f'The model name {model_name} is not valid'
# onnx / trt models do not require model_cfg specification
if model_name is None:
assert use_onnx or use_trt, \
'Specify the model_name if not using onnx / trt'
else:
# Dynamically import the model class
model_cfg = dyn_model_import(self.dataset, model_name)
self.target_size = data_cfg['image_size']
if use_onnx:
self._ort_session = onnxruntime.InferenceSession(model,
providers=['CUDAExecutionProvider',
'CPUExecutionProvider'])
inf_fn = self._inference_onnx
else:
self._vit_pose = ViTPose(model_cfg)
self._vit_pose.eval()
if use_trt:
self._vit_pose = torch.jit.load(model)
else:
ckpt = torch.load(model, map_location='cpu')
if 'state_dict' in ckpt:
self._vit_pose.load_state_dict(ckpt['state_dict'])
else:
self._vit_pose.load_state_dict(ckpt)
self._vit_pose.to(torch.device(device))
inf_fn = self._inference_torch
# Override _inference abstract with selected engine
self._inference = inf_fn # type: ignore
def reset(self):
"""
Reset the inference class to be ready for a new video.
This will reset the internal counter of frames, on videos
this is necessary to reset the tracker.
"""
min_hits = 3 if self.yolo_step == 1 else 1
use_tracker = self.is_video and not self.single_pose
self.tracker = Sort(max_age=self.yolo_step,
min_hits=min_hits,
iou_threshold=0.3) if use_tracker else None # TODO: Params
self.frame_counter = 0
@classmethod
def postprocess(cls, heatmaps, org_w, org_h):
"""
Postprocess the heatmaps to obtain keypoints and their probabilities.
Args:
heatmaps (ndarray): Heatmap predictions from the model.
org_w (int): Original width of the image.
org_h (int): Original height of the image.
Returns:
ndarray: Processed keypoints with probabilities.
"""
points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
center=np.array([[org_w // 2,
org_h // 2]]),
scale=np.array([[org_w, org_h]]),
unbiased=True, use_udp=True)
return np.concatenate([points[:, :, ::-1], prob], axis=2)
@abc.abstractmethod
def _inference(self, img: np.ndarray) -> np.ndarray:
"""
Abstract method for performing inference on an image.
It is overloaded by each inference engine.
Args:
img (ndarray): Input image for inference.
Returns:
ndarray: Inference results.
"""
raise NotImplementedError
def inference(self, img: np.ndarray) -> dict[typing.Any, typing.Any]:
"""
Perform inference on the input image.
Args:
img (ndarray): Input image for inference in RGB format.
Returns:
dict[typing.Any, typing.Any]: Inference results.
"""
# First use YOLOv8 for detection
res_pd = np.empty((0, 5))
results = None
if (self.tracker is None or
(self.frame_counter % self.yolo_step == 0 or self.frame_counter < 3)):
results = self.yolo(img, verbose=False, imgsz=self.yolo_size,
device=self.device if self.device != 'cuda' else 0,
classes=self.yolo_classes)[0]
res_pd = np.array([r[:5].tolist() for r in # TODO: Confidence threshold
results.boxes.data.cpu().numpy() if r[4] > 0.35]).reshape((-1, 5))
self.frame_counter += 1
frame_keypoints = {}
ids = None
if self.tracker is not None:
res_pd = self.tracker.update(res_pd)
ids = res_pd[:, 5].astype(int).tolist()
# Prepare boxes for inference
bboxes = res_pd[:, :4].round().astype(int)
scores = res_pd[:, 4].tolist()
pad_bbox = 10
if ids is None:
ids = range(len(bboxes))
for bbox, id in zip(bboxes, ids):
# TODO: Slightly bigger bbox
bbox[[0, 2]] = np.clip(bbox[[0, 2]] + [-pad_bbox, pad_bbox], 0, img.shape[1])
bbox[[1, 3]] = np.clip(bbox[[1, 3]] + [-pad_bbox, pad_bbox], 0, img.shape[0])
# Crop image and pad to 3/4 aspect ratio
img_inf = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
img_inf, (left_pad, top_pad) = pad_image(img_inf, 3 / 4)
keypoints = self._inference(img_inf)[0]
# Transform keypoints to original image
keypoints[:, :2] += bbox[:2][::-1] - [top_pad, left_pad]
frame_keypoints[id] = keypoints
if self.save_state:
self._img = img
self._yolo_res = results
self._tracker_res = (bboxes, ids, scores)
self._keypoints = frame_keypoints
return frame_keypoints
def draw(self, show_yolo=True, show_raw_yolo=False, confidence_threshold=0.5):
"""
Draw keypoints and bounding boxes on the image.
Args:
show_yolo (bool, optional): Whether to show YOLOv8 bounding boxes. Default is True.
show_raw_yolo (bool, optional): Whether to show raw YOLOv8 bounding boxes. Default is False.
Returns:
ndarray: Image with keypoints and bounding boxes drawn.
"""
img = self._img.copy()
bboxes, ids, scores = self._tracker_res
if self._yolo_res is not None and (show_raw_yolo or (self.tracker is None and show_yolo)):
img = np.array(self._yolo_res.plot())
if show_yolo and self.tracker is not None:
img = draw_bboxes(img, bboxes, ids, scores)
img = np.array(img)[..., ::-1] # RGB to BGR for cv2 modules
for idx, k in self._keypoints.items():
img = draw_points_and_skeleton(img.copy(), k,
joints_dict()[self.dataset]['skeleton'],
person_index=idx,
points_color_palette='gist_rainbow',
skeleton_color_palette='jet',
points_palette_samples=10,
confidence_threshold=confidence_threshold)
return img[..., ::-1] # Return RGB as original
def pre_img(self, img):
org_h, org_w = img.shape[:2]
img_input = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR) / 255
img_input = ((img_input - MEAN) / STD).transpose(2, 0, 1)[None].astype(np.float32)
return img_input, org_h, org_w
@torch.no_grad()
def _inference_torch(self, img: np.ndarray) -> np.ndarray:
# Prepare input data
img_input, org_h, org_w = self.pre_img(img)
img_input = torch.from_numpy(img_input).to(torch.device(self.device))
# Feed to model
heatmaps = self._vit_pose(img_input).detach().cpu().numpy()
return self.postprocess(heatmaps, org_w, org_h)
def _inference_onnx(self, img: np.ndarray) -> np.ndarray:
# Prepare input data
img_input, org_h, org_w = self.pre_img(img)
# Feed to model
ort_inputs = {self._ort_session.get_inputs()[0].name: img_input}
heatmaps = self._ort_session.run(None, ort_inputs)[0]
return self.postprocess(heatmaps, org_w, org_h)