Mariam-Elz committed on
Commit
bb92043
·
verified ·
1 Parent(s): 37e231b

Upload imagedream/ldm/util.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. imagedream/ldm/util.py +225 -225
imagedream/ldm/util.py CHANGED
@@ -1,226 +1,226 @@
1
- import importlib
2
-
3
- import random
4
- import torch
5
- import numpy as np
6
- from collections import abc
7
-
8
- import multiprocessing as mp
9
- from threading import Thread
10
- from queue import Queue
11
-
12
- from inspect import isfunction
13
- from PIL import Image, ImageDraw, ImageFont
14
-
15
-
16
- def log_txt_as_img(wh, xc, size=10):
17
- # wh a tuple of (width, height)
18
- # xc a list of captions to plot
19
- b = len(xc)
20
- txts = list()
21
- for bi in range(b):
22
- txt = Image.new("RGB", wh, color="white")
23
- draw = ImageDraw.Draw(txt)
24
- font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
25
- nc = int(40 * (wh[0] / 256))
26
- lines = "\n".join(
27
- xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
28
- )
29
-
30
- try:
31
- draw.text((0, 0), lines, fill="black", font=font)
32
- except UnicodeEncodeError:
33
- print("Cant encode string for logging. Skipping.")
34
-
35
- txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
36
- txts.append(txt)
37
- txts = np.stack(txts)
38
- txts = torch.tensor(txts)
39
- return txts
40
-
41
-
42
- def ismap(x):
43
- if not isinstance(x, torch.Tensor):
44
- return False
45
- return (len(x.shape) == 4) and (x.shape[1] > 3)
46
-
47
-
48
- def isimage(x):
49
- if not isinstance(x, torch.Tensor):
50
- return False
51
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
52
-
53
-
54
- def exists(x):
55
- return x is not None
56
-
57
-
58
- def default(val, d):
59
- if exists(val):
60
- return val
61
- return d() if isfunction(d) else d
62
-
63
-
64
- def mean_flat(tensor):
65
- """
66
- https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
67
- Take the mean over all non-batch dimensions.
68
- """
69
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
70
-
71
-
72
- def count_params(model, verbose=False):
73
- total_params = sum(p.numel() for p in model.parameters())
74
- if verbose:
75
- print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
76
- return total_params
77
-
78
-
79
- def instantiate_from_config(config):
80
- if not "target" in config:
81
- if config == "__is_first_stage__":
82
- return None
83
- elif config == "__is_unconditional__":
84
- return None
85
- raise KeyError("Expected key `target` to instantiate.")
86
- # import pdb; pdb.set_trace()
87
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
88
-
89
-
90
- def get_obj_from_str(string, reload=False):
91
- module, cls = string.rsplit(".", 1)
92
- # import pdb; pdb.set_trace()
93
- if reload:
94
- module_imp = importlib.import_module(module)
95
- importlib.reload(module_imp)
96
- return getattr(importlib.import_module(module, package=None), cls)
97
-
98
-
99
- def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
100
- # create dummy dataset instance
101
-
102
- # run prefetching
103
- if idx_to_fn:
104
- res = func(data, worker_id=idx)
105
- else:
106
- res = func(data)
107
- Q.put([idx, res])
108
- Q.put("Done")
109
-
110
-
111
- def parallel_data_prefetch(
112
- func: callable,
113
- data,
114
- n_proc,
115
- target_data_type="ndarray",
116
- cpu_intensive=True,
117
- use_worker_id=False,
118
- ):
119
- # if target_data_type not in ["ndarray", "list"]:
120
- # raise ValueError(
121
- # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
122
- # )
123
- if isinstance(data, np.ndarray) and target_data_type == "list":
124
- raise ValueError("list expected but function got ndarray.")
125
- elif isinstance(data, abc.Iterable):
126
- if isinstance(data, dict):
127
- print(
128
- f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
129
- )
130
- data = list(data.values())
131
- if target_data_type == "ndarray":
132
- data = np.asarray(data)
133
- else:
134
- data = list(data)
135
- else:
136
- raise TypeError(
137
- f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
138
- )
139
-
140
- if cpu_intensive:
141
- Q = mp.Queue(1000)
142
- proc = mp.Process
143
- else:
144
- Q = Queue(1000)
145
- proc = Thread
146
- # spawn processes
147
- if target_data_type == "ndarray":
148
- arguments = [
149
- [func, Q, part, i, use_worker_id]
150
- for i, part in enumerate(np.array_split(data, n_proc))
151
- ]
152
- else:
153
- step = (
154
- int(len(data) / n_proc + 1)
155
- if len(data) % n_proc != 0
156
- else int(len(data) / n_proc)
157
- )
158
- arguments = [
159
- [func, Q, part, i, use_worker_id]
160
- for i, part in enumerate(
161
- [data[i : i + step] for i in range(0, len(data), step)]
162
- )
163
- ]
164
- processes = []
165
- for i in range(n_proc):
166
- p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
167
- processes += [p]
168
-
169
- # start processes
170
- print(f"Start prefetching...")
171
- import time
172
-
173
- start = time.time()
174
- gather_res = [[] for _ in range(n_proc)]
175
- try:
176
- for p in processes:
177
- p.start()
178
-
179
- k = 0
180
- while k < n_proc:
181
- # get result
182
- res = Q.get()
183
- if res == "Done":
184
- k += 1
185
- else:
186
- gather_res[res[0]] = res[1]
187
-
188
- except Exception as e:
189
- print("Exception: ", e)
190
- for p in processes:
191
- p.terminate()
192
-
193
- raise e
194
- finally:
195
- for p in processes:
196
- p.join()
197
- print(f"Prefetching complete. [{time.time() - start} sec.]")
198
-
199
- if target_data_type == "ndarray":
200
- if not isinstance(gather_res[0], np.ndarray):
201
- return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
202
-
203
- # order outputs
204
- return np.concatenate(gather_res, axis=0)
205
- elif target_data_type == "list":
206
- out = []
207
- for r in gather_res:
208
- out.extend(r)
209
- return out
210
- else:
211
- return gather_res
212
-
213
- def set_seed(seed=None):
214
- random.seed(seed)
215
- np.random.seed(seed)
216
- if seed is not None:
217
- torch.manual_seed(seed)
218
- torch.cuda.manual_seed_all(seed)
219
-
220
- def add_random_background(image, bg_color=None):
221
- bg_color = np.random.rand() * 255 if bg_color is None else bg_color
222
- image = np.array(image)
223
- rgb, alpha = image[..., :3], image[..., 3:]
224
- alpha = alpha.astype(np.float32) / 255.0
225
- image_new = rgb * alpha + bg_color * (1 - alpha)
226
  return Image.fromarray(image_new.astype(np.uint8))
 
