import torch import torch.nn as nn import torch.nn.functional as F from transformers.models.mistral.modeling_mistral import ( MistralForCausalLM, MistralForSequenceClassification, MistralModel, MistralPreTrainedModel, ) from transformers.utils import logging from sdlm.models.mixins.modeling_mixin import ( CausalLMForSeq2SeqMixin, CDCDDiffusionModelMixin, DiffusionModelMixin, PaddingIncludedSequenceClassificationMixin, ) logger = logging.get_logger(__name__) class Sin(nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: return torch.sin(input) class MistralForDiffusionLM(DiffusionModelMixin, MistralPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) self.model = MistralModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) if not self.config.disable_timestep_embed: # self.timestep_embed = nn.Sequential( # nn.Linear(1, config.hidden_size, bias=False), # Sin(), # nn.Linear(config.hidden_size, config.hidden_size, bias=False), # ) self.timestep_embed = nn.Linear(1, config.hidden_size, bias=False) self.post_init() def post_init(self): super().post_init() # (un)toggle causal attention for decoder_layer in self.model.layers: decoder_layer.self_attn.is_causal = self.config.is_causal def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def vocab_to_hidden_dim_embed(self, input_data): return F.linear(input_data, self.get_input_embeddings().weight.data.T) class CDCDMistralForDiffusionLM(MistralForDiffusionLM, CDCDDiffusionModelMixin): pass class MistralForSeq2SeqLM(CausalLMForSeq2SeqMixin, MistralForCausalLM): pass class MistralforSequenceClassificationWithPadding( PaddingIncludedSequenceClassificationMixin, MistralForSequenceClassification ): pass