Mariam-Elz committed · Commit 5b9ca85 · verified · 1 Parent(s): c6597fb

Upload imagedream/ldm/modules/diffusionmodules/model.py with huggingface_hub

imagedream/ldm/modules/diffusionmodules/model.py CHANGED
@@ -1,1018 +1,1018 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ..attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILABLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILABLE = False
18
+ print("No module 'xformers'. Proceeding without it.")
19
+
20
+
21
+ def get_timestep_embedding(timesteps, embedding_dim):
22
+ """
23
+ Build sinusoidal timestep embeddings.
24
+ This matches the implementation in Denoising Diffusion Probabilistic Models
25
+ (adapted from Fairseq) and in tensor2tensor, but differs slightly from the
26
+ description in Section 3.5 of "Attention Is All You Need".
28
+ """
29
+ assert len(timesteps.shape) == 1
30
+
31
+ half_dim = embedding_dim // 2
32
+ emb = math.log(10000) / (half_dim - 1)
33
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
34
+ emb = emb.to(device=timesteps.device)
35
+ emb = timesteps.float()[:, None] * emb[None, :]
36
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
37
+ if embedding_dim % 2 == 1: # zero pad
38
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
39
+ return emb
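+ # Illustration: with timesteps = torch.tensor([0., 10.]) and embedding_dim = 4,
+ # half_dim is 2 and the frequencies are [1, 1e-4], so each row of the result is
+ # [sin(t), sin(t * 1e-4), cos(t), cos(t * 1e-4)] for t in {0, 10}.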
40
+
41
+
42
+ def nonlinearity(x):
43
+ # swish
44
+ return x * torch.sigmoid(x)
45
+
46
+
47
+ def Normalize(in_channels, num_groups=32):
48
+ return torch.nn.GroupNorm(
49
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
50
+ )
51
+
52
+
53
+ class Upsample(nn.Module):
54
+ def __init__(self, in_channels, with_conv):
55
+ super().__init__()
56
+ self.with_conv = with_conv
57
+ if self.with_conv:
58
+ self.conv = torch.nn.Conv2d(
59
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
60
+ )
61
+
62
+ def forward(self, x):
63
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
64
+ if self.with_conv:
65
+ x = self.conv(x)
66
+ return x
67
+
68
+
69
+ class Downsample(nn.Module):
70
+ def __init__(self, in_channels, with_conv):
71
+ super().__init__()
72
+ self.with_conv = with_conv
73
+ if self.with_conv:
74
+ # no asymmetric padding in torch conv, must do it ourselves
75
+ self.conv = torch.nn.Conv2d(
76
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
77
+ )
78
+
79
+ def forward(self, x):
80
+ if self.with_conv:
81
+ pad = (0, 1, 0, 1)
82
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
83
+ x = self.conv(x)
84
+ else:
85
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
86
+ return x
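+ # Note: the (0, 1, 0, 1) padding pads only the right and bottom edges, so the
+ # stride-2, padding-0, 3x3 conv halves H and W exactly like the avg_pool2d branch.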
87
+
88
+
89
+ class ResnetBlock(nn.Module):
90
+ def __init__(
91
+ self,
92
+ *,
93
+ in_channels,
94
+ out_channels=None,
95
+ conv_shortcut=False,
96
+ dropout,
97
+ temb_channels=512,
98
+ ):
99
+ super().__init__()
100
+ self.in_channels = in_channels
101
+ out_channels = in_channels if out_channels is None else out_channels
102
+ self.out_channels = out_channels
103
+ self.use_conv_shortcut = conv_shortcut
104
+
105
+ self.norm1 = Normalize(in_channels)
106
+ self.conv1 = torch.nn.Conv2d(
107
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
108
+ )
109
+ if temb_channels > 0:
110
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
111
+ self.norm2 = Normalize(out_channels)
112
+ self.dropout = torch.nn.Dropout(dropout)
113
+ self.conv2 = torch.nn.Conv2d(
114
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
115
+ )
116
+ if self.in_channels != self.out_channels:
117
+ if self.use_conv_shortcut:
118
+ self.conv_shortcut = torch.nn.Conv2d(
119
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
120
+ )
121
+ else:
122
+ self.nin_shortcut = torch.nn.Conv2d(
123
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
124
+ )
125
+
126
+ def forward(self, x, temb):
127
+ h = x
128
+ h = self.norm1(h)
129
+ h = nonlinearity(h)
130
+ h = self.conv1(h)
131
+
132
+ if temb is not None:
133
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
134
+
135
+ h = self.norm2(h)
136
+ h = nonlinearity(h)
137
+ h = self.dropout(h)
138
+ h = self.conv2(h)
139
+
140
+ if self.in_channels != self.out_channels:
141
+ if self.use_conv_shortcut:
142
+ x = self.conv_shortcut(x)
143
+ else:
144
+ x = self.nin_shortcut(x)
145
+
146
+ return x + h
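+ # Pre-activation residual block: GroupNorm -> swish -> conv, twice, with the
+ # projected timestep embedding added as a per-channel bias between the two
+ # convolutions, and a learned (3x3 or 1x1) shortcut when channel counts differ.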
147
+
148
+
149
+ class AttnBlock(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = q.reshape(b, c, h * w)
178
+ q = q.permute(0, 2, 1) # b,hw,c
179
+ k = k.reshape(b, c, h * w) # b,c,hw
180
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b, c, h * w)
186
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b, c, h, w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x + h_
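+ # The block above is plain single-head self-attention over the h*w spatial
+ # positions: softmax(Q K^T / sqrt(C)) applied to V, followed by a 1x1 output
+ # projection and a residual connection.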
193
+
194
+
195
+ class MemoryEfficientAttnBlock(nn.Module):
196
+ """
197
+ Uses xformers efficient implementation,
198
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199
+ Note: this is a single-head self-attention operation
200
+ """
201
+
202
+ #
203
+ def __init__(self, in_channels):
204
+ super().__init__()
205
+ self.in_channels = in_channels
206
+
207
+ self.norm = Normalize(in_channels)
208
+ self.q = torch.nn.Conv2d(
209
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
210
+ )
211
+ self.k = torch.nn.Conv2d(
212
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
213
+ )
214
+ self.v = torch.nn.Conv2d(
215
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+ self.proj_out = torch.nn.Conv2d(
218
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
+ )
220
+ self.attention_op: Optional[Any] = None
221
+
222
+ def forward(self, x):
223
+ h_ = x
224
+ h_ = self.norm(h_)
225
+ q = self.q(h_)
226
+ k = self.k(h_)
227
+ v = self.v(h_)
228
+
229
+ # compute attention
230
+ B, C, H, W = q.shape
231
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
232
+
233
+ q, k, v = map(
234
+ lambda t: t.unsqueeze(3)
235
+ .reshape(B, t.shape[1], 1, C)
236
+ .permute(0, 2, 1, 3)
237
+ .reshape(B * 1, t.shape[1], C)
238
+ .contiguous(),
239
+ (q, k, v),
240
+ )
241
+ out = xformers.ops.memory_efficient_attention(
242
+ q, k, v, attn_bias=None, op=self.attention_op
243
+ )
244
+
245
+ out = (
246
+ out.unsqueeze(0)
247
+ .reshape(B, 1, out.shape[1], C)
248
+ .permute(0, 2, 1, 3)
249
+ .reshape(B, out.shape[1], C)
250
+ )
251
+ out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
252
+ out = self.proj_out(out)
253
+ return x + out
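+ # The unsqueeze/reshape dance above only inserts a head dimension of size 1:
+ # this block treats the whole channel dimension C as a single attention head
+ # before handing q, k, v to xformers.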
254
+
255
+
256
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
257
+ def forward(self, x, context=None, mask=None):
258
+ b, c, h, w = x.shape
259
+ x_flat = rearrange(x, "b c h w -> b (h w) c")  # keep x in (b, c, h, w) for the residual
260
+ out = super().forward(x_flat, context=context, mask=mask)
261
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
262
+ return x + out
263
+
264
+
265
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
266
+ assert attn_type in [
267
+ "vanilla",
268
+ "vanilla-xformers",
269
+ "memory-efficient-cross-attn",
270
+ "linear",
271
+ "none",
272
+ ], f"attn_type {attn_type} unknown"
273
+ if XFORMERS_IS_AVAILABLE and attn_type == "vanilla":
274
+ attn_type = "vanilla-xformers"
275
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
276
+ if attn_type == "vanilla":
277
+ assert attn_kwargs is None
278
+ return AttnBlock(in_channels)
279
+ elif attn_type == "vanilla-xformers":
280
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
281
+ return MemoryEfficientAttnBlock(in_channels)
282
+ elif attn_type == "memory-efficient-cross-attn":
283
+ attn_kwargs["query_dim"] = in_channels
284
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
285
+ elif attn_type == "none":
286
+ return nn.Identity(in_channels)
287
+ else:
288
+ raise NotImplementedError()
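+ # Example: make_attn(512) returns MemoryEfficientAttnBlock(512) when xformers
+ # is importable and falls back to the vanilla AttnBlock(512) otherwise;
+ # attn_type="none" yields an nn.Identity that disables attention at that resolution.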
289
+
290
+
291
+ class Model(nn.Module):
292
+ def __init__(
293
+ self,
294
+ *,
295
+ ch,
296
+ out_ch,
297
+ ch_mult=(1, 2, 4, 8),
298
+ num_res_blocks,
299
+ attn_resolutions,
300
+ dropout=0.0,
301
+ resamp_with_conv=True,
302
+ in_channels,
303
+ resolution,
304
+ use_timestep=True,
305
+ use_linear_attn=False,
306
+ attn_type="vanilla",
307
+ ):
308
+ super().__init__()
309
+ if use_linear_attn:
310
+ attn_type = "linear"
311
+ self.ch = ch
312
+ self.temb_ch = self.ch * 4
313
+ self.num_resolutions = len(ch_mult)
314
+ self.num_res_blocks = num_res_blocks
315
+ self.resolution = resolution
316
+ self.in_channels = in_channels
317
+
318
+ self.use_timestep = use_timestep
319
+ if self.use_timestep:
320
+ # timestep embedding
321
+ self.temb = nn.Module()
322
+ self.temb.dense = nn.ModuleList(
323
+ [
324
+ torch.nn.Linear(self.ch, self.temb_ch),
325
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
326
+ ]
327
+ )
328
+
329
+ # downsampling
330
+ self.conv_in = torch.nn.Conv2d(
331
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
332
+ )
333
+
334
+ curr_res = resolution
335
+ in_ch_mult = (1,) + tuple(ch_mult)
336
+ self.down = nn.ModuleList()
337
+ for i_level in range(self.num_resolutions):
338
+ block = nn.ModuleList()
339
+ attn = nn.ModuleList()
340
+ block_in = ch * in_ch_mult[i_level]
341
+ block_out = ch * ch_mult[i_level]
342
+ for i_block in range(self.num_res_blocks):
343
+ block.append(
344
+ ResnetBlock(
345
+ in_channels=block_in,
346
+ out_channels=block_out,
347
+ temb_channels=self.temb_ch,
348
+ dropout=dropout,
349
+ )
350
+ )
351
+ block_in = block_out
352
+ if curr_res in attn_resolutions:
353
+ attn.append(make_attn(block_in, attn_type=attn_type))
354
+ down = nn.Module()
355
+ down.block = block
356
+ down.attn = attn
357
+ if i_level != self.num_resolutions - 1:
358
+ down.downsample = Downsample(block_in, resamp_with_conv)
359
+ curr_res = curr_res // 2
360
+ self.down.append(down)
361
+
362
+ # middle
363
+ self.mid = nn.Module()
364
+ self.mid.block_1 = ResnetBlock(
365
+ in_channels=block_in,
366
+ out_channels=block_in,
367
+ temb_channels=self.temb_ch,
368
+ dropout=dropout,
369
+ )
370
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
371
+ self.mid.block_2 = ResnetBlock(
372
+ in_channels=block_in,
373
+ out_channels=block_in,
374
+ temb_channels=self.temb_ch,
375
+ dropout=dropout,
376
+ )
377
+
378
+ # upsampling
379
+ self.up = nn.ModuleList()
380
+ for i_level in reversed(range(self.num_resolutions)):
381
+ block = nn.ModuleList()
382
+ attn = nn.ModuleList()
383
+ block_out = ch * ch_mult[i_level]
384
+ skip_in = ch * ch_mult[i_level]
385
+ for i_block in range(self.num_res_blocks + 1):
386
+ if i_block == self.num_res_blocks:
387
+ skip_in = ch * in_ch_mult[i_level]
388
+ block.append(
389
+ ResnetBlock(
390
+ in_channels=block_in + skip_in,
391
+ out_channels=block_out,
392
+ temb_channels=self.temb_ch,
393
+ dropout=dropout,
394
+ )
395
+ )
396
+ block_in = block_out
397
+ if curr_res in attn_resolutions:
398
+ attn.append(make_attn(block_in, attn_type=attn_type))
399
+ up = nn.Module()
400
+ up.block = block
401
+ up.attn = attn
402
+ if i_level != 0:
403
+ up.upsample = Upsample(block_in, resamp_with_conv)
404
+ curr_res = curr_res * 2
405
+ self.up.insert(0, up) # prepend to get consistent order
406
+
407
+ # end
408
+ self.norm_out = Normalize(block_in)
409
+ self.conv_out = torch.nn.Conv2d(
410
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
411
+ )
412
+
413
+ def forward(self, x, t=None, context=None):
414
+ # assert x.shape[2] == x.shape[3] == self.resolution
415
+ if context is not None:
416
+ # assume aligned context, cat along channel axis
417
+ x = torch.cat((x, context), dim=1)
418
+ if self.use_timestep:
419
+ # timestep embedding
420
+ assert t is not None
421
+ temb = get_timestep_embedding(t, self.ch)
422
+ temb = self.temb.dense[0](temb)
423
+ temb = nonlinearity(temb)
424
+ temb = self.temb.dense[1](temb)
425
+ else:
426
+ temb = None
427
+
428
+ # downsampling
429
+ hs = [self.conv_in(x)]
430
+ for i_level in range(self.num_resolutions):
431
+ for i_block in range(self.num_res_blocks):
432
+ h = self.down[i_level].block[i_block](hs[-1], temb)
433
+ if len(self.down[i_level].attn) > 0:
434
+ h = self.down[i_level].attn[i_block](h)
435
+ hs.append(h)
436
+ if i_level != self.num_resolutions - 1:
437
+ hs.append(self.down[i_level].downsample(hs[-1]))
438
+
439
+ # middle
440
+ h = hs[-1]
441
+ h = self.mid.block_1(h, temb)
442
+ h = self.mid.attn_1(h)
443
+ h = self.mid.block_2(h, temb)
444
+
445
+ # upsampling
446
+ for i_level in reversed(range(self.num_resolutions)):
447
+ for i_block in range(self.num_res_blocks + 1):
448
+ h = self.up[i_level].block[i_block](
449
+ torch.cat([h, hs.pop()], dim=1), temb
450
+ )
451
+ if len(self.up[i_level].attn) > 0:
452
+ h = self.up[i_level].attn[i_block](h)
453
+ if i_level != 0:
454
+ h = self.up[i_level].upsample(h)
455
+
456
+ # end
457
+ h = self.norm_out(h)
458
+ h = nonlinearity(h)
459
+ h = self.conv_out(h)
460
+ return h
461
+
462
+ def get_last_layer(self):
463
+ return self.conv_out.weight
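+ # Unlike the Encoder/Decoder below, Model is a full UNet: every down-block
+ # activation is stored in `hs` and concatenated back in on the way up, and an
+ # optional sinusoidal timestep embedding conditions each ResnetBlock.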
464
+
465
+
466
+ class Encoder(nn.Module):
467
+ def __init__(
468
+ self,
469
+ *,
470
+ ch,
471
+ out_ch,
472
+ ch_mult=(1, 2, 4, 8),
473
+ num_res_blocks,
474
+ attn_resolutions,
475
+ dropout=0.0,
476
+ resamp_with_conv=True,
477
+ in_channels,
478
+ resolution,
479
+ z_channels,
480
+ double_z=True,
481
+ use_linear_attn=False,
482
+ attn_type="vanilla",
483
+ **ignore_kwargs,
484
+ ):
485
+ super().__init__()
486
+ if use_linear_attn:
487
+ attn_type = "linear"
488
+ self.ch = ch
489
+ self.temb_ch = 0
490
+ self.num_resolutions = len(ch_mult)
491
+ self.num_res_blocks = num_res_blocks
492
+ self.resolution = resolution
493
+ self.in_channels = in_channels
494
+
495
+ # downsampling
496
+ self.conv_in = torch.nn.Conv2d(
497
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
498
+ )
499
+
500
+ curr_res = resolution
501
+ in_ch_mult = (1,) + tuple(ch_mult)
502
+ self.in_ch_mult = in_ch_mult
503
+ self.down = nn.ModuleList()
504
+ for i_level in range(self.num_resolutions):
505
+ block = nn.ModuleList()
506
+ attn = nn.ModuleList()
507
+ block_in = ch * in_ch_mult[i_level]
508
+ block_out = ch * ch_mult[i_level]
509
+ for i_block in range(self.num_res_blocks):
510
+ block.append(
511
+ ResnetBlock(
512
+ in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout,
516
+ )
517
+ )
518
+ block_in = block_out
519
+ if curr_res in attn_resolutions:
520
+ attn.append(make_attn(block_in, attn_type=attn_type))
521
+ down = nn.Module()
522
+ down.block = block
523
+ down.attn = attn
524
+ if i_level != self.num_resolutions - 1:
525
+ down.downsample = Downsample(block_in, resamp_with_conv)
526
+ curr_res = curr_res // 2
527
+ self.down.append(down)
528
+
529
+ # middle
530
+ self.mid = nn.Module()
531
+ self.mid.block_1 = ResnetBlock(
532
+ in_channels=block_in,
533
+ out_channels=block_in,
534
+ temb_channels=self.temb_ch,
535
+ dropout=dropout,
536
+ )
537
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
538
+ self.mid.block_2 = ResnetBlock(
539
+ in_channels=block_in,
540
+ out_channels=block_in,
541
+ temb_channels=self.temb_ch,
542
+ dropout=dropout,
543
+ )
544
+
545
+ # end
546
+ self.norm_out = Normalize(block_in)
547
+ self.conv_out = torch.nn.Conv2d(
548
+ block_in,
549
+ 2 * z_channels if double_z else z_channels,
550
+ kernel_size=3,
551
+ stride=1,
552
+ padding=1,
553
+ )
554
+
555
+ def forward(self, x):
556
+ # timestep embedding
557
+ temb = None
558
+
559
+ # downsampling
560
+ hs = [self.conv_in(x)]
561
+ for i_level in range(self.num_resolutions):
562
+ for i_block in range(self.num_res_blocks):
563
+ h = self.down[i_level].block[i_block](hs[-1], temb)
564
+ if len(self.down[i_level].attn) > 0:
565
+ h = self.down[i_level].attn[i_block](h)
566
+ hs.append(h)
567
+ if i_level != self.num_resolutions - 1:
568
+ hs.append(self.down[i_level].downsample(hs[-1]))
569
+
570
+ # middle
571
+ h = hs[-1]
572
+ h = self.mid.block_1(h, temb)
573
+ h = self.mid.attn_1(h)
574
+ h = self.mid.block_2(h, temb)
575
+
576
+ # end
577
+ h = self.norm_out(h)
578
+ h = nonlinearity(h)
579
+ h = self.conv_out(h)
580
+ return h
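+ # Shape sketch (illustrative values): with resolution=256, ch_mult=(1, 2, 4, 4)
+ # and double_z=True, a (B, in_channels, 256, 256) input passes through three
+ # Downsample stages and leaves conv_out as (B, 2 * z_channels, 32, 32).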
581
+
582
+
583
+ class Decoder(nn.Module):
584
+ def __init__(
585
+ self,
586
+ *,
587
+ ch,
588
+ out_ch,
589
+ ch_mult=(1, 2, 4, 8),
590
+ num_res_blocks,
591
+ attn_resolutions,
592
+ dropout=0.0,
593
+ resamp_with_conv=True,
594
+ in_channels,
595
+ resolution,
596
+ z_channels,
597
+ give_pre_end=False,
598
+ tanh_out=False,
599
+ use_linear_attn=False,
600
+ attn_type="vanilla",
601
+ **ignorekwargs,
602
+ ):
603
+ super().__init__()
604
+ if use_linear_attn:
605
+ attn_type = "linear"
606
+ self.ch = ch
607
+ self.temb_ch = 0
608
+ self.num_resolutions = len(ch_mult)
609
+ self.num_res_blocks = num_res_blocks
610
+ self.resolution = resolution
611
+ self.in_channels = in_channels
612
+ self.give_pre_end = give_pre_end
613
+ self.tanh_out = tanh_out
614
+
615
+ # compute in_ch_mult, block_in and curr_res at lowest res
616
+ in_ch_mult = (1,) + tuple(ch_mult)
617
+ block_in = ch * ch_mult[self.num_resolutions - 1]
618
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
619
+ self.z_shape = (1, z_channels, curr_res, curr_res)
620
+ print(
621
+ "Working with z of shape {} = {} dimensions.".format(
622
+ self.z_shape, np.prod(self.z_shape)
623
+ )
624
+ )
625
+
626
+ # z to block_in
627
+ self.conv_in = torch.nn.Conv2d(
628
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
629
+ )
630
+
631
+ # middle
632
+ self.mid = nn.Module()
633
+ self.mid.block_1 = ResnetBlock(
634
+ in_channels=block_in,
635
+ out_channels=block_in,
636
+ temb_channels=self.temb_ch,
637
+ dropout=dropout,
638
+ )
639
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
640
+ self.mid.block_2 = ResnetBlock(
641
+ in_channels=block_in,
642
+ out_channels=block_in,
643
+ temb_channels=self.temb_ch,
644
+ dropout=dropout,
645
+ )
646
+
647
+ # upsampling
648
+ self.up = nn.ModuleList()
649
+ for i_level in reversed(range(self.num_resolutions)):
650
+ block = nn.ModuleList()
651
+ attn = nn.ModuleList()
652
+ block_out = ch * ch_mult[i_level]
653
+ for i_block in range(self.num_res_blocks + 1):
654
+ block.append(
655
+ ResnetBlock(
656
+ in_channels=block_in,
657
+ out_channels=block_out,
658
+ temb_channels=self.temb_ch,
659
+ dropout=dropout,
660
+ )
661
+ )
662
+ block_in = block_out
663
+ if curr_res in attn_resolutions:
664
+ attn.append(make_attn(block_in, attn_type=attn_type))
665
+ up = nn.Module()
666
+ up.block = block
667
+ up.attn = attn
668
+ if i_level != 0:
669
+ up.upsample = Upsample(block_in, resamp_with_conv)
670
+ curr_res = curr_res * 2
671
+ self.up.insert(0, up) # prepend to get consistent order
672
+
673
+ # end
674
+ self.norm_out = Normalize(block_in)
675
+ self.conv_out = torch.nn.Conv2d(
676
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
677
+ )
678
+
679
+ def forward(self, z):
680
+ # assert z.shape[1:] == self.z_shape[1:]
681
+ self.last_z_shape = z.shape
682
+
683
+ # timestep embedding
684
+ temb = None
685
+
686
+ # z to block_in
687
+ h = self.conv_in(z)
688
+
689
+ # middle
690
+ h = self.mid.block_1(h, temb)
691
+ h = self.mid.attn_1(h)
692
+ h = self.mid.block_2(h, temb)
693
+
694
+ # upsampling
695
+ for i_level in reversed(range(self.num_resolutions)):
696
+ for i_block in range(self.num_res_blocks + 1):
697
+ h = self.up[i_level].block[i_block](h, temb)
698
+ if len(self.up[i_level].attn) > 0:
699
+ h = self.up[i_level].attn[i_block](h)
700
+ if i_level != 0:
701
+ h = self.up[i_level].upsample(h)
702
+
703
+ # end
704
+ if self.give_pre_end:
705
+ return h
706
+
707
+ h = self.norm_out(h)
708
+ h = nonlinearity(h)
709
+ h = self.conv_out(h)
710
+ if self.tanh_out:
711
+ h = torch.tanh(h)
712
+ return h
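+ # Mirror of the Encoder: with the same illustrative settings, a latent of shape
+ # (B, z_channels, 32, 32) is upsampled num_resolutions - 1 times back to
+ # (B, out_ch, 256, 256), optionally squashed by tanh when tanh_out=True.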
713
+
714
+
715
+ class SimpleDecoder(nn.Module):
716
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
717
+ super().__init__()
718
+ self.model = nn.ModuleList(
719
+ [
720
+ nn.Conv2d(in_channels, in_channels, 1),
721
+ ResnetBlock(
722
+ in_channels=in_channels,
723
+ out_channels=2 * in_channels,
724
+ temb_channels=0,
725
+ dropout=0.0,
726
+ ),
727
+ ResnetBlock(
728
+ in_channels=2 * in_channels,
729
+ out_channels=4 * in_channels,
730
+ temb_channels=0,
731
+ dropout=0.0,
732
+ ),
733
+ ResnetBlock(
734
+ in_channels=4 * in_channels,
735
+ out_channels=2 * in_channels,
736
+ temb_channels=0,
737
+ dropout=0.0,
738
+ ),
739
+ nn.Conv2d(2 * in_channels, in_channels, 1),
740
+ Upsample(in_channels, with_conv=True),
741
+ ]
742
+ )
743
+ # end
744
+ self.norm_out = Normalize(in_channels)
745
+ self.conv_out = torch.nn.Conv2d(
746
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
747
+ )
748
+
749
+ def forward(self, x):
750
+ for i, layer in enumerate(self.model):
751
+ if i in [1, 2, 3]:
752
+ x = layer(x, None)
753
+ else:
754
+ x = layer(x)
755
+
756
+ h = self.norm_out(x)
757
+ h = nonlinearity(h)
758
+ x = self.conv_out(h)
759
+ return x
760
+
761
+
762
+ class UpsampleDecoder(nn.Module):
763
+ def __init__(
764
+ self,
765
+ in_channels,
766
+ out_channels,
767
+ ch,
768
+ num_res_blocks,
769
+ resolution,
770
+ ch_mult=(2, 2),
771
+ dropout=0.0,
772
+ ):
773
+ super().__init__()
774
+ # upsampling
775
+ self.temb_ch = 0
776
+ self.num_resolutions = len(ch_mult)
777
+ self.num_res_blocks = num_res_blocks
778
+ block_in = in_channels
779
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
780
+ self.res_blocks = nn.ModuleList()
781
+ self.upsample_blocks = nn.ModuleList()
782
+ for i_level in range(self.num_resolutions):
783
+ res_block = []
784
+ block_out = ch * ch_mult[i_level]
785
+ for i_block in range(self.num_res_blocks + 1):
786
+ res_block.append(
787
+ ResnetBlock(
788
+ in_channels=block_in,
789
+ out_channels=block_out,
790
+ temb_channels=self.temb_ch,
791
+ dropout=dropout,
792
+ )
793
+ )
794
+ block_in = block_out
795
+ self.res_blocks.append(nn.ModuleList(res_block))
796
+ if i_level != self.num_resolutions - 1:
797
+ self.upsample_blocks.append(Upsample(block_in, True))
798
+ curr_res = curr_res * 2
799
+
800
+ # end
801
+ self.norm_out = Normalize(block_in)
802
+ self.conv_out = torch.nn.Conv2d(
803
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
804
+ )
805
+
806
+ def forward(self, x):
807
+ # upsampling
808
+ h = x
809
+ for k, i_level in enumerate(range(self.num_resolutions)):
810
+ for i_block in range(self.num_res_blocks + 1):
811
+ h = self.res_blocks[i_level][i_block](h, None)
812
+ if i_level != self.num_resolutions - 1:
813
+ h = self.upsample_blocks[k](h)
814
+ h = self.norm_out(h)
815
+ h = nonlinearity(h)
816
+ h = self.conv_out(h)
817
+ return h
818
+
819
+
820
+ class LatentRescaler(nn.Module):
821
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
822
+ super().__init__()
823
+ # residual block, interpolate, residual block
824
+ self.factor = factor
825
+ self.conv_in = nn.Conv2d(
826
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
827
+ )
828
+ self.res_block1 = nn.ModuleList(
829
+ [
830
+ ResnetBlock(
831
+ in_channels=mid_channels,
832
+ out_channels=mid_channels,
833
+ temb_channels=0,
834
+ dropout=0.0,
835
+ )
836
+ for _ in range(depth)
837
+ ]
838
+ )
839
+ self.attn = AttnBlock(mid_channels)
840
+ self.res_block2 = nn.ModuleList(
841
+ [
842
+ ResnetBlock(
843
+ in_channels=mid_channels,
844
+ out_channels=mid_channels,
845
+ temb_channels=0,
846
+ dropout=0.0,
847
+ )
848
+ for _ in range(depth)
849
+ ]
850
+ )
851
+
852
+ self.conv_out = nn.Conv2d(
853
+ mid_channels,
854
+ out_channels,
855
+ kernel_size=1,
856
+ )
857
+
858
+ def forward(self, x):
859
+ x = self.conv_in(x)
860
+ for block in self.res_block1:
861
+ x = block(x, None)
862
+ x = torch.nn.functional.interpolate(
863
+ x,
864
+ size=(
865
+ int(round(x.shape[2] * self.factor)),
866
+ int(round(x.shape[3] * self.factor)),
867
+ ),
868
+ )
869
+ x = self.attn(x)
870
+ for block in self.res_block2:
871
+ x = block(x, None)
872
+ x = self.conv_out(x)
873
+ return x
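+ # LatentRescaler resizes a latent by an arbitrary `factor`: ResnetBlocks,
+ # interpolate(), attention, more ResnetBlocks, then a 1x1 conv. The
+ # MergedRescale* and Upsampler classes below compose it with Encoder/Decoder.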
874
+
875
+
876
+ class MergedRescaleEncoder(nn.Module):
877
+ def __init__(
878
+ self,
879
+ in_channels,
880
+ ch,
881
+ resolution,
882
+ out_ch,
883
+ num_res_blocks,
884
+ attn_resolutions,
885
+ dropout=0.0,
886
+ resamp_with_conv=True,
887
+ ch_mult=(1, 2, 4, 8),
888
+ rescale_factor=1.0,
889
+ rescale_module_depth=1,
890
+ ):
891
+ super().__init__()
892
+ intermediate_chn = ch * ch_mult[-1]
893
+ self.encoder = Encoder(
894
+ in_channels=in_channels,
895
+ num_res_blocks=num_res_blocks,
896
+ ch=ch,
897
+ ch_mult=ch_mult,
898
+ z_channels=intermediate_chn,
899
+ double_z=False,
900
+ resolution=resolution,
901
+ attn_resolutions=attn_resolutions,
902
+ dropout=dropout,
903
+ resamp_with_conv=resamp_with_conv,
904
+ out_ch=None,
905
+ )
906
+ self.rescaler = LatentRescaler(
907
+ factor=rescale_factor,
908
+ in_channels=intermediate_chn,
909
+ mid_channels=intermediate_chn,
910
+ out_channels=out_ch,
911
+ depth=rescale_module_depth,
912
+ )
913
+
914
+ def forward(self, x):
915
+ x = self.encoder(x)
916
+ x = self.rescaler(x)
917
+ return x
918
+
919
+
920
+ class MergedRescaleDecoder(nn.Module):
921
+ def __init__(
922
+ self,
923
+ z_channels,
924
+ out_ch,
925
+ resolution,
926
+ num_res_blocks,
927
+ attn_resolutions,
928
+ ch,
929
+ ch_mult=(1, 2, 4, 8),
930
+ dropout=0.0,
931
+ resamp_with_conv=True,
932
+ rescale_factor=1.0,
933
+ rescale_module_depth=1,
934
+ ):
935
+ super().__init__()
936
+ tmp_chn = z_channels * ch_mult[-1]
937
+ self.decoder = Decoder(
938
+ out_ch=out_ch,
939
+ z_channels=tmp_chn,
940
+ attn_resolutions=attn_resolutions,
941
+ dropout=dropout,
942
+ resamp_with_conv=resamp_with_conv,
943
+ in_channels=None,
944
+ num_res_blocks=num_res_blocks,
945
+ ch_mult=ch_mult,
946
+ resolution=resolution,
947
+ ch=ch,
948
+ )
949
+ self.rescaler = LatentRescaler(
950
+ factor=rescale_factor,
951
+ in_channels=z_channels,
952
+ mid_channels=tmp_chn,
953
+ out_channels=tmp_chn,
954
+ depth=rescale_module_depth,
955
+ )
956
+
957
+ def forward(self, x):
958
+ x = self.rescaler(x)
959
+ x = self.decoder(x)
960
+ return x
961
+
962
+
963
+ class Upsampler(nn.Module):
964
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
965
+ super().__init__()
966
+ assert out_size >= in_size
967
+ num_blocks = int(np.log2(out_size // in_size)) + 1
968
+ factor_up = 1.0 + (out_size % in_size)
969
+ print(
970
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
971
+ )
972
+ self.rescaler = LatentRescaler(
973
+ factor=factor_up,
974
+ in_channels=in_channels,
975
+ mid_channels=2 * in_channels,
976
+ out_channels=in_channels,
977
+ )
978
+ self.decoder = Decoder(
979
+ out_ch=out_channels,
980
+ resolution=out_size,
981
+ z_channels=in_channels,
982
+ num_res_blocks=2,
983
+ attn_resolutions=[],
984
+ in_channels=None,
985
+ ch=in_channels,
986
+ ch_mult=[ch_mult for _ in range(num_blocks)],
987
+ )
988
+
989
+ def forward(self, x):
990
+ x = self.rescaler(x)
991
+ x = self.decoder(x)
992
+ return x
993
+
994
+
995
+ class Resize(nn.Module):
996
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
997
+ super().__init__()
998
+ self.with_conv = learned
999
+ self.mode = mode
1000
+ if self.with_conv:
1001
+ print(
1002
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
1003
+ )
1004
+ raise NotImplementedError()
1005
+ assert in_channels is not None
1006
+ # no asymmetric padding in torch conv, must do it ourselves
1007
+ self.conv = torch.nn.Conv2d(
1008
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
1009
+ )
1010
+
1011
+ def forward(self, x, scale_factor=1.0):
1012
+ if scale_factor == 1.0:
1013
+ return x
1014
+ else:
1015
+ x = torch.nn.functional.interpolate(
1016
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
1017
+ )
1018
+ return x
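
A minimal round-trip sketch of the Encoder/Decoder pair defined in this file; the hyperparameters below are illustrative and are not taken from any config in this commit.

import torch

# Illustrative settings: 3 resolution levels, attention at the 16x16 feature map.
enc = Encoder(
    ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
    attn_resolutions=[16], in_channels=3, resolution=64,
    z_channels=4, double_z=True, dropout=0.0,
)
dec = Decoder(
    ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
    attn_resolutions=[16], in_channels=3, resolution=64,
    z_channels=4, dropout=0.0,
)

x = torch.randn(1, 3, 64, 64)
moments = enc(x)        # (1, 2 * z_channels, 16, 16): two downsampling steps from 64x64
z = moments[:, :4]      # keep the first z_channels channels as a deterministic latent
rec = dec(z)            # (1, 3, 64, 64)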