1
+ import importlib
2
+
3
+ import random
4
+ import torch
5
+ import numpy as np
6
+ from collections import abc
7
+
8
+ import multiprocessing as mp
9
+ from threading import Thread
10
+ from queue import Queue
11
+
12
+ from inspect import isfunction
13
+ from PIL import Image, ImageDraw, ImageFont
14
+
15
+
16
def log_txt_as_img(wh, xc, size=10):
    """Render a batch of captions as white images with black text.

    Args:
        wh: ``(width, height)`` tuple passed to ``Image.new``.
        xc: sequence of caption strings; one image is drawn per caption.
        size: font size handed to ``ImageFont.truetype``.

    Returns:
        A ``torch`` tensor built from the stacked images, laid out as
        ``(len(xc), 3, height, width)`` with pixel values rescaled from
        ``[0, 255]`` to ``[-1, 1]``.

    Raises:
        OSError: if the font file ``data/DejaVuSans.ttf`` cannot be opened
            (raised by Pillow).
        ValueError: if ``xc`` is empty (``np.stack`` needs at least one array).
    """
    # Both the font and the wrap width are loop-invariant: load / compute once
    # instead of once per caption.
    font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
    # 40 characters per line at a width of 256 px, scaled linearly. Clamp to 1
    # so very narrow images do not produce a zero step for ``range`` below.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # Hard-wrap the caption every ``nc`` characters.
        lines = "\n".join(
            caption[start : start + nc] for start in range(0, len(caption), nc)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank canvas for captions that cannot be drawn.
            print("Cant encode string for logging. Skipping.")

        # HWC uint8 -> CHW float in [-1, 1].
        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)

    return torch.tensor(np.stack(txts))
