Mariam-Elz committed (verified)
Commit 11f10ee
1 Parent(s): 79cf658

Upload imagedream/ldm/modules/attention.py with huggingface_hub
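For reference, an upload like this is normally issued through the huggingface_hub client. The sketch below is illustrative only; the repo id is a placeholder and is not taken from this commit.

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="imagedream/ldm/modules/attention.py",  # local file to push
    path_in_repo="imagedream/ldm/modules/attention.py",     # destination path in the repo
    repo_id="<user>/<repo>",                                # placeholder, not the actual repo id
    commit_message="Upload imagedream/ldm/modules/attention.py with huggingface_hub",
)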

Files changed (1)
  1. imagedream/ldm/modules/attention.py +456 -0
imagedream/ldm/modules/attention.py ADDED
@@ -0,0 +1,456 @@
+ from inspect import isfunction
+ import math
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, einsum
+ from einops import rearrange, repeat
+ from typing import Optional, Any
+
+ from .diffusionmodules.util import checkpoint
+
+
+ try:
+     import xformers
+     import xformers.ops
+
+     XFORMERS_IS_AVAILBLE = True
+ except:
+     XFORMERS_IS_AVAILBLE = False
+
+ # CrossAttn precision handling
+ import os
+
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+
+ def exists(val):
+     return val is not None
+
+
+ def uniq(arr):
+     return {el: True for el in arr}.keys()
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def max_neg_value(t):
+     return -torch.finfo(t.dtype).max
+
+
+ def init_(tensor):
+     dim = tensor.shape[-1]
+     std = 1 / math.sqrt(dim)
+     tensor.uniform_(-std, std)
+     return tensor
+
+
+ # feedforward
+ class GEGLU(nn.Module):
+     def __init__(self, dim_in, dim_out):
+         super().__init__()
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, x):
+         x, gate = self.proj(x).chunk(2, dim=-1)
+         return x * F.gelu(gate)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = default(dim_out, dim)
+         project_in = (
+             nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+             if not glu
+             else GEGLU(dim, inner_dim)
+         )
+
+         self.net = nn.Sequential(
+             project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(
+         num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+     )
+
+
+ class SpatialSelfAttention(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.k = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.v = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.proj_out = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = rearrange(q, "b c h w -> b (h w) c")
+         k = rearrange(k, "b c h w -> b c (h w)")
+         w_ = torch.einsum("bij,bjk->bik", q, k)
+
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = rearrange(v, "b c h w -> b c (h w)")
+         w_ = rearrange(w_, "b i j -> b j i")
+         h_ = torch.einsum("bij,bjk->bik", v, w_)
+         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ class MemoryEfficientCrossAttention(nn.Module):
+     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
+         super().__init__()
+         print(
+             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+             f"{heads} heads."
+         )
+         inner_dim = dim_head * heads
+         context_dim = default(context_dim, query_dim)
+
+         self.heads = heads
+         self.dim_head = dim_head
+
+         self.with_ip = kwargs.get("with_ip", False)
+         if self.with_ip and (context_dim is not None):
+             self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+             self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+             self.ip_dim = kwargs.get("ip_dim", 16)
+             self.ip_weight = kwargs.get("ip_weight", 1.0)
+
+         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+         )
+         self.attention_op: Optional[Any] = None
+
+     def forward(self, x, context=None, mask=None):
+         q = self.to_q(x)
+
+         has_ip = self.with_ip and (context is not None)
+         if has_ip:
+             # context dim [(b frame_num), (77 + img_token), 1024]
+             token_len = context.shape[1]
+             context_ip = context[:, -self.ip_dim:, :]
+             k_ip = self.to_k_ip(context_ip)
+             v_ip = self.to_v_ip(context_ip)
+             context = context[:, :(token_len - self.ip_dim), :]
+
+         context = default(context, x)
+         k = self.to_k(context)
+         v = self.to_v(context)
+
+         b, _, _ = q.shape
+         q, k, v = map(
+             lambda t: t.unsqueeze(3)
+             .reshape(b, t.shape[1], self.heads, self.dim_head)
+             .permute(0, 2, 1, 3)
+             .reshape(b * self.heads, t.shape[1], self.dim_head)
+             .contiguous(),
+             (q, k, v),
+         )
+
+         # actually compute the attention, what we cannot get enough of
+         out = xformers.ops.memory_efficient_attention(
+             q, k, v, attn_bias=None, op=self.attention_op
+         )
+
+         if has_ip:
+             k_ip, v_ip = map(
+                 lambda t: t.unsqueeze(3)
+                 .reshape(b, t.shape[1], self.heads, self.dim_head)
+                 .permute(0, 2, 1, 3)
+                 .reshape(b * self.heads, t.shape[1], self.dim_head)
+                 .contiguous(),
+                 (k_ip, v_ip),
+             )
+             # actually compute the attention, what we cannot get enough of
+             out_ip = xformers.ops.memory_efficient_attention(
+                 q, k_ip, v_ip, attn_bias=None, op=self.attention_op
+             )
+             out = out + self.ip_weight * out_ip
+
+         if exists(mask):
+             raise NotImplementedError
+         out = (
+             out.unsqueeze(0)
+             .reshape(b, self.heads, out.shape[1], self.dim_head)
+             .permute(0, 2, 1, 3)
+             .reshape(b, out.shape[1], self.heads * self.dim_head)
+         )
+         return self.to_out(out)
+
+
+ class BasicTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         n_heads,
+         d_head,
+         dropout=0.0,
+         context_dim=None,
+         gated_ff=True,
+         checkpoint=True,
+         disable_self_attn=False,
+         **kwargs
+     ):
+         super().__init__()
+         assert XFORMERS_IS_AVAILBLE, "xformers is not available"
+         attn_cls = MemoryEfficientCrossAttention
+         self.disable_self_attn = disable_self_attn
+         self.attn1 = attn_cls(
+             query_dim=dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             context_dim=context_dim if self.disable_self_attn else None,
+         )  # is a self-attention if not self.disable_self_attn
+         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+         self.attn2 = attn_cls(
+             query_dim=dim,
+             context_dim=context_dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             **kwargs
+         )  # is self-attn if context is none
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.norm3 = nn.LayerNorm(dim)
+         self.checkpoint = checkpoint
+
+     def forward(self, x, context=None):
+         return checkpoint(
+             self._forward, (x, context), self.parameters(), self.checkpoint
+         )
+
+     def _forward(self, x, context=None):
+         x = (
+             self.attn1(
+                 self.norm1(x), context=context if self.disable_self_attn else None
+             )
+             + x
+         )
+         x = self.attn2(self.norm2(x), context=context) + x
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
+ class SpatialTransformer(nn.Module):
+     """
+     Transformer block for image-like data.
+     First, project the input (aka embedding)
+     and reshape to b, t, d.
+     Then apply standard transformer action.
+     Finally, reshape to image
+     NEW: use_linear for more efficiency instead of the 1x1 convs
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         n_heads,
+         d_head,
+         depth=1,
+         dropout=0.0,
+         context_dim=None,
+         disable_self_attn=False,
+         use_linear=False,
+         use_checkpoint=True,
+         **kwargs
+     ):
+         super().__init__()
+         if exists(context_dim) and not isinstance(context_dim, list):
+             context_dim = [context_dim]
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = Normalize(in_channels)
+         if not use_linear:
+             self.proj_in = nn.Conv2d(
+                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+             )
+         else:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     n_heads,
+                     d_head,
+                     dropout=dropout,
+                     context_dim=context_dim[d],
+                     disable_self_attn=disable_self_attn,
+                     checkpoint=use_checkpoint,
+                     **kwargs
+                 )
+                 for d in range(depth)
+             ]
+         )
+         if not use_linear:
+             self.proj_out = zero_module(
+                 nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+             )
+         else:
+             self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+         self.use_linear = use_linear
+
+     def forward(self, x, context=None):
+         # note: if no context is given, cross-attention defaults to self-attention
+         if not isinstance(context, list):
+             context = [context]
+         b, c, h, w = x.shape
+         x_in = x
+         x = self.norm(x)
+         if not self.use_linear:
+             x = self.proj_in(x)
+         x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+         if self.use_linear:
+             x = self.proj_in(x)
+         for i, block in enumerate(self.transformer_blocks):
+             x = block(x, context=context[i])
+         if self.use_linear:
+             x = self.proj_out(x)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+         if not self.use_linear:
+             x = self.proj_out(x)
+         return x + x_in
+
+
+ class BasicTransformerBlock3D(BasicTransformerBlock):
+     def forward(self, x, context=None, num_frames=1):
+         return checkpoint(
+             self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
+         )
+
+     def _forward(self, x, context=None, num_frames=1):
+         x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
+         x = (
+             self.attn1(
+                 self.norm1(x),
+                 context=context if self.disable_self_attn else None
+             )
+             + x
+         )
+         x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
+         x = self.attn2(self.norm2(x), context=context) + x
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
+ class SpatialTransformer3D(nn.Module):
+     """3D self-attention"""
+
+     def __init__(
+         self,
+         in_channels,
+         n_heads,
+         d_head,
+         depth=1,
+         dropout=0.0,
+         context_dim=None,
+         disable_self_attn=False,
+         use_linear=False,
+         use_checkpoint=True,
+         **kwargs
+     ):
+         super().__init__()
+         if exists(context_dim) and not isinstance(context_dim, list):
+             context_dim = [context_dim]
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = Normalize(in_channels)
+         if not use_linear:
+             self.proj_in = nn.Conv2d(
+                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+             )
+         else:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock3D(
+                     inner_dim,
+                     n_heads,
+                     d_head,
+                     dropout=dropout,
+                     context_dim=context_dim[d],
+                     disable_self_attn=disable_self_attn,
+                     checkpoint=use_checkpoint,
+                     **kwargs
+                 )
+                 for d in range(depth)
+             ]
+         )
+         if not use_linear:
+             self.proj_out = zero_module(
+                 nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+             )
+         else:
+             self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+         self.use_linear = use_linear
+
+     def forward(self, x, context=None, num_frames=1):
+         # note: if no context is given, cross-attention defaults to self-attention
+         if not isinstance(context, list):
+             context = [context]
+         b, c, h, w = x.shape
+         x_in = x
+         x = self.norm(x)
+         if not self.use_linear:
+             x = self.proj_in(x)
+         x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+         if self.use_linear:
+             x = self.proj_in(x)
+         for i, block in enumerate(self.transformer_blocks):
+             x = block(x, context=context[i], num_frames=num_frames)
+         if self.use_linear:
+             x = self.proj_out(x)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+         if not self.use_linear:
+             x = self.proj_out(x)
+         return x + x_in
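
Below is a minimal usage sketch of the image-prompt (IP) branch wired into SpatialTransformer3D above. It assumes the package is importable as imagedream.ldm.modules.attention, that xformers and a CUDA device are available, and it follows the context layout hinted at by the in-code comment (77 text tokens plus ip_dim image tokens per frame); all sizes are illustrative, not values taken from this commit.

import torch
from imagedream.ldm.modules.attention import SpatialTransformer3D

num_frames, batch = 4, 1
in_channels, n_heads, d_head = 320, 8, 40   # inner_dim = n_heads * d_head = in_channels
context_dim, ip_dim = 1024, 16              # 77 text tokens + 16 image-prompt tokens

block = SpatialTransformer3D(
    in_channels, n_heads, d_head,
    depth=1,
    context_dim=context_dim,
    use_linear=True,
    use_checkpoint=False,
    with_ip=True,       # routed via **kwargs to MemoryEfficientCrossAttention (attn2)
    ip_dim=ip_dim,
    ip_weight=1.0,
).cuda()

x = torch.randn(batch * num_frames, in_channels, 32, 32, device="cuda")
# one (text + image-prompt) context per frame: [(b * f), 77 + ip_dim, context_dim]
context = torch.randn(batch * num_frames, 77 + ip_dim, context_dim, device="cuda")

out = block(x, context=context, num_frames=num_frames)
print(out.shape)  # torch.Size([4, 320, 32, 32])

In the cross-attention step, the last ip_dim context tokens are split off, attended to separately through to_k_ip/to_v_ip, and the result is added back scaled by ip_weight.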