Mariam-Elz committed
Commit f34eca8 · verified · 1 Parent(s): 09dbe55

Upload imagedream/ldm/modules/diffusionmodules/openaimodel.py with huggingface_hub

imagedream/ldm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -1,1135 +1,1135 @@
1
- from abc import abstractmethod
2
- import math
3
-
4
- import numpy as np
5
- import torch
6
- import torch as th
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from einops import rearrange, repeat
10
-
11
- from imagedream.ldm.modules.diffusionmodules.util import (
12
- checkpoint,
13
- conv_nd,
14
- linear,
15
- avg_pool_nd,
16
- zero_module,
17
- normalization,
18
- timestep_embedding,
19
- convert_module_to_f16,
20
- convert_module_to_f32
21
- )
22
- from imagedream.ldm.modules.attention import (
23
- SpatialTransformer,
24
- SpatialTransformer3D,
25
- exists
26
- )
27
- from imagedream.ldm.modules.diffusionmodules.adaptors import (
28
- Resampler,
29
- ImageProjModel
30
- )
31
-
32
- ## go
33
- class AttentionPool2d(nn.Module):
34
- """
35
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
36
- """
37
-
38
- def __init__(
39
- self,
40
- spacial_dim: int,
41
- embed_dim: int,
42
- num_heads_channels: int,
43
- output_dim: int = None,
44
- ):
45
- super().__init__()
46
- self.positional_embedding = nn.Parameter(
47
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
48
- )
49
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
50
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
51
- self.num_heads = embed_dim // num_heads_channels
52
- self.attention = QKVAttention(self.num_heads)
53
-
54
- def forward(self, x):
55
- b, c, *_spatial = x.shape
56
- x = x.reshape(b, c, -1) # NC(HW)
57
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
58
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
59
- x = self.qkv_proj(x)
60
- x = self.attention(x)
61
- x = self.c_proj(x)
62
- return x[:, :, 0]
63
-
64
-
65
- class TimestepBlock(nn.Module):
66
- """
67
- Any module where forward() takes timestep embeddings as a second argument.
68
- """
69
-
70
- @abstractmethod
71
- def forward(self, x, emb):
72
- """
73
- Apply the module to `x` given `emb` timestep embeddings.
74
- """
75
-
76
-
77
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
78
- """
79
- A sequential module that passes timestep embeddings to the children that
80
- support it as an extra input.
81
- """
82
-
83
- def forward(self, x, emb, context=None, num_frames=1):
84
- for layer in self:
85
- if isinstance(layer, TimestepBlock):
86
- x = layer(x, emb)
87
- elif isinstance(layer, SpatialTransformer3D):
88
- x = layer(x, context, num_frames=num_frames)
89
- elif isinstance(layer, SpatialTransformer):
90
- x = layer(x, context)
91
- else:
92
- x = layer(x)
93
- return x
94
-
95
-
96
- class Upsample(nn.Module):
97
- """
98
- An upsampling layer with an optional convolution.
99
- :param channels: channels in the inputs and outputs.
100
- :param use_conv: a bool determining if a convolution is applied.
101
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
102
- upsampling occurs in the inner-two dimensions.
103
- """
104
-
105
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
106
- super().__init__()
107
- self.channels = channels
108
- self.out_channels = out_channels or channels
109
- self.use_conv = use_conv
110
- self.dims = dims
111
- if use_conv:
112
- self.conv = conv_nd(
113
- dims, self.channels, self.out_channels, 3, padding=padding
114
- )
115
-
116
- def forward(self, x):
117
- assert x.shape[1] == self.channels
118
- if self.dims == 3:
119
- x = F.interpolate(
120
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
121
- )
122
- else:
123
- x = F.interpolate(x, scale_factor=2, mode="nearest")
124
- if self.use_conv:
125
- x = self.conv(x)
126
- return x
127
-
128
-
129
- class TransposedUpsample(nn.Module):
130
- "Learned 2x upsampling without padding"
131
-
132
- def __init__(self, channels, out_channels=None, ks=5):
133
- super().__init__()
134
- self.channels = channels
135
- self.out_channels = out_channels or channels
136
-
137
- self.up = nn.ConvTranspose2d(
138
- self.channels, self.out_channels, kernel_size=ks, stride=2
139
- )
140
-
141
- def forward(self, x):
142
- return self.up(x)
143
-
144
-
145
- class Downsample(nn.Module):
146
- """
147
- A downsampling layer with an optional convolution.
148
- :param channels: channels in the inputs and outputs.
149
- :param use_conv: a bool determining if a convolution is applied.
150
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
151
- downsampling occurs in the inner-two dimensions.
152
- """
153
-
154
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
155
- super().__init__()
156
- self.channels = channels
157
- self.out_channels = out_channels or channels
158
- self.use_conv = use_conv
159
- self.dims = dims
160
- stride = 2 if dims != 3 else (1, 2, 2)
161
- if use_conv:
162
- self.op = conv_nd(
163
- dims,
164
- self.channels,
165
- self.out_channels,
166
- 3,
167
- stride=stride,
168
- padding=padding,
169
- )
170
- else:
171
- assert self.channels == self.out_channels
172
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
173
-
174
- def forward(self, x):
175
- assert x.shape[1] == self.channels
176
- return self.op(x)
177
-
178
-
179
- class ResBlock(TimestepBlock):
180
- """
181
- A residual block that can optionally change the number of channels.
182
- :param channels: the number of input channels.
183
- :param emb_channels: the number of timestep embedding channels.
184
- :param dropout: the rate of dropout.
185
- :param out_channels: if specified, the number of out channels.
186
- :param use_conv: if True and out_channels is specified, use a spatial
187
- convolution instead of a smaller 1x1 convolution to change the
188
- channels in the skip connection.
189
- :param dims: determines if the signal is 1D, 2D, or 3D.
190
- :param use_checkpoint: if True, use gradient checkpointing on this module.
191
- :param up: if True, use this block for upsampling.
192
- :param down: if True, use this block for downsampling.
193
- """
194
-
195
- def __init__(
196
- self,
197
- channels,
198
- emb_channels,
199
- dropout,
200
- out_channels=None,
201
- use_conv=False,
202
- use_scale_shift_norm=False,
203
- dims=2,
204
- use_checkpoint=False,
205
- up=False,
206
- down=False,
207
- ):
208
- super().__init__()
209
- self.channels = channels
210
- self.emb_channels = emb_channels
211
- self.dropout = dropout
212
- self.out_channels = out_channels or channels
213
- self.use_conv = use_conv
214
- self.use_checkpoint = use_checkpoint
215
- self.use_scale_shift_norm = use_scale_shift_norm
216
-
217
- self.in_layers = nn.Sequential(
218
- normalization(channels),
219
- nn.SiLU(),
220
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
221
- )
222
-
223
- self.updown = up or down
224
-
225
- if up:
226
- self.h_upd = Upsample(channels, False, dims)
227
- self.x_upd = Upsample(channels, False, dims)
228
- elif down:
229
- self.h_upd = Downsample(channels, False, dims)
230
- self.x_upd = Downsample(channels, False, dims)
231
- else:
232
- self.h_upd = self.x_upd = nn.Identity()
233
-
234
- self.emb_layers = nn.Sequential(
235
- nn.SiLU(),
236
- linear(
237
- emb_channels,
238
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
239
- ),
240
- )
241
- self.out_layers = nn.Sequential(
242
- normalization(self.out_channels),
243
- nn.SiLU(),
244
- nn.Dropout(p=dropout),
245
- zero_module(
246
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
247
- ),
248
- )
249
-
250
- if self.out_channels == channels:
251
- self.skip_connection = nn.Identity()
252
- elif use_conv:
253
- self.skip_connection = conv_nd(
254
- dims, channels, self.out_channels, 3, padding=1
255
- )
256
- else:
257
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
258
-
259
- def forward(self, x, emb):
260
- """
261
- Apply the block to a Tensor, conditioned on a timestep embedding.
262
- :param x: an [N x C x ...] Tensor of features.
263
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
264
- :return: an [N x C x ...] Tensor of outputs.
265
- """
266
- return checkpoint(
267
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
268
- )
269
-
270
- def _forward(self, x, emb):
271
- if self.updown:
272
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
273
- h = in_rest(x)
274
- h = self.h_upd(h)
275
- x = self.x_upd(x)
276
- h = in_conv(h)
277
- else:
278
- h = self.in_layers(x)
279
- emb_out = self.emb_layers(emb).type(h.dtype)
280
- while len(emb_out.shape) < len(h.shape):
281
- emb_out = emb_out[..., None]
282
- if self.use_scale_shift_norm:
283
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
284
- scale, shift = th.chunk(emb_out, 2, dim=1)
285
- h = out_norm(h) * (1 + scale) + shift
286
- h = out_rest(h)
287
- else:
288
- h = h + emb_out
289
- h = self.out_layers(h)
290
- return self.skip_connection(x) + h
291
-
292
-
293
- class AttentionBlock(nn.Module):
294
- """
295
- An attention block that allows spatial positions to attend to each other.
296
- Originally ported from here, but adapted to the N-d case.
297
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
298
- """
299
-
300
- def __init__(
301
- self,
302
- channels,
303
- num_heads=1,
304
- num_head_channels=-1,
305
- use_checkpoint=False,
306
- use_new_attention_order=False,
307
- ):
308
- super().__init__()
309
- self.channels = channels
310
- if num_head_channels == -1:
311
- self.num_heads = num_heads
312
- else:
313
- assert (
314
- channels % num_head_channels == 0
315
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
316
- self.num_heads = channels // num_head_channels
317
- self.use_checkpoint = use_checkpoint
318
- self.norm = normalization(channels)
319
- self.qkv = conv_nd(1, channels, channels * 3, 1)
320
- if use_new_attention_order:
321
- # split qkv before split heads
322
- self.attention = QKVAttention(self.num_heads)
323
- else:
324
- # split heads before split qkv
325
- self.attention = QKVAttentionLegacy(self.num_heads)
326
-
327
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
328
-
329
- def forward(self, x):
330
- return checkpoint(
331
- self._forward, (x,), self.parameters(), True
332
- ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
333
- # return pt_checkpoint(self._forward, x) # pytorch
334
-
335
- def _forward(self, x):
336
- b, c, *spatial = x.shape
337
- x = x.reshape(b, c, -1)
338
- qkv = self.qkv(self.norm(x))
339
- h = self.attention(qkv)
340
- h = self.proj_out(h)
341
- return (x + h).reshape(b, c, *spatial)
342
-
343
-
344
- def count_flops_attn(model, _x, y):
345
- """
346
- A counter for the `thop` package to count the operations in an
347
- attention operation.
348
- Meant to be used like:
349
- macs, params = thop.profile(
350
- model,
351
- inputs=(inputs, timestamps),
352
- custom_ops={QKVAttention: QKVAttention.count_flops},
353
- )
354
- """
355
- b, c, *spatial = y[0].shape
356
- num_spatial = int(np.prod(spatial))
357
- # We perform two matmuls with the same number of ops.
358
- # The first computes the weight matrix, the second computes
359
- # the combination of the value vectors.
360
- matmul_ops = 2 * b * (num_spatial**2) * c
361
- model.total_ops += th.DoubleTensor([matmul_ops])
362
-
363
-
364
- class QKVAttentionLegacy(nn.Module):
365
- """
366
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
367
- """
368
-
369
- def __init__(self, n_heads):
370
- super().__init__()
371
- self.n_heads = n_heads
372
-
373
- def forward(self, qkv):
374
- """
375
- Apply QKV attention.
376
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
377
- :return: an [N x (H * C) x T] tensor after attention.
378
- """
379
- bs, width, length = qkv.shape
380
- assert width % (3 * self.n_heads) == 0
381
- ch = width // (3 * self.n_heads)
382
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
383
- scale = 1 / math.sqrt(math.sqrt(ch))
384
- weight = th.einsum(
385
- "bct,bcs->bts", q * scale, k * scale
386
- ) # More stable with f16 than dividing afterwards
387
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
388
- a = th.einsum("bts,bcs->bct", weight, v)
389
- return a.reshape(bs, -1, length)
390
-
391
- @staticmethod
392
- def count_flops(model, _x, y):
393
- return count_flops_attn(model, _x, y)
394
-
395
-
396
- class QKVAttention(nn.Module):
397
- """
398
- A module which performs QKV attention and splits in a different order.
399
- """
400
-
401
- def __init__(self, n_heads):
402
- super().__init__()
403
- self.n_heads = n_heads
404
-
405
- def forward(self, qkv):
406
- """
407
- Apply QKV attention.
408
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
409
- :return: an [N x (H * C) x T] tensor after attention.
410
- """
411
- bs, width, length = qkv.shape
412
- assert width % (3 * self.n_heads) == 0
413
- ch = width // (3 * self.n_heads)
414
- q, k, v = qkv.chunk(3, dim=1)
415
- scale = 1 / math.sqrt(math.sqrt(ch))
416
- weight = th.einsum(
417
- "bct,bcs->bts",
418
- (q * scale).view(bs * self.n_heads, ch, length),
419
- (k * scale).view(bs * self.n_heads, ch, length),
420
- ) # More stable with f16 than dividing afterwards
421
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
422
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
423
- return a.reshape(bs, -1, length)
424
-
425
- @staticmethod
426
- def count_flops(model, _x, y):
427
- return count_flops_attn(model, _x, y)
428
-
429
-
430
- class Timestep(nn.Module):
431
- def __init__(self, dim):
432
- super().__init__()
433
- self.dim = dim
434
-
435
- def forward(self, t):
436
- return timestep_embedding(t, self.dim)
437
-
438
-
439
- class MultiViewUNetModel(nn.Module):
440
- """
441
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
442
- :param in_channels: channels in the input Tensor.
443
- :param model_channels: base channel count for the model.
444
- :param out_channels: channels in the output Tensor.
445
- :param num_res_blocks: number of residual blocks per downsample.
446
- :param attention_resolutions: a collection of downsample rates at which
447
- attention will take place. May be a set, list, or tuple.
448
- For example, if this contains 4, then at 4x downsampling, attention
449
- will be used.
450
- :param dropout: the dropout probability.
451
- :param channel_mult: channel multiplier for each level of the UNet.
452
- :param conv_resample: if True, use learned convolutions for upsampling and
453
- downsampling.
454
- :param dims: determines if the signal is 1D, 2D, or 3D.
455
- :param num_classes: if specified (as an int), then this model will be
456
- class-conditional with `num_classes` classes.
457
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
458
- :param num_heads: the number of attention heads in each attention layer.
459
- :param num_head_channels: if specified, ignore num_heads and instead use
460
- a fixed channel width per attention head.
461
- :param num_heads_upsample: works with num_heads to set a different number
462
- of heads for upsampling. Deprecated.
463
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
464
- :param resblock_updown: use residual blocks for up/downsampling.
465
- :param use_new_attention_order: use a different attention pattern for potentially
466
- increased efficiency.
467
- :param camera_dim: dimensionality of camera input.
468
- """
469
-
470
- def __init__(
471
- self,
472
- image_size,
473
- in_channels,
474
- model_channels,
475
- out_channels,
476
- num_res_blocks,
477
- attention_resolutions,
478
- dropout=0,
479
- channel_mult=(1, 2, 4, 8),
480
- conv_resample=True,
481
- dims=2,
482
- num_classes=None,
483
- use_checkpoint=False,
484
- use_fp16=False,
485
- use_bf16=False,
486
- num_heads=-1,
487
- num_head_channels=-1,
488
- num_heads_upsample=-1,
489
- use_scale_shift_norm=False,
490
- resblock_updown=False,
491
- use_new_attention_order=False,
492
- use_spatial_transformer=False, # custom transformer support
493
- transformer_depth=1, # custom transformer support
494
- context_dim=None, # custom transformer support
495
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
496
- legacy=True,
497
- disable_self_attentions=None,
498
- num_attention_blocks=None,
499
- disable_middle_self_attn=False,
500
- use_linear_in_transformer=False,
501
- adm_in_channels=None,
502
- camera_dim=None,
503
- with_ip=False, # whether to add image prompt images
504
- ip_dim=0, # number of extra tokens: 4 for global, 16 for local
505
- ip_weight=1.0, # weight for image prompt context
506
- ip_mode="local_resample", # adaptor mode: "global" or "local_resample"
507
- ):
508
- super().__init__()
509
- if use_spatial_transformer:
510
- assert (
511
- context_dim is not None
512
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
513
-
514
- if context_dim is not None:
515
- assert (
516
- use_spatial_transformer
517
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
518
- from omegaconf.listconfig import ListConfig
519
-
520
- if type(context_dim) == ListConfig:
521
- context_dim = list(context_dim)
522
-
523
- if num_heads_upsample == -1:
524
- num_heads_upsample = num_heads
525
-
526
- if num_heads == -1:
527
- assert (
528
- num_head_channels != -1
529
- ), "Either num_heads or num_head_channels has to be set"
530
-
531
- if num_head_channels == -1:
532
- assert (
533
- num_heads != -1
534
- ), "Either num_heads or num_head_channels has to be set"
535
-
536
- self.image_size = image_size
537
- self.in_channels = in_channels
538
- self.model_channels = model_channels
539
- self.out_channels = out_channels
540
- if isinstance(num_res_blocks, int):
541
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
542
- else:
543
- if len(num_res_blocks) != len(channel_mult):
544
- raise ValueError(
545
- "provide num_res_blocks either as an int (globally constant) or "
546
- "as a list/tuple (per-level) with the same length as channel_mult"
547
- )
548
- self.num_res_blocks = num_res_blocks
549
- if disable_self_attentions is not None:
550
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
551
- assert len(disable_self_attentions) == len(channel_mult)
552
- if num_attention_blocks is not None:
553
- assert len(num_attention_blocks) == len(self.num_res_blocks)
554
- assert all(
555
- map(
556
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
557
- range(len(num_attention_blocks)),
558
- )
559
- )
560
- print(
561
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
562
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
563
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
564
- f"attention will still not be set."
565
- )
566
-
567
- self.attention_resolutions = attention_resolutions
568
- self.dropout = dropout
569
- self.channel_mult = channel_mult
570
- self.conv_resample = conv_resample
571
- self.num_classes = num_classes
572
- self.use_checkpoint = use_checkpoint
573
- self.dtype = th.float16 if use_fp16 else th.float32
574
- self.dtype = th.bfloat16 if use_bf16 else self.dtype
575
- self.num_heads = num_heads
576
- self.num_head_channels = num_head_channels
577
- self.num_heads_upsample = num_heads_upsample
578
- self.predict_codebook_ids = n_embed is not None
579
-
580
- self.with_ip = with_ip # whether there is an image prompt
581
- self.ip_dim = ip_dim # num of extra tokens: 4 for global, 16 for local
582
- self.ip_weight = ip_weight
583
- assert ip_mode in ["global", "local_resample"]
584
- self.ip_mode = ip_mode # which mode of adaptor
585
-
586
- time_embed_dim = model_channels * 4
587
- self.time_embed = nn.Sequential(
588
- linear(model_channels, time_embed_dim),
589
- nn.SiLU(),
590
- linear(time_embed_dim, time_embed_dim),
591
- )
592
-
593
- if camera_dim is not None:
594
- time_embed_dim = model_channels * 4
595
- self.camera_embed = nn.Sequential(
596
- linear(camera_dim, time_embed_dim),
597
- nn.SiLU(),
598
- linear(time_embed_dim, time_embed_dim),
599
- )
600
-
601
- if self.num_classes is not None:
602
- if isinstance(self.num_classes, int):
603
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
604
- elif self.num_classes == "continuous":
605
- print("setting up linear c_adm embedding layer")
606
- self.label_emb = nn.Linear(1, time_embed_dim)
607
- elif self.num_classes == "sequential":
608
- assert adm_in_channels is not None
609
- self.label_emb = nn.Sequential(
610
- nn.Sequential(
611
- linear(adm_in_channels, time_embed_dim),
612
- nn.SiLU(),
613
- linear(time_embed_dim, time_embed_dim),
614
- )
615
- )
616
- else:
617
- raise ValueError()
618
-
619
- if self.with_ip and (context_dim is not None) and ip_dim > 0:
620
- if self.ip_mode == "local_resample":
621
- # ip-adapter-plus
622
- hidden_dim = 1280
623
- self.image_embed = Resampler(
624
- dim=context_dim,
625
- depth=4,
626
- dim_head=64,
627
- heads=12,
628
- num_queries=ip_dim, # num token
629
- embedding_dim=hidden_dim,
630
- output_dim=context_dim,
631
- ff_mult=4,
632
- )
633
- elif self.ip_mode == "global":
634
- self.image_embed = ImageProjModel(
635
- cross_attention_dim=context_dim,
636
- clip_extra_context_tokens=ip_dim)
637
- else:
638
- raise ValueError(f"{self.ip_mode} is not supported")
639
-
640
- self.input_blocks = nn.ModuleList(
641
- [
642
- TimestepEmbedSequential(
643
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
644
- )
645
- ]
646
- )
647
- self._feature_size = model_channels
648
- input_block_chans = [model_channels]
649
- ch = model_channels
650
- ds = 1
651
- for level, mult in enumerate(channel_mult):
652
- for nr in range(self.num_res_blocks[level]):
653
- layers = [
654
- ResBlock(
655
- ch,
656
- time_embed_dim,
657
- dropout,
658
- out_channels=mult * model_channels,
659
- dims=dims,
660
- use_checkpoint=use_checkpoint,
661
- use_scale_shift_norm=use_scale_shift_norm,
662
- )
663
- ]
664
- ch = mult * model_channels
665
- if ds in attention_resolutions:
666
- if num_head_channels == -1:
667
- dim_head = ch // num_heads
668
- else:
669
- num_heads = ch // num_head_channels
670
- dim_head = num_head_channels
671
- if legacy:
672
- # num_heads = 1
673
- dim_head = (
674
- ch // num_heads
675
- if use_spatial_transformer
676
- else num_head_channels
677
- )
678
- if exists(disable_self_attentions):
679
- disabled_sa = disable_self_attentions[level]
680
- else:
681
- disabled_sa = False
682
-
683
- if (
684
- not exists(num_attention_blocks)
685
- or nr < num_attention_blocks[level]
686
- ):
687
- layers.append(
688
- AttentionBlock(
689
- ch,
690
- use_checkpoint=use_checkpoint,
691
- num_heads=num_heads,
692
- num_head_channels=dim_head,
693
- use_new_attention_order=use_new_attention_order,
694
- )
695
- if not use_spatial_transformer
696
- else SpatialTransformer3D(
697
- ch,
698
- num_heads,
699
- dim_head,
700
- depth=transformer_depth,
701
- context_dim=context_dim,
702
- disable_self_attn=disabled_sa,
703
- use_linear=use_linear_in_transformer,
704
- use_checkpoint=use_checkpoint,
705
- with_ip=self.with_ip,
706
- ip_dim=self.ip_dim,
707
- ip_weight=self.ip_weight
708
- )
709
- )
710
- self.input_blocks.append(TimestepEmbedSequential(*layers))
711
- self._feature_size += ch
712
- input_block_chans.append(ch)
713
-
714
- if level != len(channel_mult) - 1:
715
- out_ch = ch
716
- self.input_blocks.append(
717
- TimestepEmbedSequential(
718
- ResBlock(
719
- ch,
720
- time_embed_dim,
721
- dropout,
722
- out_channels=out_ch,
723
- dims=dims,
724
- use_checkpoint=use_checkpoint,
725
- use_scale_shift_norm=use_scale_shift_norm,
726
- down=True,
727
- )
728
- if resblock_updown
729
- else Downsample(
730
- ch, conv_resample, dims=dims, out_channels=out_ch
731
- )
732
- )
733
- )
734
- ch = out_ch
735
- input_block_chans.append(ch)
736
- ds *= 2
737
- self._feature_size += ch
738
-
739
- if num_head_channels == -1:
740
- dim_head = ch // num_heads
741
- else:
742
- num_heads = ch // num_head_channels
743
- dim_head = num_head_channels
744
- if legacy:
745
- # num_heads = 1
746
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
747
- self.middle_block = TimestepEmbedSequential(
748
- ResBlock(
749
- ch,
750
- time_embed_dim,
751
- dropout,
752
- dims=dims,
753
- use_checkpoint=use_checkpoint,
754
- use_scale_shift_norm=use_scale_shift_norm,
755
- ),
756
- AttentionBlock(
757
- ch,
758
- use_checkpoint=use_checkpoint,
759
- num_heads=num_heads,
760
- num_head_channels=dim_head,
761
- use_new_attention_order=use_new_attention_order,
762
- )
763
- if not use_spatial_transformer
764
- else SpatialTransformer3D( # always uses a self-attn
765
- ch,
766
- num_heads,
767
- dim_head,
768
- depth=transformer_depth,
769
- context_dim=context_dim,
770
- disable_self_attn=disable_middle_self_attn,
771
- use_linear=use_linear_in_transformer,
772
- use_checkpoint=use_checkpoint,
773
- with_ip=self.with_ip,
774
- ip_dim=self.ip_dim,
775
- ip_weight=self.ip_weight
776
- ),
777
- ResBlock(
778
- ch,
779
- time_embed_dim,
780
- dropout,
781
- dims=dims,
782
- use_checkpoint=use_checkpoint,
783
- use_scale_shift_norm=use_scale_shift_norm,
784
- ),
785
- )
786
- self._feature_size += ch
787
-
788
- self.output_blocks = nn.ModuleList([])
789
- for level, mult in list(enumerate(channel_mult))[::-1]:
790
- for i in range(self.num_res_blocks[level] + 1):
791
- ich = input_block_chans.pop()
792
- layers = [
793
- ResBlock(
794
- ch + ich,
795
- time_embed_dim,
796
- dropout,
797
- out_channels=model_channels * mult,
798
- dims=dims,
799
- use_checkpoint=use_checkpoint,
800
- use_scale_shift_norm=use_scale_shift_norm,
801
- )
802
- ]
803
- ch = model_channels * mult
804
- if ds in attention_resolutions:
805
- if num_head_channels == -1:
806
- dim_head = ch // num_heads
807
- else:
808
- num_heads = ch // num_head_channels
809
- dim_head = num_head_channels
810
- if legacy:
811
- # num_heads = 1
812
- dim_head = (
813
- ch // num_heads
814
- if use_spatial_transformer
815
- else num_head_channels
816
- )
817
- if exists(disable_self_attentions):
818
- disabled_sa = disable_self_attentions[level]
819
- else:
820
- disabled_sa = False
821
-
822
- if (
823
- not exists(num_attention_blocks)
824
- or i < num_attention_blocks[level]
825
- ):
826
- layers.append(
827
- AttentionBlock(
828
- ch,
829
- use_checkpoint=use_checkpoint,
830
- num_heads=num_heads_upsample,
831
- num_head_channels=dim_head,
832
- use_new_attention_order=use_new_attention_order,
833
- )
834
- if not use_spatial_transformer
835
- else SpatialTransformer3D(
836
- ch,
837
- num_heads,
838
- dim_head,
839
- depth=transformer_depth,
840
- context_dim=context_dim,
841
- disable_self_attn=disabled_sa,
842
- use_linear=use_linear_in_transformer,
843
- use_checkpoint=use_checkpoint,
844
- with_ip=self.with_ip,
845
- ip_dim=self.ip_dim,
846
- ip_weight=self.ip_weight
847
- )
848
- )
849
- if level and i == self.num_res_blocks[level]:
850
- out_ch = ch
851
- layers.append(
852
- ResBlock(
853
- ch,
854
- time_embed_dim,
855
- dropout,
856
- out_channels=out_ch,
857
- dims=dims,
858
- use_checkpoint=use_checkpoint,
859
- use_scale_shift_norm=use_scale_shift_norm,
860
- up=True,
861
- )
862
- if resblock_updown
863
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
864
- )
865
- ds //= 2
866
- self.output_blocks.append(TimestepEmbedSequential(*layers))
867
- self._feature_size += ch
868
-
869
- self.out = nn.Sequential(
870
- normalization(ch),
871
- nn.SiLU(),
872
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
873
- )
874
- if self.predict_codebook_ids:
875
- self.id_predictor = nn.Sequential(
876
- normalization(ch),
877
- conv_nd(dims, model_channels, n_embed, 1),
878
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
879
- )
880
-
881
- def convert_to_fp16(self):
882
- """
883
- Convert the torso of the model to float16.
884
- """
885
- self.input_blocks.apply(convert_module_to_f16)
886
- self.middle_block.apply(convert_module_to_f16)
887
- self.output_blocks.apply(convert_module_to_f16)
888
-
889
- def convert_to_fp32(self):
890
- """
891
- Convert the torso of the model to float32.
892
- """
893
- self.input_blocks.apply(convert_module_to_f32)
894
- self.middle_block.apply(convert_module_to_f32)
895
- self.output_blocks.apply(convert_module_to_f32)
896
-
897
- def forward(
898
- self,
899
- x,
900
- timesteps=None,
901
- context=None,
902
- y=None,
903
- camera=None,
904
- num_frames=1,
905
- **kwargs,
906
- ):
907
- """
908
- Apply the model to an input batch.
909
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
910
- :param timesteps: a 1-D batch of timesteps.
911
- :param context: conditioning plugged in via cross-attention.
912
- :param y: an [N] Tensor of labels, if class-conditional, default None.
913
- :param num_frames: an integer indicating the number of frames for tensor reshaping.
914
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
915
- """
916
- assert (
917
- x.shape[0] % num_frames == 0
918
- ), "[UNet] input batch size must be divisible by num_frames!"
919
- assert (y is not None) == (
920
- self.num_classes is not None
921
- ), "must specify y if and only if the model is class-conditional"
922
-
923
- hs = []
924
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
925
- emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
926
-
927
- if self.num_classes is not None:
928
- assert y.shape[0] == x.shape[0]
929
- emb = emb + self.label_emb(y)
930
-
931
- # Add camera embeddings
932
- if camera is not None:
933
- assert camera.shape[0] == emb.shape[0]
934
- # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
935
- emb = emb + self.camera_embed(camera)
936
- ip = kwargs.get("ip", None)
937
- ip_img = kwargs.get("ip_img", None)
938
-
939
- if ip_img is not None:
940
- x[(num_frames-1)::num_frames, :, :, :] = ip_img
941
-
942
- if ip is not None:
943
- ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
944
- context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
945
-
946
- h = x.type(self.dtype)
947
- for module in self.input_blocks:
948
- h = module(h, emb, context, num_frames=num_frames)
949
- hs.append(h)
950
- h = self.middle_block(h, emb, context, num_frames=num_frames)
951
- for module in self.output_blocks:
952
- h = th.cat([h, hs.pop()], dim=1)
953
- h = module(h, emb, context, num_frames=num_frames)
954
- h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
955
- if self.predict_codebook_ids: # False
956
- return self.id_predictor(h)
957
- else:
958
- return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
959
-
960
-
961
-
962
-
963
- class MultiViewUNetModelStage2(MultiViewUNetModel):
964
- """
965
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
966
- :param in_channels: channels in the input Tensor.
967
- :param model_channels: base channel count for the model.
968
- :param out_channels: channels in the output Tensor.
969
- :param num_res_blocks: number of residual blocks per downsample.
970
- :param attention_resolutions: a collection of downsample rates at which
971
- attention will take place. May be a set, list, or tuple.
972
- For example, if this contains 4, then at 4x downsampling, attention
973
- will be used.
974
- :param dropout: the dropout probability.
975
- :param channel_mult: channel multiplier for each level of the UNet.
976
- :param conv_resample: if True, use learned convolutions for upsampling and
977
- downsampling.
978
- :param dims: determines if the signal is 1D, 2D, or 3D.
979
- :param num_classes: if specified (as an int), then this model will be
980
- class-conditional with `num_classes` classes.
981
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
982
- :param num_heads: the number of attention heads in each attention layer.
983
- :param num_head_channels: if specified, ignore num_heads and instead use
984
- a fixed channel width per attention head.
985
- :param num_heads_upsample: works with num_heads to set a different number
986
- of heads for upsampling. Deprecated.
987
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
988
- :param resblock_updown: use residual blocks for up/downsampling.
989
- :param use_new_attention_order: use a different attention pattern for potentially
990
- increased efficiency.
991
- :param camera_dim: dimensionality of camera input.
992
- """
993
-
994
- def __init__(
995
- self,
996
- image_size,
997
- in_channels,
998
- model_channels,
999
- out_channels,
1000
- num_res_blocks,
1001
- attention_resolutions,
1002
- dropout=0,
1003
- channel_mult=(1, 2, 4, 8),
1004
- conv_resample=True,
1005
- dims=2,
1006
- num_classes=None,
1007
- use_checkpoint=False,
1008
- use_fp16=False,
1009
- use_bf16=False,
1010
- num_heads=-1,
1011
- num_head_channels=-1,
1012
- num_heads_upsample=-1,
1013
- use_scale_shift_norm=False,
1014
- resblock_updown=False,
1015
- use_new_attention_order=False,
1016
- use_spatial_transformer=False, # custom transformer support
1017
- transformer_depth=1, # custom transformer support
1018
- context_dim=None, # custom transformer support
1019
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1020
- legacy=True,
1021
- disable_self_attentions=None,
1022
- num_attention_blocks=None,
1023
- disable_middle_self_attn=False,
1024
- use_linear_in_transformer=False,
1025
- adm_in_channels=None,
1026
- camera_dim=None,
1027
- with_ip=False, # whether to add image prompt images
1028
- ip_dim=0, # number of extra tokens: 4 for global, 16 for local
1029
- ip_weight=1.0, # weight for image prompt context
1030
- ip_mode="local_resample", # adaptor mode: "global" or "local_resample"
1031
- ):
1032
- super().__init__(
1033
- image_size,
1034
- in_channels,
1035
- model_channels,
1036
- out_channels,
1037
- num_res_blocks,
1038
- attention_resolutions,
1039
- dropout,
1040
- channel_mult,
1041
- conv_resample,
1042
- dims,
1043
- num_classes,
1044
- use_checkpoint,
1045
- use_fp16,
1046
- use_bf16,
1047
- num_heads,
1048
- num_head_channels,
1049
- num_heads_upsample,
1050
- use_scale_shift_norm,
1051
- resblock_updown,
1052
- use_new_attention_order,
1053
- use_spatial_transformer,
1054
- transformer_depth,
1055
- context_dim,
1056
- n_embed,
1057
- legacy,
1058
- disable_self_attentions,
1059
- num_attention_blocks,
1060
- disable_middle_self_attn,
1061
- use_linear_in_transformer,
1062
- adm_in_channels,
1063
- camera_dim,
1064
- with_ip,
1065
- ip_dim,
1066
- ip_weight,
1067
- ip_mode,
1068
- )
1069
-
1070
- def forward(
1071
- self,
1072
- x,
1073
- timesteps=None,
1074
- context=None,
1075
- y=None,
1076
- camera=None,
1077
- num_frames=1,
1078
- **kwargs,
1079
- ):
1080
- """
1081
- Apply the model to an input batch.
1082
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
1083
- :param timesteps: a 1-D batch of timesteps.
1084
- :param context: conditioning plugged in via cross-attention.
1085
- :param y: an [N] Tensor of labels, if class-conditional, default None.
1086
- :param num_frames: an integer indicating the number of frames for tensor reshaping.
1087
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
1088
- """
1089
- assert (
1090
- x.shape[0] % num_frames == 0
1091
- ), "[UNet] input batch size must be divisible by num_frames!"
1092
- assert (y is not None) == (
1093
- self.num_classes is not None
1094
- ), "must specify y if and only if the model is class-conditional"
1095
-
1096
- hs = []
1097
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
1098
- emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
1099
-
1100
- if self.num_classes is not None:
1101
- assert y.shape[0] == x.shape[0]
1102
- emb = emb + self.label_emb(y)
1103
-
1104
- # Add camera embeddings
1105
- if camera is not None:
1106
- assert camera.shape[0] == emb.shape[0]
1107
- # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
1108
- emb = emb + self.camera_embed(camera)
1109
- ip = kwargs.get("ip", None)
1110
- ip_img = kwargs.get("ip_img", None)
1111
- pixel_images = kwargs.get("pixel_images", None)
1112
-
1113
- if ip_img is not None:
1114
- x[(num_frames-1)::num_frames, :, :, :] = ip_img
1115
-
1116
- x = torch.cat((x, pixel_images), dim=1)
1117
-
1118
- if ip is not None:
1119
- ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
1120
- context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
1121
-
1122
- h = x.type(self.dtype)
1123
- for module in self.input_blocks:
1124
- h = module(h, emb, context, num_frames=num_frames)
1125
- hs.append(h)
1126
- h = self.middle_block(h, emb, context, num_frames=num_frames)
1127
- for module in self.output_blocks:
1128
- h = th.cat([h, hs.pop()], dim=1)
1129
- h = module(h, emb, context, num_frames=num_frames)
1130
- h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
1131
- if self.predict_codebook_ids: # False
1132
- return self.id_predictor(h)
1133
- else:
1134
- return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
1135
 
 
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from imagedream.ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ convert_module_to_f16,
20
+ convert_module_to_f32
21
+ )
22
+ from imagedream.ldm.modules.attention import (
23
+ SpatialTransformer,
24
+ SpatialTransformer3D,
25
+ exists
26
+ )
27
+ from imagedream.ldm.modules.diffusionmodules.adaptors import (
28
+ Resampler,
29
+ ImageProjModel
30
+ )
31
+
32
+ ## go
33
+ class AttentionPool2d(nn.Module):
34
+ """
35
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ spacial_dim: int,
41
+ embed_dim: int,
42
+ num_heads_channels: int,
43
+ output_dim: int = None,
44
+ ):
45
+ super().__init__()
46
+ self.positional_embedding = nn.Parameter(
47
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
48
+ )
49
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
50
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
51
+ self.num_heads = embed_dim // num_heads_channels
52
+ self.attention = QKVAttention(self.num_heads)
53
+
54
+ def forward(self, x):
55
+ b, c, *_spatial = x.shape
56
+ x = x.reshape(b, c, -1) # NC(HW)
57
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
58
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
59
+ x = self.qkv_proj(x)
60
+ x = self.attention(x)
61
+ x = self.c_proj(x)
62
+ return x[:, :, 0]
63
+
64
+
65
+ class TimestepBlock(nn.Module):
66
+ """
67
+ Any module where forward() takes timestep embeddings as a second argument.
68
+ """
69
+
70
+ @abstractmethod
71
+ def forward(self, x, emb):
72
+ """
73
+ Apply the module to `x` given `emb` timestep embeddings.
74
+ """
75
+
76
+
77
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
78
+ """
79
+ A sequential module that passes timestep embeddings to the children that
80
+ support it as an extra input.
81
+ """
82
+
83
+ def forward(self, x, emb, context=None, num_frames=1):
84
+ for layer in self:
85
+ if isinstance(layer, TimestepBlock):
86
+ x = layer(x, emb)
87
+ elif isinstance(layer, SpatialTransformer3D):
88
+ x = layer(x, context, num_frames=num_frames)
89
+ elif isinstance(layer, SpatialTransformer):
90
+ x = layer(x, context)
91
+ else:
92
+ x = layer(x)
93
+ return x
94
+
95
+
96
+ class Upsample(nn.Module):
97
+ """
98
+ An upsampling layer with an optional convolution.
99
+ :param channels: channels in the inputs and outputs.
100
+ :param use_conv: a bool determining if a convolution is applied.
101
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
102
+ upsampling occurs in the inner-two dimensions.
103
+ """
104
+
105
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
106
+ super().__init__()
107
+ self.channels = channels
108
+ self.out_channels = out_channels or channels
109
+ self.use_conv = use_conv
110
+ self.dims = dims
111
+ if use_conv:
112
+ self.conv = conv_nd(
113
+ dims, self.channels, self.out_channels, 3, padding=padding
114
+ )
115
+
116
+ def forward(self, x):
117
+ assert x.shape[1] == self.channels
118
+ if self.dims == 3:
119
+ x = F.interpolate(
120
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
121
+ )
122
+ else:
123
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
124
+ if self.use_conv:
125
+ x = self.conv(x)
126
+ return x
127
+
128
+
129
+ class TransposedUpsample(nn.Module):
130
+ "Learned 2x upsampling without padding"
131
+
132
+ def __init__(self, channels, out_channels=None, ks=5):
133
+ super().__init__()
134
+ self.channels = channels
135
+ self.out_channels = out_channels or channels
136
+
137
+ self.up = nn.ConvTranspose2d(
138
+ self.channels, self.out_channels, kernel_size=ks, stride=2
139
+ )
140
+
141
+ def forward(self, x):
142
+ return self.up(x)
143
+
144
+
145
+ class Downsample(nn.Module):
146
+ """
147
+ A downsampling layer with an optional convolution.
148
+ :param channels: channels in the inputs and outputs.
149
+ :param use_conv: a bool determining if a convolution is applied.
150
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
151
+ downsampling occurs in the inner-two dimensions.
152
+ """
153
+
154
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
155
+ super().__init__()
156
+ self.channels = channels
157
+ self.out_channels = out_channels or channels
158
+ self.use_conv = use_conv
159
+ self.dims = dims
160
+ stride = 2 if dims != 3 else (1, 2, 2)
161
+ if use_conv:
162
+ self.op = conv_nd(
163
+ dims,
164
+ self.channels,
165
+ self.out_channels,
166
+ 3,
167
+ stride=stride,
168
+ padding=padding,
169
+ )
170
+ else:
171
+ assert self.channels == self.out_channels
172
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
173
+
174
+ def forward(self, x):
175
+ assert x.shape[1] == self.channels
176
+ return self.op(x)
177
+
178
+
179
+ class ResBlock(TimestepBlock):
180
+ """
181
+ A residual block that can optionally change the number of channels.
182
+ :param channels: the number of input channels.
183
+ :param emb_channels: the number of timestep embedding channels.
184
+ :param dropout: the rate of dropout.
185
+ :param out_channels: if specified, the number of out channels.
186
+ :param use_conv: if True and out_channels is specified, use a spatial
187
+ convolution instead of a smaller 1x1 convolution to change the
188
+ channels in the skip connection.
189
+ :param dims: determines if the signal is 1D, 2D, or 3D.
190
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
191
+ :param up: if True, use this block for upsampling.
192
+ :param down: if True, use this block for downsampling.
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ channels,
198
+ emb_channels,
199
+ dropout,
200
+ out_channels=None,
201
+ use_conv=False,
202
+ use_scale_shift_norm=False,
203
+ dims=2,
204
+ use_checkpoint=False,
205
+ up=False,
206
+ down=False,
207
+ ):
208
+ super().__init__()
209
+ self.channels = channels
210
+ self.emb_channels = emb_channels
211
+ self.dropout = dropout
212
+ self.out_channels = out_channels or channels
213
+ self.use_conv = use_conv
214
+ self.use_checkpoint = use_checkpoint
215
+ self.use_scale_shift_norm = use_scale_shift_norm
216
+
217
+ self.in_layers = nn.Sequential(
218
+ normalization(channels),
219
+ nn.SiLU(),
220
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
221
+ )
222
+
223
+ self.updown = up or down
224
+
225
+ if up:
226
+ self.h_upd = Upsample(channels, False, dims)
227
+ self.x_upd = Upsample(channels, False, dims)
228
+ elif down:
229
+ self.h_upd = Downsample(channels, False, dims)
230
+ self.x_upd = Downsample(channels, False, dims)
231
+ else:
232
+ self.h_upd = self.x_upd = nn.Identity()
233
+
234
+ self.emb_layers = nn.Sequential(
235
+ nn.SiLU(),
236
+ linear(
237
+ emb_channels,
238
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
239
+ ),
240
+ )
241
+ self.out_layers = nn.Sequential(
242
+ normalization(self.out_channels),
243
+ nn.SiLU(),
244
+ nn.Dropout(p=dropout),
245
+ zero_module(
246
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
247
+ ),
248
+ )
249
+
250
+ if self.out_channels == channels:
251
+ self.skip_connection = nn.Identity()
252
+ elif use_conv:
253
+ self.skip_connection = conv_nd(
254
+ dims, channels, self.out_channels, 3, padding=1
255
+ )
256
+ else:
257
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
258
+
259
+ def forward(self, x, emb):
260
+ """
261
+ Apply the block to a Tensor, conditioned on a timestep embedding.
262
+ :param x: an [N x C x ...] Tensor of features.
263
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
264
+ :return: an [N x C x ...] Tensor of outputs.
265
+ """
266
+ return checkpoint(
267
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
268
+ )
269
+
270
+ def _forward(self, x, emb):
271
+ if self.updown:
272
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
273
+ h = in_rest(x)
274
+ h = self.h_upd(h)
275
+ x = self.x_upd(x)
276
+ h = in_conv(h)
277
+ else:
278
+ h = self.in_layers(x)
279
+ emb_out = self.emb_layers(emb).type(h.dtype)
280
+ while len(emb_out.shape) < len(h.shape):
281
+ emb_out = emb_out[..., None]
282
+ if self.use_scale_shift_norm:
283
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
284
+ scale, shift = th.chunk(emb_out, 2, dim=1)
285
+ h = out_norm(h) * (1 + scale) + shift
286
+ h = out_rest(h)
287
+ else:
288
+ h = h + emb_out
289
+ h = self.out_layers(h)
290
+ return self.skip_connection(x) + h
291
+
292
+
293
+ class AttentionBlock(nn.Module):
294
+ """
295
+ An attention block that allows spatial positions to attend to each other.
296
+ Originally ported from here, but adapted to the N-d case.
297
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ channels,
303
+ num_heads=1,
304
+ num_head_channels=-1,
305
+ use_checkpoint=False,
306
+ use_new_attention_order=False,
307
+ ):
308
+ super().__init__()
309
+ self.channels = channels
310
+ if num_head_channels == -1:
311
+ self.num_heads = num_heads
312
+ else:
313
+ assert (
314
+ channels % num_head_channels == 0
315
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
316
+ self.num_heads = channels // num_head_channels
317
+ self.use_checkpoint = use_checkpoint
318
+ self.norm = normalization(channels)
319
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
320
+ if use_new_attention_order:
321
+ # split qkv before split heads
322
+ self.attention = QKVAttention(self.num_heads)
323
+ else:
324
+ # split heads before split qkv
325
+ self.attention = QKVAttentionLegacy(self.num_heads)
326
+
327
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
328
+
329
+ def forward(self, x):
330
+ return checkpoint(
331
+ self._forward, (x,), self.parameters(), True
332
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
333
+ # return pt_checkpoint(self._forward, x) # pytorch
334
+
335
+ def _forward(self, x):
336
+ b, c, *spatial = x.shape
337
+ x = x.reshape(b, c, -1)
338
+ qkv = self.qkv(self.norm(x))
339
+ h = self.attention(qkv)
340
+ h = self.proj_out(h)
341
+ return (x + h).reshape(b, c, *spatial)
342
+
343
+
344
+ def count_flops_attn(model, _x, y):
345
+ """
346
+ A counter for the `thop` package to count the operations in an
347
+ attention operation.
348
+ Meant to be used like:
349
+ macs, params = thop.profile(
350
+ model,
351
+ inputs=(inputs, timestamps),
352
+ custom_ops={QKVAttention: QKVAttention.count_flops},
353
+ )
354
+ """
355
+ b, c, *spatial = y[0].shape
356
+ num_spatial = int(np.prod(spatial))
357
+ # We perform two matmuls with the same number of ops.
358
+ # The first computes the weight matrix, the second computes
359
+ # the combination of the value vectors.
360
+ matmul_ops = 2 * b * (num_spatial**2) * c
361
+ model.total_ops += th.DoubleTensor([matmul_ops])
362
+
363
+
364
+ class QKVAttentionLegacy(nn.Module):
365
+ """
366
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
367
+ """
368
+
369
+ def __init__(self, n_heads):
370
+ super().__init__()
371
+ self.n_heads = n_heads
372
+
373
+ def forward(self, qkv):
374
+ """
375
+ Apply QKV attention.
376
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
377
+ :return: an [N x (H * C) x T] tensor after attention.
378
+ """
379
+ bs, width, length = qkv.shape
380
+ assert width % (3 * self.n_heads) == 0
381
+ ch = width // (3 * self.n_heads)
382
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
383
+ scale = 1 / math.sqrt(math.sqrt(ch))
384
+ weight = th.einsum(
385
+ "bct,bcs->bts", q * scale, k * scale
386
+ ) # More stable with f16 than dividing afterwards
387
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
388
+ a = th.einsum("bts,bcs->bct", weight, v)
389
+ return a.reshape(bs, -1, length)
390
+
391
+ @staticmethod
392
+ def count_flops(model, _x, y):
393
+ return count_flops_attn(model, _x, y)
394
+
395
+
396
+ class QKVAttention(nn.Module):
397
+ """
398
+ A module which performs QKV attention and splits in a different order.
399
+ """
400
+
401
+ def __init__(self, n_heads):
402
+ super().__init__()
403
+ self.n_heads = n_heads
404
+
405
+ def forward(self, qkv):
406
+ """
407
+ Apply QKV attention.
408
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
409
+ :return: an [N x (H * C) x T] tensor after attention.
410
+ """
411
+ bs, width, length = qkv.shape
412
+ assert width % (3 * self.n_heads) == 0
413
+ ch = width // (3 * self.n_heads)
414
+ q, k, v = qkv.chunk(3, dim=1)
415
+ scale = 1 / math.sqrt(math.sqrt(ch))
416
+ weight = th.einsum(
417
+ "bct,bcs->bts",
418
+ (q * scale).view(bs * self.n_heads, ch, length),
419
+ (k * scale).view(bs * self.n_heads, ch, length),
420
+ ) # More stable with f16 than dividing afterwards
421
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
422
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
423
+ return a.reshape(bs, -1, length)
424
+
425
+ @staticmethod
426
+ def count_flops(model, _x, y):
427
+ return count_flops_attn(model, _x, y)
428
+
429
+
430
+ class Timestep(nn.Module):
431
+ def __init__(self, dim):
432
+ super().__init__()
433
+ self.dim = dim
434
+
435
+ def forward(self, t):
436
+ return timestep_embedding(t, self.dim)
437
+
438
+
439
+ class MultiViewUNetModel(nn.Module):
440
+ """
441
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
442
+ :param in_channels: channels in the input Tensor.
443
+ :param model_channels: base channel count for the model.
444
+ :param out_channels: channels in the output Tensor.
445
+ :param num_res_blocks: number of residual blocks per downsample.
446
+ :param attention_resolutions: a collection of downsample rates at which
447
+ attention will take place. May be a set, list, or tuple.
448
+ For example, if this contains 4, then at 4x downsampling, attention
449
+ will be used.
450
+ :param dropout: the dropout probability.
451
+ :param channel_mult: channel multiplier for each level of the UNet.
452
+ :param conv_resample: if True, use learned convolutions for upsampling and
453
+ downsampling.
454
+ :param dims: determines if the signal is 1D, 2D, or 3D.
455
+ :param num_classes: if specified (as an int), then this model will be
456
+ class-conditional with `num_classes` classes.
457
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
458
+ :param num_heads: the number of attention heads in each attention layer.
459
+ :param num_heads_channels: if specified, ignore num_heads and instead use
460
+ a fixed channel width per attention head.
461
+ :param num_heads_upsample: works with num_heads to set a different number
462
+ of heads for upsampling. Deprecated.
463
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
464
+ :param resblock_updown: use residual blocks for up/downsampling.
465
+ :param use_new_attention_order: use a different attention pattern for potentially
466
+ increased efficiency.
467
+ :param camera_dim: dimensionality of camera input.
468
+ """
469
+
470
+ def __init__(
471
+ self,
472
+ image_size,
473
+ in_channels,
474
+ model_channels,
475
+ out_channels,
476
+ num_res_blocks,
477
+ attention_resolutions,
478
+ dropout=0,
479
+ channel_mult=(1, 2, 4, 8),
480
+ conv_resample=True,
481
+ dims=2,
482
+ num_classes=None,
483
+ use_checkpoint=False,
484
+ use_fp16=False,
485
+ use_bf16=False,
486
+ num_heads=-1,
487
+ num_head_channels=-1,
488
+ num_heads_upsample=-1,
489
+ use_scale_shift_norm=False,
490
+ resblock_updown=False,
491
+ use_new_attention_order=False,
492
+ use_spatial_transformer=False, # custom transformer support
493
+ transformer_depth=1, # custom transformer support
494
+ context_dim=None, # custom transformer support
495
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
496
+ legacy=True,
497
+ disable_self_attentions=None,
498
+ num_attention_blocks=None,
499
+ disable_middle_self_attn=False,
500
+ use_linear_in_transformer=False,
501
+ adm_in_channels=None,
502
+ camera_dim=None,
503
+ with_ip=False, # wether add image prompt images
504
+ ip_dim=0, # number of extra token, 4 for global 16 for local
505
+ ip_weight=1.0, # weight for image prompt context
506
+ ip_mode="local_resample", # which mode of adaptor, global or local
507
+ ):
508
+ super().__init__()
509
+ if use_spatial_transformer:
510
+ assert (
511
+ context_dim is not None
512
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
513
+
514
+ if context_dim is not None:
515
+ assert (
516
+ use_spatial_transformer
517
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
518
+ from omegaconf.listconfig import ListConfig
519
+
520
+ if type(context_dim) == ListConfig:
521
+ context_dim = list(context_dim)
522
+
523
+ if num_heads_upsample == -1:
524
+ num_heads_upsample = num_heads
525
+
526
+ if num_heads == -1:
527
+ assert (
528
+ num_head_channels != -1
529
+ ), "Either num_heads or num_head_channels has to be set"
530
+
531
+ if num_head_channels == -1:
532
+ assert (
533
+ num_heads != -1
534
+ ), "Either num_heads or num_head_channels has to be set"
535
+
536
+ self.image_size = image_size
537
+ self.in_channels = in_channels
538
+ self.model_channels = model_channels
539
+ self.out_channels = out_channels
540
+ if isinstance(num_res_blocks, int):
541
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
542
+ else:
543
+ if len(num_res_blocks) != len(channel_mult):
544
+ raise ValueError(
545
+ "provide num_res_blocks either as an int (globally constant) or "
546
+ "as a list/tuple (per-level) with the same length as channel_mult"
547
+ )
548
+ self.num_res_blocks = num_res_blocks
549
+ if disable_self_attentions is not None:
550
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
551
+ assert len(disable_self_attentions) == len(channel_mult)
552
+ if num_attention_blocks is not None:
553
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
554
+ assert all(
555
+ map(
556
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
557
+ range(len(num_attention_blocks)),
558
+ )
559
+ )
560
+ print(
561
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
562
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
563
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
564
+ f"attention will still not be set."
565
+ )
566
+
567
+ self.attention_resolutions = attention_resolutions
568
+ self.dropout = dropout
569
+ self.channel_mult = channel_mult
570
+ self.conv_resample = conv_resample
571
+ self.num_classes = num_classes
572
+ self.use_checkpoint = use_checkpoint
573
+ self.dtype = th.float16 if use_fp16 else th.float32
574
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
575
+ self.num_heads = num_heads
576
+ self.num_head_channels = num_head_channels
577
+ self.num_heads_upsample = num_heads_upsample
578
+ self.predict_codebook_ids = n_embed is not None
579
+
580
+ self.with_ip = with_ip # whether an image prompt is used
581
+ self.ip_dim = ip_dim # number of extra tokens, 4 for global, 16 for local
582
+ self.ip_weight = ip_weight
583
+ assert ip_mode in ["global", "local_resample"]
584
+ self.ip_mode = ip_mode # which mode of adaptor
585
+
586
+ time_embed_dim = model_channels * 4
587
+ self.time_embed = nn.Sequential(
588
+ linear(model_channels, time_embed_dim),
589
+ nn.SiLU(),
590
+ linear(time_embed_dim, time_embed_dim),
591
+ )
592
+
593
+ if camera_dim is not None:
594
+ time_embed_dim = model_channels * 4
595
+ self.camera_embed = nn.Sequential(
596
+ linear(camera_dim, time_embed_dim),
597
+ nn.SiLU(),
598
+ linear(time_embed_dim, time_embed_dim),
599
+ )
600
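# Editorial note: the camera embedding is a small MLP that maps the per-view camera
# vector (camera_dim values) to the same width as the time embedding, so in forward()
# it can simply be added to `emb`. In ImageDream-style configs camera_dim is typically
# the flattened 4x4 extrinsics (16 values); that figure is an assumption, not read
# from this file.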
+
601
+ if self.num_classes is not None:
602
+ if isinstance(self.num_classes, int):
603
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
604
+ elif self.num_classes == "continuous":
605
+ print("setting up linear c_adm embedding layer")
606
+ self.label_emb = nn.Linear(1, time_embed_dim)
607
+ elif self.num_classes == "sequential":
608
+ assert adm_in_channels is not None
609
+ self.label_emb = nn.Sequential(
610
+ nn.Sequential(
611
+ linear(adm_in_channels, time_embed_dim),
612
+ nn.SiLU(),
613
+ linear(time_embed_dim, time_embed_dim),
614
+ )
615
+ )
616
+ else:
617
+ raise ValueError()
618
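# Editorial note: three class-conditioning variants are supported above: an integer
# num_classes builds a learned nn.Embedding lookup, "continuous" projects a scalar
# label with a single Linear layer, and "sequential" runs an adm_in_channels-wide
# vector through a two-layer MLP. All three produce a time_embed_dim vector that is
# added to the timestep embedding in forward().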
+
619
+ if self.with_ip and (context_dim is not None) and ip_dim > 0:
620
+ if self.ip_mode == "local_resample":
621
+ # ip-adapter-plus
622
+ hidden_dim = 1280
623
+ self.image_embed = Resampler(
624
+ dim=context_dim,
625
+ depth=4,
626
+ dim_head=64,
627
+ heads=12,
628
+ num_queries=ip_dim, # num token
629
+ embedding_dim=hidden_dim,
630
+ output_dim=context_dim,
631
+ ff_mult=4,
632
+ )
633
+ elif self.ip_mode == "global":
634
+ self.image_embed = ImageProjModel(
635
+ cross_attention_dim=context_dim,
636
+ clip_extra_context_tokens=ip_dim)
637
+ else:
638
+ raise ValueError(f"{self.ip_mode} is not supported")
639
+
640
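# Editorial note: with ip_mode="local_resample" the Resampler cross-attends over
# patch-level CLIP image features (embedding_dim=1280; the input shape is assumed to
# be [B, n_patches, 1280]) and emits ip_dim tokens of width context_dim; with
# ip_mode="global" the ImageProjModel expands a pooled CLIP embedding into ip_dim
# tokens. Either way the resulting tokens are concatenated to the text context in
# forward().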
+ self.input_blocks = nn.ModuleList(
641
+ [
642
+ TimestepEmbedSequential(
643
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
644
+ )
645
+ ]
646
+ )
647
+ self._feature_size = model_channels
648
+ input_block_chans = [model_channels]
649
+ ch = model_channels
650
+ ds = 1
651
+ for level, mult in enumerate(channel_mult):
652
+ for nr in range(self.num_res_blocks[level]):
653
+ layers = [
654
+ ResBlock(
655
+ ch,
656
+ time_embed_dim,
657
+ dropout,
658
+ out_channels=mult * model_channels,
659
+ dims=dims,
660
+ use_checkpoint=use_checkpoint,
661
+ use_scale_shift_norm=use_scale_shift_norm,
662
+ )
663
+ ]
664
+ ch = mult * model_channels
665
+ if ds in attention_resolutions:
666
+ if num_head_channels == -1:
667
+ dim_head = ch // num_heads
668
+ else:
669
+ num_heads = ch // num_head_channels
670
+ dim_head = num_head_channels
671
+ if legacy:
672
+ # num_heads = 1
673
+ dim_head = (
674
+ ch // num_heads
675
+ if use_spatial_transformer
676
+ else num_head_channels
677
+ )
678
+ if exists(disable_self_attentions):
679
+ disabled_sa = disable_self_attentions[level]
680
+ else:
681
+ disabled_sa = False
682
+
683
+ if (
684
+ not exists(num_attention_blocks)
685
+ or nr < num_attention_blocks[level]
686
+ ):
687
+ layers.append(
688
+ AttentionBlock(
689
+ ch,
690
+ use_checkpoint=use_checkpoint,
691
+ num_heads=num_heads,
692
+ num_head_channels=dim_head,
693
+ use_new_attention_order=use_new_attention_order,
694
+ )
695
+ if not use_spatial_transformer
696
+ else SpatialTransformer3D(
697
+ ch,
698
+ num_heads,
699
+ dim_head,
700
+ depth=transformer_depth,
701
+ context_dim=context_dim,
702
+ disable_self_attn=disabled_sa,
703
+ use_linear=use_linear_in_transformer,
704
+ use_checkpoint=use_checkpoint,
705
+ with_ip=self.with_ip,
706
+ ip_dim=self.ip_dim,
707
+ ip_weight=self.ip_weight
708
+ )
709
+ )
710
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
711
+ self._feature_size += ch
712
+ input_block_chans.append(ch)
713
+
714
+ if level != len(channel_mult) - 1:
715
+ out_ch = ch
716
+ self.input_blocks.append(
717
+ TimestepEmbedSequential(
718
+ ResBlock(
719
+ ch,
720
+ time_embed_dim,
721
+ dropout,
722
+ out_channels=out_ch,
723
+ dims=dims,
724
+ use_checkpoint=use_checkpoint,
725
+ use_scale_shift_norm=use_scale_shift_norm,
726
+ down=True,
727
+ )
728
+ if resblock_updown
729
+ else Downsample(
730
+ ch, conv_resample, dims=dims, out_channels=out_ch
731
+ )
732
+ )
733
+ )
734
+ ch = out_ch
735
+ input_block_chans.append(ch)
736
+ ds *= 2
737
+ self._feature_size += ch
738
+
739
+ if num_head_channels == -1:
740
+ dim_head = ch // num_heads
741
+ else:
742
+ num_heads = ch // num_head_channels
743
+ dim_head = num_head_channels
744
+ if legacy:
745
+ # num_heads = 1
746
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
747
+ self.middle_block = TimestepEmbedSequential(
748
+ ResBlock(
749
+ ch,
750
+ time_embed_dim,
751
+ dropout,
752
+ dims=dims,
753
+ use_checkpoint=use_checkpoint,
754
+ use_scale_shift_norm=use_scale_shift_norm,
755
+ ),
756
+ AttentionBlock(
757
+ ch,
758
+ use_checkpoint=use_checkpoint,
759
+ num_heads=num_heads,
760
+ num_head_channels=dim_head,
761
+ use_new_attention_order=use_new_attention_order,
762
+ )
763
+ if not use_spatial_transformer
764
+ else SpatialTransformer3D( # always uses self-attention
765
+ ch,
766
+ num_heads,
767
+ dim_head,
768
+ depth=transformer_depth,
769
+ context_dim=context_dim,
770
+ disable_self_attn=disable_middle_self_attn,
771
+ use_linear=use_linear_in_transformer,
772
+ use_checkpoint=use_checkpoint,
773
+ with_ip=self.with_ip,
774
+ ip_dim=self.ip_dim,
775
+ ip_weight=self.ip_weight
776
+ ),
777
+ ResBlock(
778
+ ch,
779
+ time_embed_dim,
780
+ dropout,
781
+ dims=dims,
782
+ use_checkpoint=use_checkpoint,
783
+ use_scale_shift_norm=use_scale_shift_norm,
784
+ ),
785
+ )
786
+ self._feature_size += ch
787
+
788
+ self.output_blocks = nn.ModuleList([])
789
+ for level, mult in list(enumerate(channel_mult))[::-1]:
790
+ for i in range(self.num_res_blocks[level] + 1):
791
+ ich = input_block_chans.pop()
792
+ layers = [
793
+ ResBlock(
794
+ ch + ich,
795
+ time_embed_dim,
796
+ dropout,
797
+ out_channels=model_channels * mult,
798
+ dims=dims,
799
+ use_checkpoint=use_checkpoint,
800
+ use_scale_shift_norm=use_scale_shift_norm,
801
+ )
802
+ ]
803
+ ch = model_channels * mult
804
+ if ds in attention_resolutions:
805
+ if num_head_channels == -1:
806
+ dim_head = ch // num_heads
807
+ else:
808
+ num_heads = ch // num_head_channels
809
+ dim_head = num_head_channels
810
+ if legacy:
811
+ # num_heads = 1
812
+ dim_head = (
813
+ ch // num_heads
814
+ if use_spatial_transformer
815
+ else num_head_channels
816
+ )
817
+ if exists(disable_self_attentions):
818
+ disabled_sa = disable_self_attentions[level]
819
+ else:
820
+ disabled_sa = False
821
+
822
+ if (
823
+ not exists(num_attention_blocks)
824
+ or i < num_attention_blocks[level]
825
+ ):
826
+ layers.append(
827
+ AttentionBlock(
828
+ ch,
829
+ use_checkpoint=use_checkpoint,
830
+ num_heads=num_heads_upsample,
831
+ num_head_channels=dim_head,
832
+ use_new_attention_order=use_new_attention_order,
833
+ )
834
+ if not use_spatial_transformer
835
+ else SpatialTransformer3D(
836
+ ch,
837
+ num_heads,
838
+ dim_head,
839
+ depth=transformer_depth,
840
+ context_dim=context_dim,
841
+ disable_self_attn=disabled_sa,
842
+ use_linear=use_linear_in_transformer,
843
+ use_checkpoint=use_checkpoint,
844
+ with_ip=self.with_ip,
845
+ ip_dim=self.ip_dim,
846
+ ip_weight=self.ip_weight
847
+ )
848
+ )
849
+ if level and i == self.num_res_blocks[level]:
850
+ out_ch = ch
851
+ layers.append(
852
+ ResBlock(
853
+ ch,
854
+ time_embed_dim,
855
+ dropout,
856
+ out_channels=out_ch,
857
+ dims=dims,
858
+ use_checkpoint=use_checkpoint,
859
+ use_scale_shift_norm=use_scale_shift_norm,
860
+ up=True,
861
+ )
862
+ if resblock_updown
863
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
864
+ )
865
+ ds //= 2
866
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
867
+ self._feature_size += ch
868
+
869
+ self.out = nn.Sequential(
870
+ normalization(ch),
871
+ nn.SiLU(),
872
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
873
+ )
874
+ if self.predict_codebook_ids:
875
+ self.id_predictor = nn.Sequential(
876
+ normalization(ch),
877
+ conv_nd(dims, model_channels, n_embed, 1),
878
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
879
+ )
880
+
881
+ def convert_to_fp16(self):
882
+ """
883
+ Convert the torso of the model to float16.
884
+ """
885
+ self.input_blocks.apply(convert_module_to_f16)
886
+ self.middle_block.apply(convert_module_to_f16)
887
+ self.output_blocks.apply(convert_module_to_f16)
888
+
889
+ def convert_to_fp32(self):
890
+ """
891
+ Convert the torso of the model to float32.
892
+ """
893
+ self.input_blocks.apply(convert_module_to_f32)
894
+ self.middle_block.apply(convert_module_to_f32)
895
+ self.output_blocks.apply(convert_module_to_f32)
896
+
897
+ def forward(
898
+ self,
899
+ x,
900
+ timesteps=None,
901
+ context=None,
902
+ y=None,
903
+ camera=None,
904
+ num_frames=1,
905
+ **kwargs,
906
+ ):
907
+ """
908
+ Apply the model to an input batch.
909
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
910
+ :param timesteps: a 1-D batch of timesteps.
911
+ :param context: conditioning tokens plugged in via cross-attention.
912
+ :param y: an [N] Tensor of labels, if class-conditional, default None.
913
+ :param num_frames: an integer giving the number of frames (views) for tensor reshaping.
914
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
915
+ """
916
+ assert (
917
+ x.shape[0] % num_frames == 0
918
+ ), "[UNet] input batch size must be dividable by num_frames!"
919
+ assert (y is not None) == (
920
+ self.num_classes is not None
921
+ ), "must specify y if and only if the model is class-conditional"
922
+
923
+ hs = []
924
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
925
+ emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
926
+
927
+ if self.num_classes is not None:
928
+ assert y.shape[0] == x.shape[0]
929
+ emb = emb + self.label_emb(y)
930
+
931
+ # Add camera embeddings
932
+ if camera is not None:
933
+ assert camera.shape[0] == emb.shape[0]
934
+ # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
935
+ emb = emb + self.camera_embed(camera)
936
+ ip = kwargs.get("ip", None)
937
+ ip_img = kwargs.get("ip_img", None)
938
+
939
+ if ip_img is not None:
940
+ x[(num_frames-1)::num_frames, :, :, :] = ip_img
941
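# Editorial note: the assignment above overwrites every num_frames-th latent (offset
# num_frames-1, i.e. the last view in each group) with ip_img, so one slot of each
# multi-view group carries the clean latent of the image prompt while the remaining
# views are denoised.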
+
942
+ if ip is not None:
943
+ ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
944
+ context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
945
+
946
+ h = x.type(self.dtype)
947
+ for module in self.input_blocks:
948
+ h = module(h, emb, context, num_frames=num_frames)
949
+ hs.append(h)
950
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
951
+ for module in self.output_blocks:
952
+ h = th.cat([h, hs.pop()], dim=1)
953
+ h = module(h, emb, context, num_frames=num_frames)
954
+ h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
955
+ if self.predict_codebook_ids: # False
956
+ return self.id_predictor(h)
957
+ else:
958
+ return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
959
+
960
+
961
+
962
+
963
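# Editorial sketch (not part of the original commit): a minimal smoke test for
# MultiViewUNetModel. Every hyperparameter and tensor shape below is an illustrative
# assumption in the spirit of SD-2.x / ImageDream-style configs (320 base channels,
# 1024-dim text context, 16-dim flattened camera, 1280-dim CLIP image hidden states,
# 5 views with the last one reserved for the image prompt); none of it is read from
# this repository's config files. The function is defined but never called here.
def _example_multiview_unet_smoke_test():
    unet = MultiViewUNetModel(
        image_size=32,
        in_channels=4,
        model_channels=320,
        out_channels=4,
        num_res_blocks=2,
        attention_resolutions=(4, 2, 1),
        channel_mult=(1, 2, 4, 4),
        use_spatial_transformer=True,   # required because context_dim is given
        transformer_depth=1,
        context_dim=1024,
        num_head_channels=64,
        camera_dim=16,
        with_ip=True,
        ip_dim=16,                      # 16 extra tokens -> "local_resample" Resampler
        ip_mode="local_resample",
    )
    batch, num_frames = 2, 5
    NF = batch * num_frames              # views are flattened into the batch dimension
    x = th.randn(NF, 4, 32, 32)          # noisy multi-view latents
    t = th.randint(0, 1000, (NF,))       # one diffusion timestep per latent
    ctx = th.randn(NF, 77, 1024)         # text cross-attention tokens (77 assumed)
    cam = th.randn(NF, 16)               # per-view camera vectors
    ip = th.randn(NF, 257, 1280)         # CLIP image hidden states (257 patches assumed)
    ip_img = th.randn(batch, 4, 32, 32)  # clean latent of the prompt image, one per group
    eps = unet(x, t, context=ctx, camera=cam, num_frames=num_frames, ip=ip, ip_img=ip_img)
    assert eps.shape == x.shape          # output keeps the input latent shape
    return eps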
+ class MultiViewUNetModelStage2(MultiViewUNetModel):
964
+ """
965
+ The Stage-2 multi-view UNet model with attention, timestep embedding and camera embedding; its forward pass additionally concatenates per-view pixel images to the input latents along the channel dimension.
966
+ :param in_channels: channels in the input Tensor.
967
+ :param model_channels: base channel count for the model.
968
+ :param out_channels: channels in the output Tensor.
969
+ :param num_res_blocks: number of residual blocks per downsample.
970
+ :param attention_resolutions: a collection of downsample rates at which
971
+ attention will take place. May be a set, list, or tuple.
972
+ For example, if this contains 4, then at 4x downsampling, attention
973
+ will be used.
974
+ :param dropout: the dropout probability.
975
+ :param channel_mult: channel multiplier for each level of the UNet.
976
+ :param conv_resample: if True, use learned convolutions for upsampling and
977
+ downsampling.
978
+ :param dims: determines if the signal is 1D, 2D, or 3D.
979
+ :param num_classes: if specified (as an int), then this model will be
980
+ class-conditional with `num_classes` classes.
981
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
982
+ :param num_heads: the number of attention heads in each attention layer.
983
+ :param num_head_channels: if specified, ignore num_heads and instead use
984
+ a fixed channel width per attention head.
985
+ :param num_heads_upsample: works with num_heads to set a different number
986
+ of heads for upsampling. Deprecated.
987
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
988
+ :param resblock_updown: use residual blocks for up/downsampling.
989
+ :param use_new_attention_order: use a different attention pattern for potentially
990
+ increased efficiency.
991
+ :param camera_dim: dimensionality of camera input.
992
+ """
993
+
994
+ def __init__(
995
+ self,
996
+ image_size,
997
+ in_channels,
998
+ model_channels,
999
+ out_channels,
1000
+ num_res_blocks,
1001
+ attention_resolutions,
1002
+ dropout=0,
1003
+ channel_mult=(1, 2, 4, 8),
1004
+ conv_resample=True,
1005
+ dims=2,
1006
+ num_classes=None,
1007
+ use_checkpoint=False,
1008
+ use_fp16=False,
1009
+ use_bf16=False,
1010
+ num_heads=-1,
1011
+ num_head_channels=-1,
1012
+ num_heads_upsample=-1,
1013
+ use_scale_shift_norm=False,
1014
+ resblock_updown=False,
1015
+ use_new_attention_order=False,
1016
+ use_spatial_transformer=False, # custom transformer support
1017
+ transformer_depth=1, # custom transformer support
1018
+ context_dim=None, # custom transformer support
1019
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1020
+ legacy=True,
1021
+ disable_self_attentions=None,
1022
+ num_attention_blocks=None,
1023
+ disable_middle_self_attn=False,
1024
+ use_linear_in_transformer=False,
1025
+ adm_in_channels=None,
1026
+ camera_dim=None,
1027
+ with_ip=False, # whether to add image prompt images
1028
+ ip_dim=0, # number of extra tokens, 4 for global, 16 for local
1029
+ ip_weight=1.0, # weight for image prompt context
1030
+ ip_mode="local_resample", # which mode of adaptor, global or local
1031
+ ):
1032
+ super().__init__(
1033
+ image_size,
1034
+ in_channels,
1035
+ model_channels,
1036
+ out_channels,
1037
+ num_res_blocks,
1038
+ attention_resolutions,
1039
+ dropout,
1040
+ channel_mult,
1041
+ conv_resample,
1042
+ dims,
1043
+ num_classes,
1044
+ use_checkpoint,
1045
+ use_fp16,
1046
+ use_bf16,
1047
+ num_heads,
1048
+ num_head_channels,
1049
+ num_heads_upsample,
1050
+ use_scale_shift_norm,
1051
+ resblock_updown,
1052
+ use_new_attention_order,
1053
+ use_spatial_transformer,
1054
+ transformer_depth,
1055
+ context_dim,
1056
+ n_embed,
1057
+ legacy,
1058
+ disable_self_attentions,
1059
+ num_attention_blocks,
1060
+ disable_middle_self_attn,
1061
+ use_linear_in_transformer,
1062
+ adm_in_channels,
1063
+ camera_dim,
1064
+ with_ip,
1065
+ ip_dim,
1066
+ ip_weight,
1067
+ ip_mode,
1068
+ )
1069
+
1070
+ def forward(
1071
+ self,
1072
+ x,
1073
+ timesteps=None,
1074
+ context=None,
1075
+ y=None,
1076
+ camera=None,
1077
+ num_frames=1,
1078
+ **kwargs,
1079
+ ):
1080
+ """
1081
+ Apply the model to an input batch.
1082
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
1083
+ :param timesteps: a 1-D batch of timesteps.
1084
+ :param context: conditioning tokens plugged in via cross-attention.
1085
+ :param y: an [N] Tensor of labels, if class-conditional, default None.
1086
+ :param num_frames: an integer giving the number of frames (views) for tensor reshaping.
1087
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
1088
+ """
1089
+ assert (
1090
+ x.shape[0] % num_frames == 0
1091
+ ), "[UNet] input batch size must be dividable by num_frames!"
1092
+ assert (y is not None) == (
1093
+ self.num_classes is not None
1094
+ ), "must specify y if and only if the model is class-conditional"
1095
+
1096
+ hs = []
1097
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
1098
+ emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
1099
+
1100
+ if self.num_classes is not None:
1101
+ assert y.shape[0] == x.shape[0]
1102
+ emb = emb + self.label_emb(y)
1103
+
1104
+ # Add camera embeddings
1105
+ if camera is not None:
1106
+ assert camera.shape[0] == emb.shape[0]
1107
+ # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
1108
+ emb = emb + self.camera_embed(camera)
1109
+ ip = kwargs.get("ip", None)
1110
+ ip_img = kwargs.get("ip_img", None)
1111
+ pixel_images = kwargs.get("pixel_images", None)
1112
+
1113
+ if ip_img is not None:
1114
+ x[(num_frames-1)::num_frames, :, :, :] = ip_img
1115
+
1116
+ x = torch.cat((x, pixel_images), dim=1)
1117
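# Editorial note: unlike the parent class, this concatenation is unconditional, so
# pixel_images must always be passed via kwargs and the model's in_channels must
# include those extra channels.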
+
1118
+ if ip is not None:
1119
+ ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
1120
+ context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
1121
+
1122
+ h = x.type(self.dtype)
1123
+ for module in self.input_blocks:
1124
+ h = module(h, emb, context, num_frames=num_frames)
1125
+ hs.append(h)
1126
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1127
+ for module in self.output_blocks:
1128
+ h = th.cat([h, hs.pop()], dim=1)
1129
+ h = module(h, emb, context, num_frames=num_frames)
1130
+ h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
1131
+ if self.predict_codebook_ids: # False
1132
+ return self.id_predictor(h)
1133
+ else:
1134
+ return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
1135
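# Editorial sketch (not part of the original commit): calling the Stage-2 model. The
# only behavioural difference from the parent forward() is the channel-wise
# concatenation of pixel_images, so the constructor must receive in_channels that
# account for them. The channel split below (4 latent + 4 pixel-branch channels,
# i.e. in_channels=8) and all other shapes are illustrative assumptions, not values
# taken from this repository's configs. The function is defined but never called here.
def _example_stage2_smoke_test():
    unet2 = MultiViewUNetModelStage2(
        image_size=32,
        in_channels=8,                   # 4 noisy-latent + 4 pixel_images channels (assumed)
        model_channels=320,
        out_channels=4,
        num_res_blocks=2,
        attention_resolutions=(4, 2, 1),
        channel_mult=(1, 2, 4, 4),
        use_spatial_transformer=True,
        transformer_depth=1,
        context_dim=1024,
        num_head_channels=64,
        camera_dim=16,
    )
    batch, num_frames = 1, 5
    NF = batch * num_frames
    x = th.randn(NF, 4, 32, 32)             # noisy multi-view latents
    pixel_images = th.randn(NF, 4, 32, 32)  # per-view conditioning images in latent space (assumed)
    t = th.randint(0, 1000, (NF,))
    ctx = th.randn(NF, 77, 1024)
    cam = th.randn(NF, 16)
    eps = unet2(x, t, context=ctx, camera=cam, num_frames=num_frames,
                pixel_images=pixel_images)
    assert eps.shape == (NF, 4, 32, 32)     # out_channels stays 4 despite the wider input
    return eps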