"""Custom transformer implementation for fallback."""
import logging

import torch
import torch.nn as nn

# Set up logging
logger = logging.getLogger(__name__)


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale by the inverse RMS: 1 / sqrt(mean(x^2) + eps)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * inv_rms * x


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding."""

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Generate inverse frequency tensor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Generate cos and sin cache
        self._update_cos_sin_cache(max_seq_len)

    def _update_cos_sin_cache(self, max_seq_len):
        """Update the cache of cos and sin values."""
        self.max_seq_len = max_seq_len
        t = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Outer product of positions and inverse frequencies: [max_seq_len, dim/2]
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Duplicate the frequencies so cos/sin span the full head dimension,
        # matching the rotate_half-style application below
        emb = torch.cat((freqs, freqs), dim=-1)  # [max_seq_len, dim]
        self.register_buffer("cos_cache", emb.cos(), persistent=False)
        self.register_buffer("sin_cache", emb.sin(), persistent=False)

    def forward(self, x, seq_len=None, pos=None):
        # Return the cos/sin values for the requested positions
        if pos is not None:
            # Arbitrary positions (e.g. incremental decoding with a KV cache)
            cos = self.cos_cache[pos]
            sin = self.sin_cache[pos]
        else:
            # Sequential positions 0..seq_len-1
            seq_len = x.shape[1] if seq_len is None else seq_len
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]
        return cos, sin


def rotate_half(x):
    """Rotate half the dimensions of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply rotary position embedding to q and k."""
    if position_ids is not None:
        # Gather cos/sin for arbitrary positions
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    else:
        # Broadcast over the batch and head dimensions
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class CustomAttention(nn.Module):
    """Multi-head attention with support for KV caching."""

    def __init__(self, dim, num_heads, num_kv_heads=None, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Attention projections
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)
        # Rotary embedding
        self.rope = RotaryEmbedding(self.head_dim)
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def _repeat_kv(self, x):
        """Repeat KV heads to match the number of query heads."""
        if self.num_kv_heads == self.num_heads:
            return x
        # x: [b, n_kv_heads, s, head_dim]; repeat along the head dimension
        repeats = self.num_heads // self.num_kv_heads
        return x.repeat_interleave(repeats, dim=1)

    def forward(self, x, mask=None, input_pos=None, kv_cache=None):
        batch_size, seq_len, _ = x.shape
        # Project to q, k, v
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)     # [b, nh, s, hd]
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [b, nkh, s, hd]
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [b, nkh, s, hd]
        # Apply rotary embeddings; the rope module already selects the cos/sin
        # entries for input_pos, so they are not re-indexed here
        cos, sin = self.rope(x, seq_len=seq_len, pos=input_pos)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Handle KV cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            if input_pos is not None:
                # Update the cache at the given positions (input_pos must be 1-D)
                k_cache.index_copy_(2, input_pos, k)
                v_cache.index_copy_(2, input_pos, v)
                # Attend over the entire cache
                k, v = k_cache, v_cache
        # Repeat KV heads if using grouped-query attention
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)
        # Calculate attention scores
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Apply mask if provided
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -10000.0)
        # Apply softmax and dropout
        attention_probs = self.dropout(torch.softmax(attention_scores, dim=-1))
        # Get context vector
        context = torch.matmul(attention_probs, v)
        # Reshape and project back
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(context)
        return output


class FeedForward(nn.Module):
    """Feed-forward network with GELU activation."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.w1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.w2(x)
        return x


class TransformerLayer(nn.Module):
    """A single transformer layer."""

    def __init__(
        self,
        dim,
        num_heads,
        num_kv_heads=None,
        ffn_dim=None,
        dropout=0.0,
        norm_eps=1e-5,
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.attn = CustomAttention(dim, num_heads, num_kv_heads, dropout)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn = FeedForward(dim, ffn_dim or 4 * dim, dropout)

    def forward(self, x, mask=None, input_pos=None, kv_cache=None):
        # Self-attention with residual
        h = self.norm1(x)
        h = self.attn(h, mask=mask, input_pos=input_pos, kv_cache=kv_cache)
        x = x + h
        # FFN with residual
        h = self.norm2(x)
        h = self.ffn(h)
        x = x + h
        return x


class CustomTransformerDecoder(nn.Module):
    """Custom transformer decoder that mimics the Llama architecture."""

    def __init__(
        self,
        vocab_size,
        num_layers,
        num_heads,
        num_kv_heads,
        embed_dim,
        max_seq_len,
        intermediate_dim,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=10000,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        # Note: rope_base is accepted for interface compatibility; each layer's
        # RotaryEmbedding currently uses its own default base of 10000.
        self.rope_base = rope_base
        # Token embeddings
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(
                embed_dim,
                num_heads,
                num_kv_heads,
                intermediate_dim,
                attn_dropout,
                norm_eps,
            )
            for _ in range(num_layers)
        ])
        # Final normalization and output projection
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        # KV cache (created lazily by setup_caches)
        self._kv_cache = None
        self._has_cache = False
        logger.info(
            f"Initialized CustomTransformerDecoder with {num_layers} layers, "
            f"{num_heads} heads, {embed_dim} dim"
        )

    def setup_caches(self, batch_size, dtype, decoder_max_seq_len=None):
        """Set up KV caches for inference."""
        max_seq_len = decoder_max_seq_len or self.max_seq_len
        device = next(self.parameters()).device
        self._kv_cache = []
        for layer in self.layers:
            # Create a per-layer KV cache of shape [b, n_kv_heads, max_seq_len, head_dim]
            k_cache = torch.zeros(
                batch_size,
                layer.attn.num_kv_heads,
                max_seq_len,
                layer.attn.head_dim,
                device=device,
                dtype=dtype,
            )
            v_cache = torch.zeros(
                batch_size,
                layer.attn.num_kv_heads,
                max_seq_len,
                layer.attn.head_dim,
                device=device,
                dtype=dtype,
            )
            self._kv_cache.append((k_cache, v_cache))
        self._has_cache = True
        logger.info(f"KV caches set up for batch size {batch_size}, seq length {max_seq_len}")

    def caches_are_enabled(self):
        """Check if caches are enabled."""
        return self._has_cache

    def reset_caches(self):
        """Reset the KV cache to zeros."""
        if self._has_cache and self._kv_cache:
            for k_cache, v_cache in self._kv_cache:
                k_cache.zero_()
                v_cache.zero_()

    def forward(self, x, mask=None, input_pos=None):
        batch_size, seq_len = x.shape[:2]
        # Apply embedding if input is token IDs
        if x.dim() == 2:
            x = self.tok_embeddings(x)
        # Apply transformer layers
        for i, layer in enumerate(self.layers):
            layer_cache = self._kv_cache[i] if self._has_cache else None
            x = layer(x, mask=mask, input_pos=input_pos, kv_cache=layer_cache)
        # Apply final norm
        x = self.norm(x)
        # Skip output projection if using Identity
        if isinstance(self.output, nn.Identity):
            return x
        # Apply output projection
        logits = self.output(x)
        return logits
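

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows a full-sequence forward pass and
# incremental decoding with the KV caches. The hyperparameters and tensors
# below are assumptions for a quick smoke test, not values taken from the
# surrounding application.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = CustomTransformerDecoder(
        vocab_size=1000,    # hypothetical vocabulary size
        num_layers=2,       # small config chosen for the smoke test
        num_heads=4,
        num_kv_heads=2,     # grouped-query attention: 2 KV heads shared by 4 query heads
        embed_dim=64,
        max_seq_len=128,
        intermediate_dim=256,
    )

    # Full-sequence forward pass with a causal mask
    tokens = torch.randint(0, 1000, (1, 16))
    causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
    logits = model(tokens, mask=causal_mask)
    print("logits:", tuple(logits.shape))  # expected: (1, 16, 1000)

    # Incremental decoding: feed one token at a time, passing its position
    model.setup_caches(batch_size=1, dtype=torch.float32)
    full_mask = torch.tril(torch.ones(128, 128, dtype=torch.bool))
    for step in range(4):
        pos = torch.tensor([step])
        step_logits = model(
            tokens[:, step:step + 1],
            mask=full_mask[step:step + 1],  # row of the causal mask for this position
            input_pos=pos,
        )
    print("cached step logits:", tuple(step_logits.shape))  # expected: (1, 1, 1000)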