import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from src.data.utils import pre_caption
from src.tools.files import read_txt

# CLIP preprocessing: per-channel RGB mean/std, bicubic resize to 384x384.
normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
)
transform = transforms.Compose(
    [
        transforms.Resize((384, 384), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ]
)
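# Sanity-check sketch (illustrative, not part of the pipeline): `transform`
# maps any PIL image to a normalized float tensor of shape (3, 384, 384).
#
#   img = Image.new("RGB", (640, 480))
#   x = transform(img)
#   assert x.shape == (3, 384, 384)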
class ImageDataset(Dataset):
    """Yields (transformed image, video id) pairs for every image in image_dir."""

    def __init__(
        self,
        image_dir,
        img_ext: str = "png",
        save_dir=None,
    ):
        self.image_dir = Path(image_dir)
        # Materialize the glob so the paths can be reused after the dict below.
        self.img_pths = list(self.image_dir.glob(f"*.{img_ext}"))
        self.id2pth = {img_pth.stem: img_pth for img_pth in self.img_pths}
        self.video_ids = list(self.id2pth.keys())

        # Resume support: skip ids whose features were already saved to save_dir.
        if save_dir is not None:
            save_dir = Path(save_dir)
            done_paths = {p.stem for p in save_dir.glob("*.pth")}
            print(f"video_ids: {len(self.video_ids)} - {len(done_paths)} = ", end="")
            self.video_ids = list(set(self.video_ids) - done_paths)
            print(len(self.video_ids))
            self.video_ids.sort()
            if len(self.video_ids) == 0:
                print("All videos are done")
                sys.exit(0)

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, index):
        video_id = self.video_ids[index]

        img_pth = self.id2pth[video_id]
        img = Image.open(img_pth).convert("RGB")
        img = transform(img)

        return img, video_id
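# Usage sketch (illustrative): iterating ImageDataset with a DataLoader. The
# directory names, batch size, and worker count are assumptions for the
# example, not values taken from the rest of the codebase.
#
#   from torch.utils.data import DataLoader
#
#   dataset = ImageDataset("frames/", img_ext="png", save_dir="features/")
#   loader = DataLoader(dataset, batch_size=32, num_workers=4)
#   for imgs, video_ids in loader:
#       ...  # imgs: float tensor of shape (batch, 3, 384, 384)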
class VideoDataset(Dataset):
    """Yields (video id, frame indices, transformed frames) for every video
    found one directory level below video_dir."""

    def __init__(
        self,
        video_dir,
        todo_ids=None,
        shard_id=0,
        num_shards=1,
        frames_video=15,
        extension="mp4",
        save_dir=None,
    ):
        self.video_dir = Path(video_dir)

        # todo_ids may be a path to a text file with one video id per line.
        if isinstance(todo_ids, (str, Path)):
            todo_ids = read_txt(todo_ids)

        found_paths = list(self.video_dir.glob(f"*/*.{extension}"))
        if todo_ids is not None:
            video_paths = [self.video_dir / f"{v}.{extension}" for v in todo_ids]
            video_paths = list(set(video_paths) & set(found_paths))
        else:
            video_paths = found_paths
        video_paths.sort()
        # Ids take the form "<parent dir>/<file stem>", mirroring the layout
        # expected under save_dir.
        self.id2path = {pth.parent.name + "/" + pth.stem: pth for pth in video_paths}
        self.video_ids = list(self.id2path.keys())
        self.video_ids.sort()

        # Resume support: skip ids whose features were already saved to save_dir.
        if save_dir is not None:
            save_dir = Path(save_dir)
            done_paths = {p.parent.name + "/" + p.stem for p in save_dir.glob("*/*.pth")}
            print(f"video_ids: {len(self.video_ids)} - {len(done_paths)} = ", end="")
            self.video_ids = list(set(self.video_ids) - done_paths)
            print(len(self.video_ids))
            self.video_ids.sort()
            if len(self.video_ids) == 0:
                print("All videos are done")
                sys.exit(0)

        assert len(self.video_ids) > 0, "video_ids is empty"

        # Keep only this worker's contiguous shard of the sorted id list.
        n_videos = len(self.video_ids)
        self.video_ids = self.video_ids[
            shard_id * n_videos // num_shards : (shard_id + 1) * n_videos // num_shards
        ]

        self.frames_video = frames_video

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, index):
        video_id = self.video_ids[index]
        video_path = self.id2path[video_id]
        frames, f_idxs = get_video_frames(video_path, self.frames_video)
        frames = [transform(frame) for frame in frames]
        frames = torch.stack(frames, dim=0)
        f_idxs = torch.tensor(f_idxs)

        return video_id, f_idxs, frames
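# Usage sketch (illustrative): splitting extraction across two workers via
# shard_id/num_shards; each worker takes a contiguous slice of the sorted id
# list. The paths are assumptions for the example only.
#
#   dataset = VideoDataset("videos/", shard_id=0, num_shards=2, save_dir="feats/")
#   video_id, f_idxs, frames = dataset[0]
#   # frames: (frames_video, 3, 384, 384); f_idxs[i] == -1 marks padded or
#   # corrupted slots.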
class TextDataset(Dataset):
    """Yields the unique, cleaned captions from the "edit" column of a CSV."""

    def __init__(
        self,
        csv_path,
        max_words=30,
    ):
        self.df = pd.read_csv(csv_path)
        # unique() already deduplicates; sort for a deterministic order.
        self.texts = sorted(self.df["edit"].unique().tolist())
        self.max_words = max_words

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        txt = self.texts[index]
        txt = pre_caption(txt, self.max_words)

        return txt
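# Usage sketch (illustrative; the file name is a placeholder): the CSV is
# assumed to have an "edit" column, as read in __init__ above.
#
#   texts = TextDataset("edits.csv", max_words=30)
#   print(len(texts), texts[0])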
def get_video_frames(video_pth, frames_video=15):
    """Decode frames_video evenly spaced frames from a video as RGB PIL images.

    Returns (frames, f_idxs). Slots that could not be read are filled with a
    black 384x384 image and a frame index of -1.
    """
    import cv2  # imported lazily so the module does not hard-require OpenCV

    video_pth = str(video_pth)

    cap = cv2.VideoCapture(video_pth)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frame_idxs = sample_frames(total_frames, frames_video)

    frames = []
    f_idxs = []
    for frame_idx in frame_idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()

        if not ret or frame is None:
            # Unreadable frame: treat the whole clip as corrupted and return
            # black placeholder frames with -1 indices.
            print(f"Video {video_pth} is corrupted")
            cap.release()
            frames = [
                Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8))
            ] * frames_video
            f_idxs = [-1] * frames_video
            return frames, f_idxs

        # OpenCV decodes to BGR; convert to RGB before building the PIL image,
        # matching the RGB convention used by ImageDataset above.
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        f_idxs.append(frame_idx)

    cap.release()

    # Pad short clips with black frames so every clip has frames_video entries.
    n_frames = len(frames)
    if n_frames < frames_video:
        frames += [Image.fromarray(np.zeros((384, 384, 3)).astype(np.uint8))] * (
            frames_video - n_frames
        )

    f_idxs += [-1] * (frames_video - len(f_idxs))

    return frames, f_idxs
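# Usage sketch (illustrative; the path is a placeholder):
#
#   frames, f_idxs = get_video_frames("videos/abc/000.mp4", frames_video=15)
#   # len(frames) == 15; each entry is an RGB PIL.Image. f_idxs[i] == -1 for
#   # padded slots of clips shorter than 15 frames.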
def sample_frames(vlen, frames_per_video=15):
    """Pick up to frames_per_video frame indices: split [0, vlen) into equal
    spans and take the middle frame of each span."""
    acc_samples = min(vlen, frames_per_video)
    intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
    ranges = []
    for idx, interv in enumerate(intervals[:-1]):
        ranges.append((interv, intervals[idx + 1] - 1))

    frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]

    return frame_idxs
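# Worked example: for a 100-frame video and 5 samples, the spans are
# [0, 19], [20, 39], [40, 59], [60, 79], [80, 99], and their midpoints give
#
#   sample_frames(100, 5)  # -> [9, 29, 49, 69, 89]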