"""Custom transformer implementation for fallback.""" import torch import torch.nn as nn import math import logging # Set up logging logger = logging.getLogger(__name__) class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): # Calculate RMS rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return self.weight * rms * x class RotaryEmbedding(nn.Module): """Rotary positional embedding.""" def __init__(self, dim, max_seq_len=2048, base=10000): super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.base = base # Generate frequency tensor inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) # Generate cos and sin cache self._update_cos_sin_cache(max_seq_len) def _update_cos_sin_cache(self, max_seq_len): """Update the cache of cos and sin values.""" self.max_seq_len = max_seq_len t = torch.arange(max_seq_len, device=self.inv_freq.device) # Compute cos and sin at each position freqs = torch.einsum('i,j->ij', t, self.inv_freq) cos = freqs.cos() sin = freqs.sin() self.register_buffer("cos_cache", cos, persistent=False) self.register_buffer("sin_cache", sin, persistent=False) def forward(self, x, seq_len=None, pos=None): # Get appropriate parts of the cache if pos is not None: # Handle arbitrary positions cos = self.cos_cache[pos] sin = self.sin_cache[pos] else: # Handle sequential positions seq_len = x.shape[1] if seq_len is None else seq_len cos = self.cos_cache[:seq_len] sin = self.sin_cache[:seq_len] return cos, sin def rotate_half(x): """Rotate half the dimensions of the input.""" x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): """Apply rotary position embedding to q and k.""" if position_ids is not None: # Handle arbitrary positions cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] else: # Handle sequential positions cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, dim] sin = sin.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, dim] # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class CustomAttention(nn.Module): """Multi-head attention with support for KV caching.""" def __init__(self, dim, num_heads, num_kv_heads=None, dropout=0.0): super().__init__() self.dim = dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads or num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 # Attention projections self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False) # Rotary embedding self.rope = RotaryEmbedding(self.head_dim) # Dropout self.dropout = nn.Dropout(dropout) def _repeat_kv(self, x): """Repeat KV heads to match the number of query heads.""" if self.num_kv_heads == self.num_heads: return x b, s, n_kv_head, head_dim = x.shape # Repeat the KV heads to match the number of query heads repeats = self.num_heads // self.num_kv_heads x = x.repeat_interleave(repeats, dim=2) return x def forward(self, x, mask=None, input_pos=None, kv_cache=None): batch_size, seq_len, _ = x.shape # Project to q, k, v q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [b, nh, s, hd] k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # [b, nkh, s, hd] v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # [b, nkh, s, hd] # Apply rotary embeddings cos, sin = self.rope.forward(x, seq_len=seq_len, pos=input_pos) q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=input_pos) # Handle KV cache if kv_cache is not None: k_cache, v_cache = kv_cache if input_pos is not None: # Update cache at specific positions k_cache.index_copy_(2, input_pos, k) v_cache.index_copy_(2, input_pos, v) # Use the entire cache k, v = k_cache, v_cache # Repeat KV if needed k = self._repeat_kv(k) v = self._repeat_kv(v) # Calculate attention scores attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale # Apply mask if provided if mask is not None: attention_scores = attention_scores.masked_fill(mask == 0, -10000.0) # Apply softmax and dropout attention_probs = self.dropout(torch.softmax(attention_scores, dim=-1)) # Get context vector context = torch.matmul(attention_probs, v) # Reshape and project back context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) output = self.o_proj(context) return output class FeedForward(nn.Module): """Feed-forward network with GELU activation.""" def __init__(self, dim, hidden_dim, dropout=0.0): super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.dropout = nn.Dropout(dropout) self.act = nn.GELU() def forward(self, x): x = self.w1(x) x = self.act(x) x = self.dropout(x) x = self.w2(x) return x class TransformerLayer(nn.Module): """A single transformer layer.""" def __init__( self, dim, num_heads, num_kv_heads=None, ffn_dim=None, dropout=0.0, norm_eps=1e-5 ): super().__init__() self.norm1 = RMSNorm(dim, eps=norm_eps) self.attn = CustomAttention(dim, num_heads, num_kv_heads, dropout) self.norm2 = RMSNorm(dim, eps=norm_eps) self.ffn = FeedForward( dim, ffn_dim or 4 * dim, dropout ) def forward(self, x, mask=None, input_pos=None, kv_cache=None): # Self-attention with residual h = self.norm1(x) h = self.attn(h, mask=mask, input_pos=input_pos, kv_cache=kv_cache) x = x + h # FFN with residual h = self.norm2(x) h = self.ffn(h) x = x + h return x class CustomTransformerDecoder(nn.Module): """Custom transformer decoder that mimics Llama architecture.""" def __init__( self, vocab_size, num_layers, num_heads, num_kv_heads, embed_dim, max_seq_len, intermediate_dim, attn_dropout=0.0, norm_eps=1e-5, rope_base=10000, ): super().__init__() self.vocab_size = vocab_size self.max_seq_len = max_seq_len self.embed_dim = embed_dim # Token embeddings self.tok_embeddings = nn.Embedding(vocab_size, embed_dim) # Transformer layers self.layers = nn.ModuleList([ TransformerLayer( embed_dim, num_heads, num_kv_heads, intermediate_dim, attn_dropout, norm_eps ) for _ in range(num_layers) ]) # Final normalization and output projection self.norm = RMSNorm(embed_dim, eps=norm_eps) self.output = nn.Linear(embed_dim, vocab_size, bias=False) # Initialize the KV cache self._kv_cache = None self._has_cache = False logger.info(f"Initialized CustomTransformerDecoder with {num_layers} layers, {num_heads} heads, {embed_dim} dim") def setup_caches(self, batch_size, dtype, decoder_max_seq_len=None): """Set up KV caches for inference.""" max_seq_len = decoder_max_seq_len or self.max_seq_len device = next(self.parameters()).device self._kv_cache = [] for i, layer in enumerate(self.layers): # Create a KV cache for each layer k_cache = torch.zeros( batch_size, layer.attn.num_kv_heads, max_seq_len, layer.attn.head_dim, device=device, dtype=dtype ) v_cache = torch.zeros( batch_size, layer.attn.num_kv_heads, max_seq_len, layer.attn.head_dim, device=device, dtype=dtype ) self._kv_cache.append((k_cache, v_cache)) self._has_cache = True logger.info(f"KV caches set up for {batch_size} batches, {max_seq_len} seq length") def caches_are_enabled(self): """Check if caches are enabled.""" return self._has_cache def reset_caches(self): """Reset the KV cache to zeros.""" if self._has_cache and self._kv_cache: for k_cache, v_cache in self._kv_cache: k_cache.zero_() v_cache.zero_() def forward(self, x, mask=None, input_pos=None): batch_size, seq_len = x.shape[:2] # Apply embedding if input is token IDs if x.dim() == 2: x = self.tok_embeddings(x) # Apply transformer layers for i, layer in enumerate(self.layers): layer_cache = self._kv_cache[i] if self._has_cache else None x = layer(x, mask=mask, input_pos=input_pos, kv_cache=layer_cache) # Apply final norm x = self.norm(x) # Skip output projection if using Identity if isinstance(self.output, nn.Identity): return x # Apply output projection logits = self.output(x) return logits