40
+
41
+
42
def ismap(x):
    """Return True when ``x`` is a 4-D ``torch.Tensor`` whose second dimension exceeds 3."""
    return isinstance(x, torch.Tensor) and len(x.shape) == 4 and x.shape[1] > 3
46
+
47
+
48
def isimage(x):
    """Return True when ``x`` is a 4-D ``torch.Tensor`` whose second dimension is 1 or 3."""
    if isinstance(x, torch.Tensor):
        n_dims = len(x.shape)
        return n_dims == 4 and (x.shape[1] == 3 or x.shape[1] == 1)
    return False
52
+
53
+
54
def exists(x):
    """Return True for every value except ``None`` (falsy values such as 0 or "" count as existing)."""
    if x is None:
        return False
    return True
56
+
57
+
58
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    If ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is returned, so an expensive
    default is only computed when it is actually needed. Any other ``d`` is
    returned as is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
62
+
63
+
64
def mean_flat(tensor):
    """
    Average ``tensor`` over every dimension except the first one.

    Adapted from
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
70
+
71
+
72
def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    With ``verbose=True`` the count is also printed in millions, prefixed by
    the model's class name.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params
77
+
78
+
79
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path resolved through
    ``get_obj_from_str``; the resolved callable is invoked with the keyword
    arguments found under ``config["params"]`` (none when that key is absent).

    The sentinel values ``"__is_first_stage__"`` and ``"__is_unconditional__"``
    yield ``None`` instead of an object.

    Raises:
        KeyError: if ``config`` has no ``target`` entry and is not one of the
            two sentinels.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)

    # No target: only the two sentinel markers are acceptable.
    if config == "__is_first_stage__" or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
88
+
89
+
90
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    The text after the last dot is looked up as an attribute of the module
    named by the text before it. With ``reload=True`` the module is reloaded
    via ``importlib.reload`` before the attribute is fetched.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)
97
+
98
+
99
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker entry point: run ``func`` on one chunk and report through ``Q``.

    On success ``[idx, result]`` is put on the queue. The string ``"Done"`` is
    always put last, even when ``func`` raises, so the parent's collection
    loop cannot wait forever on a worker that failed.

    Args:
        func: callable applied to ``data``.
        Q: queue shared with the parent.
        data: the chunk this worker is responsible for.
        idx: position of the chunk; also passed to ``func`` as ``worker_id``
            when ``idx_to_fn`` is true.
        idx_to_fn: whether ``func`` should receive ``worker_id=idx``.
    """
    try:
        if idx_to_fn:
            res = func(data, worker_id=idx)
        else:
            res = func(data)
        Q.put([idx, res])
    finally:
        Q.put("Done")


