"""Llama-3.2-3B wrapped as a Hugging Face PreTrainedModel with configurable attention
masking (bidirectional, bidirectional with a masked diagonal, or causal), a frozen fp16
backbone, and LoRA adapters on the attention projections."""

import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from peft import LoraConfig, get_peft_model
import os
from typing import Optional, Tuple

hf_token = os.getenv("HF_TOKEN")

class BidirectionalLlamaAttention(LlamaAttention):
    """LlamaAttention variant that reuses the pretrained projection weights so the layer
    can be driven with a non-causal (bidirectional) mask supplied by the caller."""

    def __init__(self, original_layer, masking='unidirectional'):
        super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
        # The masking mode is stored for reference; the actual mask is built by the caller.
        self.masking = masking
        # Share the pretrained projection weights of the wrapped attention layer.
        self.q_proj.weight = original_layer.q_proj.weight
        self.k_proj.weight = original_layer.k_proj.weight
        self.v_proj.weight = original_layer.v_proj.weight
        self.o_proj.weight = original_layer.o_proj.weight

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    def eager_attention_forward(self, module: nn.Module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
        # Expand the grouped key/value heads so they line up with the query heads, then
        # compute scaled dot-product attention scores.
        key_states = self.repeat_kv(key, module.num_key_value_groups)
        value_states = self.repeat_kv(value, module.num_key_value_groups)
        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

        if attention_mask is not None:
            # The mask is expected as an additive bias (0 where attention is allowed, a large
            # negative value where it is blocked), matching the mask built in
            # CustomTransformerModel.forward; adding it directly avoids the NaNs that
            # (1 - mask) * -inf would produce.
            mask_bias = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + mask_bias.to(query.dtype)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
        return attn_output, attn_weights

    def rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project to (batch, num_heads, seq_len, head_dim) and apply the rotary position
        # embeddings supplied by the parent model.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output, attn_weights = self.eager_attention_forward(
            self, query_states, key_states, value_states, attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
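
# Illustrative helper (a sketch, not called elsewhere in this file): swap every
# self-attention layer of a loaded Llama model for BidirectionalLlamaAttention,
# mirroring the commented-out loop in CustomTransformerModel.__init__ below. Assumes
# the standard LlamaForCausalLM layout (model.model.layers).
def patch_llama_attention(model, masking="bidirectional"):
    for layer in model.model.layers:
        layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=masking)
    return model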

class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
                 max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
        self.input_size = prediction_chunk
        self.masking_type = masking_type

class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Load the fp16 Llama backbone; the repository is gated, so HF_TOKEN must grant access.
        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
        self.llama.resize_token_embeddings(config.vocab_size)

        # for i, layer in enumerate(self.llama.model.layers):
        #     layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)

        # Freeze the backbone and keep only the LM head trainable; LoRA adapters are added below.
        for param in self.llama.parameters():
            param.requires_grad = False
        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        lora_config = LoraConfig(
            r=512, lora_alpha=512, lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none", task_type=None
        )

        self.llama = get_peft_model(self.llama, lora_config)
        self.llama.print_trainable_parameters()
        # self.llama = self.llama.to(torch.float16)

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"

        # Build the attention mask for the requested masking scheme as a boolean
        # keep-mask (True = may attend), then convert it to the additive form that
        # the Transformers attention implementations expect for 4D masks.
        device = input_ids.device

        masking_type = getattr(self.config, "masking_type", "bidirectional")
        if masking_type == 'bidirectional':
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        elif masking_type == 'bidirectional_masked':
            # Fully bidirectional, except that no position attends to itself.
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
            base_mask.fill_diagonal_(False)
        elif masking_type == 'unidirectional':
            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        else:
            raise ValueError(f"Unknown masking type: {masking_type}")

        # A 4D float mask is treated as an additive bias: 0.0 where attention is allowed,
        # a large negative value where it is blocked. Use the fp16 minimum so the bias
        # stays finite under the autocast(fp16) forward below.
        attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.float32, device=device)
        attention_mask = attention_mask.masked_fill(
            ~base_mask.unsqueeze(0).unsqueeze(1), torch.finfo(torch.float16).min
        )


        # Run the fp16 backbone under autocast; the KV cache is disabled because the full
        # chunk is processed in a single pass.
        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
                **kwargs
            )

        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

def disable_dropout(model):
    # named_modules() yields dotted paths (e.g. "llama.base_model.model..."), so resolve
    # the parent module before swapping the Dropout for an Identity; calling setattr on
    # the root with a dotted name would only create a dead attribute.
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
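
# Illustrative smoke test (a sketch): instantiating CustomTransformerModel downloads the
# gated meta-llama/Llama-3.2-3B checkpoint, so this assumes a CUDA device, an HF_TOKEN
# with access to that repository, and enough GPU memory for the fp16 backbone.
if __name__ == "__main__":
    config = CustomTransformerConfig(masking_type="bidirectional")
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    # Random token ids shaped (batch, prediction_chunk); reusing them as labels also
    # exercises the loss computation.
    input_ids = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device=model.device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    print(outputs["logits"].shape, outputs["loss"].item())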