File size: 3,844 Bytes
be9b1db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import gradio as gr
import numpy as np
import rerun as rr
from gradio_rerun.events import SelectionChange
from typing_extensions import TypedDict


def get_recording(recording_id) -> rr.RecordingStream:
    """Build a Rerun recording stream for the given recording id.

    The stream is always created under the ``multiview_sam_annotate``
    application id; ``recording_id`` is forwarded unchanged.
    """
    recording = rr.RecordingStream(
        application_id="multiview_sam_annotate",
        recording_id=recording_id,
    )
    return recording


class RerunLogPaths(TypedDict):
    """Typed dictionary bundling a Rerun timeline name with entity log paths."""

    # Name of the timeline; `update_keypoints` passes it to `set_time_nanos`.
    timeline_name: str
    # Parent log path -- presumably the root entity the camera paths live under; TODO confirm.
    parent_log_path: Path
    # Log paths for the cameras (one per camera, judging by the name; verify against caller).
    cam_log_path_list: list[Path]


@dataclass
class KeypointsContainer:
    """Container for include and exclude keypoints.

    Each attribute is a float array with one ``(x, y)`` row per keypoint.
    """

    include_points: np.ndarray  # shape (n,2)
    exclude_points: np.ndarray  # shape (m,2)

    @staticmethod
    def _no_points() -> np.ndarray:
        """Return a fresh, empty ``(0, 2)`` float array."""
        return np.zeros((0, 2), dtype=float)

    @staticmethod
    def _appended(existing: np.ndarray, row: np.ndarray) -> np.ndarray:
        """Return ``existing`` with ``row`` stacked underneath it."""
        # An empty array is replaced outright, so containers built from any
        # zero-length array (not only shape (0, 2)) keep working.
        if existing.shape[0] == 0:
            return row
        return np.vstack([existing, row])

    @classmethod
    def empty(cls) -> "KeypointsContainer":
        """Create an empty keypoints container"""
        return cls(include_points=cls._no_points(), exclude_points=cls._no_points())

    def add_point(self, point: tuple[float, float], label: Literal["include", "exclude"]) -> None:
        """Add a point with the specified label.

        Args:
            point: The ``(x, y)`` coordinates of the keypoint.
            label: ``"include"`` appends to ``include_points``; any other value
                appends to ``exclude_points``.

        Raises:
            ValueError: If ``point`` does not hold exactly two coordinates.
        """
        # reshape() rejects anything that is not exactly two values, so a
        # malformed point can never silently break the (n, 2) layout.
        row = np.asarray(point, dtype=float).reshape(1, 2)
        if label == "include":
            self.include_points = self._appended(self.include_points, row)
        else:
            self.exclude_points = self._appended(self.exclude_points, row)

    def clear(self) -> None:
        """Clear all points"""
        self.include_points = self._no_points()
        self.exclude_points = self._no_points()


# In this function, the `request` and `change` parameters are automatically injected by Gradio when this event listener is fired.
#
# `SelectionChange` is a subclass of `EventData`: https://www.gradio.app/docs/gradio/eventdata
# `gr.Request`: https://www.gradio.app/main/docs/gradio/request
def update_keypoints(
    active_recording_id: uuid.UUID,
    point_type: Literal["include", "exclude"],
    mv_keypoint_dict: dict[str, KeypointsContainer],
    log_paths: RerunLogPaths,
    request: gr.Request,
    change: SelectionChange,
):
    """Record the selected viewer position as a keypoint and re-log the points.

    This is a generator: it yields at most once, and yields nothing when the
    selection cannot be turned into a keypoint.

    Args:
        active_recording_id: Id of the recording to log into. An empty string
            means there is no active recording and nothing is done.
        point_type: Whether the new keypoint is an "include" or "exclude" point.
        mv_keypoint_dict: Keypoint containers keyed by camera name; the
            containers of the matching cameras are updated in place.
        log_paths: Supplies the timeline name used when logging.
        request: Injected by Gradio; not used here.
        change: Selection-change event whose payload lists the selected items.

    Yields:
        A tuple of the data read from the recording's binary stream and the
        updated ``mv_keypoint_dict``.
    """
    if active_recording_id == "":
        return

    evt = change.payload

    # We can only log a keypoint if the user selected only a single item.
    if len(evt.items) != 1:
        return
    item = evt.items[0]

    # If the selected item isn't an entity, or we don't have its position, then bail out.
    if item.type != "entity" or item.position is None:
        return

    # Work out which cameras the selection belongs to *before* touching any state.
    # Without this guard a selection on an entity unrelated to every camera would
    # reach the logging code with no container bound and raise UnboundLocalError.
    # NOTE(review): this is a substring match, so a key such as "cam1" also
    # matches a path containing "cam10" -- confirm camera names cannot overlap.
    matching_cams = [cam_name for cam_name in mv_keypoint_dict if cam_name in item.entity_path]
    if not matching_cams:
        return

    # Now we can produce a valid keypoint.
    rec: rr.RecordingStream = get_recording(active_recording_id)
    stream: rr.BinaryStream = rec.binary_stream()
    # Only the first two components of the position are used.
    current_keypoint = item.position[0:2]

    # Update the keypoints for every matching camera.
    for cam_name in matching_cams:
        mv_keypoint_dict[cam_name].add_point(current_keypoint, point_type)
    # The container of the last matching camera is the one that gets logged.
    current_keypoint_container: KeypointsContainer = mv_keypoint_dict[matching_cams[-1]]

    rec.set_time_nanos(log_paths["timeline_name"], nanos=0)
    # Log include points if any exist
    if current_keypoint_container.include_points.shape[0] > 0:
        rec.log(
            f"{item.entity_path}/include",
            rr.Points2D(current_keypoint_container.include_points, colors=(0, 255, 0), radii=5),
        )

    # Log exclude points if any exist
    if current_keypoint_container.exclude_points.shape[0] > 0:
        rec.log(
            f"{item.entity_path}/exclude",
            rr.Points2D(current_keypoint_container.exclude_points, colors=(255, 0, 0), radii=5),
        )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), mv_keypoint_dict