import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from peft import LoraConfig, get_peft_model
import os
from typing import Optional, Tuple
hf_token = os.getenv("HF_TOKEN")
class BidirectionalLlamaAttention(LlamaAttention):
    """Llama attention layer whose mask can be made bidirectional instead of causal."""

    def __init__(self, original_layer, masking='unidirectional'):
        super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
        self.masking = masking
        # Reuse the pretrained projection weights from the wrapped layer.
        self.q_proj.weight = original_layer.q_proj.weight
        self.k_proj.weight = original_layer.k_proj.weight
        self.v_proj.weight = original_layer.v_proj.weight
        self.o_proj.weight = original_layer.o_proj.weight
    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand (batch, num_kv_heads, seq, head_dim) to (batch, num_kv_heads * n_rep, seq, head_dim)
        # so grouped-query KV heads line up with the query heads.
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    def eager_attention_forward(self, module: nn.Module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
        key_states = self.repeat_kv(key, module.num_key_value_groups)
        value_states = self.repeat_kv(value, module.num_key_value_groups)
        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            # attention_mask holds 1s (attend) and 0s (mask). Convert it to an additive mask;
            # use the dtype minimum instead of -inf so 0 * (-inf) never produces NaN.
            attn_mask = (1.0 - attention_mask.to(dtype=query.dtype)) * torch.finfo(query.dtype).min
            attn_weights = attn_weights + attn_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
        return attn_output, attn_weights
    def rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Apply rotary position embeddings to queries and keys.
        cos, sin = position_embeddings
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output, attn_weights = self.eager_attention_forward(
            self, query_states, key_states, value_states, attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
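
# Hypothetical helper (not used elsewhere in this module): a minimal sketch of how the
# commented-out loop in CustomTransformerModel.__init__ would patch every decoder layer to use
# BidirectionalLlamaAttention. Assumes `llama_model` is a loaded LlamaForCausalLM instance.
def make_llama_bidirectional(llama_model, masking_type="bidirectional"):
    for layer in llama_model.model.layers:
        layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=masking_type)
    return llama_model
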
class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
                 max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
        self.input_size = prediction_chunk
        self.masking_type = masking_type
class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Load the Llama backbone in fp16.
        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
        self.llama.resize_token_embeddings(config.vocab_size)

        # Optionally swap in bidirectional attention layers (currently disabled):
        # for i, layer in enumerate(self.llama.model.layers):
        #     layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)

        # Freeze the backbone; only the LM head and the LoRA adapters are trained.
        for param in self.llama.parameters():
            param.requires_grad = False
        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        lora_config = LoraConfig(
            r=512, lora_alpha=512, lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none", task_type=None
        )
        self.llama = get_peft_model(self.llama, lora_config)
        self.llama.print_trainable_parameters()
        # self.llama = self.llama.to(torch.float16)
    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"

        # Build the (batch, 1, seq, seq) attention mask for the configured masking scheme.
        device = input_ids.device
        masking_type = getattr(self.config, "masking_type", "bidirectional")
        if masking_type == 'bidirectional':
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        elif masking_type == 'bidirectional_masked':
            # Bidirectional, except each position may not attend to itself.
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
            base_mask.fill_diagonal_(False)
        elif masking_type == 'unidirectional':
            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        else:
            raise ValueError(f"Unknown masking type: {masking_type}")

        attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
        attention_mask = attention_mask.to(dtype=torch.float32)  # float mask required for SDPA and FlashAttention

        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
                **kwargs
            )

        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
def disable_dropout(model):
    # Replace every nn.Dropout with an identity op. named_modules() yields dotted paths,
    # so resolve the parent module before swapping the child attribute.
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
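

# Minimal usage sketch (an addition, not part of the original Space code): build the config and
# model, then run a dummy forward pass. Assumes a CUDA device, the HF_TOKEN environment variable,
# and access to meta-llama/Llama-3.2-3B; the input length must equal config.prediction_chunk.
if __name__ == "__main__":
    config = CustomTransformerConfig(masking_type="bidirectional")
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    dummy_ids = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device="cuda")
    out = model(dummy_ids, labels=dummy_ids)
    print(out["loss"], out["logits"].shape)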