|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from cube3d.model.transformers.cache import Cache |
|
from cube3d.model.transformers.norm import LayerNorm, RMSNorm |
|
from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb |
|
|
|
|
|
class SwiGLUMLP(nn.Module):
    """Feed-forward block built on the SwiGLU gating scheme.

    The input is projected twice to ``hidden_dim``. One branch goes through
    SiLU (Swish) and gates the other elementwise, and the product is
    projected back to ``embed_dim``:

        out = down_proj(silu(gate_proj(x)) * up_proj(x))
    """

    def __init__(self, embed_dim, hidden_dim, bias=True, **kwargs):
        """
        Args:
            embed_dim (int): Size of the last dimension of both input and output.
            hidden_dim (int): Size of the gated intermediate representation.
            bias (bool, optional): Whether all three projections carry a bias
                term. Defaults to True.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        # Attribute names define the state_dict keys, and creation order
        # (gate, up, down) fixes the order in which parameters are initialised.
        self.gate_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)

    def forward(self, x):
        """
        Run the gated MLP over the last dimension of ``x``.

        Args:
            x (torch.Tensor): Tensor whose last dimension is ``embed_dim``.

        Returns:
            torch.Tensor: Tensor with the same leading dimensions as ``x``
            and last dimension ``embed_dim``.
        """
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
|
class SelfAttentionWithRotaryEmbedding(nn.Module):
    """Multi-head self-attention with per-head QK RMSNorm and rotary embeddings.

    Queries and keys come from a single bias-free projection (``c_qk``), and
    values come from a separate projection (``c_v``). After splitting into
    heads, queries and keys are RMS-normalised per head. The attention
    computation, including the rotary embedding, is delegated to
    ``scaled_dot_product_attention_with_rotary_emb``. An optional ``Cache``
    allows a prefill pass followed by incremental decoding.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ):
        """
        Args:
            embed_dim (int): Dimensionality of the input embeddings. Must be
                divisible by ``num_heads``.
            num_heads (int): Number of attention heads.
            bias (bool, optional): Whether the value and output projections
                use bias terms. The query/key projection never does.
                Defaults to True.
            eps (float, optional): Accepted for interface compatibility but
                currently unused. The per-head RMSNorm layers below are built
                with RMSNorm's own default epsilon.
                NOTE(review): confirm whether ``eps`` should be forwarded;
                doing so may change numerics. Defaults to 1e-6.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads

        # Fused query/key projection (no bias) and a separate value projection.
        self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Output projection back to the embedding dimension.
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        head_dim = embed_dim // num_heads
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)

    def forward(
        self,
        x,
        freqs_cis: torch.Tensor,
        attn_mask=None,
        is_causal: bool = False,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for the SelfAttentionWithRotaryEmbedding instance.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, embed_dim).
            freqs_cis (torch.Tensor): Precomputed rotary positional embeddings.
            attn_mask (Optional[torch.Tensor], optional): Attention mask to
                apply during self-attention. Defaults to None.
            is_causal (bool, optional): Whether to apply causal masking.
                Defaults to False.
            kv_cache (Optional[Cache], optional): Cache holding
                ``key_states`` / ``value_states`` for incremental decoding.
                Defaults to None.
            curr_pos_id (Optional[torch.Tensor], optional): Sequence positions
                of the current tokens. Required when ``decode`` is True and a
                ``kv_cache`` is given. Defaults to None.
            decode (bool, optional): Whether the call is an incremental
                decoding step. Defaults to False.

        Returns:
            torch.Tensor: Tensor of shape (batch, seq_len, embed_dim) after
            attention and the output projection.

        Raises:
            ValueError: If ``decode`` is True, ``kv_cache`` is given and
                ``curr_pos_id`` is None.
        """
        b, l, d = x.shape

        q, k = self.c_qk(x).chunk(2, dim=-1)
        v = self.c_v(x)

        # (b, l, d) -> (b, num_heads, l, head_dim)
        q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
        k = k.view(b, l, self.num_heads, -1).transpose(1, 2)
        v = v.view(b, l, self.num_heads, -1).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if kv_cache is not None:
            if not decode:
                # Prefill: store the l new keys/values at cache positions
                # [0, l). Attention below still runs over these new k/v only.
                kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
                kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
            else:
                if curr_pos_id is None:
                    raise ValueError(
                        "curr_pos_id is required when decode=True and a kv_cache is provided"
                    )
                # Decode: write the new keys/values into the cache at
                # curr_pos_id along the sequence axis (dim 2), then attend
                # over the whole cache.
                kv_cache.key_states.index_copy_(2, curr_pos_id, k)
                kv_cache.value_states.index_copy_(2, curr_pos_id, v)
                k = kv_cache.key_states
                v = kv_cache.value_states

        y = scaled_dot_product_attention_with_rotary_emb(
            q,
            k,
            v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            # Explicit positions are only forwarded in decode mode.
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )

        # Re-assemble heads: (b, num_heads, l, head_dim) -> (b, l, d).
        y = y.transpose(1, 2).contiguous().view(b, l, d)

        return self.c_proj(y)
