import torch
from torch import nn
from typing import Optional

from .language_config import LanguageModelConfig
from .language_components import DecoderLayer, RMSNorm, KVCache


class LanguageModel(nn.Module):
    def __init__(self, config: LanguageModelConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        return self.embed_tokens

    # Ignore copy
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> torch.FloatTensor:
        # Embeddings are supplied directly (e.g. after merging image and text
        # embeddings), so the forward pass starts from inputs_embeds rather
        # than token ids.
        hidden_states = inputs_embeds
        # Scale embeddings by sqrt(hidden_size) before the decoder stack.
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # Run the stack of decoder layers, threading through the attention
        # mask, position ids, and KV cache for incremental decoding.
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                kv_cache=kv_cache,
            )

        # Final RMS normalization; the output projection is handled by the caller.
        hidden_states = self.norm(hidden_states)

        return hidden_states
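

# A minimal, self-contained sketch of the embedding-scaling step used in
# forward() above: embeddings are multiplied by sqrt(hidden_size) before
# entering the decoder stack. The hidden size, vocabulary size, and token ids
# below are illustrative values, not taken from this repository's configs.
if __name__ == "__main__":
    hidden_size = 64
    embed = nn.Embedding(10, hidden_size)
    token_ids = torch.tensor([[1, 2, 3]])

    inputs_embeds = embed(token_ids)
    normalizer = torch.tensor(hidden_size**0.5, dtype=inputs_embeds.dtype)
    scaled = inputs_embeds * normalizer

    print(scaled.shape)  # (1, 3, 64): batch, sequence length, hidden size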