krishna-k's picture
Upload folder using huggingface_hub
06555b5 verified
from typing import List, Optional, Tuple
from .rtp import RtpPacket
from .utils import uint16_add
# Backward sequence-number distance (from the buffer origin) at or above which
# JitterBuffer.add treats an "older" packet as a stream discontinuity: the
# buffer is flushed and re-anchored on that packet. Packets fewer than this
# many sequence numbers behind the origin are treated as late and dropped.
# NOTE(review): name and value appear to mirror MAX_MISORDER from RFC 3550
# Appendix A.1 — confirm that is the intent.
MAX_MISORDER = 100
class JitterFrame:
    """A reassembled frame handed out by ``JitterBuffer``.

    ``JitterBuffer`` builds it from the joined payloads of consecutive
    buffered packets that share one timestamp.
    """

    def __init__(self, data: bytes, timestamp: int) -> None:
        # data: concatenated payload bytes of the frame.
        # timestamp: the timestamp shared by all packets of the frame.
        self.data, self.timestamp = data, timestamp
class JitterBuffer:
    """Reorder incoming RTP packets and reassemble them into frames.

    Packets are kept in a fixed-size ring indexed by
    ``sequence_number % capacity``; ``_origin`` is the sequence number of the
    oldest slot still held (``None`` until the first packet arrives).
    Consecutive packets that share a timestamp make up one frame, and a frame
    is only handed out once a packet with a different timestamp follows it.
    """

    def __init__(
        self, capacity: int, prefetch: int = 0, is_video: bool = False
    ) -> None:
        """
        :param capacity: Number of packet slots; must be a positive power of 2.
        :param prefetch: Number of complete frames that must be buffered before
            the oldest one is returned (0 and 1 both mean "as soon as one
            frame is complete").
        :param is_video: When true, the flag returned by ``add`` is raised
            whenever buffered packets had to be discarded.
        """
        # 0 passes the ``x & (x - 1) == 0`` power-of-two trick, but a
        # zero-sized ring would make every ``% self._capacity`` raise
        # ZeroDivisionError, so reject it explicitly. AssertionError is kept
        # for compatibility with existing callers; note ``python -O`` strips it.
        assert capacity > 0 and capacity & (capacity - 1) == 0, (
            "capacity must be a power of 2"
        )
        self._capacity = capacity
        # Sequence number of the oldest slot; None until the first ``add``.
        self._origin: Optional[int] = None
        # Ring of packet slots; None marks an empty slot.
        self._packets: List[Optional[RtpPacket]] = [None] * capacity
        self._prefetch = prefetch
        self._is_video = is_video

    @property
    def capacity(self) -> int:
        """Number of packet slots in the ring."""
        return self._capacity

    def add(self, packet: RtpPacket) -> Tuple[bool, Optional[JitterFrame]]:
        """Store ``packet`` and try to extract a complete frame.

        :returns: ``(pli_flag, frame)``. ``pli_flag`` can only be true for
            video buffers, when packets were evicted to make room for
            ``packet`` or the buffer was reset after a large backward jump
            (presumably the caller then requests a keyframe via PLI).
            ``frame`` is the oldest complete frame if enough frames are
            buffered, else ``None``. A packet fewer than ``MAX_MISORDER``
            sequence numbers behind the origin is dropped and
            ``(False, None)`` is returned.
        """
        pli_flag = False
        if self._origin is None:
            # First packet: anchor the ring on it.
            self._origin = packet.sequence_number
            delta = 0
            misorder = 0
        else:
            # Forward and backward distances between the origin and this
            # packet (uint16_add presumably wraps modulo 2**16).
            delta = uint16_add(packet.sequence_number, -self._origin)
            misorder = uint16_add(self._origin, -packet.sequence_number)

        if misorder < delta:
            # The packet is closer behind the origin than ahead of it.
            if misorder >= MAX_MISORDER:
                # Too far behind to be a late packet: treat it as a stream
                # discontinuity, flush everything and restart from it.
                self.remove(self.capacity)
                self._origin = packet.sequence_number
                delta = misorder = 0
                if self._is_video:
                    pli_flag = True
            else:
                # A late packet whose slot has already been released.
                return pli_flag, None

        if delta >= self.capacity:
            # The packet lies beyond the ring: remove just enough frames to
            # fit it, without leaving a partial frame at the front.
            excess = delta - self.capacity + 1
            if self.smart_remove(excess):
                # The whole ring was cleared; re-anchor on this packet.
                self._origin = packet.sequence_number
            if self._is_video:
                pli_flag = True

        pos = packet.sequence_number % self._capacity
        self._packets[pos] = packet

        return pli_flag, self._remove_frame(packet.sequence_number)

    def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]:
        """Pop the oldest complete frame once ``prefetch`` frames are ready.

        Scans the gap-free run of packets starting at the origin; a frame is
        complete when the next packet carries a different timestamp. When at
        least ``prefetch`` complete frames are found, the first frame's
        packets are removed and that frame is returned. Otherwise nothing is
        removed and ``None`` is returned.

        :param sequence_number: Unused; kept for interface compatibility.
        """
        frame = None
        frames = 0
        packets = []
        remove = 0
        timestamp = None

        for count in range(self.capacity):
            pos = (self._origin + count) % self._capacity
            packet = self._packets[pos]
            if packet is None:
                # A gap ends the contiguous run that frames are taken from.
                break
            if timestamp is None:
                timestamp = packet.timestamp
            elif packet.timestamp != timestamp:
                # We now have a complete frame; only store the first one.
                if frame is None:
                    # NOTE(review): ``_data`` is presumably the depacketized
                    # payload attached by the caller; confirm.
                    frame = JitterFrame(
                        data=b"".join([x._data for x in packets]), timestamp=timestamp
                    )
                    # Slots [origin, origin + count) hold the first frame.
                    remove = count

                # Check we have prefetched enough.
                frames += 1
                if frames >= self._prefetch:
                    self.remove(remove)
                    return frame

                # Start a new frame.
                packets = []
                timestamp = packet.timestamp

            packets.append(packet)

        return None

    def remove(self, count: int) -> None:
        """Release ``count`` slots from the front and advance the origin.

        Slots are cleared whether or not they hold a packet. With
        ``count > 0`` the origin must already be set (i.e. after the first
        ``add``), otherwise ``None % capacity`` raises TypeError.
        """
        assert count <= self._capacity
        for _ in range(count):
            pos = self._origin % self._capacity
            self._packets[pos] = None
            self._origin = uint16_add(self._origin, 1)

    def smart_remove(self, count: int) -> bool:
        """Drop at least ``count`` slots from the front without splitting a frame.

        Makes sure that all packets belonging to the same frame are removed
        to prevent sending corrupted frames to the decoder: after the first
        ``count`` slots are released, removal continues until it reaches a
        packet whose timestamp differs from the last packet removed.

        :returns: ``True`` if every slot in the ring was cleared, ``False``
            if removal stopped at the start of a new frame.
        """
        timestamp = None
        for i in range(self._capacity):
            pos = self._origin % self._capacity
            packet = self._packets[pos]
            if packet is not None:
                if i >= count and timestamp != packet.timestamp:
                    # Enough slots freed and this packet starts a new frame.
                    break
                timestamp = packet.timestamp
            self._packets[pos] = None
            self._origin = uint16_add(self._origin, 1)
            if i == self._capacity - 1:
                return True
        return False