import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
# helpers
def exists(val):
    return val is not None

def max_neg_value(t):
    return -torch.finfo(t.dtype).max
# classes
class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5
        self.causal = causal
        self.attn_drop = nn.Dropout(attn_dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(resid_dropout)
        )

    def forward(self, x):
        h, device = self.heads, x.device

        # project to queries / keys / values and split out heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        q = q * self.scale

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        # causal mask - each position may only attend to itself and earlier positions
        if self.causal:
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        attn = torch.softmax(dots, dim = -1)
        attn = self.attn_drop(attn)

        # aggregate values and merge heads back
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
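# usage sketch for the dense causal attention above (comment-only; the dim / seq_len
# values here are illustrative assumptions, not taken from this file):
#
#   attn = Attention(dim = 512, seq_len = 1024, heads = 8, dim_head = 64)
#   tokens = torch.randn(1, 1024, 512)   # (batch, sequence, dim)
#   out = attn(tokens)                   # same shape: (1, 1024, 512)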
# sparse axial causal attention
class SparseAxialCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
        super().__init__()
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size
        self.attn_drop = nn.Dropout(attn_dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(resid_dropout)
        )

    def forward(self, x):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device

        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # pad the sequence out to the fixed length expected by the masks below

        padding = seq_len - n + 1
        mask = torch.ones(b, text_len, device = device).bool()

        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        q = q * self.scale

        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = torch.softmax(dots_text, dim = -1)

        # attention dropout

        attn_text = self.attn_drop(attn_text)
        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

        # split out axis

        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))

        # similarity

        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

        # mask so image has full attention to text, but causal along axis

        bh, x, i, j = dots.shape
        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)

        mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
        mask = torch.cat((~mask, causal_mask), dim = -1)

        dots.masked_fill_(mask, mask_value)

        # attention

        attn = torch.softmax(dots, dim = -1)

        # attention dropout

        attn = self.attn_drop(attn)

        # aggregate

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # merge back axis

        out_image = rearrange(out_image, merge_axis_einops, x = img_size)

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]
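
# minimal smoke test, added as a sketch - the dim / seq_len / image_size values below
# are illustrative assumptions (a text length of 16 plus an 8 x 8 image grid), not
# values from the original file

if __name__ == '__main__':
    dim = 64
    text_seq_len = 16
    image_size = 8
    seq_len = text_seq_len + image_size ** 2   # combined text + image sequence length

    dense_attn = Attention(dim = dim, seq_len = seq_len, heads = 4, dim_head = 16)
    axial_attn = SparseAxialCausalAttention(dim = dim, seq_len = seq_len, image_size = image_size, axis = 0, heads = 4, dim_head = 16)

    tokens = torch.randn(2, seq_len, dim)      # (batch, sequence, dim)
    out = axial_attn(dense_attn(tokens))       # both modules preserve the input shape
    assert out.shape == (2, seq_len, dim)
    print(out.shape)                           # torch.Size([2, 80, 64])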