
Ruurd committed · Commit b43e862 · verified · 1 parent: 22370b2

Overhaul code for appropriate masking for full model instead of just attention layers
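In short: mask construction moves out of the per-layer BidirectionalLlamaAttention.forward and into CustomTransformerModel.forward, where a single boolean mask is built from config.masking_type and passed to the wrapped Llama model through attention_mask. The snippet below is a minimal standalone sketch of that mask logic for the three masking types; the helper name build_attention_mask is illustrative only (the repo inlines this directly in the model's forward), and the shapes follow the (batch, 1, seq_len, seq_len) layout used in the new code.

import torch

def build_attention_mask(masking_type: str, batch_size: int, seq_len: int, device="cpu") -> torch.Tensor:
    # Boolean mask of shape (batch, 1, seq_len, seq_len); True means the position may be attended to.
    if masking_type == "bidirectional":
        base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    elif masking_type == "bidirectional_masked":
        # Fully bidirectional, except each token may not attend to itself.
        base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        base_mask.fill_diagonal_(False)
    elif masking_type == "unidirectional":
        # Standard causal (lower-triangular) mask.
        base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    else:
        raise ValueError(f"Unknown masking type: {masking_type}")
    return base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()

# Example: for 'bidirectional_masked' only the diagonal is blocked.
mask = build_attention_mask("bidirectional_masked", batch_size=2, seq_len=4)
print(mask.shape)   # torch.Size([2, 1, 4, 4])
print(mask[0, 0])

Because the mask now arrives through the standard attention_mask argument, the attention override no longer needs its own masking branches and simply forwards whatever mask it is given.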

Files changed (1): llama_diffusion_model.py (+54 -154)
llama_diffusion_model.py CHANGED
@@ -1,115 +1,57 @@
+import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
-from transformers.models.llama.modeling_llama import LlamaAttention
 from torch.amp import autocast
+from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
+from transformers.models.llama.modeling_llama import LlamaAttention
 from peft import LoraConfig, get_peft_model
-from typing import Optional, Tuple
-import torch
 import os
-
-
+from typing import Optional, Tuple

 hf_token = os.getenv("HF_TOKEN")

 class BidirectionalLlamaAttention(LlamaAttention):
-    def __init__(self, original_layer, masking = 'unidirectional'):
+    def __init__(self, original_layer, masking='unidirectional'):
         super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
         self.masking = masking
-
-        # Copy weights from original layer
         self.q_proj.weight = original_layer.q_proj.weight
         self.k_proj.weight = original_layer.k_proj.weight
         self.v_proj.weight = original_layer.v_proj.weight
         self.o_proj.weight = original_layer.o_proj.weight

     def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-        """
-        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-        """
         batch, num_key_value_heads, slen, head_dim = hidden_states.shape
         if n_rep == 1:
             return hidden_states
         hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-
         return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

-    def eager_attention_forward(self,
-                                module: nn.Module,
-                                query: torch.Tensor,
-                                key: torch.Tensor,
-                                value: torch.Tensor,
-                                attention_mask: Optional[torch.Tensor],
-                                scaling: float,
-                                dropout: float = 0.0,
-                                **kwargs,
-                                ):
+    def eager_attention_forward(self, module: nn.Module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
         key_states = self.repeat_kv(key, module.num_key_value_groups)
         value_states = self.repeat_kv(value, module.num_key_value_groups)
-
-        # attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-        # if attention_mask is not None:
-        #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        #     attn_weights = attn_weights + causal_mask
-
         attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
         if attention_mask is not None:
-            # Convert bool -> float with -inf where masked
             attn_mask = attention_mask.masked_fill(~attention_mask, float('-inf')).to(query.dtype)
             attn_weights = attn_weights + attn_mask

-
         attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-
+        attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
         return attn_output, attn_weights

     def rotate_half(self, x):
-        """Rotates half the hidden dims of the input."""
         x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-
+        x2 = x[..., x.shape[-1] // 2:]
         return torch.cat((-x2, x1), dim=-1)

-
-    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-        """Applies Rotary Position Embedding to the query and key tensors.
-
-        Args:
-            q (`torch.Tensor`): The query tensor.
-            k (`torch.Tensor`): The key tensor.
-            cos (`torch.Tensor`): The cosine part of the rotary embedding.
-            sin (`torch.Tensor`): The sine part of the rotary embedding.
-            position_ids (`torch.Tensor`, *optional*):
-                Deprecated and unused.
-            unsqueeze_dim (`int`, *optional*, defaults to 1):
-                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-        Returns:
-            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-        """
+    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
         cos = cos.unsqueeze(unsqueeze_dim)
         sin = sin.unsqueeze(unsqueeze_dim)
         q_embed = (q * cos) + (self.rotate_half(q) * sin)
         k_embed = (k * cos) + (self.rotate_half(k) * sin)
-
         return q_embed, k_embed

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_value: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
+    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -117,7 +59,6 @@ class BidirectionalLlamaAttention(LlamaAttention):
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-        # Apply rotary embeddings
         cos, sin = position_embeddings
         query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -125,58 +66,17 @@ class BidirectionalLlamaAttention(LlamaAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # 🔄 **Modify the Attention Mask**
-        seq_len = hidden_states.shape[1]
-        batch_size = hidden_states.shape[0]
-        if self.masking == 'bidirectional':
-            base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
-            attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()  # ✅ Copy for each batch
-        elif self.masking == 'bidirectional_masked':
-            base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
-            base_mask[:, :].fill_diagonal_(False)  # ✅ Apply diagonal masking only in 2D
-            attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()  # ✅ Copy for each batch
-        else:  # unidirectional
-            # 🚀 Standard autoregressive (causal) mask
-            attn_mask = torch.tril(torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool))
-            attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()  # ✅ Copy for each batch
-
-
-        # Call the default attention function
         attn_output, attn_weights = self.eager_attention_forward(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attn_mask,  # ✅ Custom mask is applied here
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            **kwargs,
+            self, query_states, key_states, value_states, attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, attn_weights
+        return self.o_proj(attn_output), attn_weights
-
-
-    def _split_heads(self, tensor, num_heads, attn_head_size):
-        """
-        Splits hidden_size dim into attn_head_size and num_heads
-        """
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
-        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
-
-    def _merge_heads(self, tensor, num_heads, attn_head_size):
-        """
-        Merges attn_head_size dim and num_attn_heads dim into hidden_size
-        """
-        tensor = tensor.permute(0, 2, 1, 3).contiguous()
-        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-        return tensor.view(new_shape)

 class CustomTransformerConfig(PretrainedConfig):
-    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0, max_position_embeddings=4096, **kwargs):
+    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
+                 max_position_embeddings=4096, masking_type="bidirectional_masked", **kwargs):
         super().__init__(**kwargs)
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
@@ -186,72 +86,72 @@ class CustomTransformerConfig(PretrainedConfig):
         self.prediction_chunk = prediction_chunk
         self.max_position_embeddings = max_position_embeddings
         self.input_size = prediction_chunk
+        self.masking_type = masking_type

 class CustomTransformerModel(PreTrainedModel):
     config_class = CustomTransformerConfig

     def __init__(self, config):
         super().__init__(config)
-
-        # Load pre-trained Llama model (excluding its original lm_head)
-        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token = hf_token)
-
+        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
         self.llama.resize_token_embeddings(config.vocab_size)

         for i, layer in enumerate(self.llama.model.layers):
-            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional_masked')
+            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)

-        # Freeze Llama to retain pre-trained knowledge
         for param in self.llama.parameters():
             param.requires_grad = False
-
         for param in self.llama.lm_head.parameters():
             param.requires_grad = True

         lora_config = LoraConfig(
-            r=512,
-            lora_alpha=512,
-            lora_dropout=0.0,
-            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Llama-3 uses these attention modules
-            bias="none",
-            task_type=None
+            r=512, lora_alpha=512, lora_dropout=0.0,
+            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+            bias="none", task_type=None
         )

         self.llama = get_peft_model(self.llama, lora_config)
-        self.llama.print_trainable_parameters()  # Print number of trainable parameters
+        self.llama.print_trainable_parameters()
         self.llama = self.llama.to(torch.float16)

-
     def forward(self, input_ids, labels=None, **kwargs):
-        batch_size, seq_length = input_ids.shape
-        assert seq_length == 256, f"Expected input length input_size, got {seq_length}"
-
-        with autocast("cuda", dtype=torch.float16):  # Correct future-proof usage
-
-
-            outputs = self.llama(input_ids, output_hidden_states=True, **kwargs)
-
-            logits = outputs.logits[:,:,:self.config.vocab_size]
-
-            # Reshape logits to (batch, input_size, vocab_size)
-            logits = logits.view(batch_size, self.config.prediction_chunk, self.config.vocab_size)
-
-            loss = None
-
+        batch_size, seq_len = input_ids.shape
+        assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"
+
+        # Build attention mask
+        device = input_ids.device
+        if self.config.masking_type == 'bidirectional':
+            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
+        elif self.config.masking_type == 'bidirectional_masked':
+            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
+            base_mask.fill_diagonal_(False)
+        elif self.config.masking_type == 'unidirectional':
+            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
+        else:
+            raise ValueError(f"Unknown masking type: {self.config.masking_type}")
+
+        attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
+
+        with autocast("cuda", dtype=torch.float16):
+            outputs = self.llama(
+                input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                **kwargs
+            )
+
+        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)
+
+        loss = None
         if labels is not None:
-            assert labels.shape == (batch_size, 256), f"Labels shape mismatch: expected (batch, input_size), got {labels.shape}"
-
-            # Compute loss
-            loss_fct = torch.nn.CrossEntropyLoss()
-
+            assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
+            loss_fct = nn.CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

-
         return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

-
 def disable_dropout(model):
     for name, module in model.named_modules():
         if isinstance(module, nn.Dropout):
             setattr(model, name, nn.Identity())
-    return model
+    return model
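For reference, a rough usage sketch of the updated classes, assuming a CUDA device, an HF_TOKEN with access to meta-llama/Llama-3.2-3B, and enough memory for the fp16 weights; the 256-token sequence length comes from prediction_chunk, which forward() asserts on. This is not part of the commit, just an illustration of the new masking_type config flow.

import torch
from llama_diffusion_model import CustomTransformerConfig, CustomTransformerModel, disable_dropout

config = CustomTransformerConfig(masking_type="bidirectional_masked")  # prediction_chunk defaults to 256
model = disable_dropout(CustomTransformerModel(config))

# Dummy batch: forward() requires sequences of exactly config.prediction_chunk tokens.
input_ids = torch.randint(0, config.vocab_size, (2, config.prediction_chunk), device=model.device)
labels = input_ids.clone()

out = model(input_ids, labels=labels)
print(out["loss"].item(), out["logits"].shape)  # logits: (batch, 256, vocab_size)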