import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import gradio as gr
import numpy as np
import rerun as rr
from gradio_rerun.events import SelectionChange
from typing_extensions import TypedDict


def get_recording(recording_id) -> rr.RecordingStream:
    """Return a RecordingStream bound to the given recording id, so every callback logs into the same recording."""
    return rr.RecordingStream(application_id="multiview_sam_annotate", recording_id=recording_id)
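
# Illustrative note (not the Space's actual setup): building a RecordingStream with the same
# recording_id from different callbacks appends to one and the same recording in the viewer.
# A minimal sketch, with a hypothetical id:
#
#     rec = get_recording(uuid.uuid4())
#     stream = rec.binary_stream()  # bytes a streaming Rerun viewer component can consume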


class RerunLogPaths(TypedDict):
    timeline_name: str
    parent_log_path: Path
    cam_log_path_list: list[Path]
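
# A sketch of the shape this TypedDict describes; the timeline and entity-path values below
# are hypothetical, not necessarily the ones the app uses:
#
#     log_paths: RerunLogPaths = {
#         "timeline_name": "frame_idx",
#         "parent_log_path": Path("world"),
#         "cam_log_path_list": [Path("world/cam_0"), Path("world/cam_1")],
#     }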


@dataclass
class KeypointsContainer:
    """Container for include and exclude keypoints."""

    include_points: np.ndarray  # shape (n, 2)
    exclude_points: np.ndarray  # shape (m, 2)

    @classmethod
    def empty(cls) -> "KeypointsContainer":
        """Create an empty keypoints container."""
        return cls(
            include_points=np.zeros((0, 2), dtype=float),
            exclude_points=np.zeros((0, 2), dtype=float),
        )

    def add_point(self, point: tuple[float, float], label: Literal["include", "exclude"]) -> None:
        """Add a point with the specified label."""
        point_array = np.array([point], dtype=float)
        if label == "include":
            self.include_points = (
                np.vstack([self.include_points, point_array]) if self.include_points.shape[0] > 0 else point_array
            )
        else:
            self.exclude_points = (
                np.vstack([self.exclude_points, point_array]) if self.exclude_points.shape[0] > 0 else point_array
            )

    def clear(self) -> None:
        """Clear all points."""
        self.include_points = np.zeros((0, 2), dtype=float)
        self.exclude_points = np.zeros((0, 2), dtype=float)
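
# A minimal usage sketch of KeypointsContainer (illustrative only; the coordinates are made
# up and not part of the app's flow):
#
#     container = KeypointsContainer.empty()
#     container.add_point((320.0, 240.0), "include")  # positive prompt point
#     container.add_point((10.0, 10.0), "exclude")    # negative prompt point
#     assert container.include_points.shape == (1, 2)
#     container.clear()                               # both arrays back to shape (0, 2)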


# In this function, the `request` and `change` parameters are automatically injected by Gradio
# when this event listener fires, so they must not be listed in the listener's `inputs`
# (see the wiring sketch at the end of this module).
#
# `SelectionChange` is a subclass of `EventData`: https://www.gradio.app/docs/gradio/eventdata
# `gr.Request`: https://www.gradio.app/main/docs/gradio/request
def update_keypoints(
    active_recording_id: uuid.UUID,
    point_type: Literal["include", "exclude"],
    mv_keypoint_dict: dict[str, KeypointsContainer],
    log_paths: RerunLogPaths,
    request: gr.Request,
    change: SelectionChange,
):
    # Nothing to do until a recording has been started.
    if active_recording_id == "":
        return

    evt = change.payload

    # We can only log a keypoint if the user selected a single item.
    if len(evt.items) != 1:
        return
    item = evt.items[0]

    # If the selected item isn't an entity, or we don't have its position, bail out.
    if item.type != "entity" or item.position is None:
        return

    # Now we can produce a valid keypoint.
    rec: rr.RecordingStream = get_recording(active_recording_id)
    stream: rr.BinaryStream = rec.binary_stream()

    current_keypoint: tuple[float, float] = item.position[0:2]

    for cam_name in mv_keypoint_dict:
        if cam_name in item.entity_path:
            # Update the keypoints for the camera the user clicked in.
            mv_keypoint_dict[cam_name].add_point(current_keypoint, point_type)
            current_keypoint_container: KeypointsContainer = mv_keypoint_dict[cam_name]

            rec.set_time_nanos(log_paths["timeline_name"], nanos=0)

            # Log include points, if any exist.
            if current_keypoint_container.include_points.shape[0] > 0:
                rec.log(
                    f"{item.entity_path}/include",
                    rr.Points2D(current_keypoint_container.include_points, colors=(0, 255, 0), radii=5),
                )
            # Log exclude points, if any exist.
            if current_keypoint_container.exclude_points.shape[0] > 0:
                rec.log(
                    f"{item.entity_path}/exclude",
                    rr.Points2D(current_keypoint_container.exclude_points, colors=(255, 0, 0), radii=5),
                )

    # Ensure we consume everything from the recording.
    stream.flush()
    yield stream.read(), mv_keypoint_dict
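

# How this listener might be wired into the Gradio app. This is a minimal sketch, not the
# Space's actual UI code: the component names and state values are made up, and a
# gradio_rerun `Rerun(streaming=True)` viewer with a `selection_change` event is assumed.
# Note that `request` and `change` are injected by Gradio and therefore do not appear in
# `inputs`, while the yielded `(stream.read(), mv_keypoint_dict)` pair maps onto `outputs`.
#
#     import gradio as gr
#     from gradio_rerun import Rerun
#
#     with gr.Blocks() as demo:
#         viewer = Rerun(streaming=True)
#         recording_id = gr.State("")
#         keypoint_state = gr.State({})   # dict[str, KeypointsContainer]
#         log_paths_state = gr.State({})  # RerunLogPaths
#         point_type = gr.Radio(["include", "exclude"], value="include", label="Point type")
#
#         viewer.selection_change(
#             update_keypoints,
#             inputs=[recording_id, point_type, keypoint_state, log_paths_state],
#             outputs=[viewer, keypoint_state],
#         )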