# NOTE(review): extraction artifacts removed here — a file-size banner,
# a commit-hash fragment, and a line-number gutter that were not part of
# the original Python source.
from pathlib import Path
from typing import Optional
import torch
import typer
from tensordict import TensorDict
from typing_extensions import Annotated
import time
import shutil
from decoupled_utils import rprint
# CLI application object; suppress local-variable dumps in pretty tracebacks.
app = typer.Typer(pretty_exceptions_show_locals=False)
# Keep CLI command names identical to the Python function names
# (disables Typer's default snake_case -> kebab-case conversion).
typer.main.get_command_name = lambda name: name
def split_dataset(dataset, n: int, m: int):
    """Return the m-th of n contiguous, near-equal chunks of ``dataset``.

    The first ``len(dataset) % n`` chunks receive one extra element, so the
    chunk sizes differ by at most one and together cover the whole sequence.

    Args:
        dataset: Any sliceable sequence (list, tuple, etc.).
        n: Number of chunks to divide the sequence into.
        m: Zero-based index of the chunk to return.

    Returns:
        The slice of ``dataset`` corresponding to chunk ``m``.

    Raises:
        ValueError: If ``m`` is not in the range ``[0, n)``.
    """
    if not 0 <= m < n:
        raise ValueError(f"m must be between 0 and {n-1}, but got {m}.")
    # Base chunk size plus how many leading chunks get one extra element.
    base, extra = divmod(len(dataset), n)
    # Chunks 0..extra-1 each hold base+1 elements; the rest hold base.
    start = m * base + min(m, extra)
    stop = start + base + (1 if m < extra else 0)
    return dataset[slice(start, stop)]
