File size: 3,528 Bytes
1b7e88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json
import re
from pathlib import Path
from typing import List

import json_repair
from omagent_core.models.llms.base import BaseLLMBackend
from omagent_core.models.llms.prompt import PromptTemplate
from omagent_core.tool_system.base import ArgSchema, BaseTool
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry
from pydantic import Field
from scenedetect import FrameTimecode

from ...misc.scene import VideoScenes

# Directory containing this module; the prompt template files are resolved
# relative to it.
CURRENT_PATH = Path(__file__).parent

# Argument schema for the Rewinder tool, passed to ``ArgSchema(**ARGSCHEMA)``.
# Every argument is a required "number"; key order is preserved as declared.
ARGSCHEMA = dict(
    start_time=dict(
        type="number",
        description="Start time (in seconds) of the video to extract frames from.",
        required=True,
    ),
    end_time=dict(
        type="number",
        description="End time (in seconds) of the video to extract frames from.",
        required=True,
    ),
    number=dict(
        type="number",
        description="Number of frames of extraction. More frames means more details but more cost. Do not exceed 10.",
        required=True,
    ),
)


@registry.register_tool()
class Rewinder(BaseTool, BaseLLMBackend):
    """Tool that re-samples frames from the already-loaded video and asks the
    LLM to describe them.

    The serialized video is read from the workflow's short-term memory under
    the ``"video"`` key. Frames between ``start_time`` and ``end_time`` are
    sampled, each one preceded by a ``timestamp_<t>`` label, and the sequence is
    sent to the LLM through the system/user prompts loaded below.
    """

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "Rollback and extract frames from video which is already loaded to get more specific details for further analysis."
    )
    prompts: List[PromptTemplate] = Field(
        default=[
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_sys_prompt.prompt"),
                role="system",
            ),
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_user_prompt.prompt"),
                role="user",
            ),
        ]
    )

    def _run(
        self, start_time: float = 0.0, end_time: float = None, number: int = 1
    ) -> str:
        """Sample frames from the loaded video and describe them with the LLM.

        Args:
            start_time: Start of the range to sample, in seconds.
            end_time: End of the range to sample, in seconds. ``None`` means
                the end of the stream (``video.stream.duration``).
            number: Requested number of frames. It is used to compute the
                sampling interval and is clamped to the range [1, 10].

        Returns:
            A string embedding the LLM response, parsed with
            ``json_repair.loads``.

        Raises:
            ValueError: If no video has been loaded into short-term memory.
        """
        serialized_video = self.stm(self.workflow_instance_id).get("video", None)
        if serialized_video is None:
            raise ValueError("No video is loaded.")
        video: VideoScenes = VideoScenes.from_serializable(serialized_video)

        # Cap the cost of a single call, and guard the division below against
        # zero or negative frame counts.
        if number > 10:
            logging.warning("Number of frames exceeds 10. Will extract 10 frames.")
            number = 10
        elif number < 1:
            logging.warning("Number of frames is less than 1. Will extract 1 frame.")
            number = 1

        fps = video.stream.frame_rate
        start = FrameTimecode(timecode=start_time, fps=fps)
        if end_time is None:
            end = video.stream.duration
        else:
            end = FrameTimecode(timecode=end_time, fps=fps)

        if start_time == end_time:
            # A single instant: sample the one-frame window [start, end + 1)
            # using the native frame rate as the step.
            frames, time_stamps = video.get_video_frames((start, end + 1), fps)
        else:
            # Spread `number` samples across the range. The step is clamped to
            # at least one frame so that a range shorter than `number` frames
            # never produces a zero step.
            span = end.get_frames() - start.get_frames()
            interval = max(1, int(span / number))
            frames, time_stamps = video.get_video_frames((start, end), interval)

        # Interleave a textual timestamp label before each frame.
        payload = []
        for frame, time_stamp in zip(frames, time_stamps):
            payload.extend((f"timestamp_{time_stamp}", frame))

        response = self.infer(input_list=[{"timestamp_with_images": payload}])
        res = response[0]["choices"][0]["message"]["content"]
        # The model is expected to answer in JSON; json_repair tolerates
        # slightly malformed output.
        image_contents = json_repair.loads(res)
        # Reset the workflow's image cache after inference.
        # NOTE(review): presumably this drops the frames that were just sent —
        # confirm against how image_cache is populated.
        self.stm(self.workflow_instance_id)["image_cache"] = {}
        return f"extracted_frames described as: {image_contents}."

    async def _arun(
        self, start_time: float = 0.0, end_time: float = None, number: int = 1
    ) -> str:
        """Async entry point; delegates to the synchronous ``_run``.

        NOTE(review): ``_run`` performs blocking work (frame decoding and LLM
        inference) directly on the event loop.
        """
        return self._run(start_time, end_time, number=number)