import logging

import cv2
import numpy as np
import supervision as sv
import torch
from rt_pose import PoseEstimationPipeline, PoseEstimationOutput

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VitPose:
    """Person detection + ViTPose keypoint estimation on video frames."""

    def __init__(self):
        self.pipeline = PoseEstimationPipeline(
            object_detection_checkpoint="PekingU/rtdetr_r50vd_coco_o365",
            pose_estimation_checkpoint="usyd-community/vitpose-plus-small",
            device="cuda" if torch.cuda.is_available() else "cpu",
            dtype=torch.bfloat16,
            compile=True,  # torch.compile both models for extra speedup
        )
        self.output_video_path = None
        self.video_metadata = {}

    def video_to_frames(self, video):
        """Decode a video file into a list of BGR frames and record its metadata."""
        frames = []
        cap = cv2.VideoCapture(video)
        self.video_metadata = {
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        }

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()  # free the capture handle once all frames are read
        return frames
    
    def run(self, video):
        frames = self.video_to_frames(video)
        annotated_frames = []
        for i, frame in enumerate(frames):
            logger.info(f"Processing frame {i + 1} of {len(frames)}")
            output = self.pipeline(frame)
            annotated_frame = self.visualize_output(frame, output)
            annotated_frames.append(annotated_frame)
        logger.info(f"Processed {len(annotated_frames)} frames")
        return annotated_frames

    def visualize_output(self, image: np.ndarray, output: PoseEstimationOutput, confidence: float = 0.3) -> np.ndarray:
        """
        Visualize pose estimation output.
        """
        keypoints_xy = output.keypoints_xy.float().cpu().numpy()
        scores = output.scores.float().cpu().numpy()

        # Supervision will not draw vertices with `0` score
        # and coordinates with `(0, 0)` value
        invisible_keypoints = scores < confidence
        scores[invisible_keypoints] = 0
        keypoints_xy[invisible_keypoints] = 0

        keypoints = sv.KeyPoints(xy=keypoints_xy, confidence=scores)

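        # Scale marker size with the average detected-person height so the
        # overlay stays legible across input resolutions.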
        _, y_min, _, y_max = output.person_boxes_xyxy.T
        height = int((y_max - y_min).mean().item())
        radius = max(height // 100, 4)
        thickness = max(height // 200, 2)
        edge_annotator = sv.EdgeAnnotator(color=sv.Color.YELLOW, thickness=thickness)
        vertex_annotator = sv.VertexAnnotator(color=sv.Color.ROBOFLOW, radius=radius)

        annotated_frame = image.copy()
        annotated_frame = edge_annotator.annotate(annotated_frame, keypoints)
        annotated_frame = vertex_annotator.annotate(annotated_frame, keypoints)

        return annotated_frame
    
    def frames_to_video(self, frames, output_video_path=None):
        """Encode annotated frames to MP4, rotating landscape input to portrait."""
        output_video_path = output_video_path or self.output_video_path
        if output_video_path is None:
            raise ValueError("Set self.output_video_path or pass output_video_path")

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        height = self.video_metadata["height"]
        width = self.video_metadata["width"]

        # Always ensure vertical orientation: rotate only if the video is landscape
        rotate = width > height

        # The VideoWriter must be given the dimensions of the *output* frames
        if rotate:
            logger.info(f"Original dimensions: {width}x{height}, rotated: {height}x{width}")
            out = cv2.VideoWriter(output_video_path, fourcc, self.video_metadata["fps"], (height, width))
        else:
            logger.info(f"Dimensions: {width}x{height}")
            out = cv2.VideoWriter(output_video_path, fourcc, self.video_metadata["fps"], (width, height))

        for frame in frames:
            if rotate:
                # Rotate landscape frames 90 degrees to make them vertical
                frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)
            out.write(frame)
        out.release()
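

if __name__ == "__main__":
    # Minimal usage sketch -- "input.mp4" and "annotated.mp4" are hypothetical
    # paths, and this assumes rt_pose, supervision, and OpenCV are installed.
    estimator = VitPose()
    annotated_frames = estimator.run("input.mp4")
    estimator.frames_to_video(annotated_frames, output_video_path="annotated.mp4")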