File size: 12,357 Bytes
587665f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
class GroupNorm(nn.Module):
    """``nn.GroupNorm`` preset with ``eps=1e-6`` and learnable per-channel affine params.

    Args:
        in_channels: Number of channels (dim 1) of the input.
        num_groups: Number of groups; ``in_channels`` must be divisible by it.
    """

    def __init__(self, in_channels: int, num_groups: int = 32):
        super().__init__()
        # Attribute name ``gn`` is kept so parameter keys stay ``gn.weight`` / ``gn.bias``.
        self.gn = nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.gn(x)
        return normalized
class AdaLayerNorm(nn.Module):
def __init__(self, channels: int, cond_channels: int = 0, return_scale_shift: bool = True):
super(AdaLayerNorm, self).__init__()
self.norm = nn.LayerNorm(channels)
self.return_scale_shift = return_scale_shift
if cond_channels != 0:
if return_scale_shift:
self.proj = nn.Linear(cond_channels, channels * 3, bias=False)
else:
self.proj = nn.Linear(cond_channels, channels * 2, bias=False)
nn.init.xavier_uniform_(self.proj.weight)
def expand_dims(self, tensor: torch.Tensor, dims: list[int]) -> torch.Tensor:
for dim in dims:
tensor = tensor.unsqueeze(dim)
return tensor
def forward(self, x: torch.Tensor, cond: torch.Tensor | None = None) -> torch.Tensor:
x = self.norm(x)
if cond is None:
return x
dims = list(range(1, len(x.shape) - 1))
if self.return_scale_shift:
gamma, beta, sigma = self.proj(cond).chunk(3, dim=-1)
gamma, beta, sigma = [self.expand_dims(t, dims) for t in (gamma, beta, sigma)]
return x * (1 + gamma) + beta, sigma
else:
gamma, beta = self.proj(cond).chunk(2, dim=-1)
gamma, beta = [self.expand_dims(t, dims) for t in (gamma, beta)]
return x * (1 + gamma) + beta
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal embedding of scalar positions/timesteps.

    For each position ``t`` the output is ``[sin(t * f_k) ..., cos(t * f_k) ...]``,
    with ``f_k = 10000 ** (-2k / emb_dim)``. All sines come first, then all cosines.
    ``emb_dim`` is expected to be even; an odd value makes the frequency vector
    one element longer than the repeated positions and the product fails.

    Args:
        emb_dim: Output embedding size.
    """

    def __init__(self, emb_dim: int = 256):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.channels = emb_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` of shape ``(b, 1)`` or ``(b,)`` into ``(b, emb_dim)``."""
        # Accept a flat batch of positions; the repeat below needs a trailing axis of size 1.
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, self.channels, 2, device=t.device).float() / self.channels)
        )
        angles = t.repeat(1, self.channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class GatedConv2d(nn.Module):
    """Convolution whose SiLU-activated output is scaled by a learned sigmoid gate.

    ``out = sigmoid(gate_conv(x)) * silu(feature_conv(x))``

    The gate is a 1x1 conv (always with bias). ``bias`` only affects
    ``feature_conv``. Because the gate keeps the input's spatial size, the two
    branches agree in shape only when ``feature_conv`` also preserves it (e.g.
    ``kernel_size=3, padding=1`` or ``kernel_size=1, padding=0``).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 bias: bool = False):
        super().__init__()
        self.gate_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.feature_conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.gate_conv(x)) * F.silu(self.feature_conv(x))
class ResGatedBlock(nn.Module):
    """Conv -> GroupNorm -> SiLU -> (+emb) -> Conv -> GroupNorm, with an optional residual.

    Args:
        in_channels / out_channels: Input and output channel counts.
        mid_channels: Width after the first conv; falsy values fall back to ``out_channels``.
        num_groups: Group count for both GroupNorm layers.
        residual: If True, return ``silu(shortcut + double_conv(x))``. The shortcut is
            a 1x1 conv when channel counts differ, else the identity. If False,
            return ``double_conv(x)``.
        emb_channels: If truthy, a Linear maps ``emb`` to ``mid_channels``. The
            result is added after the first activation.
        gated_conv: Use ``GatedConv2d`` instead of ``nn.Conv2d`` for every conv.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: int | None = None,
                 num_groups: int = 32,
                 residual: bool = True,
                 emb_channels: int | None = None,
                 gated_conv: bool = False):
        super().__init__()
        self.residual = residual
        self.emb_channels = emb_channels
        hidden = mid_channels or out_channels
        conv_cls = GatedConv2d if gated_conv else nn.Conv2d

        # Submodules are created in the same order as before so parameter init is unchanged.
        self.conv1 = conv_cls(in_channels, hidden, kernel_size=3, padding=1, bias=False)
        self.norm1 = GroupNorm(hidden, num_groups=num_groups)
        self.nonlienrity = nn.SiLU()  # (sic) attribute name kept for compatibility
        if emb_channels:
            self.emb_proj = nn.Linear(emb_channels, hidden)
        self.conv2 = conv_cls(hidden, out_channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = GroupNorm(out_channels, num_groups=num_groups)
        # Note: built whenever channel counts differ, even if residual is False.
        if in_channels != out_channels:
            self.skip = conv_cls(in_channels, out_channels, kernel_size=1, padding=0)

    def double_conv(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Run the two conv/norm stages, injecting ``emb`` between them if available."""
        hidden = self.nonlienrity(self.norm1(self.conv1(x)))
        if emb is not None and self.emb_channels is not None:
            # Assumes emb is (batch, emb_channels); broadcast over H and W — TODO confirm.
            hidden = hidden + self.emb_proj(emb)[:, :, None, None]
        return self.norm2(self.conv2(hidden))

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        if not self.residual:
            return self.double_conv(x, emb)
        shortcut = self.skip(x) if hasattr(self, 'skip') else x
        return F.silu(shortcut + self.double_conv(x, emb))
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 conv or a 2x2 average pool.

    Args:
        in_channels / out_channels: Channel counts. They must be equal when
            ``use_conv`` is False, because pooling cannot change channels.
        use_conv: Use a learned conv (True) or parameter-free average pooling (False).

    Raises:
        ValueError: if ``use_conv`` is False and the channel counts differ.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 use_conv: bool = True):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
        else:
            if in_channels != out_channels:
                raise ValueError(
                    f"in_channels ({in_channels}) must equal out_channels ({out_channels}) "
                    "when use_conv=False"
                )
            self.conv = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_conv:
            return self.conv(x)
        # Zero-pad one column on the right and one row at the bottom only. The
        # unpadded stride-2 3x3 conv then maps even H/W to exactly H/2 x W/2.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Upsample(nn.Module):
    """Double spatial resolution via nearest-neighbour interpolation, then an optional 3x3 conv.

    4-D input scales the last two axes by 2. Any other rank is given
    ``scale_factor=(1, 2, 2)``, which leaves the third-from-last axis unscaled.
    The conv (a 2-D conv) is only applicable to 4-D input. When ``use_conv`` is
    False no layer is created and ``out_channels`` is ignored.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 use_conv: bool = True):
        super().__init__()
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 4:
            factor = (2, 2)
        else:
            factor = (1, 2, 2)
        upsampled = F.interpolate(x, scale_factor=factor, mode='nearest')
        if not self.use_conv:
            return upsampled
        return self.conv(upsampled)
class FeedForward(nn.Module):
    """Pre-norm MLP whose output is gated by the ``sigma`` produced by ``AdaLayerNorm``.

    ``net`` is Linear(dim -> dim * expansion_rate) -> SiLU -> Dropout ->
    Linear(-> dim) -> Dropout. No residual is added here.

    Args:
        dim: Size of the last axis of ``x``.
        emb_channels: Conditioning size passed to ``AdaLayerNorm``.
        expansion_rate: Hidden-width multiplier (``int(dim * expansion_rate)``).
        dropout: Dropout probability used after the activation and after the output layer.
    """

    def __init__(self,
                 dim: int,
                 emb_channels: int,
                 expansion_rate: int = 4,
                 dropout: float = 0.0):
        super().__init__()
        inner_dim = int(dim * expansion_rate)
        self.norm = AdaLayerNorm(dim, emb_channels)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
        self.__init_weights()

    def __init_weights(self):
        # Xavier-uniform weights for the two Linear layers; biases keep PyTorch defaults.
        nn.init.xavier_uniform_(self.net[0].weight)
        nn.init.xavier_uniform_(self.net[3].weight)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        """Return ``net(norm(x, emb)) * sigma``, or ``net(norm(x))`` when ``emb`` is None."""
        normed = self.norm(x, emb)
        if isinstance(normed, tuple):
            x, sigma = normed
            return self.net(x) * sigma
        # Unconditioned: AdaLayerNorm returned only the normalized tensor, so there is
        # no sigma. Unpacking it as a pair would silently split the batch axis.
        return self.net(normed)
class Attention(nn.Module):
    """Multi-head self-attention inside each window, with a learned 2-D relative position bias.

    Input is 6-D, ``(b, x, y, w1, w2, d)``. Every ``(w1, w2)`` window is attended
    independently. The bias table is built for square windows of side
    ``window_size``, so ``w1 * w2`` must equal ``window_size ** 2``. The output has
    the input's shape. When ``emb`` is given it is scaled by the ``sigma`` from
    ``AdaLayerNorm``.

    Args:
        dim: Size of the last axis of ``x``; must be divisible by ``dim_head``.
        emb_channels: Conditioning size passed to ``AdaLayerNorm``.
        dim_head: Per-head width; ``heads = dim // dim_head``.
        dropout: Dropout on attention weights and on the output projection.
        window_size: Side length of the attention window.

    Raises:
        ValueError: if ``dim`` is not divisible by ``dim_head``.
    """

    def __init__(
        self,
        dim: int,
        emb_channels: int = 512,
        dim_head: int = 32,
        dropout: float = 0.,
        window_size: int = 7
    ):
        super().__init__()
        if dim % dim_head != 0:
            raise ValueError('dimension should be divisible by dimension per head')
        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5
        self.norm = AdaLayerNorm(dim, emb_channels)
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.attend = nn.Sequential(
            nn.Softmax(dim=-1),
            nn.Dropout(dropout)
        )
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.Dropout(dropout)
        )
        # One learned bias per (relative 2-D offset, head). Each axis has 2*w - 1 possible offsets.
        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing='ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        # Pairwise (row, col) offsets between window cells, shifted from [-(w-1), w-1] to [0, 2w-2].
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        # Flatten each (row, col) offset into a single row index of rel_pos_bias.
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim=-1)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent=False)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        _, height, width, window_height, window_width, _ = x.shape
        h = self.heads
        normed = self.norm(x, emb)
        if isinstance(normed, tuple):
            x, sigma = normed
        else:
            # Unconditioned: AdaLayerNorm returned only the normalized tensor; skip gating.
            x, sigma = normed, None
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=h) for t in (q, k, v))  # split heads
        q = q * self.scale
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')  # add relative positional bias
        attn = self.attend(sim)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1=window_height, w2=window_width)  # merge heads
        out = self.to_out(out)
        out = rearrange(out, '(b x y) ... -> b x y ...', x=height, y=width)
        return out if sigma is None else out * sigma
