import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel


class GEGLU(nn.Module):
    """GELU-gated linear unit: one projection produces a value half and a gate half."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    """Transformer MLP: GEGLU expansion by `mult`, dropout, then projection to `dim_out`."""

    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim
        self.net = nn.Sequential(
            GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Attention(nn.Module):
    """Multi-head attention: self-attention when `context` is None, cross-attention otherwise."""

    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads
        context_dim = context_dim or query_dim

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        q = self.to_q(x)
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)
        # Split the channel dimension into attention heads.
        q, k, v = map(
            lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
            (q, k, v),
        )
        # Pin SDPA to the FlashAttention backend (CUDA, fp16/bf16 only).
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h l d -> b l (h d)")
        out = self.to_out(out)
        return out


class TransformerBlock(nn.Module):
    """Spatial transformer block: self-attention, cross-attention to `context`, then
    feed-forward, each with a pre-norm and a residual connection."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attn1 = Attention(
            query_dim=dim,
            context_dim=None,  # self-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout)
        self.attn2 = Attention(
            query_dim=dim,
            context_dim=context_dim,  # cross-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class TransformerBlockTimeMix(nn.Module):
    """Temporal transformer block: attends across frames at each spatial location.
    Its output is merged back into the spatial stream by the caller's time mixer."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = n_heads * d_head
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout)
        self.attn1 = Attention(
            query_dim=inner_dim,
            context_dim=None,  # temporal self-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout)
        self.attn2 = Attention(
            query_dim=inner_dim,
            context_dim=context_dim,  # cross-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        _, s, _ = x.shape
        # Fold spatial positions into the batch so attention runs over the frame axis.
        x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames)
        x = self.ff_in(self.norm_in(x)) + x  # residual requires inner_dim == dim
        x = self.attn1(self.norm1(x), context=None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x))
        x = rearrange(x, "(b s) t c -> (b t) s c", s=s)
        return x


class SkipConnect(nn.Module):
    """Merges the spatial and temporal branches by simple addition."""

    def __init__(self):
        super().__init__()

    def forward(
        self, x_spatial: torch.Tensor, x_temporal: torch.Tensor
    ) -> torch.Tensor:
        return x_spatial + x_temporal


class MultiviewTransformer(nn.Module):
    """Interleaves spatial transformer blocks with temporal (time-mix) blocks over a
    (batch * frames, channels, height, width) feature map and adds a global residual.

    Blocks whose `name` appears in `unflatten_names` run spatial attention jointly over
    all frames of a video (t * h * w tokens) instead of per frame.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        name: str,
        unflatten_names: list[str] | None = None,
        depth: int = 1,
        context_dim: int = 1024,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.name = name
        self.unflatten_names = unflatten_names or []

        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)
        self.time_mixer = SkipConnect()
        self.time_mix_blocks = nn.ModuleList(
            [
                TransformerBlockTimeMix(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        assert context.ndim == 3
        _, _, h, w = x.shape
        x_in = x

        # Time-mix blocks fold spatial positions into the batch, so broadcast the
        # first-frame context of each video to every spatial location.
        time_context = context
        time_context_first_timestep = time_context[::num_frames]
        time_context = repeat(
            time_context_first_timestep, "b ... -> (b n) ...", n=h * w
        )

        # Unflattened blocks attend over all frames at once, so keep one context per video.
        if self.name in self.unflatten_names:
            context = context[::num_frames]

        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.proj_in(x)

        for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
            if self.name in self.unflatten_names:
                x = rearrange(x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w)
            x = block(x, context=context)
            if self.name in self.unflatten_names:
                x = rearrange(x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w)
            x_mix = mix_block(x, context=time_context, num_frames=num_frames)
            x = self.time_mixer(x_spatial=x, x_temporal=x_mix)

        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        out = x + x_in
        return out
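

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the original module). The hyperparameters
    # below are illustrative assumptions, not taken from any training config. A CUDA
    # device with bf16 is assumed because Attention pins SDPA to the FlashAttention
    # backend, which does not run on CPU or in fp32.
    device = torch.device("cuda")
    dtype = torch.bfloat16
    b, t, c, h, w = 2, 4, 64, 8, 8  # videos, frames, channels, height, width

    model = MultiviewTransformer(
        in_channels=c,
        n_heads=4,
        d_head=16,  # n_heads * d_head == in_channels, so proj_in/proj_out are square
        name="block_0",
        unflatten_names=["block_0"],  # exercise the joint spatio-temporal attention path
    ).to(device, dtype)

    x = torch.randn(b * t, c, h, w, device=device, dtype=dtype)
    context = torch.randn(b * t, 77, 1024, device=device, dtype=dtype)  # (batch * frames, tokens, context_dim)
    out = model(x, context, num_frames=t)
    print(out.shape)  # torch.Size([8, 64, 8, 8]) -- same shape as the input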