makeavid-sd-jax / makeavid_sd / flax_impl / flax_attention_pseudo3d.py
from typing import Optional
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
#from flax_memory_efficient_attention import jax_memory_efficient_attention
#from flax_attention import FlaxAttention
from diffusers.models.attention_flax import FlaxAttention
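
# TransformerPseudo3DModel: a channels-last spatial transformer in the style of
# diffusers' FlaxTransformer2DModel, extended with temporal attention. Video
# latents of shape (batch, frames, height, width, channels) are flattened to
# (batch * frames) images, run through spatial self-/cross-attention, and
# additionally attended across the frame axis ("pseudo 3D").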
class TransformerPseudo3DModel(nn.Module):
    in_channels: int
    num_attention_heads: int
    attention_head_dim: int
    num_layers: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        inner_dim = self.num_attention_heads * self.attention_head_dim
        self.norm = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.proj_in = nn.Conv(
            inner_dim,
            kernel_size = (1, 1),
            strides = (1, 1),
            padding = 'VALID',
            dtype = self.dtype
        )
        transformer_blocks = []
        #CheckpointTransformerBlock = nn.checkpoint(
        #        BasicTransformerBlockPseudo3D,
        #        static_argnums = (2,3,4)
        #        #prevent_cse = False
        #)
        CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
        for _ in range(self.num_layers):
            transformer_blocks.append(CheckpointTransformerBlock(
                dim = inner_dim,
                num_attention_heads = self.num_attention_heads,
                attention_head_dim = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            ))
        self.transformer_blocks = transformer_blocks
        self.proj_out = nn.Conv(
            inner_dim,
            kernel_size = (1, 1),
            strides = (1, 1),
            padding = 'VALID',
            dtype = self.dtype
        )
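    # Accepts either image latents (batch, height, width, channels) or video
    # latents (batch, frames, height, width, channels); video inputs are
    # flattened along the frame axis before the spatial transformer blocks and
    # restored to 5D on the way out.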
    def __call__(self,
            hidden_states: jax.Array,
            encoder_hidden_states: Optional[jax.Array] = None
    ) -> jax.Array:
        is_video = hidden_states.ndim == 5
        f: Optional[int] = None
        if is_video:
            # jax is channels last
            # b,c,f,h,w WRONG
            # b,f,h,w,c CORRECT
            # b, c, f, h, w = hidden_states.shape
            #hidden_states = einops.rearrange(hidden_states, 'b c f h w -> (b f) c h w')
            b, f, h, w, c = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.reshape(batch, height * width, channels)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                f,
                height,
                width
            )
        hidden_states = hidden_states.reshape(batch, height, width, channels)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        if is_video:
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        return hidden_states
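
# BasicTransformerBlockPseudo3D: one transformer block consisting of spatial
# self-attention, cross-attention over an optional context embedding, temporal
# attention across frames, and a GEGLU feed-forward, each applied with a
# pre-LayerNorm and a residual connection.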
class BasicTransformerBlockPseudo3D(nn.Module):
    dim: int
    num_attention_heads: int
    attention_head_dim: int
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        self.attn1 = FlaxAttention(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
        self.attn2 = FlaxAttention(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        self.attn_temporal = FlaxAttention(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
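    # frames_length, height and width are only needed for the temporal branch:
    # tokens are regrouped from (batch * frames, height * width, dim) to
    # (batch * height * width, frames, dim) so attention runs across frames.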
    def __call__(self,
            hidden_states: jax.Array,
            context: Optional[jax.Array] = None,
            frames_length: Optional[int] = None,
            height: Optional[int] = None,
            width: Optional[int] = None
    ) -> jax.Array:
        if context is not None and frames_length is not None:
            context = context.repeat(frames_length, axis = 0)
        # self attention
        norm_hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn1(norm_hidden_states) + hidden_states
        # cross attention
        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn2(
            norm_hidden_states,
            context = context
        ) + hidden_states
        # temporal attention
        if frames_length is not None:
            #bf, hw, c = hidden_states.shape
            # (b f) (h w) c -> b f (h w) c
            #hidden_states = hidden_states.reshape(bf // frames_length, frames_length, hw, c)
            #b, f, hw, c = hidden_states.shape
            # b f (h w) c -> b (h w) f c
            #hidden_states = hidden_states.transpose(0, 2, 1, 3)
            # b (h w) f c -> (b h w) f c
            #hidden_states = hidden_states.reshape(b * hw, frames_length, c)
            hidden_states = einops.rearrange(
                hidden_states,
                '(b f) (h w) c -> (b h w) f c',
                f = frames_length,
                h = height,
                w = width
            )
            norm_hidden_states = self.norm_temporal(hidden_states)
            hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
            # (b h w) f c -> b (h w) f c
            #hidden_states = hidden_states.reshape(b, hw, f, c)
            # b (h w) f c -> b f (h w) c
            #hidden_states = hidden_states.transpose(0, 2, 1, 3)
            # b f h w c -> (b f) (h w) c
            #hidden_states = hidden_states.reshape(bf, hw, c)
            hidden_states = einops.rearrange(
                hidden_states,
                '(b h w) f c -> (b f) (h w) c',
                f = frames_length,
                h = height,
                w = width
            )
        norm_hidden_states = self.norm3(hidden_states)
        hidden_states = self.ff(norm_hidden_states) + hidden_states
        return hidden_states
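
# FeedForward: the GEGLU-gated feed-forward used in Stable Diffusion style
# transformer blocks (project up through a gated GELU, then project back to dim).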
class FeedForward(nn.Module):
    dim: int
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        self.net_0 = GEGLU(self.dim, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype = self.dtype)
    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states
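
# GEGLU: projects to 2 * (4 * dim) features, splits the result into a linear
# half and a gate half, and returns linear * gelu(gate).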
class GEGLU(nn.Module):
    dim: int
    dtype: jnp.dtype = jnp.float32
    def setup(self) -> None:
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)
    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
        return hidden_linear * nn.gelu(hidden_gelu)
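
if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module; the shapes and
    # hyperparameters below are illustrative assumptions. In this implementation
    # num_attention_heads * attention_head_dim should equal in_channels (so the
    # reshape after proj_in and the residual after proj_out are valid), and
    # in_channels must be divisible by the 32 GroupNorm groups.
    model = TransformerPseudo3DModel(
        in_channels = 64,
        num_attention_heads = 8,
        attention_head_dim = 8,
        num_layers = 1
    )
    rng = jax.random.PRNGKey(0)
    # video latents, channels last: (batch, frames, height, width, channels)
    sample = jnp.zeros((1, 4, 16, 16, 64), dtype = jnp.float32)
    # cross attention context, e.g. text embeddings: (batch, tokens, context_dim)
    context = jnp.zeros((1, 77, 768), dtype = jnp.float32)
    params = model.init(rng, sample, context)
    out = model.apply(params, sample, context)
    print(out.shape)  # (1, 4, 16, 16, 64)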