|
from torch.utils.data import Dataset |
|
from torchvision import datasets |
|
import torchvision.transforms as transforms |
|
from scipy.signal import convolve2d |
|
import numpy as np |
|
import torch |
|
import math |
|
import random |
|
from PIL import Image |
|
import os |
|
import glob |
|
import einops |
|
import torchvision.transforms.functional as F |
|
import time |
|
from tqdm import tqdm |
|
import json |
|
import pickle |
|
import io |
|
import cv2 |
|
|
|
import libs.clip |
|
import bisect |
|
|
|
|
|
class UnlabeledDataset(Dataset):
    """View of a dataset that hides the last element of every sample.

    Each sample of the wrapped dataset is sliced with ``[:-1]``; the
    element that is dropped is treated as the label.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        """Return sample ``item`` without its trailing element.

        A sample with exactly one remaining field is returned as that
        field itself; otherwise the remaining fields come back as a tuple.
        """
        stripped = tuple(self.dataset[item][:-1])
        return stripped[0] if len(stripped) == 1 else stripped
|
|
|
|
|
class LabeledDataset(Dataset):
    """Pair every sample of ``dataset`` with the label at the same index."""

    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = labels

    def __len__(self):
        # The length is taken from the samples, not from the labels.
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(sample, label)`` for index ``item``."""
        sample = self.dataset[item]
        label = self.labels[item]
        return sample, label
