|
|
|
|
|
import numpy as np |
|
import argparse |
|
import requests |
|
import concurrent.futures |
|
from mpi4py import MPI |
|
import warnings |
|
from pathlib import Path |
|
from tqdm.auto import tqdm |
|
|
|
# MPI world communicator plus this process's rank and the total process count.
# RANK decides which process creates the output directory and, when more than
# one process is running, which partition this process downloads.
COMM = MPI.COMM_WORLD
RANK, SIZE = COMM.Get_rank(), COMM.Get_size()
|
|
|
def request_save(url, save_fp):
    """Download ``url`` and store the response body at ``save_fp``.

    The body is streamed into a sibling ``<name>.part`` file and renamed to
    ``save_fp`` only after the transfer finished successfully, so an
    interrupted or failed download never leaves a file behind under the final
    name.

    Args:
        url: Address to fetch.
        save_fp: Destination file path (``str`` or ``pathlib.Path``).

    Raises:
        requests.exceptions.RequestException: On connection problems,
            timeouts, or a 4xx/5xx response status.
        OSError: If the destination cannot be written.
    """
    target = Path(save_fp)
    # The extra suffix keeps an unfinished file from matching "*.mp4" globs.
    partial = target.with_name(target.name + '.part')
    try:
        with requests.get(url, timeout=5, stream=True) as response:
            # Without this check an HTML error page would be saved as a video.
            response.raise_for_status()
            with open(partial, 'wb') as handler:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    handler.write(chunk)
        partial.replace(target)
    finally:
        # Only still present when something above failed before the rename.
        if partial.exists():
            partial.unlink()
|
|
|
def get_urls(split, timeout=30):
    """Fetch the JSON index of video paths for one dataset split.

    Args:
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        timeout: Seconds to wait on the server before giving up; passed
            straight to ``requests.get``.

    Returns:
        The decoded JSON document. ``main`` uses it as a mapping from a
        relative video path to its download URL.

    Raises:
        ValueError: If ``split`` is not a known split name.
        requests.exceptions.RequestException: If the index cannot be
            downloaded. The error is printed and then re-raised so the caller
            never receives ``None`` in place of the mapping.
    """
    index_urls = {
        'train': "http://imagine.enpc.fr/~ventural/covr/dataset/webvid2m-covr_paths-train.json",
        'val': "http://imagine.enpc.fr/~ventural/covr/dataset/webvid8m-covr_paths-val.json",
        'test': "http://imagine.enpc.fr/~ventural/covr/dataset/webvid8m-covr_paths-test.json",
    }
    if split not in index_urls:
        raise ValueError("Split must be one of train, val, or test")

    try:
        response = requests.get(index_urls[split], timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching JSON: {e}")
        raise
|
|
|
def main(args):
    """Download the missing videos of one partition of a dataset split.

    Args:
        args: Parsed command-line namespace providing ``split``, ``data_dir``,
            ``partitions``, ``part`` and ``processes``.

    Raises:
        ValueError: If ``args.split`` is not ``'train'``, ``'val'`` or
            ``'test'``.
    """
    if args.split == 'train':
        video_dir = Path(f"{args.data_dir}/WebVid/2M/train")
    elif args.split in ['val', 'test']:
        video_dir = Path(f"{args.data_dir}/WebVid/8M/train")
    else:
        raise ValueError("Split must be one of train, val, or test")

    # Only one rank creates the root directory; the others wait for it.
    if RANK == 0:
        video_dir.mkdir(parents=True, exist_ok=True)
    COMM.barrier()

    path2url = get_urls(args.split)

    # Partition the complete, sorted path list *before* dropping the files
    # that already exist. The split then depends only on the index, so every
    # job/rank computes the same boundaries no matter what the others have
    # already written to disk.
    all_paths = sorted(path2url)
    part_paths = np.array_split(all_paths, args.partitions)[args.part]

    # Skip the videos that are already on disk.
    found_paths = {str(p.relative_to(video_dir)) for p in video_dir.glob('*/*.mp4')}
    paths = [str(p) for p in part_paths if str(p) not in found_paths]

    # Create each destination sub-directory once.
    for vid_dir in {(video_dir / path).parent for path in paths}:
        vid_dir.mkdir(parents=True, exist_ok=True)

    # Download in batches of 1000 so the progress bar advances per batch.
    failed = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.processes) as executor:
        for start in tqdm(range(0, len(paths), 1000)):
            futures = {
                executor.submit(request_save, path2url[path], video_dir / path): path
                for path in paths[start:start + 1000]
            }
            for future in concurrent.futures.as_completed(futures):
                error = future.exception()
                if error is not None:
                    # Best effort: report the failure and keep going.
                    failed += 1
                    tqdm.write(f"Failed to download {futures[future]}: {error}")

    if failed:
        print(f"{failed} of {len(paths)} downloads failed; re-run to retry them")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, resolve which partition
    # this process is responsible for, validate it, and start downloading.
    parser = argparse.ArgumentParser(description='Shutter Image/Video Downloader')
    parser.add_argument('--partitions', type=int, default=1,
                        help='Number of partitions to split the dataset into, to run multiple jobs in parallel')
    parser.add_argument('--part', type=int, default=0,
                        help='Partition number to download where 0 <= part < partitions')
    parser.add_argument('--data_dir', type=str, default='./datasets',
                        help='Directory where webvid data is stored.')
    parser.add_argument('--split', type=str, default='test', choices=['train', 'val', 'test'],
                        help='Which split to download')
    parser.add_argument('--processes', type=int, default=8)
    args = parser.parse_args()

    # Under MPI with several processes, each rank takes the partition whose
    # index equals its rank.
    if SIZE > 1:
        warnings.warn("Overriding --part with MPI rank number")
        args.part = RANK

    if args.partitions < 1:
        raise ValueError("Number of partitions must be at least 1")
    # A negative index would silently select a partition counted from the end.
    if args.part < 0:
        raise ValueError("Part idx must not be negative")
    if args.part >= args.partitions:
        raise ValueError("Part idx must be less than number of partitions")
    main(args)