Mariam-Elz committed (verified)
Commit 5764704 · 1 Parent(s): 11f10ee

Upload imagedream/ldm/modules/diffusionmodules/adaptors.py with huggingface_hub

imagedream/ldm/modules/diffusionmodules/adaptors.py ADDED
@@ -0,0 +1,163 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # force a contiguous copy; shape stays (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D)
            latents (torch.Tensor): latent features, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # more stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


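# Usage sketch (shapes are illustrative, not from the original file): the latents
# cross-attend over the concatenation of image tokens and themselves.
#   attn = PerceiverAttention(dim=1024, dim_head=64, heads=16)
#   x = torch.randn(2, 257, 1024)      # image features (b, n1, D)
#   latents = torch.randn(2, 8, 1024)  # latent queries (b, n2, D)
#   out = attn(x, latents)             # -> (2, 8, 1024)

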
class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self,
                 cross_attention_dim=1024,
                 clip_embeddings_dim=1024,
                 clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # from 1024 -> 4 * 1024
        self.proj = torch.nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


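# Usage sketch (illustrative shapes): expand one pooled CLIP embedding into
# clip_extra_context_tokens context tokens for cross-attention.
#   proj = ImageProjModel()              # 1024 -> 4 tokens of width 1024
#   image_embeds = torch.randn(2, 1024)
#   tokens = proj(image_embeds)          # -> (2, 4, 1024)

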
class SimpleReSampler(nn.Module):
    def __init__(self, embedding_dim=1280, output_dim=1024):
        super().__init__()
        self.proj_out = nn.Linear(embedding_dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

    def forward(self, latents):
        """
        latents: shape (B, 256, embedding_dim)
        """
        latents = self.proj_out(latents)
        return self.norm_out(latents)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim,
                                           dim_head=dim_head,
                                           heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


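# Shape summary: Resampler maps (B, N, embedding_dim) -> (B, num_queries, output_dim),
# distilling a variable-length token sequence into a fixed set of learned queries;
# the smoke test below exercises both Resampler and SimpleReSampler.

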
if __name__ == '__main__':
    # smoke test: project (4, 257, 1280) CLIP-style features with both adaptors
    resampler = Resampler(embedding_dim=1280)
    simple = SimpleReSampler(embedding_dim=1280)
    tensor = torch.rand(4, 257, 1280)
    print(resampler(tensor).shape)  # torch.Size([4, 8, 1024])
    print(simple(tensor).shape)     # torch.Size([4, 257, 1024])