Mariam-Elz committed (verified)
Commit 584bda0 · Parent: f34eca8

Upload imagedream/ldm/modules/encoders/modules.py with huggingface_hub

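For context, a minimal sketch of how a commit like this is typically produced with huggingface_hub; the repo_id below is hypothetical, since the target repository is not shown on this page.

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="imagedream/ldm/modules/encoders/modules.py",  # local file to upload
    path_in_repo="imagedream/ldm/modules/encoders/modules.py",     # destination path in the repo
    repo_id="Mariam-Elz/example-repo",  # hypothetical; actual repo id not shown here
    commit_message="Upload imagedream/ldm/modules/encoders/modules.py with huggingface_hub",
)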
imagedream/ldm/modules/encoders/modules.py CHANGED
@@ -1,329 +1,329 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import numpy as np
import open_clip
from PIL import Image
from ...util import default, count_params


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    def encode(self, x):
        return x


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0.0 and not disable_dropout:
            mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        ip_mode=None,
    ):
        """
        Args:
            ip_mode (str, optional): the image processing mode. "global" returns a single
                normalized CLIP image embedding; a "local" mode returns per-patch tokens
                from the visual transformer. Defaults to None (text only, in which case
                the visual tower is discarded).
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        if ip_mode is None:
            del model.visual

        self.model = model
        self.preprocess = preprocess
        self.device = device
        self.max_length = max_length
        self.ip_mode = ip_mode
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def forward_image(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        if isinstance(pil_image, torch.Tensor):
            pil_image = pil_image.cpu().numpy()
        if isinstance(pil_image, np.ndarray):
            if pil_image.ndim == 3:
                pil_image = pil_image[None, :, :, :]
            pil_image = [Image.fromarray(x) for x in pil_image]

        images = []
        for image in pil_image:
            images.append(self.preprocess(image).to(self.device))

        image = torch.stack(images, 0)  # to [b, 3, h, w]
        if self.ip_mode == "global":
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        elif self.ip_mode is not None and "local" in self.ip_mode:
            image_features = self.encode_image_with_transformer(image)
        else:
            # image conditioning requires ip_mode to be "global" or a "local" variant
            raise ValueError(f"forward_image does not support ip_mode={self.ip_mode!r}")

        return image_features  # [b, d] for "global", [b, l, d] for "local"

    def encode_image_with_transformer(self, x):
        visual = self.model.visual
        x = visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat(
            [visual.class_embedding.to(x.dtype) + \
             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        # x = visual.patch_dropout(x)
        x = visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        hidden = self.image_transformer_forward(x)
        x = hidden[-2].permute(1, 0, 2)  # LND -> NLD
        return x

    def image_transformer_forward(self, x):
        encoder_states = ()
        trans = self.model.visual.transformer
        for r in trans.resblocks:
            if trans.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, None, None, None)
            else:
                x = r(x, attn_mask=None)
            encoder_states = encoder_states + (x,)
        return encoder_states

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
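For reference, a minimal usage sketch of the text path, not part of the uploaded file. It assumes the module is imported from within the imagedream package and that open_clip can fetch the laion2b_s32b_b79k weights for ViT-H-14; the encoder is kept on CPU so the sketch stays self-contained.

import torch

# Hypothetical illustration; FrozenOpenCLIPEmbedder is defined in the module above.
encoder = FrozenOpenCLIPEmbedder(
    arch="ViT-H-14",
    version="laion2b_s32b_b79k",
    device="cpu",        # the class defaults to "cuda"; CPU here for illustration
    layer="penultimate",
    ip_mode=None,        # text only: the visual tower is discarded
)
with torch.no_grad():
    z = encoder.encode(["a photo of an astronaut riding a horse"])
print(z.shape)  # expected: torch.Size([1, 77, 1024]) for the ViT-H-14 text tower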