|
|
|
|
|
class DecoderLayerWithRotaryEmbedding(nn.Module):
    """Pre-norm transformer decoder block: rotary self-attention, then a SwiGLU MLP.

    Each sub-layer reads a non-affine LayerNorm of its input, and its output
    is added back as a residual:

        h   = x + attn(ln_1(x))
        out = h + mlp(ln_2(h))
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Args:
            embed_dim (int): Width of the residual stream.
            num_heads (int): Number of attention heads.
            bias (bool, optional): Whether the attention and MLP linear layers
                carry bias terms. Defaults to True.
            eps (float, optional): Epsilon for both LayerNorms. It is also
                passed on to the attention module. Defaults to 1e-6.
        """
        super().__init__()

        norm_kwargs = dict(elementwise_affine=False, eps=eps)
        mlp_hidden_dim = 4 * embed_dim

        # Attribute names define the state_dict keys; keep them stable.
        self.ln_1 = LayerNorm(embed_dim, **norm_kwargs)
        self.attn = SelfAttentionWithRotaryEmbedding(
            embed_dim, num_heads=num_heads, bias=bias, eps=eps
        )
        self.ln_2 = LayerNorm(embed_dim, **norm_kwargs)
        self.mlp = SwiGLUMLP(embed_dim, mlp_hidden_dim, bias=bias)

    @classmethod
    def from_config(cls, cfg):
        """
        Build a layer from a configuration object.

        Args:
            cfg: Object exposing ``n_embd`` (embedding size), ``n_head``
                (attention heads), ``bias`` (use bias terms) and ``eps``
                (normalization epsilon).

        Returns:
            A new instance configured from ``cfg``.
        """
        embed_dim = cfg.n_embd
        num_heads = cfg.n_head
        use_bias = cfg.bias
        norm_eps = cfg.eps
        return cls(embed_dim, num_heads=num_heads, bias=use_bias, eps=norm_eps)

    def forward(
        self,
        x,
        freqs_cis: torch.Tensor,
        attn_mask=None,
        is_causal: bool = True,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Apply the attention and MLP sub-layers with residual connections.

        Args:
            x (torch.Tensor): Residual-stream input.
            freqs_cis (torch.Tensor): Precomputed rotary positional embeddings,
                forwarded to the attention module.
            attn_mask (Optional[torch.Tensor], optional): Attention mask.
                Defaults to None.
            is_causal (bool, optional): Whether attention is causally masked.
                Defaults to True.
            kv_cache (Optional[Cache], optional): Key/value cache for
                incremental decoding. Defaults to None.
            curr_pos_id (Optional[torch.Tensor], optional): Positions of the
                current tokens when decoding. Defaults to None.
            decode (bool, optional): Whether this is an incremental decoding
                step. Defaults to False.

        Returns:
            torch.Tensor: Updated residual stream, same shape as ``x``.
        """
        attn_out = self.attn(
            self.ln_1(x),
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        hidden = x + attn_out

        mlp_out = self.mlp(self.ln_2(hidden))
        return hidden + mlp_out
|
|