File size: 12,712 Bytes
b41a54a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import math

from einops import rearrange, repeat

import torch
from torch import nn, einsum
import torch.nn.functional as F

from axial_positional_embedding import AxialPositionalEmbedding
from text2punks.transformer import Transformer


# helper functions

def exists(val):
    """Tell whether `val` holds a value, i.e. is anything other than None.

    Falsy values such as 0, '' or an empty container still count as existing.
    """
    if val is None:
        return False
    return True

def default(val, d):
    """Return `val`, falling back to `d` only when `val` is None.

    Falsy but non-None values (0, '', empty containers) are returned as-is.
    """
    if val is None:
        return d
    return val

def set_requires_grad(model, value):
    """Set the `requires_grad` flag of every parameter of `model` to `value`.

    `model` only needs a `parameters()` method yielding the parameters;
    nothing is returned, the parameters are updated in place.
    """
    for parameter in model.parameters():
        parameter.requires_grad = value

def eval_decorator(fn):
    """Decorate `fn` so that it runs with its model switched to eval mode.

    The wrapper reads `model.training`, calls `model.eval()`, invokes
    `fn(model, *args, **kwargs)` and finally calls
    `model.train(was_training)` to put the previous mode back.

    The previous mode is restored even when `fn` raises, so a failing call
    cannot leave the model stuck in eval mode. The exception itself is
    propagated unchanged.

    Args:
        fn: callable whose first positional argument is the model; the model
            must expose `training`, `eval()` and `train(mode)`.

    Returns:
        The wrapped callable; it returns whatever `fn` returns.
    """
    # function-scope import keeps this helper self-contained
    from functools import wraps

    @wraps(fn)  # keep fn's __name__ / __doc__ on the wrapper
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # runs on both the normal and the exceptional path
            model.train(was_training)
    return inner

# sampling helper functions

