Manireddy1508 commited on
Commit
2e547b7
·
verified ·
1 Parent(s): 26548e6

Upload uno.py

Browse files
Files changed (1) hide show
  1. uno/dataset/uno.py +132 -0
uno/dataset/uno.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import os
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torchvision.transforms.functional as TVF
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from torchvision.transforms import Compose, Normalize, ToTensor
24
+
25
+ def bucket_images(images: list[torch.Tensor], resolution: int = 512):
26
+ bucket_override=[
27
+ # h w
28
+ (256, 768),
29
+ (320, 768),
30
+ (320, 704),
31
+ (384, 640),
32
+ (448, 576),
33
+ (512, 512),
34
+ (576, 448),
35
+ (640, 384),
36
+ (704, 320),
37
+ (768, 320),
38
+ (768, 256)
39
+ ]
40
+ bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
41
+ bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
42
+
43
+ aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
44
+ mean_aspect_ratio = np.mean(aspect_ratios)
45
+
46
+ new_h, new_w = bucket_override[0]
47
+ min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
48
+ for h, w in bucket_override:
49
+ aspect_diff = np.abs(h / w - mean_aspect_ratio)
50
+ if aspect_diff < min_aspect_diff:
51
+ min_aspect_diff = aspect_diff
52
+ new_h, new_w = h, w
53
+
54
+ images = [TVF.resize(image, (new_h, new_w)) for image in images]
55
+ images = torch.stack(images, dim=0)
56
+ return images
57
+
58
+ class FluxPairedDatasetV2(Dataset):
59
+ def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
60
+ super().__init__()
61
+ self.json_file = json_file
62
+ self.resolution = resolution
63
+ self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
64
+ self.image_root = os.path.dirname(json_file)
65
+
66
+ with open(self.json_file, "rt") as f:
67
+ self.data_dicts = json.load(f)
68
+
69
+ self.transform = Compose([
70
+ ToTensor(),
71
+ Normalize([0.5], [0.5]),
72
+ ])
73
+
74
+ def __getitem__(self, idx):
75
+ data_dict = self.data_dicts[idx]
76
+ image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
77
+ txt = data_dict["prompt"]
78
+ image_tgt_path = data_dict.get("image_tgt_path", None)
79
+ ref_imgs = [
80
+ Image.open(os.path.join(self.image_root, path)).convert("RGB")
81
+ for path in image_paths
82
+ ]
83
+ ref_imgs = [self.transform(img) for img in ref_imgs]
84
+ img = None
85
+ if image_tgt_path is not None:
86
+ img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
87
+ img = self.transform(img)
88
+
89
+ return {
90
+ "img": img,
91
+ "txt": txt,
92
+ "ref_imgs": ref_imgs,
93
+ }
94
+
95
+ def __len__(self):
96
+ return len(self.data_dicts)
97
+
98
+ def collate_fn(self, batch):
99
+ img = [data["img"] for data in batch]
100
+ txt = [data["txt"] for data in batch]
101
+ ref_imgs = [data["ref_imgs"] for data in batch]
102
+ assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
103
+
104
+ n_ref = len(ref_imgs[0])
105
+
106
+ img = bucket_images(img, self.resolution)
107
+ ref_imgs_new = []
108
+ for i in range(n_ref):
109
+ ref_imgs_i = [refs[i] for refs in ref_imgs]
110
+ ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
111
+ ref_imgs_new.append(ref_imgs_i)
112
+
113
+ return {
114
+ "txt": txt,
115
+ "img": img,
116
+ "ref_imgs": ref_imgs_new,
117
+ }
118
+
119
+ if __name__ == '__main__':
120
+ import argparse
121
+ from pprint import pprint
122
+ parser = argparse.ArgumentParser()
123
+ # parser.add_argument("--json_file", type=str, required=True)
124
+ parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
125
+ args = parser.parse_args()
126
+ dataset = FluxPairedDatasetV2(args.json_file, 512)
127
+ dataloder = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
128
+
129
+ for i, data_dict in enumerate(dataloder):
130
+ pprint(i)
131
+ pprint(data_dict)
132
+ breakpoint()