@app.command()
def main(
    data_dir: Path,
    # NOTE(review): mutable default list — Typer evaluates it once per process,
    # which is acceptable for a CLI entry point, but worth confirming.
    splits: Optional[list[str]] = ["train", "val"],
    add_vggface2_text_tokens: bool = False,
    use_tmp: bool = False,
    use_all: bool = False,
    allow_zero_idx: bool = False,
    use_timestamp: bool = False,
    delete_after_combining: bool = False,
    allow_existing: bool = False,
    force_overwrite: bool = False,
    move_files: bool = False,
    allow_tmp: bool = False,
    mem_efficient: bool = False,
    output_dir: Optional[Path] = None,
    require_image_tokens: bool = False,
    min_idx: Optional[int] = None,
    max_idx: Optional[int] = None,
    split_num: Optional[int] = None,
    split_idx: Optional[int] = None,
) -> None:
    """Combine per-shard memory-mapped TensorDict folders into one dataset per split.

    For each name in ``splits``, folders under ``data_dir`` whose names contain
    the split name plus an underscore-delimited numeric suffix are loaded with
    ``TensorDict.load_memmap``, concatenated, filtered down to rows that were
    actually written, compacted to smaller integer dtypes, and saved back as a
    single memory-mapped TensorDict under ``output_dir`` (if given) or
    ``data_dir``.  Various flags control folder selection, deduplication,
    overwriting, and post-combine cleanup.

    NOTE(review): the original paste had its indentation stripped; the nesting
    below (everything per-split inside the ``for split`` loop, including the
    cleanup blocks at the bottom) was reconstructed from the data flow — the
    cleanup references per-split state such as ``folders``.  Confirm against
    version control.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for split in splits:
        if allow_tmp:
            # Accept both "tmp" and final folders; when both exist for the same
            # numeric id, prefer the non-tmp one.
            all_folders = sorted([folder for folder in data_dir.iterdir() if folder.is_dir() and split in folder.name and "_" in folder.name and (allow_existing or "existing" not in folder.name)])
            print(f"All folders: len({len(all_folders)})")
            from collections import defaultdict
            unique_ids = defaultdict(list)
            for folder in all_folders:
                # Folder names end in "_<shard id>".
                folder_id = int(folder.name.split("_")[-1])
                unique_ids[folder_id].append(folder)
            folders = []
            for folder_id, _folders in unique_ids.items():
                if len(_folders) == 1:
                    folders.append(_folders[0])
                else:
                    # Duplicate shard id: keep only the non-tmp variant(s).
                    for folder in _folders:
                        if "tmp" not in folder.name:
                            folders.append(folder)
            folders = sorted(folders)
            print(f"Using {len(folders)} folders for {split}")
        else:
            # use_all ignores the tmp filter; otherwise use_tmp selects only
            # "tmp"-named folders.
            folders = sorted([folder for folder in data_dir.iterdir() if folder.is_dir() and split in folder.name and "_" in folder.name and (use_all or (not use_tmp or "tmp" in folder.name)) and (allow_existing or "existing" not in folder.name)])
        if min_idx is not None and max_idx is not None:
            # Keep only shards whose numeric suffix lies in [min_idx, max_idx].
            print(f"Filtering with min_idx: {min_idx} and max_idx: {max_idx}")
            _tmp_folders = []
            for folder in folders:
                _name = int(folder.name.split("_")[-1])
                if min_idx <= _name <= max_idx:
                    _tmp_folders.append(folder)
            folders = _tmp_folders
            print(f"Filtered folders and got: {len(folders)}")
        if split_num is not None and split_idx is not None:
            # Process only this worker's contiguous share of the shard list.
            folders = split_dataset(folders, split_num, split_idx)
            print(f"Filtered folders and got: {len(folders)}")
        initial_folder_count = len(folders)
        # Drop folders with no contents at all.
        folders = [folder for folder in folders if any(folder.iterdir())]
        removed_folders_count = initial_folder_count - len(folders)
        print(f"Removed {removed_folders_count} empty folders")
        if len(folders) == 0:
            print(f"No folders found for {split}")
            continue
        print(f"{split} folders: {folders}")
        # Only folders with a meta.json are valid TensorDict memmaps.
        _tensors = [TensorDict.load_memmap(folder) for folder in folders if (folder / "meta.json").exists()]
        _tensors = [tensor for tensor in _tensors if tensor.shape[0] > 0]
        for _tensor in _tensors:
            # Older shards may predate the write_flag column; assume all rows
            # written in that case.
            if "write_flag" not in _tensor:
                _tensor["write_flag"] = torch.ones((len(_tensor), 1), dtype=torch.bool)
        loaded_tensors = torch.cat(_tensors, dim=0)
        del _tensors
        if add_vggface2_text_tokens:
            # Add empty text-token columns sized off img_input_ids.
            # NOTE(review): widths (47 and 1) are magic numbers — presumably the
            # VGGFace2 text tokenization length; confirm.
            loaded_tensors.set("txt_input_ids", loaded_tensors["img_input_ids"].new_zeros(loaded_tensors["img_input_ids"].shape[0], 47), inplace=True)
            loaded_tensors.set("txt_attention_mask", loaded_tensors["img_input_ids"].new_zeros(loaded_tensors["img_input_ids"].shape[0], 1), inplace=True)
            print(f"Added VGGFace2 text tokens to {split}")
        index_keys = ("img_label", "img_input_ids", "txt_input_ids", "input_ids")
        if not mem_efficient:
            # Downcast token/label columns to int32 before filtering.
            for key in index_keys:
                if key in loaded_tensors:
                    loaded_tensors[key] = loaded_tensors[key].to(torch.int32)
        # A row counts as written when its write_flag is set and its token ids
        # look populated (all-positive img tokens, or any-positive input_ids).
        if "img_input_ids" in loaded_tensors:
            written_indices = ((loaded_tensors["write_flag"] > 0).squeeze(-1) & (loaded_tensors["img_input_ids"] > 0).all(dim=-1))
        else:
            if mem_efficient:
                written_indices = (loaded_tensors["write_flag"] > 0).squeeze(-1)
            else:
                written_indices = ((loaded_tensors["write_flag"] > 0).squeeze(-1) & (loaded_tensors["input_ids"] > 0).any(dim=-1))
        # NOTE(review): this prints the total row count, not the number of True
        # entries (written_indices is a boolean mask over all rows).
        print(f"Valid elements for {split}: {written_indices.shape[0]}")
        loaded_tensors = loaded_tensors[written_indices]
        # idx == -1 marks a row that was never assigned a sample index.
        invalid_indices = loaded_tensors["idx"].squeeze(-1) == -1
        if require_image_tokens:
            invalid_modality = ~(loaded_tensors["modality"] > 0).any(dim=-1)
            invalid_indices |= invalid_modality
            print(f"Found {invalid_modality.sum()} invalid indices for {split} due to missing image tokens")
        print(f"Invalid indices for {split}: {invalid_indices.sum()}")
        loaded_tensors = loaded_tensors[~invalid_indices]
        if allow_zero_idx is False:
            # Deduplicate rows by their "idx" value.
            _, idx = torch.unique(loaded_tensors["idx"].to(device), dim=0, sorted=True, return_inverse=True)
            # NOTE(review): torch.unique(idx) over the inverse labels yields
            # 0..U-1, so this keeps the FIRST U rows, not one row per duplicate
            # group — correct only if duplicates are trailing. Verify intent.
            loaded_tensors = loaded_tensors[torch.unique(idx, return_inverse=False).to(loaded_tensors.device)]
        print(f"After filtering: {loaded_tensors.shape[0]}")
        if loaded_tensors.shape[0] == 0:
            # NOTE(review): this aborts the whole command, skipping any
            # remaining splits — `continue` may have been intended.
            rprint(f"WARNING!!! No valid elements for {split}")
            return
        for _key in ["img_input_ids", "input_ids"]:
            if _key in loaded_tensors:
                # Compact token ids to int16; assert guards against overflow.
                assert 0 <= loaded_tensors[_key].min() and loaded_tensors[_key].max() < torch.iinfo(torch.int16).max
                loaded_tensors[_key] = loaded_tensors[_key].to(torch.int16)
        # Drop the trailing singleton dimension on these columns.
        index_keys = ("img_label", "txt_attention_mask", "attention_mask")
        for key in index_keys:
            if key in loaded_tensors:
                loaded_tensors[key] = loaded_tensors[key].squeeze(-1)
        # write_flag was bookkeeping only; don't persist it.
        if "write_flag" in loaded_tensors:
            del loaded_tensors["write_flag"]
        if split_idx is not None:
            # Encode the worker index in the output folder name.
            split = f"split_{split_idx}_{split}"
        if use_timestamp:
            loaded_tensors.memmap(data_dir / f"{split}_existing_{int(time.time())}")
        else:
            if (data_dir / f"{split}").exists():
                print("Already exists!")
                if force_overwrite:
                    shutil.rmtree(data_dir / f"{split}")
                else:
                    # NOTE(review): drops into the debugger on a name collision
                    # instead of failing — intentional for interactive use?
                    breakpoint()
            if output_dir is not None:
                loaded_tensors.memmap(output_dir / f"{split}")
            else:
                loaded_tensors.memmap(data_dir / f"{split}")
        if delete_after_combining:
            # Best-effort removal of the source shard folders.
            for folder in folders:
                try:
                    rprint(f"Removing folder: {folder}")
                    shutil.rmtree(folder)
                except Exception as e:
                    rprint(f"Error removing folder: {e}")
        if force_overwrite:
            # Remove every leftover train_* shard, then flatten the combined
            # "train" output up into data_dir itself.
            from pathlib import Path
            for train_folder in Path(data_dir).glob('train_*'):
                rprint(f"Removing folder: {train_folder}")
                if train_folder.is_file():
                    train_folder.unlink()
                else:
                    shutil.rmtree(train_folder)
            train_dir = data_dir / 'train'
            if train_dir.exists() and train_dir.is_dir():
                for item in train_dir.iterdir():
                    shutil.move(str(item), str(train_dir.parent))
                shutil.rmtree(train_dir)
        elif move_files:
            # Flatten data_dir/train into data_dir without deleting shards.
            train_dir = data_dir / 'train'
            if train_dir.exists() and train_dir.is_dir():
                for item in train_dir.iterdir():
                    shutil.move(str(item), str(train_dir.parent))
            # Check if train_dir is empty after moving files
            if train_dir.exists() and train_dir.is_dir():
                if not any(train_dir.iterdir()):
                    shutil.rmtree(train_dir)
                    rprint(f"Removed empty train directory: {train_dir}")
if __name__ == "__main__":
    # Script entry point: dispatch to the Typer CLI.
    # (Removed a stray trailing "|" extraction artifact that would have been
    # a syntax error.)
    app()