def top_k(logits, thres = 0.5):
    """Keep the largest logits along the last dimension; set the rest to -inf.

    The number of kept entries is `max(int((1 - thres) * num_logits), 1)`,
    where `num_logits` is the size of the last dimension, so at least one
    entry always survives.

    Args:
        logits: floating-point tensor of any rank >= 1; filtering is applied
            independently to every slice along the last dimension.
        thres: fraction of entries to discard (0.5 keeps the top half).

    Returns:
        A new tensor shaped like `logits` holding the kept values at their
        original positions and `-inf` everywhere else. `logits` itself is
        not modified.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    # torch.topk picked along the last dimension, so scatter along that same
    # dimension; a hard-coded dim of 1 only works for 2-D input
    filtered.scatter_(-1, ind, val)
    return filtered

# main CLIP class

class CLIP(nn.Module):
    """Contrastive scorer for paired text-token and image-token sequences.

    Each modality is embedded (token + learned position embedding), encoded
    by its own non-causal `Transformer`, mean-pooled over the sequence
    dimension, layer-normalised, linearly projected to `dim_latent` and
    L2-normalised. Similarities are dot products of these latents multiplied
    by `exp(self.temperature)`.
    """

    def __init__(
        self,
        *,
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 10000,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        num_visual_tokens = 256,
        visual_enc_depth = 6,
        visual_image_seq_len = 256,
        visual_image_size = 24,
        visual_heads = 8,
        attn_pdrop = 0.1,
        resid_pdrop = 0.1,
        embd_pdrop = 0.1,
        ff_dropout = 0.1,
        attn_types = None
    ):
        """Build the text branch, the image branch and the learnable scale.

        Args:
            dim_text / dim_image: embedding width of each branch.
            dim_latent: width of the shared latent space.
            num_text_tokens / num_visual_tokens: vocabulary sizes.
            text_seq_len / visual_image_seq_len: number of position embeddings
                (and `seq_len` handed to each `Transformer`).
            text_enc_depth / visual_enc_depth: `depth` of each `Transformer`.
            text_heads / visual_heads: attention heads; the per-head width is
                the branch width integer-divided by the head count.
            visual_image_size: passed as `image_size` to the image encoder only.
            attn_pdrop, resid_pdrop, embd_pdrop, ff_dropout, attn_types:
                forwarded unchanged to both `Transformer` instances.
        """
        super().__init__()

        # settings that both encoders receive verbatim
        shared_encoder_kwargs = dict(
            causal = False,
            attn_dropout = attn_pdrop,
            resid_dropout = resid_pdrop,
            embd_dropout = embd_pdrop,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
        )

        # text branch

        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)

        self.text_transformer = Transformer(
            dim = dim_text,
            seq_len = text_seq_len,
            depth = text_enc_depth,
            heads = text_heads,
            dim_head = dim_text // text_heads,
            **shared_encoder_kwargs
        )

        self.text_ln = nn.LayerNorm(dim_text)
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # image branch

        self.image_emb = nn.Embedding(num_visual_tokens, dim_image)
        self.image_pos_emb = nn.Embedding(visual_image_seq_len, dim_image)

        self.visual_transformer = Transformer(
            dim = dim_image,
            seq_len = visual_image_seq_len,
            depth = visual_enc_depth,
            heads = visual_heads,
            dim_head = dim_image // visual_heads,
            image_size = visual_image_size,
            **shared_encoder_kwargs
        )

        self.image_ln = nn.LayerNorm(dim_image)
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        # stored in log space; forward() exponentiates it before use
        self.temperature = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; invoked for every module via `self.apply`.

        LayerNorm gets weight 1 and bias 0. Linear and Embedding weights are
        drawn from N(0, 0.02) and an existing Linear bias is zeroed. Any
        other module type is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        text,
        image,
        return_loss = False
    ):
        """Score text/image pairs, or compute the symmetric contrastive loss.

        Args:
            text: integer token ids, indexed as (batch, text positions).
            image: integer token ids, indexed as (batch, image positions).
            return_loss: selects the return value, see below.

        Returns:
            With `return_loss` false: one scaled similarity per row, pairing
            `text[i]` with `image[i]`.
            With `return_loss` true: the mean of the text-to-image and
            image-to-text cross entropies over the full similarity matrix,
            where row `i` is expected to match column `i`.
        """
        device = text.device

        text_positions = torch.arange(text.shape[1], device = device)
        image_positions = torch.arange(image.shape[1], device = device)

        text_tokens = self.text_emb(text) + self.text_pos_emb(text_positions)
        image_tokens = self.image_emb(image) + self.image_pos_emb(image_positions)

        # encode, then average over the sequence dimension
        pooled_text = self.text_transformer(text_tokens).mean(dim = 1)
        pooled_image = self.visual_transformer(image_tokens).mean(dim = 1)

        text_latents = self.to_text_latent(self.text_ln(pooled_text))
        image_latents = self.to_visual_latent(self.image_ln(pooled_image))

        # unit-length latents turn the dot products below into cosine similarities
        text_latents = F.normalize(text_latents, p = 2, dim = -1)
        image_latents = F.normalize(image_latents, p = 2, dim = -1)

        scale = self.temperature.exp()

        if return_loss:
            pairwise = einsum('i d, j d -> i j', text_latents, image_latents) * scale
            targets = torch.arange(text.shape[0], device = device)
            text_to_image = F.cross_entropy(pairwise, targets)
            image_to_text = F.cross_entropy(pairwise.t(), targets)
            return (text_to_image + image_to_text) / 2

        return einsum('n d, n d -> n', text_latents, image_latents) * scale

# main Text2Punks class

class Text2Punks(nn.Module):
    """Causal transformer over a text token sequence followed by image tokens.

    Text and image tokens share one output vocabulary: indices below
    `self.num_text_tokens` are text, the remaining `num_image_tokens` indices
    are image tokens. A fixed boolean mask keeps text positions from
    predicting image tokens and image positions from predicting text tokens.
    """

    def __init__(
        self,
        *,
        n_embd,
        n_layer = 12,
        n_head = 12,
        d_head = 64,
        num_text_tokens = 10000,
        text_seq_len = 256,
        num_image_tokens = 222,
        image_seq_len = 576,
        image_size = 24,
        attn_pdrop = 0.1,
        resid_pdrop = 0.1,
        embd_pdrop = 0.1,
        ff_dropout = 0.1,
        attn_types = None,
        loss_img_weight = 7,
        loss_txt_weight = 7,
    ):
        """Create embeddings, the causal `Transformer`, the logits head and mask.

        Args:
            n_embd: model width (embedding size).
            n_layer / n_head / d_head: `depth`, `heads` and `dim_head` of the
                `Transformer`.
            num_text_tokens: text vocabulary size before the per-position
                padding tokens are added.
            text_seq_len: number of text tokens per sample.
            num_image_tokens: image vocabulary size.
            image_seq_len: number of image tokens per sample.
            image_size: passed to the `Transformer` and used as the grid
                height when `generate_images` folds the token sequence.
            attn_pdrop, resid_pdrop, embd_pdrop, ff_dropout, attn_types:
                forwarded unchanged to the `Transformer`.
            loss_img_weight / loss_txt_weight: weights of the image and text
                cross entropies in the combined loss.
        """
        super().__init__()

        num_text_tokens = num_text_tokens + text_seq_len  # reserve unique padding tokens for each position (text seq len)

        self.text_emb = nn.Embedding(num_text_tokens, n_embd)
        self.image_emb = nn.Embedding(num_image_tokens, n_embd)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, n_embd) # +1 for <bos> a.k.a <sos>
        # learned absolute position table for image tokens, starts at zero
        self.image_pos_emb = nn.Parameter(torch.zeros(1, image_seq_len, n_embd))

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens
        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len
        self.image_size = image_size # grid height used when decoding generated tokens

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_seq_len = seq_len
        self.total_tokens = total_tokens

        self.transformer = Transformer(
            dim = n_embd,
            causal = True,
            seq_len = seq_len,
            depth = n_layer,
            heads = n_head,
            dim_head = d_head,
            attn_dropout = attn_pdrop,
            resid_dropout = resid_pdrop,
            embd_dropout = embd_pdrop,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_size = image_size,
        )

        self.to_logits = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, self.total_tokens),
        )

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        # True marks a forbidden (position, vocabulary index) combination:
        # image positions may not emit text tokens and vice versa
        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        # non-persistent: rebuilt from the constructor arguments, not saved in state_dict
        self.register_buffer('logits_mask', logits_mask, persistent=False)
        self.loss_img_weight = loss_img_weight
        self.loss_txt_weight = loss_txt_weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; invoked for every module via `self.apply`.

        Linear and Embedding weights are drawn from N(0, 0.02) and an existing
        Linear bias is zeroed; LayerNorm gets weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        decoder,
        *,
        clip = None,
        filter_thres = 0.5,
        temperature = 1.,
        img = None,
        num_init_img_tokens = None
    ):
        """Sample image tokens for `text` and map them to pixel values.

        Runs without gradients and with the model in eval mode.

        Args:
            text: integer text token ids, (batch, length); only the first
                `text_seq_len` columns are used.
            decoder: 2-D lookup table whose row index is an image token id;
                its first three columns are gathered as the channel values.
            clip: optional scorer called as
                `clip(text_seq, img_seq, return_loss = False)`.
            filter_thres: `thres` passed to `top_k` when filtering logits.
            temperature: divisor applied to the filtered logits before softmax.
            img: optional image token ids of length `image_seq_len` used to
                prime generation.
            num_init_img_tokens: how many leading tokens of `img` to keep;
                defaults to `int(0.4375 * image_seq_len)`.

        Returns:
            `(images, scores)`: `images` is a float tensor laid out as
            (batch, channel, height, width) with height `self.image_size`;
            `scores` is the output of `clip`, or None when no `clip` is given.
        """
        text_seq_len, image_seq_len, num_text_tokens = self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len

        batch = text.shape[0]
        text = text[:, :text_seq_len] # make sure text is within bounds
        out = text

        if exists(img):
            assert img.shape[1] == image_seq_len, f'input image must have the correct image size {image_seq_len}'

            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len))  # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            trunc_img = img[:, :num_img_tokens]
            out = torch.cat((out, trunc_img), dim = -1)

        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            # split what has been produced so far back into the two forward() inputs
            text_part, image_part = out[:, :text_seq_len], out[:, text_seq_len:]

            # only the prediction for the next position is needed
            logits = self(text_part, image_part)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)

            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
            out = torch.cat((out, sample), dim=-1)

        text_seq = out[:, :text_seq_len]
        img_seq = out[:, -image_seq_len:]

        scores = None
        if exists(clip):
            scores = clip(text_seq, img_seq, return_loss = False)

        # look every token id up in the decoder table, once per channel
        img_seq = repeat(img_seq, 'b p -> b p c', c=3)
        decoder = repeat(decoder, 'p c -> b p c', b=batch)
        images = torch.gather(decoder, 1, img_seq)
        # fold the flat token axis into a grid; the height comes from the
        # constructor's image_size and the width is inferred from the length
        images = rearrange(images, 'b (h w) c -> b c h w', h = self.image_size)
        images = images.float()

        return images, scores

    def forward(
        self,
        text,
        image = None,
        return_loss = False
    ):
        """Compute next-token logits, or the weighted text/image loss.

        Args:
            text: integer text token ids with exactly `text_seq_len` columns;
                id 0 is treated as padding and replaced by a padding token
                unique to its position.
            image: optional integer image token ids, (batch, image length).
                May be omitted (or empty) when only logits are requested.
            return_loss: when true, `image` is required and losses are returned.

        Returns:
            With `return_loss` false: the masked logits from `self.to_logits`,
            one row per input position.
            With `return_loss` true: `(loss, loss_text, loss_img)` where
            `loss` is the average of the two cross entropies weighted by
            `loss_txt_weight` and `loss_img_weight`.
        """
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
        if return_loss:
            # checked up front so a missing image is reported clearly
            assert exists(image), 'when training, image must be supplied'

        device, total_seq_len = text.device, self.total_seq_len

        # swap padding id 0 for the position-specific padding ids reserved at
        # the end of the text vocabulary
        text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        text = F.pad(text, (1, 0), value = 0) # add <bos>

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        # image is optional: a text-only call skips the image embeddings
        if exists(image):
            image_len = image.shape[1]
            image_emb = self.image_emb(image)
            image_emb += self.image_pos_emb[:, :image_len, :]
            tokens = torch.cat((tokens, image_emb), dim = 1)

        # when training, if the length exceeds the total text + image length
        # remove the last token, since it needs not to be trained

        if tokens.shape[1] > total_seq_len:
            tokens = tokens[:, :-1]

        seq_len = tokens.shape[1]

        out = self.transformer(tokens)
        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if not return_loss:
            return logits

        # image labels live after the text vocabulary in the shared logit space
        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

        # F.cross_entropy expects the class dimension second
        logits = rearrange(logits, 'b n c -> b c n')

        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        loss = (self.loss_txt_weight * loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + self.loss_txt_weight)
        return loss, loss_text, loss_img