|
from typing import List, Optional, Tuple
|
|
|
|
from .rtp import RtpPacket
|
|
from .utils import uint16_add
|
|
|
|
# Backward sequence-number distance at or above which JitterBuffer.add() stops
# treating a packet as a late arrival (which it drops) and instead flushes the
# buffer and restarts it from that packet.
MAX_MISORDER = 100
|
|
|
|
|
|
class JitterFrame:
    """Payload bytes of one frame paired with the frame's timestamp."""

    def __init__(self, data: bytes, timestamp: int) -> None:
        """Store ``data`` and ``timestamp`` as public attributes."""
        self.data, self.timestamp = data, timestamp
|
|
|
|
|
|
class JitterBuffer:
    """Ring buffer that reorders RTP packets and reassembles them into frames.

    Each packet is stored in slot ``sequence_number % capacity``. ``_origin``
    is the sequence number of the oldest slot still tracked (``None`` until
    the first packet arrives). A frame is a contiguous run of packets,
    starting at the origin, that share the same timestamp; it is considered
    complete once a packet with a different timestamp follows it.
    """

    def __init__(
        self, capacity: int, prefetch: int = 0, is_video: bool = False
    ) -> None:
        """Create an empty buffer.

        :param capacity: Number of packet slots; must be a positive power
            of 2.
        :param prefetch: Number of complete frames that must be buffered
            before :meth:`add` releases the oldest one (0 and 1 behave alike).
        :param is_video: When true, :meth:`add` sets its first return value
            whenever buffered packets had to be discarded.
        :raises AssertionError: If ``capacity`` is not a positive power of 2.
        """
        # A capacity of 0 used to slip through the bit test (0 & -1 == 0) and
        # then failed on every add() with ZeroDivisionError, so reject it here.
        # Kept as an assertion so the exception type seen by callers does not
        # change.
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "capacity must be a power of 2"
        self._capacity = capacity
        self._origin: Optional[int] = None
        self._packets: List[Optional[RtpPacket]] = [None] * capacity
        self._prefetch = prefetch
        self._is_video = is_video

    @property
    def capacity(self) -> int:
        """Number of packet slots in the buffer."""
        return self._capacity

    def add(self, packet: RtpPacket) -> Tuple[bool, Optional[JitterFrame]]:
        """Insert ``packet`` and try to release the oldest complete frame.

        :param packet: The received packet; its ``sequence_number`` selects
            the slot and its ``timestamp`` groups packets into frames.
        :returns: ``(pli_flag, frame)``. ``pli_flag`` is true only for video
            buffers, when the buffer was reset or overflowed while handling
            this packet (presumably so the caller can request a new key
            frame -- confirm against the caller). ``frame`` is the oldest
            complete frame if enough frames are buffered, otherwise ``None``.
        """
        pli_flag = False
        if self._origin is None:
            # the very first packet defines where the window starts
            self._origin = packet.sequence_number
            delta = 0
            misorder = 0
        else:
            # distance of the packet ahead of / behind the origin;
            # uint16_add presumably wraps modulo 2**16 -- TODO confirm
            delta = uint16_add(packet.sequence_number, -self._origin)
            misorder = uint16_add(self._origin, -packet.sequence_number)

        if misorder < delta:
            # the packet is nearer behind the origin than ahead of it
            if misorder < MAX_MISORDER:
                # late packet for a slot that was already released: drop it
                return pli_flag, None

            # too far behind to be a late arrival: flush everything and
            # restart the window at this packet
            self.remove(self.capacity)
            self._origin = packet.sequence_number
            delta = 0
            if self._is_video:
                pli_flag = True

        if delta >= self.capacity:
            # remove just enough frames to fit the received packet
            excess = delta - self.capacity + 1
            if self.smart_remove(excess):
                # the whole buffer was emptied: restart at this packet
                self._origin = packet.sequence_number
            if self._is_video:
                pli_flag = True

        pos = packet.sequence_number % self._capacity
        self._packets[pos] = packet

        return pli_flag, self._remove_frame(packet.sequence_number)

    def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]:
        """Pop the oldest complete frame if enough frames are buffered.

        Walks the contiguous run of packets starting at the origin and stops
        at the first empty slot.

        :param sequence_number: Unused; kept for interface compatibility.
        :returns: The oldest complete frame, whose packets are then removed
            from the buffer, or ``None`` if no frame can be released yet.
        """
        frame = None
        frames = 0
        packets = []
        remove = 0
        timestamp = None

        for count in range(self.capacity):
            pos = (self._origin + count) % self._capacity
            packet = self._packets[pos]
            if packet is None:
                # gap in the sequence: nothing beyond it can be used yet
                break
            if timestamp is None:
                timestamp = packet.timestamp
            elif packet.timestamp != timestamp:
                # we now have a complete frame, only store the first one
                if frame is None:
                    frame = JitterFrame(
                        data=b"".join([x._data for x in packets]),
                        timestamp=timestamp,
                    )
                    remove = count

                # check we have prefetched enough
                frames += 1
                if frames >= self._prefetch:
                    self.remove(remove)
                    return frame

                # start a new frame
                packets = []
                timestamp = packet.timestamp

            packets.append(packet)

        return None

    def remove(self, count: int) -> None:
        """Clear the ``count`` oldest slots and advance the origin past them.

        :param count: Number of slots to clear; at most ``capacity``.
        :raises AssertionError: If ``count`` exceeds the capacity.
        """
        assert count <= self._capacity
        for _ in range(count):
            pos = self._origin % self._capacity
            self._packets[pos] = None
            self._origin = uint16_add(self._origin, 1)

    def smart_remove(self, count: int) -> bool:
        """Clear at least ``count`` slots, stopping only on a frame boundary.

        Makes sure that all packets belonging to the same frame are removed
        to prevent sending corrupted frames to the decoder: once ``count``
        slots are cleared, removal continues until a packet with a different
        timestamp than the last removed one is found.

        :param count: Minimum number of slots to clear.
        :returns: ``True`` if every slot of the buffer was cleared, ``False``
            if removal stopped early at a frame boundary.
        """
        timestamp = None
        for i in range(self._capacity):
            pos = self._origin % self._capacity
            packet = self._packets[pos]
            if packet is not None:
                if i >= count and timestamp != packet.timestamp:
                    # enough removed and a new frame starts here: keep it
                    break
                timestamp = packet.timestamp
            self._packets[pos] = None
            self._origin = uint16_add(self._origin, 1)
            if i == self._capacity - 1:
                return True
        return False
|
|
|