Spaces:
Paused
Paused
Upload uno.py
Browse files- uno/dataset/uno.py +132 -0
uno/dataset/uno.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
3 |
+
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torchvision.transforms.functional as TVF
|
22 |
+
from torch.utils.data import DataLoader, Dataset
|
23 |
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
24 |
+
|
25 |
+
def bucket_images(images: list[torch.Tensor], resolution: int = 512):
    """Resize a group of images to one shared bucket size and stack them.

    A fixed table of (height, width) buckets, defined for a base size of 512,
    is scaled to ``resolution`` and each side is rounded down to a multiple of
    16. The bucket whose height/width ratio is closest to the mean
    height/width ratio of ``images`` is chosen (the earliest bucket wins a
    tie). Every image is then resized to exactly that bucket, so individual
    aspect ratios are not preserved.

    Args:
        images: Tensors whose last two dimensions are read as (height, width).
        resolution: Target base size; 512 leaves the bucket table unscaled.

    Returns:
        The resized images stacked along a new leading dimension.
    """
    # (height, width) pairs for the 512 base resolution, tallest-last.
    base_buckets = (
        (256, 768),
        (320, 768),
        (320, 704),
        (384, 640),
        (448, 576),
        (512, 512),
        (576, 448),
        (640, 384),
        (704, 320),
        (768, 320),
        (768, 256),
    )

    # Scale to the requested resolution, then snap each side down to a
    # multiple of 16.
    buckets = []
    for base_h, base_w in base_buckets:
        scaled_h = int(base_h / 512 * resolution)
        scaled_w = int(base_w / 512 * resolution)
        buckets.append((scaled_h // 16 * 16, scaled_w // 16 * 16))

    target_ratio = np.mean([tensor.shape[-2] / tensor.shape[-1] for tensor in images])

    def ratio_gap(bucket):
        bucket_h, bucket_w = bucket
        return np.abs(bucket_h / bucket_w - target_ratio)

    # min() returns the first minimal element, matching a strict "<" scan
    # that starts from the first bucket.
    new_h, new_w = min(buckets, key=ratio_gap)

    resized = [TVF.resize(tensor, (new_h, new_w)) for tensor in images]
    return torch.stack(resized, dim=0)
|
57 |
+
|
58 |
+
class FluxPairedDatasetV2(Dataset):
|
59 |
+
def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
|
60 |
+
super().__init__()
|
61 |
+
self.json_file = json_file
|
62 |
+
self.resolution = resolution
|
63 |
+
self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
|
64 |
+
self.image_root = os.path.dirname(json_file)
|
65 |
+
|
66 |
+
with open(self.json_file, "rt") as f:
|
67 |
+
self.data_dicts = json.load(f)
|
68 |
+
|
69 |
+
self.transform = Compose([
|
70 |
+
ToTensor(),
|
71 |
+
Normalize([0.5], [0.5]),
|
72 |
+
])
|
73 |
+
|
74 |
+
def __getitem__(self, idx):
|
75 |
+
data_dict = self.data_dicts[idx]
|
76 |
+
image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
|
77 |
+
txt = data_dict["prompt"]
|
78 |
+
image_tgt_path = data_dict.get("image_tgt_path", None)
|
79 |
+
ref_imgs = [
|
80 |
+
Image.open(os.path.join(self.image_root, path)).convert("RGB")
|
81 |
+
for path in image_paths
|
82 |
+
]
|
83 |
+
ref_imgs = [self.transform(img) for img in ref_imgs]
|
84 |
+
img = None
|
85 |
+
if image_tgt_path is not None:
|
86 |
+
img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
|
87 |
+
img = self.transform(img)
|
88 |
+
|
89 |
+
return {
|
90 |
+
"img": img,
|
91 |
+
"txt": txt,
|
92 |
+
"ref_imgs": ref_imgs,
|
93 |
+
}
|
94 |
+
|
95 |
+
def __len__(self):
|
96 |
+
return len(self.data_dicts)
|
97 |
+
|
98 |
+
def collate_fn(self, batch):
|
99 |
+
img = [data["img"] for data in batch]
|
100 |
+
txt = [data["txt"] for data in batch]
|
101 |
+
ref_imgs = [data["ref_imgs"] for data in batch]
|
102 |
+
assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
|
103 |
+
|
104 |
+
n_ref = len(ref_imgs[0])
|
105 |
+
|
106 |
+
img = bucket_images(img, self.resolution)
|
107 |
+
ref_imgs_new = []
|
108 |
+
for i in range(n_ref):
|
109 |
+
ref_imgs_i = [refs[i] for refs in ref_imgs]
|
110 |
+
ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
|
111 |
+
ref_imgs_new.append(ref_imgs_i)
|
112 |
+
|
113 |
+
return {
|
114 |
+
"txt": txt,
|
115 |
+
"img": img,
|
116 |
+
"ref_imgs": ref_imgs_new,
|
117 |
+
}
|
118 |
+
|
119 |
+
if __name__ == '__main__':
    # Manual smoke test: iterate the dataset in batches of four, print each
    # collated batch and pause in the debugger after every one.
    import argparse
    from pprint import pprint

    cli = argparse.ArgumentParser()
    cli.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
    cli_args = cli.parse_args()

    demo_dataset = FluxPairedDatasetV2(cli_args.json_file, 512)
    loader = DataLoader(demo_dataset, batch_size=4, collate_fn=demo_dataset.collate_fn)

    for batch_index, batch in enumerate(loader):
        pprint(batch_index)
        pprint(batch)
        breakpoint()
|