Upload 4 files
- app.py +87 -0
- image_captioning_model.pt +3 -0
- model.py +442 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,87 @@
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
from transformers import AutoTokenizer
from model import CaptioningTransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 128
patch_size = 8
d_model = 192
n_layers = 6
n_heads = 8

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Instantiate the model
model = CaptioningTransformer(
    image_size=image_size,
    in_channels=3,  # RGB images
    vocab_size=tokenizer.vocab_size,
    device=device,
    patch_size=patch_size,
    n_layers=n_layers,
    d_model=d_model,
    n_heads=n_heads,
).to(device)

# Load the pre-trained weights (the .pt file must be present in the repo)
model_path = "image_captioning_model.pt"
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


# Autoregressive inference: sample one token at a time until the end token or max_len
def make_prediction(model, sos_token, eos_token, image, max_len, temp, device):
    # Start with the start-of-sequence token; generated tokens are accumulated on the CPU
    log_tokens = [sos_token.cpu()]
    with torch.inference_mode():
        # Get image embeddings from the encoder
        image_embedding = model.encoder(image.to(device))
        for _ in range(max_len):
            input_tokens = torch.cat(log_tokens, dim=1)
            data_pred = model.decoder(input_tokens.to(device), image_embedding)
            # Get the logits for the most recent position only and sample with temperature
            dist = torch.distributions.Categorical(logits=data_pred[:, -1] / temp)
            next_tokens = dist.sample().reshape(1, 1)
            log_tokens.append(next_tokens.cpu())
            if next_tokens.item() == eos_token:  # 102 is [SEP] for distilbert-base-uncased
                break
    return torch.cat(log_tokens, dim=1)


# Define the Gradio prediction function
def predict(image: Image.Image):
    # Preprocess the image (convert to RGB so grayscale/RGBA uploads also work)
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # Shape: (1, 3, image_size, image_size)
    # Create a start-of-sequence token (101 is [CLS] for distilbert-base-uncased)
    sos_token = 101 * torch.ones(1, 1).long().to(device)
    # Generate caption tokens with the inference function
    tokens = make_prediction(
        model, sos_token, 102, img_tensor, max_len=50, temp=0.5, device=device
    )
    # Decode tokens to text (skipping special tokens)
    caption = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return caption


# Create a Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Captioning Model",
    description="Upload an image and get a caption generated by the model.",
)

if __name__ == "__main__":
    iface.launch()
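A quick way to smoke-test the app outside of the Gradio UI is to import it and call predict() directly. This is a minimal sketch, not part of the upload; it assumes the checkpoint sits next to app.py, and the script name and image path are hypothetical placeholders.

# local_test.py (hypothetical) - exercises the same predict() path the Space uses
from PIL import Image

import app  # importing app builds the model and loads image_captioning_model.pt

caption = app.predict(Image.open("test.jpg"))  # "test.jpg" is a placeholder path
print(caption)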
image_captioning_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:526b7cb3a1a70d6bb5503629b69e9d664efd0ba8f22a7cc1d035b9a42f6abc24
size 72371272
model.py
ADDED
@@ -0,0 +1,442 @@
import torch
import torch.nn as nn
import math


class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, d_model: int = 128):
        super().__init__()

        self.patch_size = patch_size
        self.d_model = d_model

        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, d_model)

    def forward(self, x):
        batch_size, c, h, w = x.shape

        # Unfold to extract patches: shape becomes (batch_size, in_channels * patch_size * patch_size, num_patches)
        # num_patches = (H / patch_size) * (W / patch_size)
        patches = self.unfold(x)

        # Transpose to (batch_size, num_patches, in_channels * patch_size * patch_size)
        patches = patches.transpose(1, 2)

        # Apply a linear projection to each patch: (batch_size, num_patches, in_channels * patch_size * patch_size) -> (batch_size, num_patches, d_model)
        return self.proj(patches)


# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        """
        super().__init__()

        # Instead of precomputing fixed values, we compute them in the forward pass from the sinusoidal encoding formula
        self.d_model = d_model

    def forward(self, x):
        device = x.device
        half_dim = self.d_model // 2  # Use half for sin and half for cos
        emb = math.log(10000.0) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]  # (seq_length, half_dim)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# Multi-Head Self-Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_heads: number of self-attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()
        # The embedding dimensions must split evenly across the attention heads
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_key = d_model // n_heads

        self.Wq = nn.Linear(d_model, d_model)  # Learnable weights for query
        self.Wk = nn.Linear(d_model, d_model)  # Learnable weights for key
        self.Wv = nn.Linear(d_model, d_model)  # Learnable weights for value
        self.Wo = nn.Linear(d_model, d_model)  # Learnable weights for output

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        query: (batch_size, q_length, d_model)
        key: (batch_size, k_length, d_model)
        value: (batch_size, v_length, d_model)
        """
        batch_size = key.size(0)

        # Matrix multiplication for Q, K, and V tensors
        Q = self.Wq(query)
        K = self.Wk(key)
        V = self.Wv(value)

        # Split each tensor into heads
        Q = Q.view(batch_size, -1, self.n_heads, self.d_key).permute(
            0, 2, 1, 3
        )  # (batch_size, n_heads, q_length, d_key)
        K = K.view(batch_size, -1, self.n_heads, self.d_key).permute(
            0, 2, 1, 3
        )  # (batch_size, n_heads, k_length, d_key)
        V = V.view(batch_size, -1, self.n_heads, self.d_key).permute(
            0, 2, 1, 3
        )  # (batch_size, n_heads, v_length, d_key)

        # Scaled dot product
        # K^T becomes (batch_size, n_heads, d_key, k_length)
        scaled_dot_product = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(
            self.d_key
        )  # (batch_size, n_heads, q_length, k_length)

        if mask is not None:
            # Filling masked positions with 0 would still contribute weight after softmax (e^0 = 1),
            # so we fill them with an extremely large negative number instead
            scaled_dot_product = scaled_dot_product.masked_fill(mask == 0, -float("inf"))

        # Softmax for the attention probabilities
        attention_probs = torch.softmax(scaled_dot_product, dim=-1)

        # Multiply by V to get attention with respect to the values
        A = torch.matmul(self.dropout(attention_probs), V)

        # Reshape attention back to (batch_size, q_length, d_model)
        A = (
            A.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, -1, self.n_heads * self.d_key)
        )

        # Pass through the final linear layer
        output = self.Wo(A)

        # Output shape: (batch_size, q_length, d_model)
        # Attention probs shape: (batch_size, n_heads, q_length, k_length)
        return output, attention_probs


# Position-Wise Feed-Forward Network (FFN)
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        dropout: probability of dropout
        """
        super().__init__()

        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=(d_model * 4)),
            nn.GELU(),
            nn.Linear(in_features=(d_model * 4), out_features=d_model),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        return self.ffn(x)


# Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_heads: number of self-attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        # Multi-Head Self-Attention sublayer
        self.attention = MultiHeadAttention(
            d_model=d_model, n_heads=n_heads, dropout=dropout
        )
        self.attention_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        # Position-wise Feed-Forward Network
        self.position_wise_ffn = PositionwiseFeedForward(
            d_model=d_model, dropout=dropout
        )
        self.ffn_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """
        src: embedded sequences (batch_size, seq_length, d_model)
        """
        # Multi-Head Attention
        # Q, K, V, src_mask: no source mask is needed because every image produces the same number of patches
        _src, attention_probs = self.attention(src, src, src, None)

        # Residual addition and layer normalization: add the input embeddings back to the self-attention output
        src = self.attention_layer_norm(src + self.dropout(_src))

        # Position-wise Feed-Forward Network
        _src = self.position_wise_ffn(src)

        # Residual addition and layer normalization
        src = self.ffn_layer_norm(src + self.dropout(_src))

        return src, attention_probs


# The Encoder takes in images and returns encodings to be passed to the decoder
class Encoder(nn.Module):
    def __init__(
        self,
        image_size: int,
        in_channels: int,
        patch_size: int = 16,
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 4,
        dropout: float = 0.1,
    ):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_layers: number of encoder layers in the encoder block
        n_heads: number of self-attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        self.patch_size = patch_size

        self.patch_emb = PatchEmbedding(
            patch_size=patch_size, in_channels=in_channels, d_model=d_model
        )

        seq_length = (image_size // patch_size) ** 2

        # The image patches use a learnable positional encoding
        self.pos_embedding = nn.Parameter(
            torch.empty(1, seq_length, d_model).normal_(std=0.02)
        )

        # Create n_layers encoder layers
        self.layers = nn.ModuleList(
            [
                EncoderLayer(d_model=d_model, n_heads=n_heads, dropout=dropout)
                for _ in range(n_layers)
            ]
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """
        src: batch of images (batch_size, in_channels, image_size, image_size)
        """
        # Extract the patches and apply a linear projection
        batch_size = src.shape[0]
        src = self.patch_emb(src)

        # Add the learned positional embedding
        src = src + self.pos_embedding

        # Pass the sequences through each encoder layer
        for layer in self.layers:
            src, attention_probs = layer(src)

        self.attention_probs = attention_probs

        return src


# Decoder Layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_heads: number of self-attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        # Masked Multi-Head Self-Attention sublayer
        self.masked_attention = MultiHeadAttention(
            d_model=d_model, n_heads=n_heads, dropout=dropout
        )
        self.masked_attention_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        # Multi-Head Cross-Attention sublayer
        self.attention = MultiHeadAttention(
            d_model=d_model, n_heads=n_heads, dropout=dropout
        )
        self.attention_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        # Position-wise Feed-Forward Network
        self.position_wise_ffn = PositionwiseFeedForward(
            d_model=d_model, dropout=dropout
        )
        self.ffn_layer_norm = nn.LayerNorm(d_model)  # Layer normalization

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, trg, src, trg_mask):
        """
        trg: embedded captions (batch_size, trg_seq_length, d_model)
        src: embedded images (batch_size, src_seq_length, d_model)
        trg_mask: mask for the captions preventing peeking at future tokens (1, 1, trg_seq_length, trg_seq_length)
        """
        # Masked Multi-Head Attention
        # The target mask prevents the model from seeing future tokens, so each prediction
        # is made solely from past and present tokens.
        _trg, masked_attention_probs = self.masked_attention(
            trg, trg, trg, trg_mask
        )  # Q, K, V, mask

        # Residual addition and layer normalization
        trg = self.masked_attention_layer_norm(trg + self.dropout(_trg))

        # Multi-Head Cross-Attention - this time the encoder output is passed in as src,
        # which lets the model learn relationships between the image patches and the caption tokens.
        _trg, attention_probs = self.attention(trg, src, src, None)  # Q, K, V, mask
        # Residual addition and layer normalization
        trg = self.attention_layer_norm(trg + self.dropout(_trg))

        # Position-wise Feed-Forward Network
        _trg = self.position_wise_ffn(trg)
        # Residual addition and layer normalization
        trg = self.ffn_layer_norm(trg + self.dropout(_trg))

        return trg, attention_probs, masked_attention_probs


# The Decoder takes the encoded images from the encoder and generates captions
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 4,
        dropout: float = 0.1,
    ):
        """
        vocab_size: size of the target vocabulary
        d_model: dimensions of the embeddings (number of values in each embedding vector)
        n_layers: number of decoder layers in the decoder block
        n_heads: number of self-attention heads per sequence
        dropout: probability of dropout
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)

        # Scale down the initial embedding weights
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        # Initialize sinusoidal positional embeddings
        self.pos_emb = PositionalEncoding(d_model=d_model)

        # Create n_layers decoder layers
        self.layers = nn.ModuleList(
            [
                DecoderLayer(d_model=d_model, n_heads=n_heads, dropout=dropout)
                for _ in range(n_layers)
            ]
        )
        self.dropout = nn.Dropout(p=dropout)

        # Output layer
        self.Wo = nn.Linear(in_features=d_model, out_features=vocab_size)

    def make_trg_mask(self, trg):
        seq_length = trg.shape[1]

        # Lower-triangular (causal) mask
        trg_mask = torch.tril(
            torch.ones((seq_length, seq_length), device=trg.device)
        ).bool()

        return trg_mask.unsqueeze(0).unsqueeze(
            0
        )  # (batch_size=1, n_heads=1, seq_length, seq_length)

    def forward(self, trg, src):
        """
        trg: target token ids (batch_size, trg_seq_length)
        src: embedded images (batch_size, src_seq_length, d_model)
        """
        # Embed the target captions
        trg = self.embedding(trg)
        batch_size, l, h = trg.shape

        trg_index = torch.arange(l, device=trg.device)
        pos_emb = self.pos_emb(trg_index).reshape(1, l, h).expand(batch_size, l, h)
        # Add the fixed sinusoidal positional embedding
        trg = trg + pos_emb

        # Create a target mask so the model cannot peek at future caption tokens
        trg_mask = self.make_trg_mask(trg)  # (1, 1, trg_seq_length, trg_seq_length)

        # Pass the sequences through each decoder layer
        for layer in self.layers:
            trg, attention_probs, masked_attention_probs = layer(trg, src, trg_mask)

        # attention_probs (cross-attention): (batch_size, n_heads, trg_seq_len, src_seq_len),
        # where trg_seq_len is the caption length and src_seq_len is the number of patches from the encoder
        self.attention_probs = attention_probs
        # masked_attention_probs (self-attention): (batch_size, n_heads, trg_seq_len, trg_seq_len)
        self.masked_attention_probs = masked_attention_probs

        # Final linear output layer
        return self.Wo(trg)


class CaptioningTransformer(nn.Module):
    def __init__(
        self,
        image_size: int,
        in_channels: int,
        vocab_size: int,
        device,
        patch_size: int = 16,
        d_model: int = 128,
        n_layers: int = 3,
        n_heads: int = 4,
    ):
        super().__init__()

        self.device = device

        # Create an encoder and decoder with the specified parameters
        self.encoder = Encoder(
            image_size=image_size,
            in_channels=in_channels,
            patch_size=patch_size,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
        )

        self.decoder = Decoder(
            vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads
        )

    def forward(self, src, trg):
        # Encoder layers
        src = self.encoder(src)  # (batch_size, src_seq_length, d_model)

        # Decoder layers: pass both the target (for masked self-attention) and the encoded image (for cross-attention)
        output = self.decoder(trg, src)

        return output
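As a rough shape sanity check (a sketch, not part of the upload; it assumes model.py is importable, reuses the hyperparameters from app.py, and uses 30522 as the DistilBERT vocabulary size): a 128x128 image with patch_size=8 yields (128/8)^2 = 256 patches, and the decoder emits one logit vector per caption position.

# shape_check.py (hypothetical)
import torch
from model import CaptioningTransformer

device = torch.device("cpu")
model = CaptioningTransformer(
    image_size=128, in_channels=3, vocab_size=30522, device=device,
    patch_size=8, d_model=192, n_layers=6, n_heads=8,
)

images = torch.randn(2, 3, 128, 128)         # (batch, channels, H, W)
captions = torch.randint(0, 30522, (2, 20))  # (batch, trg_seq_length) of token ids

print(model.encoder(images).shape)    # torch.Size([2, 256, 192]) -> 256 patches of dim d_model=192
print(model(images, captions).shape)  # torch.Size([2, 20, 30522]) -> one logit vector per caption position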
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
transformers
gradio
Pillow