File size: 2,280 Bytes
690f890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import numpy as np
import argparse

from .utils import convert_to_numpy

class FlowAnnotator:
    def __init__(self, cfg, device=None):
        try:
            from raft import RAFT
            from raft.utils.utils import InputPadder
            from raft.utils import flow_viz
        except:
            import warnings
            warnings.warn(
                "ignore raft import, please pip install raft package. you can refer to models/VACE-Annotators/flow/raft-1.0.0-py3-none-any.whl")

        params = {
            "small": False,
            "mixed_precision": False,
            "alternate_corr": False
        }
        params = argparse.Namespace(**params)
        pretrained_model = cfg['PRETRAINED_MODEL']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.model = RAFT(params)
        self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_model, map_location="cpu", weights_only=True).items()})
        self.model = self.model.to(self.device).eval()
        self.InputPadder = InputPadder
        self.flow_viz = flow_viz

    def forward(self, frames):
        # frames / RGB
        frames = [torch.from_numpy(convert_to_numpy(frame).astype(np.uint8)).permute(2, 0, 1).float()[None].to(self.device) for frame in frames]
        flow_up_list, flow_up_vis_list = [], []
        with torch.no_grad():
            for i, (image1, image2) in enumerate(zip(frames[:-1], frames[1:])):
                padder = self.InputPadder(image1.shape)
                image1, image2 = padder.pad(image1, image2)
                flow_low, flow_up = self.model(image1, image2, iters=20, test_mode=True)
                flow_up = flow_up[0].permute(1, 2, 0).cpu().numpy()
                flow_up_vis = self.flow_viz.flow_to_image(flow_up)
                flow_up_list.append(flow_up)
                flow_up_vis_list.append(flow_up_vis)
        return flow_up_list, flow_up_vis_list  # RGB


class FlowVisAnnotator(FlowAnnotator):
    def forward(self, frames):
        flow_up_list, flow_up_vis_list = super().forward(frames)
        return flow_up_vis_list[:1] + flow_up_vis_list