|
|
|
|
|
class DatasetFactory(object):
    """Base class that exposes a ``train`` and a ``test`` dataset.

    Subclasses are expected to assign ``self.train`` / ``self.test`` and to
    override ``data_shape`` (and optionally ``has_label``, ``fid_stat``,
    ``sample_label`` and ``label_prob``).
    """

    def __init__(self):
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        """Return the dataset for ``split``.

        Args:
            split: Either ``"train"`` or ``"test"``.
            labeled: When the factory has labels, ``True`` returns the
                dataset as stored and ``False`` wraps it in
                ``UnlabeledDataset``. Must be falsy when the factory has
                no labels.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
            AssertionError: If ``labeled`` is requested but ``has_label``
                is false.
        """
        if split == "train":
            dataset = self.train
        elif split == "test":
            dataset = self.test
        else:
            raise ValueError(
                f"unknown split {split!r}; expected 'train' or 'test'"
            )

        if self.has_label:
            return dataset if labeled else UnlabeledDataset(dataset)

        assert not labeled, "labeled=True requested but this dataset has no labels"
        return dataset

    def unpreprocess(self, v):
        """Map ``v`` through ``0.5 * (v + 1)`` and clamp the result to [0, 1].

        ``v`` must support ``clamp_`` (e.g. a torch tensor). The in-place
        clamp is applied to the freshly computed result, not to the input.
        """
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v

    @property
    def has_label(self):
        """Whether samples carry a label; ``True`` unless overridden."""
        return True

    @property
    def data_shape(self):
        """Shape of one sample; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def data_dim(self):
        """Number of scalar entries in one sample (product of ``data_shape``)."""
        return int(np.prod(self.data_shape))

    @property
    def fid_stat(self):
        """FID statistics reference; ``None`` unless overridden."""
        return None

    def sample_label(self, n_samples, device):
        """Not implemented in the base class."""
        raise NotImplementedError

    def label_prob(self, k):
        """Not implemented in the base class."""
        raise NotImplementedError
|
|
|
|
|
def center_crop_arr(pil_image, image_size):
    """Resize a PIL image so its short side equals ``image_size`` and
    return the central ``image_size`` x ``image_size`` crop as an array.

    Args:
        pil_image: Image exposing ``size`` and ``resize`` (a PIL image).
        image_size: Side length of the square output, in pixels.

    Returns:
        ``np.ndarray`` sliced from ``np.array(pil_image)`` along its first
        two axes.
    """
    # Halve with a box filter while the short side is still at least twice
    # the target, before the final bicubic resize.
    while min(pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Scale so that the short side lands on image_size.
    ratio = image_size / min(pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return pixels[top:top + image_size, left:left + image_size]
|
|
|
|
|
|
|
|
|
|
|
def center_crop(width, height, img):
    """Crop the central square of ``img`` and resize it to ``width`` x ``height``.

    Args:
        width: Output width passed to ``Image.resize``.
        height: Output height passed to ``Image.resize``.
        img: Array whose first two axes are cropped to a centered square
            with side ``min(img.shape[:2])``.

    Returns:
        The resized image as a ``uint8`` ``np.ndarray``.
    """
    resample = Image.LANCZOS
    crop = np.min(img.shape[:2])
    img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
              (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
    try:
        img = Image.fromarray(img, 'RGB')
    except Exception:
        # Best-effort fallback: let Pillow infer the mode when the array
        # cannot be read as RGB. Only ``Exception`` is caught so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        img = Image.fromarray(img)
    img = img.resize((width, height), resample)

    return np.array(img).astype(np.uint8)
|
|
|
|
|
class MSCOCODatabase(Dataset):
    """Image/caption dataset backed by a COCO annotation file.

    Items are ``(image, captions)``: ``image`` is a ``float32`` array in
    ``c h w`` layout computed as ``pixel / 127.5 - 1.0``, and ``captions``
    is the list of ``'caption'`` strings annotated for that image.
    """

    def __init__(self, root, annFile, size=None):
        """
        Args:
            root: Directory that the annotation ``file_name`` entries are
                joined onto.
            annFile: Path handed to ``pycocotools.coco.COCO``.
            size: Used as both output height and width for ``center_crop``.
        """
        # Imported lazily so pycocotools is only needed when this class is used.
        from pycocotools.coco import COCO

        self.root = root
        self.height = self.width = size

        self.coco = COCO(annFile)
        self.keys = sorted(self.coco.imgs.keys())

    def _load_image(self, key: int):
        """Open the image registered under ``key`` and convert it to RGB."""
        info = self.coco.loadImgs(key)[0]
        full_path = os.path.join(self.root, info["file_name"])
        return Image.open(full_path).convert("RGB")

    def _load_target(self, key: int):
        """Return the annotation records for image ``key``."""
        ann_ids = self.coco.getAnnIds(key)
        return self.coco.loadAnns(ann_ids)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]

        pixels = np.array(self._load_image(key)).astype(np.uint8)
        pixels = center_crop(self.width, self.height, pixels).astype(np.float32)
        # Rescale with pixel / 127.5 - 1.0 and keep float32.
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)
        pixels = einops.rearrange(pixels, 'h w c -> c h w')

        captions = [ann['caption'] for ann in self._load_target(key)]
        return pixels, captions
|
|
|
|
|
def get_feature_dir_info(root):
    """Count samples and captions per sample in a feature directory.

    Files matching ``*_*.npy`` are caption files named ``<sample>_<k>.npy``;
    the remaining ``*.npy`` files are counted as samples.

    Returns:
        ``(num_data, n_captions)`` where ``n_captions`` maps every index in
        ``range(num_data)`` to the number of caption files found for it.
    """
    all_npy = glob.glob(os.path.join(root, '*.npy'))
    caption_npy = glob.glob(os.path.join(root, '*_*.npy'))

    # Caption files also match '*.npy', so subtract them out.
    num_data = len(all_npy) - len(caption_npy)
    n_captions = dict.fromkeys(range(num_data), 0)

    for path in caption_npy:
        stem = os.path.splitext(os.path.basename(path))[0]
        sample_idx, _caption_idx = stem.split('_')
        n_captions[int(sample_idx)] += 1

    return num_data, n_captions
|
|
|
|
|
class MSCOCOFeatureDataset(Dataset):
    """Dataset over pre-extracted ``.npy`` features stored in one directory.

    Sample ``i`` is read from ``<root>/<i>.npy`` and one of its captions
    from ``<root>/<i>_<k>.npy``.
    """

    def __init__(self, root, need_squeeze=False, full_feature=False, fix_test_order=False):
        """
        Args:
            root: Directory scanned by ``get_feature_dir_info``.
            need_squeeze: In the non-full mode, ``squeeze()`` the caption
                array before returning it.
            full_feature: Treat caption files as pickled dicts and return
                their individual fields.
            fix_test_order: In full mode, always use the last caption
                instead of a random one.
        """
        self.root = root
        self.num_data, self.n_captions = get_feature_dir_info(root)
        self.need_squeeze = need_squeeze
        self.full_feature = full_feature
        self.fix_test_order = fix_test_order

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        """Return the features of sample ``index`` with one caption.

        Full mode returns
        ``(z, token_embedding, token_mask, token, caption)``; otherwise
        ``(z, c)`` is returned.
        """
        z = np.load(os.path.join(self.root, f'{index}.npy'))

        if not self.full_feature:
            k = random.randint(0, self.n_captions[index] - 1)
            c = np.load(os.path.join(self.root, f'{index}_{k}.npy'))
            return (z, c.squeeze()) if self.need_squeeze else (z, c)

        last = self.n_captions[index] - 1
        k = last if self.fix_test_order else random.randint(0, last)

        # NOTE: allow_pickle=True unpickles file contents; only point this
        # at trusted feature directories.
        record = np.load(os.path.join(self.root, f'{index}_{k}.npy'), allow_pickle=True).item()
        # 'promt' (sic) is the key as stored in the feature files.
        return (
            z,
            record['token_embedding'],
            record['token_mask'],
            record['token'],
            record['promt'],
        )
|
|
|
|
|
class JDBFeatureDataset(Dataset):
    """Dataset pairing pre-extracted features with their source images.

    ``<root>/img_text_pair.jsonl`` lists one JSON object per line whose
    ``'img_path'`` is resolved under ``<root>/imgs``; the matching feature
    file is looked up under ``<root>/features``.
    """

    def __init__(self, root, resolution, llm):
        """
        Args:
            root: Directory holding ``img_text_pair.jsonl``, ``imgs`` and
                ``features``.
            resolution: Selects the ``image_latent_<resolution>`` entry and
                the side length of the returned image crop.
            llm: Suffix selecting the ``token_*_<llm>`` entries.
        """
        super().__init__()
        index_path = os.path.join(root, 'img_text_pair.jsonl')
        self.img_root = os.path.join(root, 'imgs')
        self.feature_root = os.path.join(root, 'features')
        self.resolution = resolution
        self.llm = llm
        with open(index_path, 'r', encoding='utf-8') as handle:
            self.file_list = [json.loads(row)['img_path'] for row in handle]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(z, token_embedding, token_mask, token, caption, img)``.

        ``img`` is the center-cropped image as ``float32`` in ``c h w``
        layout, computed as ``pixel / 127.5 - 1.0``.
        """
        rel_path = self.file_list[idx]
        # The feature file shares the image's base name, with '.jpg' -> '.npy'.
        npy_name = rel_path.split('/')[-1].replace('.jpg', '.npy')
        feature_path = os.path.join(self.feature_root, npy_name)
        img_path = os.path.join(self.img_root, rel_path)

        # NOTE: allow_pickle=True unpickles file contents; only use with
        # trusted feature files.
        features = np.load(feature_path, allow_pickle=True).item()
        picture = Image.open(img_path)
        picture.load()
        picture = picture.convert("RGB")

        z = features[f'image_latent_{self.resolution}']
        token_embedding = features[f'token_embedding_{self.llm}']
        token_mask = features[f'token_mask_{self.llm}']
        token = features[f'token_{self.llm}']
        caption = features['batch_caption']

        pixels = center_crop_arr(picture, image_size=self.resolution)
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)
        pixels = einops.rearrange(pixels, 'h w c -> c h w')

        return z, token_embedding, token_mask, token, caption, pixels
