from functools import partial
from itertools import islice, cycle

from torch import nn

from text2punks.attention import Attention, SparseAxialCausalAttention

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, depth = 1):
    # normalize a single value or a list into a tuple, repeating a lone value `depth` times
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth

# classes

class SequentialSequence(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        # each entry is an (attention, feedforward) pair, applied with residual connections
        for f, g in self.layers:
            x = x + f(x)
            x = x + g(x)
        return x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim)
        )
        # note: dropout here sits between the GELU and the output projection; GPT-style blocks
        # instead place nn.Dropout(resid_pdrop) after the nn.Linear(4 * n_embd, n_embd) projection

    def forward(self, x):
        return self.net(x)


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        resid_dropout = 0.,
        embd_dropout = 0.,
        ff_dropout = 0.,
        image_size = 24,
        attn_types = None,
    ):
        super().__init__()
        layers = nn.ModuleList([])

        # default to full attention, then cycle the given attention types across the depth
        attn_types = default(attn_types, ('full',))
        attn_types = cast_tuple(attn_types)
        attn_type_layer = islice(cycle(attn_types), depth)

        for attn_type in attn_type_layer:
            if attn_type == 'full':
                attn_class = partial(Attention, causal = causal)
            elif attn_type == 'axial_row':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_size)
            elif attn_type == 'axial_col':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_size)
            else:
                raise ValueError(f'attention type "{attn_type}" is not valid')

            attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

            layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
            ]))

        # append a final block with full attention, so the stack holds depth + 1 blocks in total

        attn_class = partial(Attention, causal = causal)
        attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

        layers.append(nn.ModuleList([
            PreNorm(dim, attn),
            PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
        ]))

        self.layers = SequentialSequence(layers)
        self.embd_drop = nn.Dropout(embd_dropout)

    def forward(self, x):
        x = self.embd_drop(x)
        return self.layers(x)
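
# Minimal usage sketch (illustrative, not part of the original module): builds a small model
# mixing axial and full attention and runs dummy embeddings through it. It assumes the
# Attention and SparseAxialCausalAttention classes in text2punks.attention accept the keyword
# arguments used above; the hyperparameters below are arbitrary placeholders.

if __name__ == '__main__':
    import torch

    text_len, image_size, dim = 64, 24, 256
    seq_len = text_len + image_size * image_size   # text tokens followed by a 24x24 image grid

    model = Transformer(
        dim = dim,
        depth = 4,                                 # 4 cycled blocks + 1 final full-attention block
        seq_len = seq_len,
        attn_types = ('axial_row', 'axial_col'),   # cycled as row, col, row, col across the depth
        image_size = image_size,
    )

    x = torch.randn(1, seq_len, dim)               # (batch, seq_len, dim) token embeddings
    print(model(x).shape)                          # expected: torch.Size([1, 640, 256])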