class MaxViTBlock(nn.Module):
    """MaxViT-style block: windowed (block) attention followed by grid attention.

    Input/output shape is ``(b, channels, H, W)``, with ``H`` and ``W`` divisible by
    ``window_size``. Each enabled stage regroups the pixels into windows, applies
    ``x + attn(x)`` and ``x + ff(x)`` (both conditioned on ``emb``), and restores
    the layout. The window stage groups contiguous ``w x w`` patches. The grid
    stage groups pixels strided across the image.

    Args:
        channels: Channel count; must be divisible by ``heads``.
        emb_channels: Conditioning size for the inner ``AdaLayerNorm`` layers.
        heads: Number of attention heads (``dim_head = channels // heads``).
        window_size: Window / grid side length ``w``.
        window_attn / grid_attn: Enable the respective stage.
        expansion_rate: Hidden-width multiplier of the feed-forward layers.
        dropout: Dropout probability for attention and feed-forward.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``channels``.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int = 512,
        heads: int = 1,
        window_size: int = 8,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        dropout: float = 0.0,
    ):
        super(MaxViTBlock, self).__init__()
        # Otherwise the inner layers would be sized to dim_head * heads != channels,
        # and every forward call would fail on the residual add.
        if heads <= 0 or channels % heads != 0:
            raise ValueError(f"channels ({channels}) must be divisible by heads ({heads})")
        dim_head = channels // heads
        layer_dim = dim_head * heads
        w = window_size
        # Stage switches live in dedicated attributes: ``self.grid_attn`` is replaced
        # by the Attention submodule below (its name fixes the state_dict keys), so it
        # must not double as a boolean in forward().
        self.use_window_attn = window_attn
        self.use_grid_attn = grid_attn
        self.window_attn = window_attn
        self.grid_attn = grid_attn
        if window_attn:
            self.wind_rearrange_forward = Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1=w, w2=w)  # block-like attention
            self.wind_attn = Attention(
                dim=layer_dim,
                emb_channels=emb_channels,
                dim_head=dim_head,
                dropout=dropout,
                window_size=w
            )
            self.wind_ff = FeedForward(dim=layer_dim,
                                       emb_channels=emb_channels,
                                       expansion_rate=expansion_rate,
                                       dropout=dropout)
            self.wind_rearrange_backward = Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)')
        if grid_attn:
            self.grid_rearrange_forward = Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1=w, w2=w)  # grid-like attention
            self.grid_attn = Attention(
                dim=layer_dim,
                emb_channels=emb_channels,
                dim_head=dim_head,
                dropout=dropout,
                window_size=w
            )
            self.grid_ff = FeedForward(dim=layer_dim,
                                       emb_channels=emb_channels,
                                       expansion_rate=expansion_rate,
                                       dropout=dropout)
            self.grid_rearrange_backward = Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)')

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        if self.use_window_attn:
            x = self.wind_rearrange_forward(x)
            x = x + self.wind_attn(x, emb=emb)
            x = x + self.wind_ff(x, emb=emb)
            x = self.wind_rearrange_backward(x)
        if self.use_grid_attn:
            x = self.grid_rearrange_forward(x)
            x = x + self.grid_attn(x, emb=emb)
            x = x + self.grid_ff(x, emb=emb)
            x = self.grid_rearrange_backward(x)
        return x
|