import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mistral.modeling_mistral import (
    MistralForCausalLM,
    MistralForSequenceClassification,
    MistralModel,
    MistralPreTrainedModel,
)
from transformers.utils import logging

from sdlm.models.mixins.modeling_mixin import (
    CausalLMForSeq2SeqMixin,
    CDCDDiffusionModelMixin,
    DiffusionModelMixin,
    PaddingIncludedSequenceClassificationMixin,
)

logger = logging.get_logger(__name__)


class Sin(nn.Module):
    """Element-wise sine activation, used in sinusoidal timestep embeddings."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.sin(input)


class MistralForDiffusionLM(DiffusionModelMixin, MistralPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MistralModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if not self.config.disable_timestep_embed:
            # Earlier variant: a two-layer sinusoidal timestep embedding.
            # self.timestep_embed = nn.Sequential(
            #     nn.Linear(1, config.hidden_size, bias=False),
            #     Sin(),
            #     nn.Linear(config.hidden_size, config.hidden_size, bias=False),
            # )
            self.timestep_embed = nn.Linear(1, config.hidden_size, bias=False)
        self.post_init()

    def post_init(self):
        super().post_init()
        # Toggle causal attention on or off in every decoder layer,
        # as configured by `config.is_causal`.
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.is_causal = self.config.is_causal

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def vocab_to_hidden_dim_embed(self, input_data):
        # Project a vocabulary-sized vector (e.g. a simplex over tokens) into
        # the hidden dimension via the transpose of the input embedding matrix.
        return F.linear(input_data, self.get_input_embeddings().weight.data.T)


class CDCDMistralForDiffusionLM(MistralForDiffusionLM, CDCDDiffusionModelMixin):
    pass


class MistralForSeq2SeqLM(CausalLMForSeq2SeqMixin, MistralForCausalLM):
    pass


class MistralforSequenceClassificationWithPadding(
    PaddingIncludedSequenceClassificationMixin, MistralForSequenceClassification
):
    pass
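

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the checked-in
# model code). It assumes a standard MistralConfig extended with the two
# custom fields this module reads (`disable_timestep_embed` and `is_causal`);
# the sdlm mixins may require additional setup in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import MistralConfig

    config = MistralConfig(
        vocab_size=32000,
        hidden_size=256,  # tiny configuration for a quick smoke test
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    # Custom fields consumed in __init__ and post_init above (assumed values).
    config.disable_timestep_embed = False
    config.is_causal = False  # diffusion decoding is typically bidirectional

    model = MistralForDiffusionLM(config)

    # Map one-hot vocabulary vectors back into the hidden space.
    one_hot = F.one_hot(
        torch.tensor([[1, 2, 3]]), num_classes=config.vocab_size
    ).float()
    hidden = model.vocab_to_hidden_dim_embed(one_hot)
    print(hidden.shape)  # expected: torch.Size([1, 3, 256])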