def parallel_data_prefetch(
    func: callable,
    data,
    n_proc,
    target_data_type="ndarray",
    cpu_intensive=True,
    use_worker_id=False,
):
    """Split ``data`` into chunks, run ``func`` on each chunk concurrently, and merge the results.

    Args:
        func: callable applied to each chunk; it also receives
            ``worker_id=<chunk index>`` when ``use_worker_id`` is true.
        data: an ``np.ndarray`` or any iterable. For a dict only the values
            are used (a warning is printed).
        n_proc: requested number of workers. In list mode fewer workers are
            started when the data does not split into that many chunks.
        target_data_type: ``"ndarray"`` splits with ``np.array_split`` and
            returns the concatenated array; ``"list"`` returns one flat list;
            any other value returns the list of per-chunk results.
        cpu_intensive: use ``multiprocessing`` processes when true, threads
            otherwise.
        use_worker_id: forward the chunk index to ``func`` as ``worker_id``.

    Returns:
        The per-chunk results merged in chunk order (see ``target_data_type``).

    Raises:
        ValueError: if ``n_proc`` is smaller than 1, or an ndarray is passed
            together with ``target_data_type="list"``.
        TypeError: if ``data`` is not iterable.
        RuntimeError: if at least one worker finished without a result.
    """
    import time

    if n_proc < 1:
        raise ValueError(f"n_proc must be a positive integer, got {n_proc}.")

    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread

    # Partition the data. np.array_split always yields exactly n_proc parts;
    # list slicing may yield fewer (e.g. 4 items over 3 workers -> 2 chunks),
    # so the worker count is derived from the chunks rather than from n_proc.
    if target_data_type == "ndarray":
        chunks = np.array_split(data, n_proc)
    else:
        step = -(-len(data) // n_proc)  # ceil(len(data) / n_proc)
        chunks = (
            [data[i : i + step] for i in range(0, len(data), step)] if step else []
        )

    workers = [
        proc(
            target=_do_parallel_data_prefetch,
            args=[func, Q, part, i, use_worker_id],
        )
        for i, part in enumerate(chunks)
    ]
    n_workers = len(workers)

    print("Start prefetching...")
    start = time.time()
    gather_res = [[] for _ in range(n_workers)]
    received = set()  # chunk indices for which a result arrived
    started = []  # only started workers may be joined / terminated
    try:
        for p in workers:
            p.start()
            started.append(p)

        finished = 0
        while finished < n_workers:
            res = Q.get()
            if isinstance(res, str) and res == "Done":
                finished += 1
            else:
                gather_res[res[0]] = res[1]
                received.add(res[0])

        if len(received) != n_workers:
            missing = sorted(set(range(n_workers)) - received)
            raise RuntimeError(
                f"Prefetch workers {missing} finished without returning a result."
            )

    except Exception as e:
        print("Exception: ", e)
        for p in started:
            # threading.Thread has no terminate(); only processes can be killed.
            if hasattr(p, "terminate"):
                p.terminate()
        raise
    finally:
        for p in started:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == "ndarray":
        # Results are already in chunk order because they are stored by index.
        return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
    elif target_data_type == "list":
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
212
+
213
def set_seed(seed=None):
    """Seed Python's ``random`` and NumPy, plus torch (CPU and all CUDA devices) when a seed is given.

    With ``seed=None`` only ``random.seed`` and ``np.random.seed`` are called
    (with ``None``); the torch generators are left untouched.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    if seed is None:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
219
+
220
def add_random_background(image, bg_color=None):
    """Composite an image with an alpha channel onto a solid background.

    Args:
        image: anything ``np.array`` can convert whose last axis holds the
            colour channels followed by alpha (e.g. an RGBA PIL image).
        bg_color: background value used where the image is transparent. When
            ``None`` a single random value in ``[0, 255)`` is drawn with
            ``np.random.rand`` and applied to all three colour channels.

    Returns:
        A PIL image created from the blended ``uint8`` array. Input without an
        alpha channel (three channels only) is treated as fully opaque and
        returned without blending.
    """
    if bg_color is None:
        bg_color = np.random.rand() * 255

    image = np.array(image)
    rgb, alpha = image[..., :3], image[..., 3:]
    if alpha.shape[-1] == 0:
        # No alpha channel: the empty slice cannot broadcast against rgb, and
        # an opaque image has no transparent area to fill anyway.
        return Image.fromarray(rgb.astype(np.uint8))

    alpha = alpha.astype(np.float32) / 255.0  # 0..255 -> 0..1 opacity
    image_new = rgb * alpha + bg_color * (1 - alpha)
    return Image.fromarray(image_new.astype(np.uint8))