import re
from typing import Union

import torch


def pre_caption(caption, max_words=50):
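    """Lowercase ``caption``, replace basic punctuation with spaces, collapse
    repeated whitespace, and truncate to at most ``max_words`` words."""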
    caption = re.sub(
        r"([.!\"()*#:;~])",
        " ",
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        " ",
        caption,
    )
    caption = caption.rstrip("\n")
    caption = caption.strip(" ")

    caption_words = caption.split(" ")
    if len(caption_words) > max_words:
        caption = " ".join(caption_words[:max_words])

    return caption


def id2int(data, sub=""):
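    """Convert an id string, or a list of id strings, to int(s) by replacing
    every non-digit character with ``sub``."""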
    if isinstance(data, list):
        return [remove_non_digits(d, sub) for d in data]
    else:
        return remove_non_digits(data, sub)


def remove_non_digits(string, sub: str = ""):
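    """Replace every non-digit character in ``string`` with ``sub`` and cast
    the result to ``int``."""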
    return int(re.sub(r"\D", sub, string))


def get_middle_frame(reference_vid_pth):
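    """Return the middle frame of the video at ``reference_vid_pth`` as a PIL
    image, or a black 384x384 placeholder if the video is missing or corrupted."""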
    from pathlib import Path

    import cv2
    import numpy as np
    from PIL import Image

    reference_vid_pth = str(reference_vid_pth)

    if not Path(reference_vid_pth).exists():
        print(f"Video {reference_vid_pth} does not exist")
        return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8))

    cap = cv2.VideoCapture(reference_vid_pth)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    middle_frame_index = total_frames // 2

    # Seek to the middle frame, decode it, and release the capture handle.
    cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_index)
    ret, frame = cap.read()
    cap.release()

    if not ret or frame is None:
        print(f"Video {reference_vid_pth} is corrupted")
        return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8))

    # OpenCV decodes to BGR; convert to RGB before building the PIL image.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(frame)

    return pil_image


def get_random_frame(reference_vid_pth):
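    """Return a random frame of the video at ``reference_vid_pth`` as a PIL
    image, or a black 384x384 placeholder if the video is missing or corrupted."""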
    from pathlib import Path

    import cv2
    import numpy as np
    from PIL import Image

    reference_vid_pth = str(reference_vid_pth)

    if not Path(reference_vid_pth).exists():
        print(f"Video {reference_vid_pth} does not exist")
        return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8))

    cap = cv2.VideoCapture(reference_vid_pth)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Guard against a zero frame count so randint does not raise on corrupted videos.
    random_frame_index = np.random.randint(0, max(total_frames, 1))

    # Seek to the sampled frame, decode it, and release the capture handle.
    cap.set(cv2.CAP_PROP_POS_FRAMES, random_frame_index)
    ret, frame = cap.read()
    cap.release()

    if not ret or frame is None:
        print(f"Video {reference_vid_pth} is corrupted")
        return Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8))

    # OpenCV decodes to BGR; convert to RGB before building the PIL image.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(frame)

    return pil_image


def sample_frames(frames_videos, vlen):
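    """Split the ``vlen`` frames into ``min(frames_videos, vlen)`` even
    intervals and return the middle frame index of each interval."""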
    import numpy as np

    acc_samples = min(frames_videos, vlen)
    intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
    ranges = []
    for idx, interv in enumerate(intervals[:-1]):
        ranges.append((interv, intervals[idx + 1] - 1))

    frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]

    return frame_idxs


class FrameLoader:
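    """Load one or more frames from a video and apply ``transform`` to each.

    ``method`` selects the strategy: the single middle frame ("middle"), a
    single random frame ("random"), or ``frames_video`` evenly spaced frames
    ("sample").
    """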
    def __init__(self, transform, frames_video=1, method="middle"):
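        """Store the transform, validate ``frames_video`` for the chosen
        ``method``, and bind the matching frame-selection function."""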
        self.transform = transform
        self.method = method

        if method == "middle":
            self.get_frame = get_middle_frame
            assert frames_video == 1, "frames_video must be 1 for middle frame method"
        elif method == "random":
            self.get_frame = get_random_frame
            assert frames_video == 1, "frames_video must be 1 for random frame method"
        elif method == "sample":
            assert frames_video > 1, "frames_video must be > 1 for sample frame method"
            self.frames_video = frames_video
        else:
            raise ValueError(f"Invalid method: {method}")

    def __call__(self, video_pth: str):
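        """Return the transformed frames stacked into a tensor for the
        "sample" method, or a single transformed frame otherwise."""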
        if self.method == "sample":
            frames = self.get_video_frames(video_pth, 0.0, None)
            return torch.stack(frames)
        else:
            return self.transform(self.get_frame(video_pth))

    def get_video_frames(
        self,
        video_pth: str,
        start_time: float = 0.0,
        end_time: Union[float, None] = None,
    ) -> list:
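        """Decode ``self.frames_video`` evenly spaced frames between
        ``start_time`` and ``end_time`` (in seconds; the whole video when
        ``end_time`` is None) and return them as a list of transformed frames."""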
        import cv2
        from PIL import Image

        cap = cv2.VideoCapture(video_pth)

        fps = cap.get(cv2.CAP_PROP_FPS)
        if end_time is not None:
            start_frame = int(fps * start_time)
            end_frame = int(fps * end_time)
            vlen = end_frame - start_frame
        else:
            start_frame = 0
            vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        frame_idxs = sample_frames(self.frames_video, vlen)
        frame_idxs = [frame_idx + start_frame for frame_idx in frame_idxs]
        if self.frames_video != len(frame_idxs):
            # Short clips yield fewer indices; repeat them so the output
            # always stacks to self.frames_video frames.
            frame_idxs = (frame_idxs * self.frames_video)[: self.frames_video]
            print(f"Video {video_pth} has less than {self.frames_video} frames")

        frames = []
        for index in frame_idxs:
            # Clamp to 0 so the first sampled index cannot seek to position -1.
            cap.set(cv2.CAP_PROP_POS_FRAMES, max(index - 1, 0))
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb).convert("RGB"))

        cap.release()

        if len(frames) > 0:
            video_data = [self.transform(frame) for frame in frames]
            return video_data
        else:
            raise ValueError(f"video path: {video_pth} error.")
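

# Minimal usage sketch (illustrative only): the torchvision transform and the
# video path below are assumptions for demonstration, not part of the pipeline.
if __name__ == "__main__":
    from torchvision import transforms

    transform = transforms.Compose(
        [
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
        ]
    )
    loader = FrameLoader(transform, frames_video=4, method="sample")
    # loader("path/to/video.mp4") would return a (4, 3, 384, 384) tensor.
    print(pre_caption("An Example CAPTION: with Punctuation!!"))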