|
|
|
|
|
class JDBFullFeatures(DatasetFactory):
    """Factory bundling a ``JDBFeatureDataset`` train split with an
    ``MSCOCOFeatureDataset`` test split, plus visualisation prompts."""

    def __init__(self, train_path, val_path, resolution, llm, cfg=False, p_uncond=None, fix_test_order=False):
        """
        Args:
            train_path: Root directory for ``JDBFeatureDataset``.
            val_path: Directory containing ``val/``, ``empty_context.npy``
                and ``run_vis/``.
            resolution: Forwarded to the train dataset; also drives
                ``data_shape``.
            llm: Forwarded to the train dataset.
            cfg: Must be falsy (asserted).
            p_uncond: Accepted but not used here.
            fix_test_order: Forwarded to the test dataset.
        """
        super().__init__()
        print('Prepare dataset...')
        self.resolution = resolution

        self.train = JDBFeatureDataset(train_path, resolution=resolution, llm=llm)
        val_root = os.path.join(val_path, 'val')
        self.test = MSCOCOFeatureDataset(val_root, full_feature=True, fix_test_order=fix_test_order)
        # The test split is required to contain exactly this many samples.
        assert len(self.test) == 40504

        print('Prepare dataset ok')

        empty_context_path = os.path.join(val_path, 'empty_context.npy')
        self.empty_context = np.load(empty_context_path, allow_pickle=True).item()

        assert not cfg

        # Visualisation items are files in run_vis/, ordered by the integer
        # before the first '.' of each file name.
        vis_dir = os.path.join(val_path, 'run_vis')
        vis_files = sorted(os.listdir(vis_dir), key=lambda name: int(name.split('.')[0]))

        prompts, embeddings, masks, tokens = [], [], [], []
        for name in vis_files:
            entry = np.load(os.path.join(vis_dir, name), allow_pickle=True).item()
            # 'promt' (sic) is the key as stored in the files.
            prompts.append(entry['promt'])
            embeddings.append(entry['token_embedding'])
            masks.append(entry['token_mask'])
            tokens.append(entry['token'])

        self.prompts = prompts
        self.token_embedding = np.array(embeddings)
        self.token_mask = np.array(masks)
        self.token = np.array(tokens)

    @property
    def data_shape(self):
        """``(4, 64, 64)`` for resolution 512, otherwise ``(4, 32, 32)``."""
        return (4, 64, 64) if self.resolution == 512 else (4, 32, 32)

    @property
    def fid_stat(self):
        """Path of the FID statistics file used for this dataset."""
        return 'assets/fid_stats/fid_stats_mscoco256_val.npz'
|
|
|
|
|
def get_dataset(name, **kwargs):
    """Build the dataset factory registered under ``name``.

    Only ``'JDB_demo_features'`` is known; ``kwargs`` are forwarded to
    ``JDBFullFeatures``.

    Raises:
        NotImplementedError: For any other ``name`` (carried as the
            exception argument).
    """
    if name != 'JDB_demo_features':
        raise NotImplementedError(name)
    return JDBFullFeatures(**kwargs)
|
|