import torch
import torch.nn as nn
import math


class ExtractPatches(nn.Module):
    def __init__(self, patch_size: int = 16):
        """
        patch_size: side length (in pixels) of each square patch
        """
        super().__init__()

        self.patch_size = patch_size
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        batch_size, c, h, w = x.shape

        # Unfold applies a sliding window to extract non-overlapping patches
        # The transpose and reshape flatten each patch, giving (batch_size, num_patches, c * patch_size * patch_size)
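        # e.g. a 64x64 RGB image with patch_size=16 yields (64 / 16)^2 = 16 patches,
        # each flattened to 3 * 16 * 16 = 768 values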
        return (
            self.unfold(x)
            .transpose(1, 2)
            .reshape(batch_size, -1, c * self.patch_size * self.patch_size)
        )


# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        """
        super().__init__()

        # Instead of precomputing fixed values, we compute the encoding in the forward pass using the sinusoidal encoding formula
        self.d_model = d_model

    def forward(self, x):
        """
        x: 1D tensor of position indices (seq_length,)
        """
        device = x.device
        half_dim = self.d_model // 2  # Use half for sin and half for cos
        emb = math.log(10000.0) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]  # (seq_length, half_dim)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# Multi-Head Self-Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_heads: number of self attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()
        assert (
            d_model % n_heads == 0
        )  # We want to make sure that the dimensions are split evenly among the attention heads.
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_key = d_model // n_heads

        self.Wq = nn.Linear(d_model, d_model)  # Learnable weights for query
        self.Wk = nn.Linear(d_model, d_model)  # Learnable weights for key
        self.Wv = nn.Linear(d_model, d_model)  # Learnable weights for value
        self.Wo = nn.Linear(d_model, d_model)  # Learnable weights for output

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        query: (batch_size, q_length, d_model)
        key: (batch_size, k_length, d_model)
        value: (batch_size, v_length, d_model)
        """
        batch_size = key.size(0)

        # Matrix multiplication for Q, K, and V tensors
        Q = self.Wq(query)
        K = self.Wk(key)
        V = self.Wv(value)

        # Split each tensor into heads
        Q = Q.view(batch_size, -1, self.n_heads, self.d_key).permute(
            0, 2, 1, 3
        )  # (batch_size, n_heads, q_length, d_key)
        K = K.view(batch_size, -1, self.n_heads, self.d_key).permute(
            0, 2, 1, 3
        )  # (batch_size, n_heads, k_length, d_key)
        V = V.view(batch_size, -1, self.n_heads, self.d_key).permute(
            0, 2, 1, 3
        )  # (batch_size, n_heads, v_length, d_key)

        # Scaled dot product
        # K^T becomes (batch_size, n_heads, d_key, k_length)
        scaled_dot_product = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(
            self.d_key
        )  # (batch_size, n_heads, q_length, k_length)

        if mask is not None:
            scaled_dot_product = scaled_dot_product.masked_fill(
                mask == 0, -float("inf")
            )  # Filling with 0 would still give a weight of e^0 = 1 after the softmax. Instead we fill masked positions with negative infinity so they receive (near) zero attention

        # Softmax function for attention probabilities
        attention_probs = torch.softmax(scaled_dot_product, dim=-1)

        # Multiply by V to get attention with respect to the values
        A = torch.matmul(self.dropout(attention_probs), V)

        # Reshape attention back to (batch_size, q_length, d_model)
        A = (
            A.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, -1, self.n_heads * self.d_key)
        )

        # Pass through the final linear layer
        output = self.Wo(A)

        return (
            output,
            attention_probs,
        )  # Output shape: (batch_size, q_length, d_model), Attention probs shape: (batch_size, n_heads, q_length, k_length)


# Position-Wise Feed Forward Network (FFN)
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        dropout: probability of dropout
        """
        super().__init__()

        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=(d_model * 4)),
            nn.ReLU(),
            nn.Linear(in_features=(d_model * 4), out_features=d_model),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        return self.ffn(x)


# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_heads: number of self attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        # Multi-Head Self-Attention sublayer
        self.attention = MultiHeadAttention(
            d_model=d_model, n_heads=n_heads, dropout=dropout
        )
        self.attention_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        # Position-wise Feed-forward Network
        self.position_wise_ffn = PositionwiseFeedForward(
            d_model=d_model, dropout=dropout
        )
        self.ffn_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """
        src: embedded sequences (batch_size, seq_length, d_model)
        """
        # Multi-Head Attention

        _src, attention_probs = self.attention(
            src, src, src, None
        )  # Q, K, V, src_mask: no source mask is needed because every image yields the same number of patches (no padding)

        # Residual Addition and Layer Normalization
        src = self.attention_layer_norm(
            src + self.dropout(_src)
        )  # We do residual addition by adding back the src (the embeddings) to the output of Self-Attention

        # Position-wise Feed-forward Network
        _src = self.position_wise_ffn(src)

        # Residual Addition and Layer Normalization
        src = self.ffn_layer_norm(src + self.dropout(_src))

        return src, attention_probs


# The Encoder that takes in images and returns the encoding to be passed into the decoder
class Encoder(nn.Module):
    def __init__(
        self,
        image_size: int,
        in_channels: int,
        patch_size: int = 16,
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 4,
        dropout: float = 0.1,
    ):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_layers: number of encoder layers in the encoder block
        n_heads: number of self attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        self.patch_size = patch_size

        self.extract_patches = ExtractPatches(patch_size=patch_size)
        self.fc_in = nn.Linear(in_channels * patch_size * patch_size, d_model)

        seq_length = (image_size // patch_size) ** 2

        # Image src is going to use a learnable positional encoding
        self.pos_embedding = nn.Parameter(
            torch.empty(1, seq_length, d_model).normal_(std=0.02)
        )

        # Create n_layers encoders
        self.layers = nn.ModuleList(
            [
                EncoderLayer(d_model=d_model, n_heads=n_heads, dropout=dropout)
                for _ in range(n_layers)
            ]
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """
        src: batch of images (batch_size, in_channels, image_size, image_size)
        """

        # Extract the patches and apply a linear layer
        batch_size = src.shape[0]
        src = self.fc_in(self.extract_patches(src))

        # Add the learned positional embedding
        src = src + self.pos_embedding

        # Pass the sequences through each encoder layer
        for layer in self.layers:
            src, attention_probs = layer(src)

        self.attention_probs = attention_probs

        return src


# Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_heads: number of self attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        # Masked Multi-Head Self-Attention sublayer
        self.masked_attention = MultiHeadAttention(
            d_model=d_model, n_heads=n_heads, dropout=dropout
        )
        self.masked_attention_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        # Multi-Head Self-Attention sublayer
        self.attention = MultiHeadAttention(
            d_model=d_model, n_heads=n_heads, dropout=dropout
        )
        self.attention_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        # Position-wise Feed-forward Network
        self.position_wise_ffn = PositionwiseFeedForward(
            d_model=d_model, dropout=dropout
        )
        self.ffn_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, trg, src, trg_mask):
        """
        trg: embedded captions (batch_size, trg_seq_length, d_model)
        src: embedded images (batch_size, src_seq_length, d_model)
        trg_mask: mask for the captions preventing peeking at future tokens (batch_size, 1, trg_seq_length, trg_seq_length)
        """

        # Masked Multi-Head Attention

        # The target mask is used to prevent the model from seeing future tokens. This ensures that the prediction is made solely based on past and present tokens.
        _trg, masked_attention_probs = self.masked_attention(
            trg, trg, trg, trg_mask
        )  # Q, K, V, mask

        # Residual Addition and Layer Normalization
        trg = self.masked_attention_layer_norm(trg + self.dropout(_trg))

        # Multi-Head Attention - This time, we also pass in the output of the encoder layers as src.
        # This is important because this allows us to keep track of and learn relationships between the input and output tokens.
        _trg, attention_probs = self.attention(trg, src, src, None)  # Q, K, V, mask
        # Residual Addition and Layer Normalization
        trg = self.attention_layer_norm(trg + self.dropout(_trg))

        # Position-wise Feed-forward Network
        _trg = self.position_wise_ffn(trg)
        # Residual Addition and Layer Normalization
        trg = self.ffn_layer_norm(trg + self.dropout(_trg))

        return trg, attention_probs, masked_attention_probs


# The Decoder Module that takes the encoded images from the encoder and generates captions
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 4,
        dropout: float = 0.1,
    ):
        """
        vocab_size: size of the target vocabulary
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_layers: number of decoder layers in the decoder block
        n_heads: number of self attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)

        # Scale down the randomly initialized embedding weights
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        # Initialize sinusoidal positional embeddings
        self.pos_emb = PositionalEncoding(d_model=d_model)

        # Create n_layers decoders
        self.layers = nn.ModuleList(
            [
                DecoderLayer(d_model=d_model, n_heads=n_heads, dropout=dropout)
                for _ in range(n_layers)
            ]
        )
        self.dropout = nn.Dropout(p=dropout)

        # Output layer
        self.Wo = nn.Linear(in_features=d_model, out_features=vocab_size)

    def make_trg_mask(self, trg):
        seq_length = trg.shape[1]

        trg_mask = torch.tril(
            torch.ones((seq_length, seq_length), device=trg.device)
        ).bool()
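        # e.g. for seq_length = 3 the lower-triangular mask is:
        # [[1, 0, 0],
        #  [1, 1, 0],
        #  [1, 1, 1]]  (row i can attend to positions <= i)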

        return trg_mask.unsqueeze(0).unsqueeze(
            0
        )  # (batch_size=1, n_heads=1, seq_length, seq_length)

    def forward(self, trg, src):
        """
        trg: target caption token indices (batch_size, trg_seq_length)
        src: encoded image patches (batch_size, src_seq_length, d_model)
        """

        # Embed the target captions
        trg = self.embedding(trg)
        batch_size, l, h = trg.shape

        trg_index = torch.arange(l, device=trg.device)
        pos_emb = self.pos_emb(trg_index).reshape(1, l, h).expand(batch_size, l, h)
        # Add the fixed sinusoidal positional embedding
        trg += pos_emb

        # Create a target mask for the target captions to prevent the model from peeking at future tokens
        trg_mask = self.make_trg_mask(
            trg
        )  # (batch_size, 1, trg_seq_length, trg_seq_length)

        # Pass the sequences through each decoder layer
        for layer in self.layers:
            trg, attention_probs, masked_attention_probs = layer(trg, src, trg_mask)

        self.attention_probs = attention_probs  # (batch_size, n_heads, trg_seq_len, src_seq_len): trg_seq_len is the caption length, src_seq_len is the number of image patches
        self.masked_attention_probs = masked_attention_probs  # (batch_size, n_heads, trg_seq_len, trg_seq_len)

        # Final linear output layer
        return self.Wo(trg)


class CaptioningTransformer(nn.Module):
    def __init__(
        self,
        image_size: int,
        in_channels: int,
        vocab_size: int,
        device,
        patch_size: int = 16,
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 4,
    ):
        super().__init__()

        self.device = device

        # Create an encoder and decoder with specified parameters
        self.encoder = Encoder(
            image_size=image_size,
            in_channels=in_channels,
            patch_size=patch_size,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
        )

        self.decoder = Decoder(
            vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads
        )

    def forward(self, src, trg):
        # Encoder layers
        src = self.encoder(src)  # (batch_size, src_seq_length, d_model)

        # Decoder layers
        output = self.decoder(
            trg, src
        )  # Pass in both the target (for Masked Multi-Head Self-Attention) and the source (for Cross-Attention)

        return output
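

# Minimal smoke test sketch: the image size, vocabulary size, and caption length
# below are illustrative values (assumptions, not taken from this file) chosen
# only to exercise the forward pass and check the output shape.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_size, in_channels, vocab_size = 64, 3, 1000
    batch_size, caption_length = 2, 12

    model = CaptioningTransformer(
        image_size=image_size,
        in_channels=in_channels,
        vocab_size=vocab_size,
        device=device,
        patch_size=16,
        d_model=128,
        n_layers=3,
        n_heads=4,
    ).to(device)

    # Dummy batch: random images and random caption token indices
    images = torch.randn(
        batch_size, in_channels, image_size, image_size, device=device
    )
    captions = torch.randint(
        0, vocab_size, (batch_size, caption_length), device=device
    )

    logits = model(images, captions)
    print(logits.shape)  # expected: torch.Size([2, 12, 1000])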