import abc
import os
from typing import Optional
import typing
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from .configs.ViTPose_common import data_cfg
from .sort import Sort
from .vit_models.model import ViTPose
from .vit_utils.inference import draw_bboxes, pad_image
from .vit_utils.top_down_eval import keypoints_from_heatmaps
from .vit_utils.util import dyn_model_import, infer_dataset_by_path
from .vit_utils.visualization import draw_points_and_skeleton, joints_dict
# Optional dependencies: torch_tensorrt is only required for TensorRT
# engines, onnxruntime only for .onnx models.
try:
    import torch_tensorrt
except ModuleNotFoundError:
    pass
try:
    import onnxruntime
except ModuleNotFoundError:
    pass
__all__ = ['VitInference']
# Compatibility shim: np.bool was removed in NumPy 1.24+, but some
# dependencies still reference it.
np.bool = np.bool_
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Map detection-class names to the corresponding YOLO (COCO) class indices
DETC_TO_YOLO_YOLOC = {
'human': [0],
'cat': [15],
'dog': [16],
'horse': [17],
'sheep': [18],
'cow': [19],
'elephant': [20],
'bear': [21],
'zebra': [22],
'giraffe': [23],
'animals': [15, 16, 17, 18, 19, 20, 21, 22, 23]
}
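# e.g. DETC_TO_YOLO_YOLOC['human'] -> [0], the COCO class index for 'person'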
class VitInference:
"""
Class for performing inference using ViTPose models with YOLOv8 human detection and SORT tracking.
Args:
model (str): Path to the ViT model file (.pth, .onnx, .engine).
yolo (str): Path of the YOLOv8 model to load.
model_name (str, optional): Name of the ViT model architecture to use.
Valid values are 's', 'b', 'l', 'h'.
Defaults to None, is necessary when using .pth checkpoints.
det_class (str, optional): the detection class. if None it is inferred by the dataset.
valid values are 'human', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'animals' (which is all previous but human)
dataset (str, optional): Name of the dataset. If None it's extracted from the file name.
Valid values are 'coco', 'coco_25', 'wholebody', 'mpii',
'ap10k', 'apt36k', 'aic'
yolo_size (int, optional): Size of the input image for YOLOv8 model. Defaults to 320.
device (str, optional): Device to use for inference. Defaults to 'cuda' if available, else 'cpu'.
is_video (bool, optional): Flag indicating if the input is video. Defaults to False.
single_pose (bool, optional): Flag indicating if the video (on images this flag has no effect)
will contain a single pose.
In this case the SORT tracker is not used (increasing performance)
but people id tracking
won't be consistent among frames.
yolo_step (int, optional): The tracker can be used to predict the bboxes instead of yolo for performance,
this flag specifies how often yolo is applied (e.g. 1 applies yolo every frame).
This does not have any effect when is_video is False.
"""
def __init__(self, model: str,
yolo: str,
model_name: Optional[str] = None,
det_class: Optional[str] = None,
dataset: Optional[str] = None,
yolo_size: Optional[int] = 320,
device: Optional[str] = None,
is_video: Optional[bool] = False,
single_pose: Optional[bool] = False,
yolo_step: Optional[int] = 1):
assert os.path.isfile(model), f'The model file {model} does not exist'
assert os.path.isfile(yolo), f'The YOLOv8 model {yolo} does not exist'
# Device priority is cuda / mps / cpu
if device is None:
if torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
self.device = device
self.yolo = YOLO(yolo, task='detect')
self.yolo_size = yolo_size
self.yolo_step = yolo_step
self.is_video = is_video
self.single_pose = single_pose
self.reset()
# State saving during inference
self.save_state = True # Can be disabled manually
self._img = None
self._yolo_res = None
self._tracker_res = None
self._keypoints = None
# Use extension to decide which kind of model has been loaded
use_onnx = model.endswith('.onnx')
use_trt = model.endswith('.engine')
# Extract dataset name
if dataset is None:
dataset = infer_dataset_by_path(model)
assert dataset in ['mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k'], \
'The specified dataset is not valid'
# Dataset can now be set for visualization
self.dataset = dataset
        # Now that the dataset is known, pick the matching YOLO classes if not set
if det_class is None:
det_class = 'animals' if dataset in ['ap10k', 'apt36k'] else 'human'
self.yolo_classes = DETC_TO_YOLO_YOLOC[det_class]
assert model_name in [None, 's', 'b', 'l', 'h'], \
f'The model name {model_name} is not valid'
# onnx / trt models do not require model_cfg specification
if model_name is None:
assert use_onnx or use_trt, \
'Specify the model_name if not using onnx / trt'
else:
# Dynamically import the model class
model_cfg = dyn_model_import(self.dataset, model_name)
self.target_size = data_cfg['image_size']
if use_onnx:
self._ort_session = onnxruntime.InferenceSession(model,
providers=['CUDAExecutionProvider',
'CPUExecutionProvider'])
inf_fn = self._inference_onnx
else:
self._vit_pose = ViTPose(model_cfg)
self._vit_pose.eval()
if use_trt:
self._vit_pose = torch.jit.load(model)
else:
ckpt = torch.load(model, map_location='cpu')
if 'state_dict' in ckpt:
self._vit_pose.load_state_dict(ckpt['state_dict'])
else:
self._vit_pose.load_state_dict(ckpt)
self._vit_pose.to(torch.device(device))
inf_fn = self._inference_torch
# Override _inference abstract with selected engine
self._inference = inf_fn # type: ignore
def reset(self):
"""
Reset the inference class to be ready for a new video.
This will reset the internal counter of frames, on videos
this is necessary to reset the tracker.
"""
min_hits = 3 if self.yolo_step == 1 else 1
use_tracker = self.is_video and not self.single_pose
self.tracker = Sort(max_age=self.yolo_step,
min_hits=min_hits,
iou_threshold=0.3) if use_tracker else None # TODO: Params
self.frame_counter = 0
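    # Hedged per-video usage sketch (assumes `videos` is a hypothetical
    # iterable of RGB frame sequences): call reset() between videos so the
    # tracker and the frame counter start fresh.
    #
    #     for video in videos:
    #         model.reset()
    #         for frame in video:
    #             keypoints = model.inference(frame)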
@classmethod
def postprocess(cls, heatmaps, org_w, org_h):
"""
Postprocess the heatmaps to obtain keypoints and their probabilities.
Args:
heatmaps (ndarray): Heatmap predictions from the model.
org_w (int): Original width of the image.
org_h (int): Original height of the image.
Returns:
ndarray: Processed keypoints with probabilities.
"""
points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
center=np.array([[org_w // 2,
org_h // 2]]),
scale=np.array([[org_w, org_h]]),
unbiased=True, use_udp=True)
return np.concatenate([points[:, :, ::-1], prob], axis=2)
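    # Shape sketch (assuming a batch of one crop with K keypoints):
    # heatmaps (1, K, H, W) -> points (1, K, 2) and prob (1, K, 1),
    # concatenated into the returned (1, K, 3) array of (y, x, probability).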
@abc.abstractmethod
def _inference(self, img: np.ndarray) -> np.ndarray:
"""
Abstract method for performing inference on an image.
It is overloaded by each inference engine.
Args:
img (ndarray): Input image for inference.
Returns:
ndarray: Inference results.
"""
raise NotImplementedError
def inference(self, img: np.ndarray) -> dict[typing.Any, typing.Any]:
"""
Perform inference on the input image.
Args:
img (ndarray): Input image for inference in RGB format.
Returns:
dict[typing.Any, typing.Any]: Inference results.
"""
# First use YOLOv8 for detection
res_pd = np.empty((0, 5))
results = None
if (self.tracker is None or
(self.frame_counter % self.yolo_step == 0 or self.frame_counter < 3)):
results = self.yolo(img, verbose=False, imgsz=self.yolo_size,
device=self.device if self.device != 'cuda' else 0,
classes=self.yolo_classes)[0]
res_pd = np.array([r[:5].tolist() for r in # TODO: Confidence threshold
results.boxes.data.cpu().numpy() if r[4] > 0.35]).reshape((-1, 5))
self.frame_counter += 1
frame_keypoints = {}
ids = None
if self.tracker is not None:
res_pd = self.tracker.update(res_pd)
ids = res_pd[:, 5].astype(int).tolist()
# Prepare boxes for inference
bboxes = res_pd[:, :4].round().astype(int)
scores = res_pd[:, 4].tolist()
pad_bbox = 10
if ids is None:
ids = range(len(bboxes))
for bbox, id in zip(bboxes, ids):
# TODO: Slightly bigger bbox
bbox[[0, 2]] = np.clip(bbox[[0, 2]] + [-pad_bbox, pad_bbox], 0, img.shape[1])
bbox[[1, 3]] = np.clip(bbox[[1, 3]] + [-pad_bbox, pad_bbox], 0, img.shape[0])
# Crop image and pad to 3/4 aspect ratio
img_inf = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
img_inf, (left_pad, top_pad) = pad_image(img_inf, 3 / 4)
keypoints = self._inference(img_inf)[0]
# Transform keypoints to original image
keypoints[:, :2] += bbox[:2][::-1] - [top_pad, left_pad]
frame_keypoints[id] = keypoints
if self.save_state:
self._img = img
self._yolo_res = results
self._tracker_res = (bboxes, ids, scores)
self._keypoints = frame_keypoints
return frame_keypoints
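    # Hedged example of consuming the returned dict (ids come from the SORT
    # tracker when it is active, otherwise they are just the box order):
    #
    #     for person_id, kpts in model.inference(rgb_img).items():
    #         confident = kpts[kpts[:, 2] > 0.5]  # keep confident keypoints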
def draw(self, show_yolo=True, show_raw_yolo=False, confidence_threshold=0.5):
"""
Draw keypoints and bounding boxes on the image.
Args:
show_yolo (bool, optional): Whether to show YOLOv8 bounding boxes. Default is True.
show_raw_yolo (bool, optional): Whether to show raw YOLOv8 bounding boxes. Default is False.
Returns:
ndarray: Image with keypoints and bounding boxes drawn.
"""
img = self._img.copy()
bboxes, ids, scores = self._tracker_res
if self._yolo_res is not None and (show_raw_yolo or (self.tracker is None and show_yolo)):
img = np.array(self._yolo_res.plot())
if show_yolo and self.tracker is not None:
img = draw_bboxes(img, bboxes, ids, scores)
img = np.array(img)[..., ::-1] # RGB to BGR for cv2 modules
for idx, k in self._keypoints.items():
img = draw_points_and_skeleton(img.copy(), k,
joints_dict()[self.dataset]['skeleton'],
person_index=idx,
points_color_palette='gist_rainbow',
skeleton_color_palette='jet',
points_palette_samples=10,
confidence_threshold=confidence_threshold)
return img[..., ::-1] # Return RGB as original
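    # Hedged display sketch: draw() returns RGB while OpenCV windows expect
    # BGR, so flip the channel order before showing:
    #
    #     cv2.imshow('pose', model.draw()[..., ::-1])
    #     cv2.waitKey(0)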
    def pre_img(self, img):
        """Resize and normalize an RGB crop into the (1, 3, H, W) model input."""
        org_h, org_w = img.shape[:2]
        img_input = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR) / 255
        # Normalize with ImageNet statistics and convert HWC -> NCHW
        img_input = ((img_input - MEAN) / STD).transpose(2, 0, 1)[None].astype(np.float32)
        return img_input, org_h, org_w
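    # Hedged numeric check of the normalization above: a pure red pixel
    # (255, 0, 0) scales to (1.0, 0.0, 0.0) and then normalizes to roughly
    # ((1.0 - 0.485) / 0.229, -0.456 / 0.224, -0.406 / 0.225) ≈ (2.25, -2.04, -1.80).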
@torch.no_grad()
def _inference_torch(self, img: np.ndarray) -> np.ndarray:
# Prepare input data
img_input, org_h, org_w = self.pre_img(img)
img_input = torch.from_numpy(img_input).to(torch.device(self.device))
# Feed to model
heatmaps = self._vit_pose(img_input).detach().cpu().numpy()
return self.postprocess(heatmaps, org_w, org_h)
def _inference_onnx(self, img: np.ndarray) -> np.ndarray:
# Prepare input data
img_input, org_h, org_w = self.pre_img(img)
# Feed to model
ort_inputs = {self._ort_session.get_inputs()[0].name: img_input}
heatmaps = self._ort_session.run(None, ort_inputs)[0]
        return self.postprocess(heatmaps, org_w, org_h)