Upload 12 files
- .gitignore +14 -0
- README.md +187 -10
- config.py +41 -0
- data.py +117 -0
- download_model_weight.py +131 -0
- fine_tune.py +1282 -0
- inference.py +84 -0
- llama_torchrun.py +1435 -0
- metric.py +28 -0
- model.py +489 -0
- tokenizer.py +21 -0
- trainer.py +469 -0
.gitignore
ADDED
@@ -0,0 +1,14 @@
snapshot.pt
snapshot2.pt
llama.py
snapshot_3.pt
metric.py
weights/
gpt4all.json
fine_tune.py
old_files/
snapshot_4650.pt
snapshot (1).pt
README.md
CHANGED
@@ -1,13 +1,190 @@
# Introducing StoryLlama - A Smaller Language Model for Bedtime Stories!

- So, I trained a Llama, an 88M-parameter architecture I coded from the ground up, to build a small instruct model, going through the below-mentioned stages from scratch.
- Trained on the TinyStories dataset from HuggingFace, consisting of 4B tokens, for a total of 5000 steps.

### Pretraining

#### Dataset

- I used the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset from HuggingFace.

1) Train dataset - approx. 2M records
2) Val dataset - approx. 26K records

---

#### ModelArgs (Hyperparameters)

Below is a table summarizing the configuration parameters for the model:

| Parameter | Description | Default Value | Type |
|-------------------------------|---------------------------------------------------------------|------------------------------|---------|
| `epochs` | Number of training epochs | `4` | `int` |
| `block_size` | Size of each block (context length) | `512` | `int` |
| `batch_size` | Batch size for training | `64` | `int` |
| `inference` | Inference mode (not specified) | `None` | `None` |
| `embeddings_dims` | Dimensionality of embeddings | `512` | `int` |
| `attn_dropout` | Dropout rate for attention layers | `0.1` | `float` |
| `no_of_heads` | Number of attention heads | `8` | `int` |
| `dropout` | Dropout rate for the model | `0.1` | `float` |
| `val_epochs` | Number of validation epochs | `2` | `int` |
| `max_lr` | Maximum learning rate | `6e-4` | `float` |
| `no_of_decoder_layers` | Number of decoder layers | `8` | `int` |
| `weight_decay_optim` | Weight decay for the optimizer | `0.1` | `float` |
| `beta_1` | Beta 1 for Adam optimizer | `0.9` | `float` |
| `beta_2` | Beta 2 for Adam optimizer | `0.95` | `float` |
| `clip` | Gradient clipping value | `1.0` | `float` |
| `device` | Device to run the model (`cuda` or `cpu`) | `'cuda'` | `str` |
| `no_kv_heads` | Number of key-value heads | `2` | `int` |
| `vocab_size` | Size of the vocabulary | `50304` | `int` |
| `eps` | Epsilon value for numerical stability | `1e-5` | `float` |
| `dtype` | Data type for tensors (`bfloat16` if supported, else `float16`) | `'bfloat16'` or `'float16'` | `str` |
| `save_checkpoint_dir` | Directory to save model checkpoints | `"checkpoints"` | `str` |
| `prompt` | Default prompt for inference | `"Once upon a time"` | `str` |
| `save_checkpoint_iter` | Save checkpoint every N iterations | `50` | `int` |
| `total_iters` | Total number of training iterations | `10000` | `int` |
| `eval_iters` | Evaluate model every N iterations | `50` | `int` |
| `eval_check` | Check evaluation metrics every N iterations | `100` | `int` |
| `warmup_iters` | Number of warmup iterations for learning rate scheduling | `700` | `int` |
| `min_lr` | Minimum learning rate (10% of `max_lr`) | `0.1 * max_lr` | `float` |
| `lr_decay_iters` | Number of iterations for learning rate decay | `10000` | `int` |
| `total_batch_size` | Total batch size (in tokens) across all devices | `524288` | `int` |
| `micro_batch_size` | Micro batch size per device | `batch_size` | `int` |
| `gradient_accumulation_steps` | Gradient accumulation steps | derived from `total_batch_size` | `int` |

---
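The last two table entries are derived rather than hand-set. A rough sketch of the arithmetic `config.py` (shown later in this upload) uses, assuming the 2-GPU setup described under Hardware Setup below:

```python
total_batch_size = 524288        # target tokens per optimizer step
micro_batch_size = 64            # = batch_size, sequences per GPU per forward pass
block_size = 512                 # tokens per sequence
n_gpus = 2                       # torch.cuda.device_count() on the 2x A100 setup

# tokens processed by one forward/backward pass across all GPUs
tokens_per_micro_step = micro_batch_size * block_size * n_gpus   # 65,536

# accumulate gradients until the target token count is reached
gradient_accumulation_steps = total_batch_size // tokens_per_micro_step  # 8

# cosine schedule floor
max_lr = 6e-4
min_lr = 0.1 * max_lr            # 6e-5
```

So with the defaults, each optimizer step sees 524,288 tokens accumulated over 8 micro-steps.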
#### Hardware Setup

- Used DDP with PyTorch's torchrun on 2x NVIDIA A100 SXM (80 GB VRAM each) rented on runpod.io.
- The model is 0.768 GB in size but needs around 4 GB of VRAM when loaded in fp32 precision.

---

#### Frameworks:
**PyTorch**

---

#### Epochs/Steps
- Iterations (train) = 5k
- Val iterations = every 50 steps

---

#### Losses
- Train loss - 1.43
- Val loss - 1.45

---

#### Screenshots of the loss curves

- Loss Curves (Train and Val)



---
#### Output

- Prompt: Once upon a time



---

### Local setup

### Requirements

```bash
git clone https://github.com/YuvrajSingh-mist/StoryLlama.git
cd StoryLlama
bash ./install.sh
```

- A wandb.ai account for plotting your loss curves

- On your terminal run
```bash
wandb login
```

- Enter the API key, follow the instructions, and once you are successfully logged in follow the given steps.

- Download the model

```bash
python download_model_weight.py --model_type pretrained
```

---

### Running

#### Training a model

- Kindly change 'device' to any of your available CUDA GPUs.

To run:

```bash
bash ./install.sh
```

```bash
torchrun --standalone --nproc_per_node=gpu trainer.py \
    --epochs 10 \
    --block_size 256 \
    --batch_size 128 \
    --embeddings_dims 768 \
    --attn_dropout 0.2 \
    --no_of_heads 12 \
    --dropout 0.2 \
    --val_epochs 3 \
    --max_lr 5e-4 \
    --no_of_decoder_layers 6 \
    --weight_decay_optim 0.01 \
    --beta_1 0.85 \
    --beta_2 0.99 \
    --clip 0.5 \
    --device "cuda" \
    --no_kv_heads 4 \
    --vocab_size 50257 \
    --eps 1e-6 \
    --dtype "float16" \
    --save_checkpoint_dir "model_checkpoints" \
    --prompt "Once upon a time" \
    --save_checkpoint_iter 100 \
    --total_iters 5000 \
    --eval_iters 200 \
    --eval_check 500 \
    --warmup_iters 1000 \
    --min_lr 1e-5 \
    --lr_decay_iters 2000 \
    --total_batch_size 262144 \
    --micro_batch_size 128 \
    --gradient_accumulation_steps 4
```

--standalone - use when all the GPUs are on one server
--nproc_per_node - number of GPUs to spawn; pass the keyword gpu to use all available GPUs
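trainer.py itself is not included in this excerpt, but torchrun always follows the same contract: it spawns one process per GPU and sets `RANK`, `LOCAL_RANK` and `WORLD_SIZE` in each process's environment. An illustrative sketch of the per-worker setup (mirroring the `setup()` helper that appears in fine_tune.py later in this upload, not a copy of trainer.py):

```python
import os
import torch
from torch.distributed import init_process_group, destroy_process_group

def setup():
    init_process_group(backend="nccl")                    # join the process group torchrun created
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # pin this worker to its own GPU

setup()
# ... build the model, wrap it in DistributedDataParallel, run the training loop ...
destroy_process_group()
```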

#### Inference on a model

```bash
python inference.py --prompt "Once upon a time" --max_length 100 --temperature 0.8 --topk 50
```
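For context on the two sampling knobs: `--temperature` rescales the logits before the softmax and `--topk` restricts sampling to the k most likely tokens. A minimal, generic sketch of that decoding step (not lifted from inference.py):

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.8, top_k=50):
    # logits: (vocab_size,) unnormalized scores for the next token
    logits = logits / temperature                     # <1.0 sharpens, >1.0 flattens the distribution
    topk_vals, topk_idx = torch.topk(logits, top_k)   # keep only the k highest-scoring tokens
    probs = F.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)  # sample within the top-k set
    return topk_idx[choice]
```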
config.py
ADDED
@@ -0,0 +1,41 @@
import argparse
from dataclasses import dataclass
import torch

@dataclass
class ModelArgs:

    epochs: int = 4
    block_size: int = 512
    batch_size: int = 64
    inference = None
    embeddings_dims: int = 512
    attn_dropout: float = 0.1
    no_of_heads: int = 8
    dropout: float = 0.1
    val_epochs: int = 2
    max_lr: float = 6e-4
    no_of_decoder_layers: int = 8
    weight_decay_optim: float = 0.1
    beta_1: float = 0.9
    beta_2: float = 0.95
    clip: float = 1.0
    device: str = 'cuda'
    no_kv_heads: int = 2
    vocab_size: int = 50304
    eps: float = 1e-5
    dtype: str = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    save_checkpoint_dir: str = "checkpoints"
    prompt: str = "Once upon a time"

    save_checkpoint_iter: int = 50
    total_iters: int = 10000
    eval_iters: int = 50
    eval_check: int = 100
    warmup_iters: int = 700
    min_lr: float = 0.1 * max_lr
    lr_decay_iters: int = 10000
    total_batch_size: int = 524288
    micro_batch_size: int = batch_size
    gradient_accumulation_steps: int = total_batch_size // (micro_batch_size * (block_size * torch.cuda.device_count()))
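A small sketch of how the rest of the repo consumes this config (data.py below imports it the same way); the CPU-only caveat is an observation about the expression above, not documented behaviour:

```python
from config import ModelArgs

print(ModelArgs.block_size)   # 512 -- context length used when tokenizing
print(ModelArgs.dtype)        # 'bfloat16' on bf16-capable GPUs, else 'float16'

# Note: gradient_accumulation_steps evaluates torch.cuda.device_count() when the
# class body runs, so importing config.py on a CPU-only machine (device count 0)
# raises ZeroDivisionError.
```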
data.py
ADDED
@@ -0,0 +1,117 @@
import torch.nn.functional as F

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from datasets import load_dataset
from torch.utils.data import DataLoader
from tokenizer import Tokenizer
from config import ModelArgs


tokenizer = Tokenizer().ready_tokenizer()


tinystories = True
fw = False
fw_train = None
fw_test = None
if(tinystories):
    fw_train = load_dataset("roneneldan/TinyStories", split="train")
    fw_test = load_dataset("roneneldan/TinyStories", split="validation")
    print(fw_train)
    print(fw_test)
if(fw):
    fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
    fw_train = fw_train.train_test_split(test_size=0.01)
    print(fw_train)
    print(fw_train)


tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        max_length=ModelArgs.block_size,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )


def prepare_dataset(split, device, batch_size):
    print("Device is: ", device)

    def collate_fn(batch):
        # Extract text data
        texts = [item["text"] for item in batch]

        input_encodings = tokenizer(texts, max_length=ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")

        # Next-token prediction: labels are the input ids shifted left by one,
        # with EOS as the target for the final position.
        input_encodings["labels"] = input_encodings["input_ids"].clone()
        input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]
        input_encodings["labels"][:, -1] = tokenizer.eos_token_id

        return input_encodings


    data_loader = None
    if(tinystories):
        if(split == 'train'):
            data_loader = DataLoader(
                fw_train,
                # generator=generator,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
        elif(split == 'val'):
            data_loader = DataLoader(
                fw_test,
                batch_size=batch_size,
                sampler=DistributedSampler(fw_test, shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    elif(fw):
        if(split == 'train'):
            data_loader = DataLoader(
                fw_train['train'],
                batch_size=batch_size,
                sampler=DistributedSampler(fw_train['train'], shuffle=True),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
        elif(split == 'val'):
            data_loader = DataLoader(
                fw_train['test'],
                batch_size=batch_size,
                # generator=generator,
                sampler=DistributedSampler(fw_train["test"]),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=False
            )
    return data_loader
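The label shift in `collate_fn` is what turns plain text batches into a next-token-prediction target. A tiny, self-contained illustration with made-up token ids (`50256` is assumed here as the EOS id):

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13]])   # toy token ids for one sequence
eos_token_id = 50256                           # assumed GPT-2-style EOS id

labels = input_ids.clone()
labels[:, :-1] = input_ids[:, 1:]              # shift left: each position predicts the next token
labels[:, -1] = eos_token_id                   # last position has no successor -> EOS

print(labels)                                  # tensor([[   11,    12,    13, 50256]])
```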
download_model_weight.py
ADDED
@@ -0,0 +1,131 @@
# import gdown
# import os
# import argparse

# def download_model(model_id, folder, filename):
#     os.makedirs(folder, exist_ok=True)
#     url = f"https://drive.google.com/uc?id={model_id}"
#     output_path = os.path.join(folder, filename)
#     print(f"Downloading model to {output_path}...")
#     gdown.download(url, output_path, quiet=False)
#     print("Download complete!")

# def main():
#     parser = argparse.ArgumentParser(description="Download models using gdown and organize them into appropriate folders.")
#     parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
#     parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
#     parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")

#     args = parser.parse_args()

#     pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
#     fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
#     dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"

#     if args.pretrained:
#         download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt")
#     if args.sft:
#         download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt")
#     if args.dpo:
#         download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt")

# if __name__ == "__main__":
#     main()


# import os
# import argparse

# def download_model(model_id, folder, filename, access_token):
#     os.makedirs(folder, exist_ok=True)
#     output_path = os.path.join(folder, filename)

#     url = f"https://www.googleapis.com/drive/v3/files/{model_id}?alt=media"
#     command = f"curl -H \"Authorization: Bearer {access_token}\" {url} -o {output_path}"

#     print(f"Downloading model to {output_path}...")
#     os.system(command)
#     print("Download complete!")

# def main():
#     parser = argparse.ArgumentParser(description="Download models using Google Drive API and organize them into appropriate folders.")
#     parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
#     parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
#     parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
#     parser.add_argument("--token", type=str, required=True, help="Google Drive API Access Token")

#     args = parser.parse_args()

#     pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
#     fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
#     dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"

#     if args.pretrained:
#         download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt", args.token)
#     if args.sft:
#         download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt", args.token)
#     if args.dpo:
#         download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt", args.token)

# if __name__ == "__main__":
#     main()


# download_model_weight.py
import os
import argparse
from huggingface_hub import hf_hub_download, login

def download_model(repo_id, filename, cache_dir):

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            resume_download=True,
            force_download=False,
            token=os.getenv("HF_TOKEN")
        )

        if os.path.exists(model_path) and os.path.getsize(model_path) > 1024*1024:
            return model_path
        raise ValueError("Downloaded file is too small or invalid")
    except Exception as e:
        print(f"Download failed: {str(e)}")
        raise

def main():
    parser = argparse.ArgumentParser(description="Download models from Hugging Face Hub")
    parser.add_argument("--model_type",
                        choices=["pretrained"],
                        required=True,
                        help="Type of model to download")

    args = parser.parse_args()

    model_config = {
        "pretrained": {
            "repo_id": "YuvrajSingh9886/StoryLlama",
            "filename": "snapshot_4650.pt",
            "cache_dir": "weights/pretrained"
        }
    }

    config = model_config[args.model_type]
    os.makedirs(config["cache_dir"], exist_ok=True)

    print(f"Downloading {args.model_type} model...")
    model_path = download_model(
        config["repo_id"],
        config["filename"],
        config["cache_dir"]
    )
    print(f"Successfully downloaded to: {model_path}")

if __name__ == "__main__":

    login(token=os.getenv("HF_TOKEN"))
    main()
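For reference, the active part of this script reduces to one `hf_hub_download` call keyed off `--model_type`. A minimal programmatic equivalent of `python download_model_weight.py --model_type pretrained` (same repo id and filename as configured above; export `HF_TOKEN` only if the repo requires authentication):

```python
from download_model_weight import download_model

path = download_model(
    repo_id="YuvrajSingh9886/StoryLlama",
    filename="snapshot_4650.pt",
    cache_dir="weights/pretrained",
)
print(path)  # local path returned by hf_hub_download inside weights/pretrained
```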
fine_tune.py
ADDED
@@ -0,0 +1,1282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 185860
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass
|
8 |
+
# from torchtune.modules import RMSNorm
|
9 |
+
from tokenizers import Tokenizer
|
10 |
+
from pathlib import Path
|
11 |
+
import torch.multiprocessing as mp
|
12 |
+
from torch.utils.data.distributed import DistributedSampler
|
13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
14 |
+
from torch.distributed import init_process_group, destroy_process_group
|
15 |
+
import torch
|
16 |
+
from datasets import Dataset
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoderModelOutput
|
19 |
+
import wandb
|
20 |
+
from tqdm import tqdm
|
21 |
+
from functools import partial
|
22 |
+
|
23 |
+
import torch.optim as optim
|
24 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
25 |
+
|
26 |
+
|
27 |
+
# Load model directly
|
28 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
29 |
+
import os
|
30 |
+
|
31 |
+
|
32 |
+
# import wandb
|
33 |
+
# wandb.login()
|
34 |
+
|
35 |
+
|
36 |
+
# from torch.utils.tensorboard import SummaryWriter
|
37 |
+
|
38 |
+
|
39 |
+
from datasets import load_dataset, concatenate_datasets
|
40 |
+
# use name="sample-10BT" to use the 10BT sample
|
41 |
+
# fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
|
42 |
+
# print(fw_train)
|
43 |
+
# Select only 1000 rows from the dataset
|
44 |
+
# fw_train = fw_train.select(range(1000000))
|
45 |
+
# alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
|
46 |
+
# dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train')
|
47 |
+
# merged_dataset = concatenate_datasets([alpaca, dolly])
|
48 |
+
dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True)
|
49 |
+
# print(fw_train)
|
50 |
+
# Split the dataset into training and validation sets
|
51 |
+
merged_dataset = dataset.train_test_split(test_size=0.1)
|
52 |
+
print(merged_dataset)
|
53 |
+
# fw_train = fw_train.train_test_split(test_size=0.2)
|
54 |
+
# print(fw_train)
|
55 |
+
|
56 |
+
# Access the splits
|
57 |
+
# train_dataset = train_val_split['train']
|
58 |
+
# val_dataset = train_val_split['test']
|
59 |
+
|
60 |
+
# train_dataset = fw_train.train_test_split(test_size=0.2)
|
61 |
+
|
62 |
+
|
63 |
+
def setup(rank=None, world_size=None):
|
64 |
+
# os.environ['MASTER_ADDR'] = 'localhost'
|
65 |
+
# os.environ['MASTER_PORT'] = '12355'
|
66 |
+
init_process_group("nccl")
|
67 |
+
# torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
|
68 |
+
|
69 |
+
def cleanup():
|
70 |
+
destroy_process_group()
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
@dataclass
|
75 |
+
class ModelArgs:
|
76 |
+
#Hyperparameters
|
77 |
+
|
78 |
+
epochs = 5
|
79 |
+
block_size = 128
|
80 |
+
batch_size = 64
|
81 |
+
embeddings_dims = 786
|
82 |
+
attn_dropout = 0.1
|
83 |
+
no_of_heads = 6 #IMP needs to be thoroughly calculated
|
84 |
+
dropout = 0.1
|
85 |
+
# epochs = 100
|
86 |
+
val_epochs = 2
|
87 |
+
max_lr = 2e-4
|
88 |
+
no_of_decoder_layers = 6 #IMP needs to be thoroughly calculated
|
89 |
+
weight_decay_optim = 0.1
|
90 |
+
beta_1 = 0.9
|
91 |
+
beta_2 = 0.95
|
92 |
+
clip = 1.0
|
93 |
+
device = 'cuda'
|
94 |
+
no_kv_heads = 2
|
95 |
+
vocab_size = 50258
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
from pathlib import Path
|
100 |
+
data_path = Path('data')
|
101 |
+
data_path.mkdir(exist_ok=True)
|
102 |
+
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
|
103 |
+
# !cp input.txt data/input.txt
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
#Datasets
|
108 |
+
|
109 |
+
# Using tinyshakespeare
|
110 |
+
|
111 |
+
# with open('data/input.txt', 'r', encoding='utf-8') as f:
|
112 |
+
# text = f.read()
|
113 |
+
|
114 |
+
|
115 |
+
# Load the tokenizer
|
116 |
+
# tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json")
|
117 |
+
|
118 |
+
# Encode and decode functions
|
119 |
+
# encode = lambda s: tokenizer.encode(s).ids
|
120 |
+
# decode = lambda l: tokenizer.decode(l)
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
def _save_snapshot(model, optimizer, scheduler, epoch, step):
|
125 |
+
snapshot = {
|
126 |
+
"MODEL_STATE": model.module.state_dict(),
|
127 |
+
"OPTIMIZER_STATE": optimizer.state_dict(),
|
128 |
+
"SCHEDULER_STATE": scheduler.state_dict(), # NEW: Save scheduler state
|
129 |
+
"EPOCHS_RUN": epoch,
|
130 |
+
"STEP_RUN": step
|
131 |
+
}
|
132 |
+
torch.save(snapshot, "/kaggle/working/snapshot_fine_tuned_model_with_gradient_clipping_3.pt")
|
133 |
+
print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")
|
134 |
+
|
135 |
+
def _load_snapshot(snapshot_path, model, optimizer, scheduler):
|
136 |
+
snapshot = torch.load(snapshot_path)
|
137 |
+
model.load_state_dict(snapshot["MODEL_STATE"])
|
138 |
+
# optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
|
139 |
+
# scheduler.load_state_dict(snapshot["SCHEDULER_STATE"]) # Load scheduler state
|
140 |
+
epoch = snapshot["EPOCHS_RUN"]
|
141 |
+
step = snapshot["STEP_RUN"]
|
142 |
+
print(f"Resuming from Epoch {epoch}, Step {step}")
|
143 |
+
return epoch, step
|
144 |
+
|
145 |
+
#Subword level tokenization
|
146 |
+
|
147 |
+
#Loading custom trained BPE
|
148 |
+
# Load the tokenizer
|
149 |
+
# tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json")
|
150 |
+
# vocab_size = tokenizer.get_vocab_size()
|
151 |
+
# Encode and decode functions
|
152 |
+
# encode = lambda s: tokenizer.encode(s).ids
|
153 |
+
# decode = lambda l: tokenizer.decode(l)
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
###############################################################################
|
160 |
+
#Character level tokenization
|
161 |
+
|
162 |
+
# # here are all the unique characters that occur in this text
|
163 |
+
# chars = sorted(list(set(text)))
|
164 |
+
# vocab_size = len(chars)
|
165 |
+
|
166 |
+
|
167 |
+
# # create a mapping from characters to integers
|
168 |
+
# stoi = { ch: i for i,ch in enumerate(chars) }
|
169 |
+
# itos = { i:ch for i,ch in enumerate(chars) }
|
170 |
+
# encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
|
171 |
+
# decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
|
172 |
+
|
173 |
+
|
174 |
+
# Convert the dataset to Hugging Face Dataset format
|
175 |
+
# train_hf_dataset = Dataset.from_dict({"text": train_dataset['train']['text']})
|
176 |
+
# val_hf_dataset = Dataset.from_dict({"text": train_dataset['test']['text']})
|
177 |
+
|
178 |
+
# Tokenize the dataset using the `map` function
|
179 |
+
|
180 |
+
|
181 |
+
# from google.colab import userdata
|
182 |
+
# HF_TOKEN = userdata.get('HF_TOKEN')
|
183 |
+
|
184 |
+
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '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')
|
185 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
186 |
+
# if tokenizer.pad_token is None:
|
187 |
+
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
188 |
+
# print("ADDED THE TOKENS: ", tokenizer.pad_token_id)
|
189 |
+
# tokenizer.bos_token = "[INST]"
|
190 |
+
# tokenizer.eos_token = "[/INST]"
|
191 |
+
# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
192 |
+
|
193 |
+
def tokenize_function(examples):
|
194 |
+
return tokenizer(
|
195 |
+
examples['text'],
|
196 |
+
max_length=ModelArgs.block_size,
|
197 |
+
padding='max_length',
|
198 |
+
truncation=True,
|
199 |
+
return_tensors='pt'
|
200 |
+
)
|
201 |
+
|
202 |
+
|
203 |
+
## Load the tokenizer
|
204 |
+
# tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json")
|
205 |
+
|
206 |
+
# # Tokenization functions
|
207 |
+
# def encode_train(examples):
|
208 |
+
# tokens = []
|
209 |
+
# for example in examples['text']:
|
210 |
+
# out = tokenizer.encode(example).ids
|
211 |
+
# tokens.append(out) # Append the tokenized sequence (do not flatten)
|
212 |
+
# return {"tokens": tokens}
|
213 |
+
|
214 |
+
# def encode_val(examples):
|
215 |
+
# tokens = []
|
216 |
+
# for example in examples['text']:
|
217 |
+
# out = tokenizer.encode(example).ids
|
218 |
+
# tokens.append(out) # Append the tokenized sequence (do not flatten)
|
219 |
+
# return {"tokens": tokens}
|
220 |
+
|
221 |
+
# Apply tokenization with batching
|
222 |
+
# train_data = train_dataset['train'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8)
|
223 |
+
# val_data = train_dataset['test'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8)
|
224 |
+
|
225 |
+
# # # Extract tokens from the processed datasets
|
226 |
+
# # train_tokens = train_data['tokens']
|
227 |
+
# # val_tokens = val_data['tokens']
|
228 |
+
|
229 |
+
# # Flatten the tokenized data
|
230 |
+
# # train_tokens = [token_id for seq in train_data['input_ids'] for token_id in seq]
|
231 |
+
# # val_tokens = [token_id for seq in val_data['input_ids'] for token_id in seq]
|
232 |
+
|
233 |
+
# try:
|
234 |
+
# train_tensors = [torch.tensor(seq) for seq in tqdm(train_data['input_ids'], desc="Converting train_data to tensors")]
|
235 |
+
# train_data_tensor = torch.cat(train_tensors)
|
236 |
+
# except Exception as e:
|
237 |
+
# print(f"Error during tensor conversion: {e}")
|
238 |
+
|
239 |
+
# try:
|
240 |
+
# train_tensors = [torch.tensor(seq) for seq in tqdm(val_data['input_ids'], desc="Converting train_data to tensors")]
|
241 |
+
# val_data_tensor = torch.cat(train_tensors)
|
242 |
+
# except Exception as e:
|
243 |
+
# print(f"Error during tensor conversion: {e}")
|
244 |
+
# print("Train tokens count: ", train_data_tensor)
|
245 |
+
# print("Val tokens count: ", val_data_tensor)
|
246 |
+
|
247 |
+
|
248 |
+
def prepare_dataset(split, batch_size):
|
249 |
+
|
250 |
+
# alpaca_prompt = '''
|
251 |
+
|
252 |
+
|
253 |
+
# ### Instruction:
|
254 |
+
# {}
|
255 |
+
|
256 |
+
# ### Response:
|
257 |
+
# {}
|
258 |
+
# '''
|
259 |
+
# Load a subset of the C4 dataset with a glob pattern for specific training files
|
260 |
+
# dataset = load_dataset("allenai/c4", data_files=["en/c4-train.00001-of-01024.json.gz"], trust_remote_code=True)
|
261 |
+
|
262 |
+
# Initialize tokenizer
|
263 |
+
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
264 |
+
|
265 |
+
def collate_fn(batch):
|
266 |
+
# Extract text data
|
267 |
+
# texts = [item ["text"] for item in batch]
|
268 |
+
|
269 |
+
# Set the pad token if it isn't set already
|
270 |
+
# if tokenizer.pad_token is None:
|
271 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
272 |
+
outputs = []
|
273 |
+
texts = []
|
274 |
+
for item in batch:
|
275 |
+
instruction = item['prompt']
|
276 |
+
# input = item['input']
|
277 |
+
output = item['completion']
|
278 |
+
# out = alpaca_prompt.format(instruction, output)
|
279 |
+
texts.append(instruction)
|
280 |
+
outputs.append(output)
|
281 |
+
# Tokenize text data
|
282 |
+
input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
|
283 |
+
# output_encodings = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
|
284 |
+
input_encodings["labels"] = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
|
285 |
+
# out = {"input": input_encodings}
|
286 |
+
# input_encodings["labels"] = input_encodings["input_ids"].clone() # Use `input_ids` as labels
|
287 |
+
# input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:] # Shift right
|
288 |
+
# input_encodings["labels"][:, -1] = tokenizer.pad_token_id # Ignore the last token (no target for it)
|
289 |
+
# Return tokenized input tensors
|
290 |
+
# return out
|
291 |
+
return input_encodings
|
292 |
+
|
293 |
+
# Create DistributedSampler for proper shuffling and partitioning across processes
|
294 |
+
# dist_sampler = DistributedSampler(fw_train["text"], shuffle=True)
|
295 |
+
|
296 |
+
# Create DataLoader with custom collate_fn
|
297 |
+
# print(fw_dataset)
|
298 |
+
dataloader = None
|
299 |
+
if(split == 'train'):
|
300 |
+
data_loader = DataLoader(
|
301 |
+
merged_dataset['train'],
|
302 |
+
batch_size=batch_size,
|
303 |
+
sampler=DistributedSampler(merged_dataset['train'], shuffle=True),
|
304 |
+
collate_fn=collate_fn,
|
305 |
+
drop_last=True,
|
306 |
+
shuffle=False
|
307 |
+
)
|
308 |
+
elif(split == 'val'):
|
309 |
+
data_loader = DataLoader(
|
310 |
+
merged_dataset['test'],
|
311 |
+
batch_size=batch_size,
|
312 |
+
sampler=DistributedSampler(merged_dataset["test"], shuffle=True),
|
313 |
+
collate_fn=collate_fn,
|
314 |
+
drop_last=True,
|
315 |
+
shuffle=False
|
316 |
+
)
|
317 |
+
|
318 |
+
return data_loader
|
319 |
+
# Convert to tensors
|
320 |
+
# train_data_tensor = torch.tensor(train_tokens, dtype=torch.long)
|
321 |
+
# val_data_tensor = torch.tensor(val_tokens, dtype=torch.long)
|
322 |
+
|
323 |
+
# # Debug output
|
324 |
+
# print("Number of train tokens:", len(train_data_tensor))
|
325 |
+
# print("Number of validation tokens:", len(val_data_tensor))
|
326 |
+
|
327 |
+
|
328 |
+
# def create_sequences(data, block_size):
|
329 |
+
# sequences = []
|
330 |
+
|
331 |
+
# for seq in data:
|
332 |
+
# if len(seq) < block_size:
|
333 |
+
# # while(len(sequence) < block_size):
|
334 |
+
# # sequence = data[i:i + block_size + 1]
|
335 |
+
|
336 |
+
# # Pad the sequence if it's shorter than block_size
|
337 |
+
# padding_length = block_size - len(seq)
|
338 |
+
# seq = torch.cat([seq, torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long)])
|
339 |
+
# sequences.append(seq)
|
340 |
+
# out = torch.tensor(sequences, dtype=torch.long)
|
341 |
+
# return out
|
342 |
+
|
343 |
+
# train_data = create_sequences(train_data['input_ids'], ModelArgs.block_size)
|
344 |
+
# val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size)
|
345 |
+
|
346 |
+
|
347 |
+
def get_batch(split):
|
348 |
+
# generate a small batch of data of inputs x and targets y
|
349 |
+
data = train_data if split == 'train' else val_data
|
350 |
+
ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
|
351 |
+
x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
|
352 |
+
y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
|
353 |
+
x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
|
354 |
+
return x, y
|
355 |
+
|
356 |
+
from torch.utils.data import Dataset
|
357 |
+
|
358 |
+
class TokenDataset(Dataset):
|
359 |
+
def __init__(self, data, block_size):
|
360 |
+
self.data = data
|
361 |
+
self.block_size = block_size
|
362 |
+
|
363 |
+
def __len__(self):
|
364 |
+
return len(self.data) - self.block_size # Ensure valid indexing
|
365 |
+
|
366 |
+
def __getitem__(self, idx):
|
367 |
+
x = self.data[idx:idx + self.block_size]
|
368 |
+
y = self.data[idx + 1:idx + self.block_size + 1]
|
369 |
+
return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
|
370 |
+
|
371 |
+
|
372 |
+
# train_rows = 11895089
|
373 |
+
# encoded_data = torch.tensor(encode(fw_train['text']), dtype=torch.long)
|
374 |
+
# train_data = train_data[:train_rows]
|
375 |
+
# val_data = val_data[train_rows:]
|
376 |
+
|
377 |
+
# train_dataset = TokenDataset(train_data_tensor, ModelArgs.block_size)
|
378 |
+
# val_dataset = TokenDataset(val_data_tensor, ModelArgs.block_size)
|
379 |
+
# encoded_data = torch.tensor(encode(text), dtype=torch.long)
|
380 |
+
|
381 |
+
# print(train_data)
|
382 |
+
# print(val_data)
|
383 |
+
# train_dataset = TextDataset(train_data, ModelArgs.block_size)
|
384 |
+
# val_dataset = TextDataset(val_data, ModelArgs.block_size)
|
385 |
+
|
386 |
+
# print(train_dataset)
|
387 |
+
# print(val_dataset)
|
388 |
+
|
389 |
+
|
390 |
+
# # Convert the tokenized data into a list of sequences
|
391 |
+
# train_sequences = [train_data[i:i + ModelArgs.block_size] for i in range(0, len(train_data) - ModelArgs.block_size)]
|
392 |
+
# val_sequences = [val_data[i:i + ModelArgs.block_size] for i in range(0, len(val_data) - ModelArgs.block_size)]
|
393 |
+
|
394 |
+
# Define collate_fn
|
395 |
+
# def collate_fn(batch):
|
396 |
+
# block_size = ModelArgs.block_size
|
397 |
+
# batch_size = len(batch)
|
398 |
+
# x = torch.zeros((batch_size, block_size), dtype=torch.long)
|
399 |
+
# y = torch.zeros((batch_size, block_size), dtype=torch.long)
|
400 |
+
# for i, sequence in enumerate(batch):
|
401 |
+
# print("Shape x: ", sequence[:-1].shape)
|
402 |
+
# print("Shape of y: ", len(sequence[1:]))
|
403 |
+
# x[i] = sequence[:-1] # Input is all tokens except the last one
|
404 |
+
# y[i] = sequence[1:] # Target is all tokens except the first one
|
405 |
+
# return x, y
|
406 |
+
|
407 |
+
|
408 |
+
|
409 |
+
def create_sequences(data, block_size):
|
410 |
+
sequences = []
|
411 |
+
|
412 |
+
for seq in data:
|
413 |
+
len(seq)
|
414 |
+
if len(seq) < block_size:
|
415 |
+
# while(len(sequence) < block_size):
|
416 |
+
# sequence = data[i:i + block_size + 1]
|
417 |
+
|
418 |
+
# Pad the sequence if it's shorter than block_size
|
419 |
+
padding_length = block_size - len(seq)
|
420 |
+
seq = torch.cat([seq, torch.full((padding_length,), tokenizer.encode('[PAD]').ids[0], dtype=torch.long)])
|
421 |
+
|
422 |
+
else:
|
423 |
+
if len(seq) > block_size:
|
424 |
+
seq = seq[:block_size]
|
425 |
+
# while(len(sequence) < block_size):
|
426 |
+
# sequence = data[i:i + block_size + 1]
|
427 |
+
|
428 |
+
# Pad the sequence if it's shorter than block_size
|
429 |
+
# padding_length = block_size - len(seq)
|
430 |
+
# seq = torch.cat([seq, torch.full((padding_length,), tokenizer.encode('[PAD]').ids[0], dtype=torch.long)])
|
431 |
+
sequences.append(seq)
|
432 |
+
out = torch.tensor(sequences, dtype=torch.long)
|
433 |
+
return out
|
434 |
+
|
435 |
+
# train_data = create_sequences(train_data_flat['input_ids'], ModelArgs.block_size)
|
436 |
+
# val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size)
|
437 |
+
|
438 |
+
|
439 |
+
# Define collate_fn
|
440 |
+
def collate_fn(split , batch):
|
441 |
+
block_size = ModelArgs.block_size
|
442 |
+
batch_size = len(batch)
|
443 |
+
if(split == 'train'):
|
444 |
+
data = train_data_tensor
|
445 |
+
elif(split == 'test'):
|
446 |
+
data = val_data_tensor
|
447 |
+
ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
|
448 |
+
x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
|
449 |
+
y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
|
450 |
+
|
451 |
+
# print("Shape of x: ", len(x))
|
452 |
+
# print("Length of y: ", len(y))
|
453 |
+
# x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
|
454 |
+
# x = torch.zeros((batch_size, block_size), dtype=torch.long)
|
455 |
+
# y = torch.zeros((batch_size, block_size), dtype=torch.long)
|
456 |
+
# for i, sequence in enumerate(batch):
|
457 |
+
# print("Seq: ", sequence)
|
458 |
+
# print("Shape x: ", sequence[:-1].shape)
|
459 |
+
# print("Shape of y: ", len(sequence[1:]))
|
460 |
+
# x[i] = sequence[:-1] # Input is all tokens except the last one
|
461 |
+
# y[i] = sequence[1:] # Target is all tokens except the first one
|
462 |
+
return x, y
|
463 |
+
|
464 |
+
|
465 |
+
class Normalization(nn.Module):
|
466 |
+
def __init__(
|
467 |
+
self,
|
468 |
+
|
469 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
470 |
+
):
|
471 |
+
super().__init__()
|
472 |
+
self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
|
473 |
+
|
474 |
+
|
475 |
+
def forward(self, x):
|
476 |
+
|
477 |
+
x = self.rmsnorm_layer(x)
|
478 |
+
return x
|
479 |
+
|
480 |
+
|
481 |
+
|
482 |
+
# import numpy as np
|
483 |
+
class RotaryEmbeddings(nn.Module):
|
484 |
+
def __init__(
|
485 |
+
self,
|
486 |
+
device,
|
487 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
488 |
+
block_size: int = ModelArgs.block_size,
|
489 |
+
batch_size: int = ModelArgs.batch_size
|
490 |
+
):
|
491 |
+
super().__init__()
|
492 |
+
|
493 |
+
self.embeddings_dims = embeddings_dims
|
494 |
+
self.block_size = block_size
|
495 |
+
self.batch_size = batch_size
|
496 |
+
self.theta = 0
|
497 |
+
|
498 |
+
|
499 |
+
# def init_matrix(self, seq_len):
|
500 |
+
# self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
|
501 |
+
# for pos in range(seq_len):
|
502 |
+
# for j in range(1, self.embeddings_dims // 2):
|
503 |
+
# self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
|
504 |
+
# self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
|
505 |
+
# self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
|
506 |
+
# self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
|
507 |
+
# self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
|
508 |
+
# return self.matrix
|
509 |
+
self.device=device
|
510 |
+
|
511 |
+
def init_matrix(self, seq_len):
|
512 |
+
self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
|
513 |
+
|
514 |
+
positions = torch.arange(seq_len, dtype=torch.float32, device = self.device).unsqueeze(1)
|
515 |
+
# dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
|
516 |
+
theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
|
517 |
+
angles = positions * theta
|
518 |
+
|
519 |
+
cos_angles = torch.cos(angles)
|
520 |
+
sin_angles = torch.sin(angles)
|
521 |
+
|
522 |
+
indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device)
|
523 |
+
# print(indices)
|
524 |
+
# print(indices.shape)
|
525 |
+
# print(indices[::2])
|
526 |
+
even_indices = indices[::2]
|
527 |
+
odd_indices = indices[1::2]
|
528 |
+
|
529 |
+
self.matrix[:, even_indices, even_indices] = cos_angles
|
530 |
+
self.matrix[:, odd_indices, odd_indices] = sin_angles
|
531 |
+
self.matrix[:, odd_indices, even_indices] = -sin_angles
|
532 |
+
self.matrix[:, even_indices, odd_indices] = cos_angles
|
533 |
+
|
534 |
+
return self.matrix
|
535 |
+
|
536 |
+
def forward(self, x):
|
537 |
+
# B,T,C = x.shape
|
538 |
+
# print("MATRIX:",x)
|
539 |
+
if(x > self.block_size or x < self.block_size):
|
540 |
+
matrix = self.init_matrix(x)
|
541 |
+
return matrix
|
542 |
+
else:
|
543 |
+
matrix = self.init_matrix(self.block_size)
|
544 |
+
|
545 |
+
return matrix
|
546 |
+
|
547 |
+
|
548 |
+
class RotaryAttentionHead(nn.Module):
|
549 |
+
def __init__(
|
550 |
+
self,
|
551 |
+
device,
|
552 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
553 |
+
no_of_heads: int = ModelArgs.no_of_heads,
|
554 |
+
attn_dropout: int = ModelArgs.attn_dropout
|
555 |
+
):
|
556 |
+
super().__init__()
|
557 |
+
self.head_size = embeddings_dims // no_of_heads
|
558 |
+
self.query = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
559 |
+
self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
560 |
+
self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
561 |
+
self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device = device)
|
562 |
+
self.dropout = nn.Dropout(p = attn_dropout)
|
563 |
+
self.device = device
|
564 |
+
def forward(self,x):
|
565 |
+
# print(x.shape)
|
566 |
+
batch, block_size, embeddings_dims = x.shape
|
567 |
+
query = self.query(x)
|
568 |
+
# print(query)
|
569 |
+
key = self.key(x)
|
570 |
+
values = self.value(x)
|
571 |
+
matrix = self.rotary_matrix(block_size)
|
572 |
+
|
573 |
+
# print(matrix.shape)
|
574 |
+
# print(query.shape)
|
575 |
+
masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
576 |
+
rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
577 |
+
rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
578 |
+
weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
579 |
+
weights_masked = weights.masked_fill(masked == 0, float('-inf'))
|
580 |
+
scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
|
581 |
+
scaled_weights = F.softmax(scaled_weights, dim=-1)
|
582 |
+
value = scaled_weights @ values
|
583 |
+
out = self.dropout(value)
|
584 |
+
return out
|
585 |
+
|
586 |
+
|
587 |
+
class MQA(nn.Module):
|
588 |
+
def __init__(
|
589 |
+
self,
|
590 |
+
device,
|
591 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
592 |
+
block_size: int = ModelArgs.block_size,
|
593 |
+
no_of_kv_heads: int = ModelArgs.no_of_heads,
|
594 |
+
no_of_heads: int = ModelArgs.no_of_heads,
|
595 |
+
|
596 |
+
):
|
597 |
+
super().__init__()
|
598 |
+
|
599 |
+
self.no_of_kv_heads = no_of_kv_heads
|
600 |
+
self.no_of_q_heads = no_of_heads // no_of_kv_heads
|
601 |
+
self.head_size = embeddings_dims // self.no_of_q_heads
|
602 |
+
self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device = device)
|
603 |
+
# self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
|
604 |
+
self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
|
605 |
+
self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
|
606 |
+
self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
|
607 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
|
608 |
+
self.device = device
|
609 |
+
self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, device = self.device) for _ in range(self.no_of_q_heads)])
|
610 |
+
|
611 |
+
def scaled_dot_product(self, q, k, v, block_size, matrix):
|
612 |
+
|
613 |
+
# masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
614 |
+
|
615 |
+
masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
616 |
+
rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
617 |
+
rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
618 |
+
weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
619 |
+
weights_masked = weights.masked_fill(masked == 0, float('-inf'))
|
620 |
+
scaled_weights = weights_masked / (torch.sqrt(torch.tensor(k.shape[-1])))
|
621 |
+
scaled_weights = F.softmax(scaled_weights, dim=-1)
|
622 |
+
value = scaled_weights @ v
|
623 |
+
out = self.dropout(value)
|
624 |
+
return value
|
625 |
+
|
626 |
+
def forward(self,x):
|
627 |
+
# print("MQA: ", x.shape)
|
628 |
+
batch, block_size, embeddings_dims = x.shape
|
629 |
+
|
630 |
+
# query = self.query(x)
|
631 |
+
matrix = self.rotary_matrix(block_size)
|
632 |
+
|
633 |
+
|
634 |
+
key = self.key(x)
|
635 |
+
values = self.value(x)
|
636 |
+
|
637 |
+
multi_query_concat = torch.cat([self.scaled_dot_product(query(x), key, values, block_size, matrix) for query in self.multi_query], dim=-1)
|
638 |
+
|
639 |
+
|
640 |
+
linear_layer= self.linear_layer(multi_query_concat)
|
641 |
+
out = self.dropout(linear_layer)
|
642 |
+
return out
|
643 |
+
|
644 |
+
|
645 |
+
class GQA(nn.Module):
|
646 |
+
def __init__(
|
647 |
+
self,
|
648 |
+
device,
|
649 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
650 |
+
block_size: int = ModelArgs.block_size,
|
651 |
+
no_of_q_heads: int = ModelArgs.no_of_heads,
|
652 |
+
no_of_kv_heads: int = ModelArgs.no_kv_heads
|
653 |
+
):
|
654 |
+
super().__init__()
|
655 |
+
|
656 |
+
self.no_of_kv_heads = no_of_kv_heads
|
657 |
+
self.no_of_q_heads = no_of_q_heads
|
658 |
+
self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
|
659 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_kv_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device)
|
660 |
+
self.device = device
|
661 |
+
self.mqa = nn.ModuleList([MQA(embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_kv_heads)])
|
662 |
+
|
663 |
+
def forward(self,x):
|
664 |
+
|
665 |
+
batch, block_size, embeddings_dims = x.shape
|
666 |
+
|
667 |
+
|
668 |
+
grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
|
669 |
+
|
670 |
+
linear_layer= self.linear_layer(grouped_query_concat)
|
671 |
+
out = self.dropout(linear_layer)
|
672 |
+
return out
|
673 |
+
|
674 |
+
|
675 |
+
class Swish(nn.Module):
|
676 |
+
def __init__(
|
677 |
+
self,
|
678 |
+
device,
|
679 |
+
block_size: int = ModelArgs.block_size,
|
680 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
681 |
+
):
|
682 |
+
super().__init__()
|
683 |
+
|
684 |
+
self.sig = torch.nn.Sigmoid()
|
685 |
+
|
686 |
+
|
687 |
+
def forward(self, x):
|
688 |
+
swish = x * self.sig(x)
|
689 |
+
|
690 |
+
return swish
|
691 |
+
|
692 |
+
|
693 |
+
|
694 |
+
class SWiGLU(nn.Module):
|
695 |
+
def __init__(
|
696 |
+
self,
|
697 |
+
device,
|
698 |
+
block_size: int = ModelArgs.block_size,
|
699 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
700 |
+
):
|
701 |
+
super().__init__()
|
702 |
+
|
703 |
+
self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
|
704 |
+
self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
705 |
+
self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
706 |
+
self.linear_layer3 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
707 |
+
|
708 |
+
|
709 |
+
|
710 |
+
|
711 |
+
def forward(self, x):
|
712 |
+
swish_res = self.swish(self.linear_layer1(x))
|
713 |
+
x_V = self.linear_layer2(x)
|
714 |
+
res = torch.mul(swish_res, x_V)
|
715 |
+
out = self.linear_layer3(res)
|
716 |
+
return out
|
717 |
+
|
718 |
+
|
719 |
+
|
720 |
+
class FFN(nn.Module):
|
721 |
+
def __init__(self,
|
722 |
+
device,
|
723 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
724 |
+
block_size: int = ModelArgs.block_size,
|
725 |
+
vocab_size: int = ModelArgs.vocab_size,
|
726 |
+
dropout = ModelArgs.dropout
|
727 |
+
|
728 |
+
):
|
729 |
+
super().__init__()
|
730 |
+
|
731 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
|
732 |
+
self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device)
|
733 |
+
self.dropout = nn.Dropout(p = dropout)
|
734 |
+
def forward(self, x):
|
735 |
+
|
736 |
+
x = self.swiglue(x)
|
737 |
+
x = self.linear_layer(x)
|
738 |
+
x = self.dropout(x)
|
739 |
+
return x
|
740 |
+
|
741 |
+
|
742 |
+
class DecoderLayer(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 dropout = ModelArgs.dropout,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 ):
        super().__init__()

        self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device=device)
        self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, no_of_kv_heads=ModelArgs.no_kv_heads, no_of_q_heads=ModelArgs.no_of_heads, device=device)
        # self.norm = Normalization(embeddings_dims=embeddings_dims)
        self.norm1 = Normalization(embeddings_dims=embeddings_dims)
        self.norm2 = Normalization(embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.norm1(x + self.gqa(x))
        x = self.norm2(x + self.feedforward_network(x))
        return x

class Llama(nn.Module):
    def __init__(self,
                 device,
                 embeddings_dims: int = ModelArgs.embeddings_dims,
                 no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
                 block_size: int = ModelArgs.block_size,
                 vocab_size: int = ModelArgs.vocab_size,
                 dropout = ModelArgs.dropout
                 ):
        super().__init__()

        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device=device)
        self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device=device) for _ in range(no_of_decoder_layers)])
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device=device)
        self.dropout = nn.Dropout(p=dropout)
        # self.norm = Normalization(embeddings_dims)

    def forward(self, x):
        x = self.embeddings(x)
        x = self.dropout(x)
        x = self.decoder(x)
        # x = self.norm(x)
        x = self.linear_layer(x)
        # out = self.norm(x)
        return x

# device = "cuda" if torch.cuda.is_available() else "cpu"
|
795 |
+
# # device = "cpu"
|
796 |
+
# ModelArgs.device = device
|
797 |
+
# model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
|
798 |
+
# model = model.to(ModelArgs.device)
|
799 |
+
|
800 |
+
#Printing a summary of the architecture
|
801 |
+
# !pip install torchinfo
|
802 |
+
# from torchinfo import summary
|
803 |
+
# # idx, targets = get_batch('test')
|
804 |
+
# idx = torch.randint(
|
805 |
+
# low=0,
|
806 |
+
# high=ModelArgs.vocab_size,
|
807 |
+
# size=(ModelArgs.batch_size, ModelArgs.block_size),
|
808 |
+
# dtype=torch.long
|
809 |
+
# )
|
810 |
+
# # sample_idx = random.randint(range(len(train_dataset)))
|
811 |
+
# # idx, targets = train_dataset[0]
|
812 |
+
# idx = idx.to(ModelArgs.device)
|
813 |
+
# # targets = targets.to(ModelArgs.device)
|
814 |
+
# summary(model=model,
|
815 |
+
# input_data=idx,
|
816 |
+
# # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
|
817 |
+
# col_names=["input_size", "output_size", "num_params", "trainable"],
|
818 |
+
# col_width=20,
|
819 |
+
# row_settings=["var_names"])
|
820 |
+
|
821 |
+
|
822 |
+
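# A dependency-free alternative to the torchinfo summary sketched above
# (illustrative only): count trainable parameters directly. Shown on a
# stand-in module so the snippet runs on its own; swap in the Llama instance
# from this file to size the real model.
import torch.nn as nn

def count_parameters(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

if __name__ == "__main__":
    _stand_in = nn.Sequential(nn.Embedding(50304, 512), nn.Linear(512, 50304))
    print(f"{count_parameters(_stand_in) / 1e6:.1f}M parameters")
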
def find_unused_parameters(model):
    unused = []
    for name, param in model.named_parameters():
        if param.grad is None:
            unused.append(name)
    return unused

def greedy_decode(
    model,
    tokenizer,
    prompt,
    max_length=50,
    repetition_penalty=1.2,
    context_window=10,
    temperature=1.0,
    eos_token_id=None
):
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
    generated_tokens = []
    eos_token_id = eos_token_id or tokenizer.eos_token_id  # Use EOS token if provided

    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs[:, -1, :]  # Get logits for the last token

        # Apply temperature scaling
        if temperature != 1.0:
            logits = logits / temperature

        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            for token in set(generated_tokens[-context_window:]):  # Penalize recent tokens
                logits[0, token] /= repetition_penalty

        # Greedy selection
        next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
        generated_tokens.append(next_token.item())

        # Stop if EOS token is generated
        if next_token.item() == eos_token_id:
            break

        # Append the new token to the input
        input_ids = torch.cat([input_ids, next_token], dim=1)

    # Decode the generated tokens
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)

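# Example call for greedy_decode above (illustrative; kept commented out
# because it assumes `model` and `tokenizer` have already been built as
# elsewhere in this script):
# story = greedy_decode(
#     model,
#     tokenizer,
#     prompt="Once upon a time",
#     max_length=60,
#     repetition_penalty=1.2,
#     context_window=10,
#     temperature=0.7,
# )
# print(story)
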
def save_to_file(text):
    with open('generations.txt', 'a') as f:
        f.writelines(text + "\n\n")


# Train the model

# writer = SummaryWriter(log_dir="runs/experiment")

from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR

# Warmup phase for 2000 steps
def warmup_fn(step):
    if step < 2000:
        return step / 2000  # LR gradually increases
    return 1.0


from torch.optim.lr_scheduler import LambdaLR

def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
    """
    Trapezoidal learning rate scheduler:
    - Increases linearly for `warmup_steps` steps.
    - Remains constant for `plateau_steps` steps.
    - Decreases linearly for `decay_steps` steps.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup
            return float(step) / float(max(1, warmup_steps))
        elif step < warmup_steps + plateau_steps:
            # Constant plateau
            return 1.0
        else:
            # Linear decay
            decay_step = step - (warmup_steps + plateau_steps)
            return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps)))

    return LambdaLR(optimizer, lr_lambda)


torch.set_float32_matmul_precision('high')

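# Self-contained sketch of the trapezoidal schedule above (toy numbers;
# the trainer below uses 40k/20k/40k): sample the LR at a few steps to see the
# ramp / plateau / decay shape.
import torch
import torch.nn as nn
import torch.optim as optim

_demo_params = [nn.Parameter(torch.zeros(1))]
_demo_opt = optim.AdamW(_demo_params, lr=6e-4)
_demo_sched = trapezoidal_lr_scheduler(_demo_opt, max_lr=6e-4, total_steps=100, warmup_steps=40, plateau_steps=20, decay_steps=40)

for _demo_step in range(100):
    _demo_opt.step()
    _demo_sched.step()
    if _demo_step in (0, 20, 39, 50, 70, 99):
        print(_demo_step, _demo_sched.get_last_lr()[0])
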
def train():
    setup()
    device = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(int(device))

    # train_dataloader = prepare_dataset(ModelArgs.batch_size)
    # rank = torch.distributed.get_rank()
    print(f"Start running DDP on rank {device}.")
    # # create model and move it to GPU with id rank
    # device_id = rank % torch.cuda.device_count()
    # CFG = ModelArgs()

    if(device == 0):

        # # Initialise run
        wandb.init(
            # entity = 'rajceo2031',
            project = 'Llama-DDP-Pretrain-10-billion-tokens',
            # config = CFG,
            # save_code = True,
            # group = 'ANN',
            # job_type = 'train'
        )

    model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
    # Optimizer and scheduler setup

    model = model.to(device)

    print(f"Model on device {device} is ready")
    # Wrap model with DDP after moving to GPU
    # model = DDP(model, device_ids=[device])
    optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim)
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4000, T_mult=1, eta_min=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30000, eta_min=1e-6)
    _load_snapshot('/kaggle/input/models/snapshot2.pt', model, optimizer, scheduler)
    optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim)

    # model = torch.compile(model)

    # Define the trapezoidal learning rate scheduler
    total_steps = 100000   # Total steps (40k + 20k + 40k)
    warmup_steps = 40000   # Steps for warmup (increase)
    plateau_steps = 20000  # Steps for plateau (constant)
    decay_steps = 40000    # Steps for decay (decrease)

    # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25000, eta_min=1e-6)  # with the prev optim snapshot
    new_scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps)

    # warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_fn)
    # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
    # Cosine decay after warmup
    # new_scheduler = CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)

    # Combine both schedulers
    # scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, new_scheduler], milestones=[2000])

    # Reset learning rate to 1e-4
    # for param_group in optimizer.param_groups:
    #     param_group['lr'] = ModelArgs.max_lr
    # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1, eta_min=1e-6)
    # print("Old optimizer with new lr ready")
    model = DDP(model, device_ids=[device])
    print(f"Model on device {device} is ready")

    # optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)
    # Create DataLoader with collate_fn
    # train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
    # val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
    # print("Loader is ready")
    # print(train_loader)
    # print(next(iter(train_loader)))

    save_chechpoint_iter = 1000
    total_iters = 20000
    eval_iters = 200
    eval_check = 100
    # for X,y in train_loader:
    #     print(X.shape)
    #     print(y.shape)

    # alpaca_prompt = '''
    # ### Instruction:
    # {instruction}
    # ### Input:
    # {input}
    # ### Response:
    # '''
    # Only create progress bar for rank 0
    # eval_epoch_iterator = range(eval_iters)
    # train_epoch_iterator = range(total_iters)
    # if device == 0:
    #     train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training")

    # train_epoch_iterator = range(ModelArgs.epochs)
    # if device == 0:  # Ensure tqdm only runs on rank 0
    #     train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training Progress", position=0, leave=True)

    # lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=total_steps - initial_iters)
    world_size = torch.cuda.device_count()

    @torch.inference_mode()
    def estimate_loss(val_loader, train_loader=None):
        out = {}
        # train_loader = prepare_dataset('train', ModelArgs.batch_size)
        model.eval()
        loader = None
        epoch_loss = None
        epoch_losses = []
        # print("Starting the eval...")
        for split in ['train', 'val']:
            print(f"Starting with {split} evaluation...")
            # losses = torch.zeros(ModelArgs.val_epochs)
            if(split == 'train'):
                loader = train_loader
            if(split == 'val'):
                loader = val_loader
            for step in range(eval_check):
                total_loss = 0
                # loader.sampler.set_epoch(step)
                total_batches = 0
                batch = next(iter(loader))
                # for batch in loader:  # Loop through DataLoader batches
                idx = batch['input_ids']
                targets = batch['labels']['input_ids']
                idx = idx.to(device)
                targets = targets.to(device)

                logits = model(idx)
                batch_size, block_size, embeddings_dims = logits.shape
                logits = logits.view(batch_size * block_size, embeddings_dims)  # Flatten tokens
                targets = targets.view(batch_size * block_size)

                loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

                total_loss += loss.item()
                total_batches += 1

                # Compute mean loss for this evaluation pass
                epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
                epoch_losses.append(epoch_loss)

                # print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}")

            # Compute mean loss across all evaluation passes
            out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
            epoch_loss = None
            epoch_losses = []

        model.train()
        return out

    # model = model.to(rank)
    model.train()
    train_dataloader = prepare_dataset('train', ModelArgs.batch_size)
    val_loader = prepare_dataset('val', ModelArgs.batch_size)
    # for step in tqdm(range(total_iters)):
    for epoch in range(ModelArgs.epochs):
        # torch.cuda.synchronize()

        train_dataloader.sampler.set_epoch(epoch)

        val_loader.sampler.set_epoch(epoch)
        print("Both loaders are ready")
        epochs = ModelArgs.epochs

        # train_step_iterator = range(len(train_dataloader))
        # if device == 0:  # Only create progress bar on rank 0
        #     train_step_iterator = tqdm(train_step_iterator, desc="Training Progress", position=0, leave=True)

        # Print progress on rank 0
        train_loader_length = 0
        if(device == 0):
            train_loader_length = len(train_dataloader)
            print("Total batches: ", train_loader_length)
        # print("Length of : ", len(train_dataloader))
        # print("Length of val: ", len(val_loader))
        for step, batch in enumerate(train_dataloader):
            # print("Dataloader things: ", batch)
            # print("Total batches: ", len(train_dataloader))
            if(device == 0):
                if(step % 100 == 0):
                    # if(step == train_loader_length):
                    #     break
                    print("Batch : ", step, "/", len(train_dataloader))
            # all_gpus_avg_train_loss = None
            # all_gpus_avg_val_loss = None
            # every once in a while evaluate the loss on train and val sets
            if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
                losses = estimate_loss(val_loader, train_dataloader)
                avg_train_loss = losses['train']
                avg_val_loss = losses['val']
                # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
                # if device == 0:  # Only print on main process
                print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f} | Val Loss: {losses['val']:.4f}")
                # print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f}")
                # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
                # Log training loss more frequently
                # Aggregate average loss across all GPUs
                avg_train_loss = torch.Tensor([losses['train']]).to(device)
                avg_val_loss = torch.Tensor([losses['val']]).to(device)
                torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

                if device == 0:
                    all_gpus_avg_train_loss = avg_train_loss / world_size
                    print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss.item():.4f}")
                    all_gpus_avg_val_loss = avg_val_loss / world_size
                    print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")

                    # writer.add_scalar("All_GPUs_Train_losses", all_gpus_avg_train_loss.item(), global_step=step)
                    # writer.add_scalar("All_GPUs_Val_losses", all_gpus_avg_val_loss.item(), global_step=step)
                    # writer.add_scalar("training_step_loss", losses['train'], global_step=step)
                    # writer.add_scalar("val_step_loss", losses['val'], global_step=step)
                    # writer.add_scalar("GPU", device, global_step=step)
                    # writer.add_scalar("Epoch", epoch, global_step=step)

                    wandb.log({
                        "Learning Rate": new_scheduler.get_last_lr()[0],
                        "All_GPUs_Train_losses": all_gpus_avg_train_loss,
                        "All_GPUs_Val_losses": all_gpus_avg_val_loss,
                        "training_step_loss": losses['train'],
                        "val_step_loss": losses['val'],
                        "Step": step,
                        "Epoch": epoch
                    })

            # Loading a checkpoint
            # if(os.path.exists('snapshot.pt')):
            #    model, optimizer = _load_snapshot(model=model, optimizer=optimizer, epoch=epoch, step=step, snapshot_path='snapshot.pt')

            # if(step % save_chechpoint_iter == 0 and device == 0 and step != 0):
            #     _save_snapshot(epoch=epoch, model=model, optimizer=optimizer, step=step)

            if step % save_chechpoint_iter == 0 and device == 0 and step != 0:
                print(f"Saving the model checkpoint for step: {step}")
                _save_snapshot(model, optimizer, scheduler, epoch, step)

            # batch = {k: v.to(self.local_rank) for k, v in batch.items()}
            idx = batch['input_ids'].to(device)
            # idx, targets = get_batch(split='train')
            # print(f"Starting the train step: {step}...")
            # for idx, targets in train_loader:
            # idx, targets = next(iter(train_loader))

            # print("Idx: ", idx)
            # print("Targets: ", targets)

            # idx = idx.to(device)
            targets = batch['labels']['input_ids'].to(device)
            # with torch.autocast(device_type=device, dtype=torch.bfloat16()):
            logits = model(idx)
            batch_size, block_size, embeddings_dims = logits.shape
            # print(logits.shape)
            # print(targets)
            logits = logits.view(batch_size * block_size, embeddings_dims)
            # print("OK")
            targets = targets.view(batch_size * block_size)
            # print("OK2")
            loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            # Compute gradient norms before clipping
            total_norm_before = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
            )

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)

            # Compute gradient norms after clipping
            total_norm_after = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
            )

            if(device == 0 and step != 0 and step % 100 == 0):
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

            optimizer.step()
            new_scheduler.step()
            # torch.cuda.synchronize()
            # print(loss.item())
            # if(step % 100 == 0):
            #     print(f'Step : {step} | GPU: {device} Loss: {loss.item()}')
            # if device == 0:
            #     print("loss: ", loss.item())
            #     train_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
            # print(loss.item())
            # break

            # if step != 0 and (step % eval_iters == 0 or step == total_steps - 1):
            #     loss_values = estimate_loss()
            #     print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))

            # Add after a training step:
            # unused_params = find_unused_parameters(model)
            # print("Unused parameters:", unused_params)
            # break
            # if device == 0 and step % 200 == 0 and step != 0:
            #     count = 5
            #     while(count):  # Only generate text on the main process
            #         print("Generating text...")
            #         alpaca_prompt = '''
            #         ### Instruction:
            #         {}
            #         ### Input:
            #         {}
            #         ### Response:
            #         '''
            #         prompt = alpaca_prompt.format("You are a helpful assistant.", "Say a joke.", "")
            #         generated_text = greedy_decode(
            #             model,
            #             tokenizer,
            #             prompt,
            #             max_length=60,
            #             repetition_penalty=1.2,
            #             context_window=10,
            #             temperature=0.7  # Lower temperature for more deterministic output
            #         )
            #         # generated_text = beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0)
            #         print(f" Step: {step} | Generated Text: {generated_text}")
            #         save_to_file(generated_text)
            #         count -= 1

            # if step != 0:
            #     train_step_iterator.set_postfix({"Train loss": f"{all_gpus_avg_train_loss.item():.4f} | Val Loss : {all_gpus_avg_val_loss.item():.4f}"})

        # break
    # Cleanup
    if device == 0:
        # writer.close()
        wandb.finish()
    cleanup()


world_size = torch.cuda.device_count()
print(f"World size: {world_size}")
train()

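# Illustration of the loss computation used in the training loop above
# (a standalone sketch with dummy shapes; pad id 0 stands in for
# tokenizer.pad_token_id): logits are flattened to (batch * block_size, vocab)
# and padded positions are ignored.
import torch
import torch.nn.functional as F

_B, _T, _V, _PAD = 2, 4, 11, 0
_logits = torch.randn(_B, _T, _V)
_targets = torch.randint(1, _V, (_B, _T))
_targets[:, -1] = _PAD                     # pretend the last position is padding

_loss = F.cross_entropy(_logits.view(_B * _T, _V), _targets.view(_B * _T), ignore_index=_PAD)
print(_loss.item())
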
inference.py
ADDED
@@ -0,0 +1,84 @@
from config import ModelArgs
from model import Llama
import torch
import torch.nn.functional as F
from tokenizer import Tokenizer
import argparse


tokenizer = Tokenizer()
tokenizer = tokenizer.ready_tokenizer()


def remove_prefix(state_dict, prefix):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]  # Remove the prefix
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict

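# Quick check of remove_prefix above (illustrative): keys saved from a
# torch.compile'd model carry an '_orig_mod.' prefix that the plain model
# does not expect.
# >>> remove_prefix({'_orig_mod.linear_layer.weight': 0, 'embeddings.weight': 1}, '_orig_mod.')
# {'linear_layer.weight': 0, 'embeddings.weight': 1}
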
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated_tokens = []
    ModelArgs.inference = True
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs[:, -1, :]

            probs = F.softmax(logits, dim=-1)

            # Top-k filtering
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)

            # Apply temperature scaling
            # probs = probs / temperature

            # Sample from top-k
            next_token = torch.multinomial(top_k_probs, num_samples=1)

            # generated_tokens.append(next_token.item())

            xcol = torch.gather(top_k_indices, -1, next_token)
            input_ids = torch.cat([input_ids, xcol], dim=1)  # dim=1 because it is the sequence dimension

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


def main():
    torch.set_float32_matmul_precision('high')

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Once upon a time")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    # parser.add_argument("--repetition_penalty", type=float, default=1.2)
    args = parser.parse_args()

    model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
    # model = torch.compile(model)
    model = model.to(ModelArgs.device)

    dict_model = torch.load('weights/pretrained/snapshot_4650.pt')
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])
    model.eval()
    print("Model ready")
    # prompt = 'Its a secret'

    with torch.no_grad():
        generated_text = topk_sampling(model, args.prompt, max_length=args.max_length, top_k=args.top_k, temperature=args.temperature, device=ModelArgs.device)
        print("Generated: ", generated_text)
        # generated_text = beam_search(model, tokenizer, args.prompt, beam_width=5, max_length=50, temperature=1.0)
        print(args.prompt + generated_text)


if __name__ == '__main__':
    main()

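# Standalone sketch of the top-k sampling step used above (dummy logits, no
# model): keep the k most likely tokens, sample among them with multinomial
# (which renormalises the kept probabilities), then map the sampled position
# back to a vocabulary id.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
_logits = torch.randn(1, 10)                     # pretend vocabulary of 10 tokens
_probs = F.softmax(_logits, dim=-1)
_top_k_probs, _top_k_indices = torch.topk(_probs, k=5, dim=-1)
_choice = torch.multinomial(_top_k_probs, num_samples=1)
_next_token = torch.gather(_top_k_indices, -1, _choice)
print(_next_token.item())
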
llama_torchrun.py
ADDED
@@ -0,0 +1,1435 @@
#Based on Llama from Meta (https://github.com/meta-llama/llama/blob/main/llama/model.py)
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from tokenizers import Tokenizer
|
9 |
+
from pathlib import Path
|
10 |
+
import torch.multiprocessing as mp
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
13 |
+
from torch.distributed import init_process_group, destroy_process_group
|
14 |
+
import torch
|
15 |
+
from datasets import Dataset
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoderModelOutput
|
18 |
+
import wandb
|
19 |
+
from tqdm import tqdm
|
20 |
+
from functools import partial
|
21 |
+
import tiktoken
|
22 |
+
import torch.optim as optim
|
23 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
24 |
+
|
25 |
+
# Load model directly
|
26 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
27 |
+
import os
|
28 |
+
|
29 |
+
torch.manual_seed(1337)
|
30 |
+
torch.cuda.manual_seed(1337)
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
# import wandb
|
36 |
+
# wandb.login()
|
37 |
+
|
38 |
+
|
39 |
+
# from torch.utils.tensorboard import SummaryWriter
|
40 |
+
|
41 |
+
|
42 |
+
from datasets import load_dataset, concatenate_datasets
|
43 |
+
|
44 |
+
# data = {}
|
45 |
+
# texts = []
|
46 |
+
# with open('data/input.txt', 'r') as f:
|
47 |
+
# texts.append(f.readlines())
|
48 |
+
|
49 |
+
# # print(texts)
|
50 |
+
# # print(len(texts[0]))
|
51 |
+
# data = {
|
52 |
+
# "text": texts[0]
|
53 |
+
# }
|
54 |
+
# fw_train = Dataset.from_dict(data)
|
55 |
+
# print(fw_train)
|
56 |
+
# fw_train = load_dataset("karpathy/tiny_shakespeare", split="train", trust_remote_code=True)
|
57 |
+
# print(fw_train['text'])
|
58 |
+
# text = fw_train['text'][0].split("\n")
|
59 |
+
# print(text)
|
60 |
+
# filtered_lines = [line for line in text if line != '']
|
61 |
+
# print(len(filtered_lines))
|
62 |
+
# use name="sample-10BT" to use the 10BT sample
|
63 |
+
|
64 |
+
tinystories = True
|
65 |
+
fw = False
|
66 |
+
fw_train = None
|
67 |
+
fw_test = None
|
68 |
+
if(tinystories):
|
69 |
+
fw_train = load_dataset("roneneldan/TinyStories", split="train")
|
70 |
+
fw_test = load_dataset("roneneldan/TinyStories", split="validation")
|
71 |
+
print(fw_train)
|
72 |
+
print(fw_test)
|
73 |
+
if(fw):
|
74 |
+
fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
|
75 |
+
fw_train = fw_train.train_test_split(test_size=0.01)
|
76 |
+
print(fw_train)
|
77 |
+
print(fw_train)
|
78 |
+
# Select only 1000 rows from the dataset
|
79 |
+
# fw_train = fw_train.select(range(1000000))
|
80 |
+
# alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
|
81 |
+
# dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train')
|
82 |
+
# merged_dataset = concatenate_datasets([alpaca, dolly])
|
83 |
+
# dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True)
|
84 |
+
# print(fw_train)
|
85 |
+
# Split the dataset into training and validation sets
|
86 |
+
# Split the dataset into training and validation sets
|
87 |
+
# fw_train = fw_train.train_test_split(test_size=0.01)
|
88 |
+
# print(fw_train)
|
89 |
+
|
90 |
+
|
91 |
+
# Access the splits
|
92 |
+
# train_dataset = train_val_split['train']
|
93 |
+
# val_dataset = train_val_split['test']
|
94 |
+
|
95 |
+
# train_dataset = fw_train.train_test_split(test_size=0.2)
|
96 |
+
|
97 |
+
|
98 |
+
def setup(rank=None, world_size=None):
|
99 |
+
# os.environ['MASTER_ADDR'] = 'localhost'
|
100 |
+
# os.environ['MASTER_PORT'] = '12355'
|
101 |
+
init_process_group("nccl")
|
102 |
+
# torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
|
103 |
+
|
104 |
+
def cleanup():
|
105 |
+
destroy_process_group()
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
@dataclass
class ModelArgs:
    # Hyperparameters
    epochs = 4
    block_size = 512
    batch_size = 64
    embeddings_dims = 512
    attn_dropout = 0.1
    no_of_heads = 8
    dropout = 0.1
    # epochs = 100
    val_epochs = 2
    max_lr = 6e-4
    no_of_decoder_layers = 8  # IMP needs to be thoroughly calculated
    weight_decay_optim = 0.1
    beta_1 = 0.9
    beta_2 = 0.95
    clip = 1.0
    device = 'cuda'
    no_kv_heads = 2
    vocab_size = 50304  # powers of 2 so nice!
    eps = 1e-5
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    # dtype = 'bfloat16'

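# Back-of-the-envelope sizing from the hyperparameters above (a rough sketch,
# not an exact count): the token embedding and the untied output head are each
# roughly vocab_size * embeddings_dims parameters, and the eight decoder blocks
# account for the remainder of the ~88M total mentioned in the README.
_vocab_size, _embeddings_dims = 50304, 512

_embedding_params = _vocab_size * _embeddings_dims                  # ~25.8M
_lm_head_params = _embeddings_dims * _vocab_size + _vocab_size      # weight + bias
print(f"embedding + head ~ {(_embedding_params + _lm_head_params) / 1e6:.1f}M")
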
def _save_snapshot(model, optimizer, scheduler, epoch, step):
    snapshot = {
        "MODEL_STATE": model.module.state_dict(),
        "OPTIMIZER_STATE": optimizer.state_dict(),
        # "SCHEDULER_STATE": scheduler.state_dict(),
        "EPOCHS_RUN": epoch,
        "STEP_RUN": step
    }
    torch.save(snapshot, f"snapshot_{step}.pt")
    print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")

def _load_snapshot(snapshot_path, model, optimizer, scheduler):
    snapshot = torch.load(snapshot_path)
    model.load_state_dict(snapshot["MODEL_STATE"])
    optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
    # scheduler.load_state_dict(snapshot["SCHEDULER_STATE"])  # Load scheduler state
    epoch = snapshot["EPOCHS_RUN"]
    step = snapshot["STEP_RUN"]
    print(f"Resuming from Epoch {epoch}, Step {step}")
    return epoch, step

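# Note on the helpers above (illustrative, not part of the original script):
# _save_snapshot reaches for model.module because the model is DDP-wrapped by
# the time checkpoints are written. A defensive variant, if an unwrapped model
# ever needs to be saved, would be:
# state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
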
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '...')
|
166 |
+
|
167 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
168 |
+
# if tokenizer.pad_token is None:
|
169 |
+
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
170 |
+
# print("ADDED THE TOKENS: ", tokenizer.pad_token_id)
|
171 |
+
# tokenizer.bos_token = "[INST]"
|
172 |
+
# tokenizer.eos_token = "[/INST]"
|
173 |
+
# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
174 |
+
|
175 |
+
def tokenize_function(examples):
|
176 |
+
return tokenizer(
|
177 |
+
examples['text'],
|
178 |
+
max_length=ModelArgs.block_size,
|
179 |
+
padding='max_length',
|
180 |
+
truncation=True,
|
181 |
+
return_tensors='pt'
|
182 |
+
)
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
def prepare_dataset(split, device, batch_size):
|
188 |
+
print("Device is: ", device)
|
189 |
+
# alpaca_prompt = '''
|
190 |
+
|
191 |
+
|
192 |
+
# ### Instruction:
|
193 |
+
# {}
|
194 |
+
|
195 |
+
# ### Response:
|
196 |
+
# {}
|
197 |
+
# '''
|
198 |
+
# Load a subset of the C4 dataset with a glob pattern for specific training files
|
199 |
+
# dataset = load_dataset("allenai/c4", data_files=["en/c4-train.00001-of-01024.json.gz"], trust_remote_code=True)
|
200 |
+
|
201 |
+
# Initialize tokenizer
|
202 |
+
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
203 |
+
# generator = torch.Generator(device=device)
|
204 |
+
def collate_fn(batch):
|
205 |
+
# Extract text data
|
206 |
+
texts = [item ["text"] for item in batch]
|
207 |
+
|
208 |
+
# Set the pad token if it isn't set already
|
209 |
+
# if tokenizer.pad_token is None:
|
210 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
211 |
+
# outputs = []
|
212 |
+
# texts = []
|
213 |
+
# for item in batch:
|
214 |
+
# instruction = item['prompt']
|
215 |
+
# # input = item['input']
|
216 |
+
# output = item['completion']
|
217 |
+
# # out = alpaca_prompt.format(instruction, output)
|
218 |
+
# texts.append(instruction)
|
219 |
+
# outputs.append(output)
|
220 |
+
# Tokenize text data
|
221 |
+
input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
|
222 |
+
# output_encodings = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
|
223 |
+
# input_encodings["labels"] = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
|
224 |
+
# out = {"input": input_encodings}
|
225 |
+
# input_encodings['input_ids'][: , input_encodings["attention_mask"] == 0] = -100
|
226 |
+
input_encodings["labels"] = input_encodings["input_ids"].clone() # Use `input_ids` as labels
|
227 |
+
|
228 |
+
input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:] # Shift right
|
229 |
+
input_encodings["labels"][:, -1] = tokenizer.eos_token_id # Let the last token be end
|
230 |
+
# Return tokenized input tensors
|
231 |
+
# return out
|
232 |
+
return input_encodings
|
233 |
+
|
234 |
+
# Create DistributedSampler for proper shuffling and partitioning across processes
|
235 |
+
# dist_sampler = DistributedSampler(fw_train["text"], shuffle=True)
|
236 |
+
|
237 |
+
# Create DataLoader with custom collate_fn
|
238 |
+
# print(fw_dataset)
|
239 |
+
dataloader = None
|
240 |
+
if(tinystories):
|
241 |
+
if(split == 'train'):
|
242 |
+
data_loader = DataLoader(
|
243 |
+
fw_train,
|
244 |
+
# generator=generator,
|
245 |
+
batch_size=batch_size,
|
246 |
+
|
247 |
+
sampler=DistributedSampler(fw_train, shuffle=True),
|
248 |
+
collate_fn=collate_fn,
|
249 |
+
drop_last=True,
|
250 |
+
shuffle=False
|
251 |
+
)
|
252 |
+
elif(split == 'val'):
|
253 |
+
data_loader = DataLoader(
|
254 |
+
fw_test,
|
255 |
+
|
256 |
+
|
257 |
+
batch_size=batch_size,
|
258 |
+
sampler=DistributedSampler(fw_test, shuffle=True),
|
259 |
+
collate_fn=collate_fn,
|
260 |
+
drop_last=True,
|
261 |
+
shuffle=False
|
262 |
+
)
|
263 |
+
elif(fw):
|
264 |
+
if(split == 'train'):
|
265 |
+
data_loader = DataLoader(
|
266 |
+
fw_train['train'],
|
267 |
+
batch_size=batch_size,
|
268 |
+
|
269 |
+
|
270 |
+
sampler=DistributedSampler(fw_train['train'], shuffle=True),
|
271 |
+
collate_fn=collate_fn,
|
272 |
+
drop_last=True,
|
273 |
+
shuffle=False
|
274 |
+
)
|
275 |
+
elif(split == 'val'):
|
276 |
+
data_loader = DataLoader(
|
277 |
+
fw_train['test'],
|
278 |
+
batch_size=batch_size,
|
279 |
+
# generator=generator,
|
280 |
+
sampler=DistributedSampler(fw_train["test"]),
|
281 |
+
collate_fn=collate_fn,
|
282 |
+
|
283 |
+
drop_last=True,
|
284 |
+
shuffle=False
|
285 |
+
)
|
286 |
+
return data_loader
|
287 |
+
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
|
293 |
+
class Normalization(nn.Module):
|
294 |
+
def __init__(
|
295 |
+
self,
|
296 |
+
|
297 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
298 |
+
):
|
299 |
+
super().__init__()
|
300 |
+
self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
|
301 |
+
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
|
305 |
+
x = self.rmsnorm_layer(x)
|
306 |
+
return x
|
307 |
+
|
308 |
+
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
# import numpy as np
|
313 |
+
class RotaryEmbeddings(nn.Module):
|
314 |
+
def __init__(
|
315 |
+
self,
|
316 |
+
device,
|
317 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
318 |
+
block_size: int = ModelArgs.block_size,
|
319 |
+
batch_size: int = ModelArgs.batch_size
|
320 |
+
):
|
321 |
+
super().__init__()
|
322 |
+
|
323 |
+
self.embeddings_dims = embeddings_dims
|
324 |
+
self.block_size = block_size
|
325 |
+
self.batch_size = batch_size
|
326 |
+
self.theta = 0
|
327 |
+
self.device=device
|
328 |
+
|
329 |
+
# self.d_model = embeddings_dims
|
330 |
+
# self.i = torch.arange(0, embeddings_dims, dtype=torch.float32)
|
331 |
+
# # self.pos = torch.arange(0, block_size, dtype=torch.float32)
|
332 |
+
# self.exp = ((2 * self.i)) / self.d_model
|
333 |
+
# self.theta = 10000 ** self.exp
|
334 |
+
# # print(self.theta.shape)
|
335 |
+
# self.x_reshaped = torch.randn(batch_size, block_size, embeddings_dims,dtype=torch.float32, device=device)
|
336 |
+
|
337 |
+
# self.cos = torch.cos((self.i / self.theta))
|
338 |
+
# self.sin = torch.sin((self.i / self.theta))
|
339 |
+
|
340 |
+
# self.even = self.sin[::2]
|
341 |
+
# self.odd = self.cos[1::2]
|
342 |
+
|
343 |
+
# # self.block = torch.empty((odd.size(0) + even.size(0),), dtype=self.even.dtype)
|
344 |
+
# self.x_reshaped[..., : , ::2] = self.even
|
345 |
+
# self.x_reshaped[..., : , 1::2] = self.odd
|
346 |
+
|
347 |
+
|
348 |
+
def apply_rope(self, seq):
|
349 |
+
batch_size, seq_len, embeds_dims = seq.shape
|
350 |
+
# print(seq.shape)
|
351 |
+
# print(self.embeddings_dims)
|
352 |
+
# self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
|
353 |
+
|
354 |
+
positions = torch.arange(0 , embeds_dims, 2, dtype=torch.float32, device = self.device).unsqueeze(0)
|
355 |
+
# dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
|
356 |
+
theta = 10000 ** (-2 * (positions) / embeds_dims)
|
357 |
+
angles = positions * theta
|
358 |
+
angles = angles.expand(seq_len, -1) # because this thing needs to be applied to every sequence in the batch but with embeds dims halved
|
359 |
+
x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)
|
360 |
+
|
361 |
+
cos_angles = torch.cos(angles)
|
362 |
+
sin_angles = torch.sin(angles)
|
363 |
+
# print(cos_angles.shape)
|
364 |
+
# print(sin_angles.shape)
|
365 |
+
# print(x_reshaped.shape)
|
366 |
+
# indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device)
|
367 |
+
|
368 |
+
out = torch.stack([x_reshaped[..., 0]*cos_angles - (x_reshaped[...,1] * sin_angles), x_reshaped[...,1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=-1)
|
369 |
+
out = out.view(batch_size, seq_len, embeds_dims)
|
370 |
+
return out
|
371 |
+
|
372 |
+
def forward(self, x):
|
373 |
+
# print("X shape: ", x.shape)
|
374 |
+
# print("X is: ", x)
|
375 |
+
# B,T,C = x.shape
|
376 |
+
# print("MATRIX:",x)
|
377 |
+
# if(x > self.block_size or x < self.block_size):
|
378 |
+
# matrix = self.init_matrix(x)
|
379 |
+
# return matrix
|
380 |
+
# else:
|
381 |
+
# matrix = self.init_matrix(self.block_size)
|
382 |
+
|
383 |
+
# return matrix
|
384 |
+
# if(ModelArgs.inference):
|
385 |
+
res = self.apply_rope(x)
|
386 |
+
return res
|
387 |
+
# else:
|
388 |
+
# return self.x_reshaped
|
389 |
+
|
390 |
+
class RotaryAttentionHead(nn.Module):
|
391 |
+
def __init__(
|
392 |
+
self,
|
393 |
+
device,
|
394 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
395 |
+
no_of_heads: int = ModelArgs.no_of_heads,
|
396 |
+
attn_dropout: int = ModelArgs.attn_dropout
|
397 |
+
):
|
398 |
+
super().__init__()
|
399 |
+
self.head_size = embeddings_dims // no_of_heads
|
400 |
+
self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
401 |
+
self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
402 |
+
self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
403 |
+
self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
|
404 |
+
self.dropout = nn.Dropout(p = attn_dropout)
|
405 |
+
self.device = device
|
406 |
+
def forward(self,x):
|
407 |
+
# print(x.shape)
|
408 |
+
# print("X is: ", x)
|
409 |
+
batch, block_size, embeddings_dims = x.shape
|
410 |
+
query = self.query(x)
|
411 |
+
# print(query)
|
412 |
+
key = self.key(x)
|
413 |
+
values = self.value(x)
|
414 |
+
# matrix = self.rotary_matrix(block_size)
|
415 |
+
rotary_q = self.rope(query)
|
416 |
+
rotary_k = self.rope(key)
|
417 |
+
|
418 |
+
# print(matrix.shape)
|
419 |
+
# print(query.shape)
|
420 |
+
masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
421 |
+
# rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
422 |
+
# rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
423 |
+
weights = rotary_q.permute(2,0,1) @ rotary_k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
424 |
+
weights_masked = weights.masked_fill(masked == 0, float('-inf'))
|
425 |
+
scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
|
426 |
+
scaled_weights = F.softmax(scaled_weights, dim=-1)
|
427 |
+
value = scaled_weights @ values
|
428 |
+
out = self.dropout(value)
|
429 |
+
return out
|
430 |
+
|
431 |
+
|
432 |
+
# # import numpy as np
|
433 |
+
# class RotaryEmbeddings(nn.Module):
|
434 |
+
# def __init__(
|
435 |
+
# self,
|
436 |
+
# device,
|
437 |
+
# embeddings_dims: int = ModelArgs.embeddings_dims,
|
438 |
+
# block_size: int = ModelArgs.block_size,
|
439 |
+
# batch_size: int = ModelArgs.batch_size
|
440 |
+
# ):
|
441 |
+
# super().__init__()
|
442 |
+
|
443 |
+
# self.embeddings_dims = embeddings_dims
|
444 |
+
# self.block_size = block_size
|
445 |
+
# self.batch_size = batch_size
|
446 |
+
# self.theta = 0
|
447 |
+
|
448 |
+
|
449 |
+
# # def init_matrix(self, seq_len):
|
450 |
+
# # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
|
451 |
+
# # for pos in range(seq_len):
|
452 |
+
# # for j in range(1, self.embeddings_dims // 2):
|
453 |
+
# # self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
|
454 |
+
# # self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
|
455 |
+
# # self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
|
456 |
+
# # self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
|
457 |
+
# # self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
|
458 |
+
# # return self.matrix
|
459 |
+
# self.device=device
|
460 |
+
|
461 |
+
# def init_matrix(self, seq_len):
|
462 |
+
# self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
|
463 |
+
|
464 |
+
# positions = torch.arange(0 , seq_len, 2, dtype=torch.float32, device = self.device).unsqueeze(1)
|
465 |
+
# # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
|
466 |
+
# theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
|
467 |
+
# angles = positions * theta
|
468 |
+
|
469 |
+
# cos_angles = torch.cos(angles)
|
470 |
+
# sin_angles = torch.sin(angles)
|
471 |
+
|
472 |
+
# indices = torch.arange(seq_len, dtype=torch.int64, device = self.device)
|
473 |
+
# # print(indices)
|
474 |
+
# # print(indices.shape)
|
475 |
+
# # print(indices[::2])
|
476 |
+
# even_indices = indices[::2]
|
477 |
+
# odd_indices = indices[1::2]
|
478 |
+
|
479 |
+
# self.matrix[:, even_indices, even_indices] = cos_angles
|
480 |
+
# self.matrix[:, odd_indices, odd_indices] = sin_angles
|
481 |
+
# self.matrix[:, odd_indices, even_indices] = -sin_angles
|
482 |
+
# self.matrix[:, even_indices, odd_indices] = cos_angles
|
483 |
+
|
484 |
+
# return self.matrix
|
485 |
+
|
486 |
+
# def forward(self, x):
|
487 |
+
# # B,T,C = x.shape
|
488 |
+
# # print("MATRIX:",x)
|
489 |
+
# if(x > self.block_size or x < self.block_size):
|
490 |
+
# matrix = self.init_matrix(x)
|
491 |
+
# return matrix
|
492 |
+
# else:
|
493 |
+
# matrix = self.init_matrix(self.block_size)
|
494 |
+
|
495 |
+
# return matrix
|
496 |
+
|
497 |
+
|
498 |
+
# class RotaryAttentionHead(nn.Module):
|
499 |
+
# def __init__(
|
500 |
+
# self,
|
501 |
+
# device,
|
502 |
+
# embeddings_dims: int = ModelArgs.embeddings_dims,
|
503 |
+
# no_of_heads: int = ModelArgs.no_of_heads,
|
504 |
+
# attn_dropout: int = ModelArgs.attn_dropout
|
505 |
+
# ):
|
506 |
+
# super().__init__()
|
507 |
+
# self.head_size = embeddings_dims // no_of_heads
|
508 |
+
# self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
509 |
+
# self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
510 |
+
# self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
511 |
+
# self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
|
512 |
+
# self.dropout = nn.Dropout(p = attn_dropout)
|
513 |
+
# self.device = device
|
514 |
+
# def forward(self,x):
|
515 |
+
# # print(x.shape)
|
516 |
+
# batch, block_size, embeddings_dims = x.shape
|
517 |
+
# query = self.query(x)
|
518 |
+
# # print(query)
|
519 |
+
# key = self.key(x)
|
520 |
+
# values = self.value(x)
|
521 |
+
# matrix = self.rotary_matrix(block_size)
|
522 |
+
|
523 |
+
# # print(matrix.shape)
|
524 |
+
# # print(query.shape)
|
525 |
+
# masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
526 |
+
# rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
527 |
+
# rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
528 |
+
# weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
529 |
+
# weights_masked = weights.masked_fill(masked == 0, float('-inf'))
|
530 |
+
# scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
|
531 |
+
# scaled_weights = F.softmax(scaled_weights, dim=-1)
|
532 |
+
# value = scaled_weights @ values
|
533 |
+
# out = self.dropout(value)
|
534 |
+
# return out
|
535 |
+
|
536 |
+
|
537 |
+
class MQA(nn.Module):
|
538 |
+
def __init__(
|
539 |
+
self,
|
540 |
+
device,
|
541 |
+
no_of_q_heads: int,
|
542 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
543 |
+
block_size: int = ModelArgs.block_size,
|
544 |
+
|
545 |
+
|
546 |
+
):
|
547 |
+
super().__init__()
|
548 |
+
|
549 |
+
|
550 |
+
# self.no_of_q_heads = no_of_heads // no_of_kv_heads
|
551 |
+
# self.no_of_q_heads = no_of_q_heads
|
552 |
+
self.no_of_kv_heads = 2 # I want to have a kv for each pair of query heads
|
553 |
+
self.head_size = embeddings_dims // no_of_q_heads
|
554 |
+
# self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
|
555 |
+
self.rotary= RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
|
556 |
+
# self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size, device = device)
|
557 |
+
# self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
|
558 |
+
self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
|
559 |
+
self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
|
560 |
+
self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
|
561 |
+
self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
|
562 |
+
self.device = device
|
563 |
+
self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device = self.device) for _ in range(self.no_of_kv_heads)])
|
564 |
+
|
565 |
+
def scaled_dot_product(self, q, k, v, block_size):
|
566 |
+
|
567 |
+
# masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
568 |
+
q = self.rotary(q)
|
569 |
+
masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
570 |
+
# rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
571 |
+
# rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
572 |
+
# print("Query: ", q.shape)
|
573 |
+
# print("Keys: ", k.shape)
|
574 |
+
# print(q.permute(2,0,1).shape)
|
575 |
+
# print(k.permute(2,0,1).transpose(-2, -1).shape)
|
576 |
+
# weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
577 |
+
# weights = q @ k.permute(2,1,0)
|
578 |
+
# print(weights.shape)
|
579 |
+
# print(masked.shape)
|
580 |
+
weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
|
581 |
+
masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
|
582 |
+
weights_normalized = nn.functional.softmax(masked_values, dim=-1) # normalize over the key/sequence dimension so each query position gets a distribution over the tokens it may attend to
|
583 |
+
weights_normalized = self.dropout(weights_normalized)
|
584 |
+
out = weights_normalized @ v
|
585 |
+
return out
|
586 |
+
|
587 |
+
def forward(self,x):
|
588 |
+
# print("MQA: ", x.shape)
|
589 |
+
batch, block_size, embeddings_dims = x.shape
|
590 |
+
|
591 |
+
# query = self.query(x)
|
592 |
+
# matrix = self.rotary_matrix(block_size)
|
593 |
+
|
594 |
+
|
595 |
+
key = self.key(x)
|
596 |
+
values = self.value(x)
|
597 |
+
# print("Keys: ", key.shape)
|
598 |
+
# print("Values: ", values.shape)
|
599 |
+
# rotary_value = self.rotary(values)
|
600 |
+
rotary_key = self.rotary(key)
|
601 |
+
multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
|
602 |
+
# print("Multi query: ", multi_query_concat.shape)
|
603 |
+
|
604 |
+
linear_layer= self.linear_layer(multi_query_concat)
|
605 |
+
# out = self.dropout(linear_layer)
|
606 |
+
return linear_layer
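# A minimal shape-check sketch for the MQA block above (illustrative only, never called),
# assuming the ModelArgs defaults (embeddings_dims=512) and a CPU device:
def _mqa_shape_sketch():
    mqa = MQA(device='cpu', no_of_q_heads=4)           # head_size = 512 // 4 = 128
    x = torch.randn(2, 16, ModelArgs.embeddings_dims)  # (batch, seq_len, embeddings_dims)
    out = mqa(x)
    # Two query heads of size 128 share one key/value projection; their outputs are
    # concatenated (256) and projected back, so the output shape matches the input.
    assert out.shape == x.shape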
|
607 |
+
|
608 |
+
|
609 |
+
class GQA(nn.Module):
|
610 |
+
def __init__(
|
611 |
+
self,
|
612 |
+
device,
|
613 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
614 |
+
block_size: int = ModelArgs.block_size,
|
615 |
+
# no_of_q_heads: int = ModelArgs.no_of_heads,
|
616 |
+
mqa_heads: int = ModelArgs.no_kv_heads
|
617 |
+
):
|
618 |
+
super().__init__()
|
619 |
+
|
620 |
+
# self.no_of_kv_heads = no_of_kv_heads
|
621 |
+
self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
|
622 |
+
# self.head_dim = embeddings_dims // self.no_kv_heads
|
623 |
+
self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
|
624 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device)
|
625 |
+
self.device = device
|
626 |
+
self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
|
627 |
+
# self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)
|
628 |
+
def forward(self,x):
|
629 |
+
|
630 |
+
batch, block_size, embeddings_dims = x.shape
|
631 |
+
|
632 |
+
# res = self.mqa(x)
|
633 |
+
grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
|
634 |
+
|
635 |
+
linear_layer= self.linear_layer(grouped_query_concat) # the MQA groups together form GQA; this projection folds the concatenated group outputs back to embeddings_dims
|
636 |
+
out = self.dropout(linear_layer)
|
637 |
+
return out
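# How the head grouping above works out with the ModelArgs defaults (no_of_heads=8,
# mqa_heads=2): GQA builds 8 // 2 = 4 MQA groups, each holding 2 query heads of size
# embeddings_dims // 4 that share one key/value projection, and the concatenated group
# outputs (4 * embeddings_dims) are projected back to embeddings_dims.
# A small sketch (illustrative only, never called) that checks the end-to-end shape:
def _gqa_shape_sketch():
    gqa = GQA(device='cpu', mqa_heads=2)
    x = torch.randn(2, 16, ModelArgs.embeddings_dims)
    assert gqa(x).shape == x.shape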
|
638 |
+
|
639 |
+
|
640 |
+
class Swish(nn.Module):
|
641 |
+
def __init__(
|
642 |
+
self,
|
643 |
+
device,
|
644 |
+
block_size: int = ModelArgs.block_size,
|
645 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
646 |
+
):
|
647 |
+
super().__init__()
|
648 |
+
|
649 |
+
self.sig = torch.nn.Sigmoid()
|
650 |
+
|
651 |
+
|
652 |
+
def forward(self, x):
|
653 |
+
swish = x * self.sig(x)
|
654 |
+
|
655 |
+
return swish
|
656 |
+
|
657 |
+
|
658 |
+
|
659 |
+
class SWiGLU(nn.Module):
|
660 |
+
def __init__(
|
661 |
+
self,
|
662 |
+
device,
|
663 |
+
block_size: int = ModelArgs.block_size,
|
664 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
665 |
+
):
|
666 |
+
super().__init__()
|
667 |
+
self.hidden_dims = int(2 * ( 4 * embeddings_dims) / 3)
|
668 |
+
self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
|
669 |
+
self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
|
670 |
+
self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
|
671 |
+
self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
672 |
+
|
673 |
+
|
674 |
+
|
675 |
+
|
676 |
+
def forward(self, x):
|
677 |
+
swish_res = self.swish(self.linear_layer1(x))
|
678 |
+
x_V = self.linear_layer2(x)
|
679 |
+
res = torch.mul(swish_res, x_V)
|
680 |
+
out = self.linear_layer3(res)
|
681 |
+
return out
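# Note on the hidden width above: 2 * (4 * d) / 3 follows the LLaMA-style convention. Because
# the SwiGLU feed-forward uses three weight matrices instead of two, shrinking the hidden size
# to roughly 8d/3 keeps the parameter count comparable to a classic 4d FFN. For the default
# embeddings_dims of 512 this gives int(2 * 4 * 512 / 3) = 1365 hidden units.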
|
682 |
+
|
683 |
+
|
684 |
+
|
685 |
+
class FFN(nn.Module):
|
686 |
+
def __init__(self,
|
687 |
+
device,
|
688 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
689 |
+
block_size: int = ModelArgs.block_size,
|
690 |
+
vocab_size: int = ModelArgs.vocab_size,
|
691 |
+
dropout = ModelArgs.dropout
|
692 |
+
|
693 |
+
):
|
694 |
+
super().__init__()
|
695 |
+
|
696 |
+
# self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
|
697 |
+
self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device)
|
698 |
+
self.dropout = nn.Dropout(p = dropout)
|
699 |
+
def forward(self, x):
|
700 |
+
|
701 |
+
x = self.swiglue(x)
|
702 |
+
# x = self.linear_layer(x)
|
703 |
+
x = self.dropout(x)
|
704 |
+
return x
|
705 |
+
|
706 |
+
|
707 |
+
class DecoderLayer(nn.Module):
|
708 |
+
def __init__(self,
|
709 |
+
device,
|
710 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
711 |
+
dropout = ModelArgs.dropout,
|
712 |
+
block_size: int = ModelArgs.block_size,
|
713 |
+
vocab_size: int = ModelArgs.vocab_size,
|
714 |
+
|
715 |
+
) :
|
716 |
+
super().__init__()
|
717 |
+
|
718 |
+
|
719 |
+
self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device)
|
720 |
+
self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device = device)
|
721 |
+
# self.norm = Normalization(embeddings_dims=embeddings_dims)
|
722 |
+
self.norm1 = Normalization(embeddings_dims=embeddings_dims)
|
723 |
+
self.norm2 = Normalization(embeddings_dims=embeddings_dims)
|
724 |
+
self.dropout = nn.Dropout(p = dropout)
|
725 |
+
def forward(self, x):
|
726 |
+
|
727 |
+
x = x + self.gqa(self.norm1(x))
|
728 |
+
x = x + self.feedforward_network(self.norm2(x))
|
729 |
+
return x
|
730 |
+
|
731 |
+
|
732 |
+
class Llama(nn.Module):
|
733 |
+
def __init__(self,
|
734 |
+
device,
|
735 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
736 |
+
no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
|
737 |
+
block_size: int = ModelArgs.block_size,
|
738 |
+
vocab_size: int = ModelArgs.vocab_size,
|
739 |
+
dropout = ModelArgs.dropout
|
740 |
+
|
741 |
+
) :
|
742 |
+
super().__init__()
|
743 |
+
|
744 |
+
self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device)
|
745 |
+
self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)])
|
746 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device)
|
747 |
+
self.dropout = nn.Dropout(p = dropout)
|
748 |
+
# self.norm = Normalization(embeddings_dims)
|
749 |
+
|
750 |
+
|
751 |
+
# weight tying: the token embedding matrix and the output projection share the same weights
|
752 |
+
self.embeddings.weight = self.linear_layer.weight
|
753 |
+
|
754 |
+
self.apply(self._init_weights)
|
755 |
+
|
756 |
+
def _init_weights(self, module):
|
757 |
+
if isinstance(module, nn.Linear):
|
758 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
759 |
+
|
760 |
+
if module.bias is not None:
|
761 |
+
nn.init.zeros_(module.bias)
|
762 |
+
elif isinstance(module, nn.Embedding):
|
763 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
764 |
+
|
765 |
+
|
766 |
+
|
767 |
+
def forward(self, x):
|
768 |
+
x = self.embeddings(x)
|
769 |
+
x = self.dropout(x)
|
770 |
+
x = self.decoder(x)
|
771 |
+
# x = self.norm(x)
|
772 |
+
x = self.linear_layer(x)
|
773 |
+
# out = self.norm(x)
|
774 |
+
return x
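# An optional helper sketch (never called by the script) to report the parameter count of the
# model above. Because of weight tying, the embedding and output projection share one tensor,
# so counting unique storages avoids double counting.
def _count_parameters(model: nn.Module) -> int:
    unique = {p.data_ptr(): p for p in model.parameters()}
    return sum(p.numel() for p in unique.values())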
|
775 |
+
|
776 |
+
|
777 |
+
# adapted from Andrej Karpathy's GitHub
|
778 |
+
def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
|
779 |
+
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
|
780 |
+
generated_tokens = []
|
781 |
+
ModelArgs.inference=True
|
782 |
+
for _ in range(max_length):
|
783 |
+
with torch.no_grad():
|
784 |
+
outputs = model.module(input_ids)
|
785 |
+
logits = outputs[:, -1, :]
|
786 |
+
|
787 |
+
probs = F.softmax(logits, dim=-1)
|
788 |
+
|
789 |
+
# Top-k filtering
|
790 |
+
top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
|
791 |
+
|
792 |
+
|
793 |
+
# Apply temperature scaling
|
794 |
+
# probs = probs / temperature
|
795 |
+
|
796 |
+
# Sample from top-k
|
797 |
+
next_token = torch.multinomial(top_k_probs, num_samples=1)
|
798 |
+
|
799 |
+
# generated_tokens.append(next_token.item())
|
800 |
+
|
801 |
+
xcol = torch.gather(top_k_indices, -1, next_token)
|
802 |
+
input_ids = torch.cat([input_ids, xcol], dim=1) # dim=1 is the sequence dimension, so the sampled token is appended to the prompt
|
803 |
+
|
804 |
+
return tokenizer.decode(input_ids[0], skip_special_tokens=True)
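# Usage note for topk_sampling above: it assumes a DDP-wrapped model (it calls model.module)
# and a module-level `tokenizer`; the sampled index is gathered from top_k_indices because
# torch.multinomial returns positions within the top-k slice, not vocabulary ids.
# A hedged example call, kept commented out so the script's behaviour is unchanged:
# sample = topk_sampling(model, "Once upon a time", device=ModelArgs.device, max_length=50, top_k=50)
# print(sample)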
|
805 |
+
|
806 |
+
def beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0):
|
807 |
+
device = next(model.module.parameters()).device
|
808 |
+
input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
|
809 |
+
beam_scores = torch.zeros(beam_width, device=device)
|
810 |
+
beam_sequences = input_ids.repeat(beam_width, 1)
|
811 |
+
|
812 |
+
for _ in range(max_length):
|
813 |
+
outputs = model(beam_sequences)
|
814 |
+
logits = outputs[:, -1, :] / temperature
|
815 |
+
probs = F.softmax(logits, dim=-1)
|
816 |
+
top_probs, top_indices = torch.topk(probs, beam_width, dim=-1)
|
817 |
+
|
818 |
+
# Expand beams
|
819 |
+
beam_scores = beam_scores.unsqueeze(-1) + torch.log(top_probs)
|
820 |
+
beam_scores = beam_scores.view(-1)
|
821 |
+
top_indices = top_indices.view(-1)
|
822 |
+
|
823 |
+
# Select top beams
|
824 |
+
beam_scores, top_beams = torch.topk(beam_scores, beam_width)
|
825 |
+
beam_sequences = torch.cat([beam_sequences[top_beams // beam_width], top_indices[top_beams].unsqueeze(-1)], dim=-1)
|
826 |
+
|
827 |
+
# Return the best sequence
|
828 |
+
best_sequence = beam_sequences[0]
|
829 |
+
return tokenizer.decode(best_sequence, skip_special_tokens=True)
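# Usage note for beam_search above: it scores beams by summed log-probabilities without length
# normalisation, reads the device from model.module, and feeds the whole (beam_width, seq_len)
# batch through the model at every step. A hedged example call, commented out:
# story = beam_search(model, tokenizer, "Once upon a time", beam_width=5, max_length=50, temperature=0.7)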
|
830 |
+
|
831 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
832 |
+
# device = "cpu"
|
833 |
+
# ModelArgs.device = device
|
834 |
+
model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
|
835 |
+
model = model.to(ModelArgs.device)
|
836 |
+
|
837 |
+
# Printing a summary of the architecture
|
838 |
+
# !pip install torchinfo
|
839 |
+
from torchinfo import summary
|
840 |
+
# idx, targets = get_batch('test')
|
841 |
+
idx = torch.randint(
|
842 |
+
low=0,
|
843 |
+
high=ModelArgs.vocab_size,
|
844 |
+
size=(ModelArgs.batch_size, ModelArgs.block_size),
|
845 |
+
dtype=torch.long
|
846 |
+
)
|
847 |
+
# sample_idx = random.randint(range(len(train_dataset)))
|
848 |
+
# idx, targets = train_dataset[0]
|
849 |
+
idx = idx.to(ModelArgs.device)
|
850 |
+
# targets = targets.to(ModelArgs.device)
|
851 |
+
summary(model=model,
|
852 |
+
input_data=idx,
|
853 |
+
# input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
|
854 |
+
col_names=["input_size", "output_size", "num_params", "trainable"],
|
855 |
+
col_width=20,
|
856 |
+
row_settings=["var_names"])
|
857 |
+
|
858 |
+
|
859 |
+
def find_unused_parameters(model):
|
860 |
+
unused = []
|
861 |
+
for name, param in model.named_parameters():
|
862 |
+
if param.grad is None:
|
863 |
+
unused.append(name)
|
864 |
+
return unused
|
865 |
+
|
866 |
+
def greedy_decode(
|
867 |
+
model,
|
868 |
+
tokenizer,
|
869 |
+
prompt,
|
870 |
+
device,
|
871 |
+
max_length=50,
|
872 |
+
repetition_penalty=1.2,
|
873 |
+
context_window=10,
|
874 |
+
temperature=1.0,
|
875 |
+
eos_token_id=None,
|
876 |
+
|
877 |
+
):
|
878 |
+
# model.eval()
|
879 |
+
# device = next(model.parameters()).device
|
880 |
+
input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
|
881 |
+
generated_tokens = []
|
882 |
+
eos_token_id = eos_token_id or tokenizer.eos_token_id # Use EOS token if provided
|
883 |
+
|
884 |
+
for _ in range(max_length):
|
885 |
+
with torch.no_grad():
|
886 |
+
outputs = model.module(input_ids)
|
887 |
+
logits = outputs[:, -1, :] # Get logits for the last token
|
888 |
+
|
889 |
+
# Apply temperature scaling
|
890 |
+
# if temperature != 1.0:
|
891 |
+
# logits = logits / temperature
|
892 |
+
|
893 |
+
# Apply repetition penalty
|
894 |
+
# if repetition_penalty != 1.0 and len(generated_tokens) > 0:
|
895 |
+
# for token in set(generated_tokens[-context_window:]): # Penalize recent tokens
|
896 |
+
# logits[0, token] /= repetition_penalty
|
897 |
+
|
898 |
+
# Greedy selection
|
899 |
+
next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
|
900 |
+
generated_tokens.append(next_token.item())
|
901 |
+
|
902 |
+
# Stop if EOS token is generated
|
903 |
+
# if next_token.item() == eos_token_id:
|
904 |
+
# break
|
905 |
+
|
906 |
+
# Append the new token to the input
|
907 |
+
input_ids = torch.cat([input_ids, next_token], dim=1)
|
908 |
+
|
909 |
+
# Decode the generated tokens
|
910 |
+
return tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
911 |
+
|
912 |
+
|
913 |
+
|
914 |
+
def save_to_file(text):
|
915 |
+
|
916 |
+
with open('generations.txt', 'a') as f:
|
917 |
+
f.writelines(text + "\n\n")
|
918 |
+
|
919 |
+
|
920 |
+
#Train the model
|
921 |
+
|
922 |
+
|
923 |
+
# writer = SummaryWriter(log_dir="runs/experiment")
|
924 |
+
|
925 |
+
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
|
926 |
+
|
927 |
+
# Warmup phase for 2000 steps
|
928 |
+
def warmup_fn(step):
|
929 |
+
if step < 2000:
|
930 |
+
return step / 2000 # LR gradually increases
|
931 |
+
return 1.0
|
932 |
+
|
933 |
+
|
934 |
+
from torch.optim.lr_scheduler import LambdaLR
|
935 |
+
|
936 |
+
def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
|
937 |
+
"""
|
938 |
+
Trapezoidal learning rate scheduler:
|
939 |
+
- Increases linearly for `warmup_steps` steps.
|
940 |
+
- Remains constant for `plateau_steps` steps.
|
941 |
+
- Decreases linearly for `decay_steps` steps.
|
942 |
+
"""
|
943 |
+
def lr_lambda(step):
|
944 |
+
if step < warmup_steps:
|
945 |
+
# Linear warmup
|
946 |
+
return float(step) / float(max(1, warmup_steps))
|
947 |
+
elif step < warmup_steps + plateau_steps:
|
948 |
+
# Constant plateau
|
949 |
+
return 1.0
|
950 |
+
else:
|
951 |
+
# Linear decay
|
952 |
+
decay_step = step - (warmup_steps + plateau_steps)
|
953 |
+
return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps)))
|
954 |
+
|
955 |
+
return LambdaLR(optimizer, lr_lambda)
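# A usage sketch for the trapezoidal scheduler above, with purely illustrative step counts
# (the training loop further down uses the cosine get_lr helper instead):
# scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps=10000,
#                                      warmup_steps=700, plateau_steps=8300, decay_steps=1000)
# ...then call scheduler.step() once per optimizer step.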
|
956 |
+
|
957 |
+
|
958 |
+
torch.set_float32_matmul_precision('high')
|
959 |
+
|
960 |
+
scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))
|
961 |
+
|
962 |
+
save_chechpoint_iter = 50
|
963 |
+
total_iters = 10000
|
964 |
+
eval_iters = 50
|
965 |
+
eval_check = 100
|
966 |
+
warmup_iters = 700
|
967 |
+
min_lr = 0.1 * ModelArgs.max_lr
|
968 |
+
lr_decay_iters = 10000
|
969 |
+
total_batch_size = 524288
|
970 |
+
micro_batch_size = ModelArgs.batch_size
|
971 |
+
gradient_accumulation_steps = total_batch_size // (micro_batch_size * (ModelArgs.block_size * torch.cuda.device_count()))
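# With the defaults this targets roughly 0.5M tokens per optimizer step:
# total_batch_size / (batch_size * block_size * n_gpus) = 524288 / (64 * 512 * n_gpus),
# i.e. 16 micro-steps on a single GPU or 8 micro-steps on two GPUs.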
|
972 |
+
|
973 |
+
# learning rate decay scheduler (cosine with warmup) from https://github.com/karpathy/nanoGPT/blob/master/train.py
|
974 |
+
|
975 |
+
def get_lr(it):
|
976 |
+
# 1) linear warmup for warmup_iters steps
|
977 |
+
if it < warmup_iters:
|
978 |
+
return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
|
979 |
+
# 2) if it > lr_decay_iters, return min learning rate
|
980 |
+
if it > lr_decay_iters:
|
981 |
+
return min_lr
|
982 |
+
# 3) in between, use cosine decay down to min learning rate
|
983 |
+
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
984 |
+
assert 0 <= decay_ratio <= 1
|
985 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
986 |
+
return min_lr + coeff * (ModelArgs.max_lr - min_lr)
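# A tiny optional sanity check for the schedule above (never called by the training loop):
# lr ramps from ~max_lr/700 at step 0 up to max_lr at the end of warmup, then follows a
# cosine decay down to min_lr at lr_decay_iters.
def _print_lr_schedule(points=(0, warmup_iters, (warmup_iters + lr_decay_iters) // 2, lr_decay_iters)):
    for it in points:
        print(f"step {it}: lr = {get_lr(it):.2e}")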
|
987 |
+
|
988 |
+
|
989 |
+
def train():
|
990 |
+
setup()
|
991 |
+
device = int(os.environ["LOCAL_RANK"])
|
992 |
+
|
993 |
+
torch.cuda.set_device(int(device))
|
994 |
+
# torch.set_default_device('cuda')
|
995 |
+
# train_dataloader = prepare_dataset(ModelArgs.batch_size)
|
996 |
+
# rank = torch.distributed.get_rank()
|
997 |
+
print(f"Start running DDP on rank {device}.")
|
998 |
+
# # create model and move it to GPU with id rank
|
999 |
+
# device_id = rank % torch.cuda.device_count()
|
1000 |
+
# CFG = ModelArgs()
|
1001 |
+
|
1002 |
+
if(device == 0):
|
1003 |
+
|
1004 |
+
|
1005 |
+
|
1006 |
+
# # Initialise run
|
1007 |
+
wandb.init(
|
1008 |
+
# entity = 'rajceo2031',
|
1009 |
+
project = 'Llama-DDP-Pretrain-10-billion-tokens',
|
1010 |
+
# config = CFG,
|
1011 |
+
# save_code = True,
|
1012 |
+
#group = 'ANN',
|
1013 |
+
#job_type = 'train'
|
1014 |
+
)
|
1015 |
+
print("wand initialized")
|
1016 |
+
|
1017 |
+
model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
|
1018 |
+
|
1019 |
+
# print(f"Model on device {device} is ready")
|
1020 |
+
print(f"Model on device {device} is ready")
|
1021 |
+
|
1022 |
+
# Wrap model with DDP after moving to GPU
|
1023 |
+
# model = DDP(model, device_ids=[device])
|
1024 |
+
# optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=1e-8)
|
1025 |
+
# # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4000, T_mult=1, eta_min=1e-5)
|
1026 |
+
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(None, T_max=30000, eta_min=1e-6)
|
1027 |
+
# _load_snapshot('/kaggle/input/models/snapshot2.pt', model.module, None, None)
|
1028 |
+
optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=ModelArgs.eps)
|
1029 |
+
|
1030 |
+
# model = torch.compile(model)
|
1031 |
+
model = model.to(device)
|
1032 |
+
|
1033 |
+
model = DDP(model, device_ids=[device])
|
1034 |
+
|
1035 |
+
|
1036 |
+
# new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25000, eta_min=1e-6) #with the prev optim snapshot
|
1037 |
+
# new_scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps)
|
1038 |
+
|
1039 |
+
# warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_fn)
|
1040 |
+
# new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
|
1041 |
+
# Cosine decay after warmup
|
1042 |
+
# new_scheduler = CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
|
1043 |
+
|
1044 |
+
# Combine both schedulers
|
1045 |
+
# scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, new_scheduler], milestones=[2000])
|
1046 |
+
|
1047 |
+
# Reset learning rate to 1e-4
|
1048 |
+
# for param_group in optimizer.param_groups:
|
1049 |
+
# param_group['lr'] = ModelArgs.max_lr
|
1050 |
+
# scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1, eta_min=1e-6)
|
1051 |
+
# print("Old optimizer with new lr ready")
|
1052 |
+
|
1053 |
+
|
1054 |
+
|
1055 |
+
# optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)
|
1056 |
+
# Create DataLoader with collate_fn
|
1057 |
+
# train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
|
1058 |
+
# val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
|
1059 |
+
# print("Loader is ready")
|
1060 |
+
# print(train_loader)
|
1061 |
+
# print(next(iter(train_loader)))
|
1062 |
+
|
1063 |
+
|
1064 |
+
# for X,y in train_loader:
|
1065 |
+
# print(X.shape)
|
1066 |
+
# print(y.shape)
|
1067 |
+
|
1068 |
+
# alpaca_prompt = '''
|
1069 |
+
|
1070 |
+
# ### Instruction:
|
1071 |
+
# {instruction}
|
1072 |
+
|
1073 |
+
# ### Input:
|
1074 |
+
# {input}
|
1075 |
+
|
1076 |
+
# ### Response:
|
1077 |
+
|
1078 |
+
# '''
|
1079 |
+
# Only create progress bar for rank 0
|
1080 |
+
# eval_epoch_iterator = range(eval_iters)
|
1081 |
+
# train_epoch_iterator = range(total_iters)
|
1082 |
+
# if device == 0:
|
1083 |
+
# train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training")
|
1084 |
+
|
1085 |
+
# train_epoch_iterator = range(ModelArgs.epochs)
|
1086 |
+
# if device == 0: # Ensure tqdm only runs on rank 0
|
1087 |
+
# train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training Progress", position=0, leave=True)
|
1088 |
+
|
1089 |
+
# lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters)
|
1090 |
+
|
1091 |
+
|
1092 |
+
|
1093 |
+
model.eval()
|
1094 |
+
world_size = torch.cuda.device_count()
|
1095 |
+
@torch.inference_mode()
|
1096 |
+
def estimate_loss(val_loader, val_iterator, device):
|
1097 |
+
out = {}
|
1098 |
+
# train_loader = prepare_dataset('train', ModelArgs.batch_size)
|
1099 |
+
|
1100 |
+
# val_loader_iterator = iter(val_loader)
|
1101 |
+
loader = None
|
1102 |
+
epoch_loss = None
|
1103 |
+
epoch_losses = []
|
1104 |
+
# print("Starting the eval...")
|
1105 |
+
for split in ['val']:
|
1106 |
+
print(f"Starting with {split} evaluation...")
|
1107 |
+
# losses = torch.zeros(ModelArgs.val_epochs)
|
1108 |
+
# if(split == 'train'):
|
1109 |
+
# loader = train_loader
|
1110 |
+
# if(split == 'val'):
|
1111 |
+
# loader = val_loader
|
1112 |
+
for step in range(eval_check):
|
1113 |
+
try:
|
1114 |
+
batch = next(val_iterator)
|
1115 |
+
except StopIteration:
|
1116 |
+
val_loader_iterator = iter(val_loader)
|
1117 |
+
batch = next(val_loader_iterator)
|
1118 |
+
|
1119 |
+
total_loss = 0
|
1120 |
+
# loader.sampler.set_epoch(step)
|
1121 |
+
total_batches = 0
|
1122 |
+
# batch = next(val_loader_iterator)
|
1123 |
+
# for batch in loader: # Loop through DataLoader batches
|
1124 |
+
idx = batch['input_ids']
|
1125 |
+
targets = batch['labels']
|
1126 |
+
idx = idx.to(device)
|
1127 |
+
targets = targets.to(device)
|
1128 |
+
with torch.autocast(device_type=device, dtype=torch.bfloat16):
|
1129 |
+
|
1130 |
+
logits = model(idx)
|
1131 |
+
batch_size, block_size, embeddings_dims = logits.shape
|
1132 |
+
logits = logits.view(batch_size * block_size, embeddings_dims) # Flatten tokens
|
1133 |
+
targets = targets.view(batch_size * block_size)
|
1134 |
+
|
1135 |
+
loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
|
1136 |
+
|
1137 |
+
total_loss += loss.item()
|
1138 |
+
total_batches += 1
|
1139 |
+
|
1140 |
+
# Compute mean loss for this epoch
|
1141 |
+
epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
|
1142 |
+
epoch_losses.append(epoch_loss)
|
1143 |
+
|
1144 |
+
# print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}")
|
1145 |
+
|
1146 |
+
# Compute mean loss across all evaluation epochs
|
1147 |
+
out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
|
1148 |
+
epoch_loss = None
|
1149 |
+
epoch_losses = []
|
1150 |
+
|
1151 |
+
model.train()
|
1152 |
+
return out
|
1153 |
+
|
1154 |
+
# model = model.to(rank)
|
1155 |
+
model.train()
|
1156 |
+
count = 0
|
1157 |
+
|
1158 |
+
train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
|
1159 |
+
val_loader= prepare_dataset('val', device, ModelArgs.batch_size)
|
1160 |
+
# for step in tqdm(range(total_iters)):
|
1161 |
+
# for epoch in range(ModelArgs.epochs):
|
1162 |
+
# torch.cuda.synchronize()
|
1163 |
+
|
1164 |
+
# train_dataloader.sampler.set_epoch(epoch)
|
1165 |
+
|
1166 |
+
# val_loader.sampler.set_epoch(epoch)
|
1167 |
+
print("Loaders ready both")
|
1168 |
+
epochs = ModelArgs.epochs
|
1169 |
+
|
1170 |
+
# train_step_iterator = range(len(train_dataloader))
|
1171 |
+
# if device == 0: # Only create progress bar on rank 0
|
1172 |
+
# train_step_iterator = tqdm(train_step_iterator, desc="Training Progress", position=0, leave=True)
|
1173 |
+
|
1174 |
+
# Print progress on rank 0
|
1175 |
+
train_loader_length = 0
|
1176 |
+
train_data_iterator = iter(train_dataloader)
|
1177 |
+
val_data_iterator = iter(val_loader)
|
1178 |
+
token_count = 0
|
1179 |
+
if(device == 0):
|
1180 |
+
train_loader_length = len(train_dataloader)
|
1181 |
+
# print("Total batches: ", train_loader_length)
|
1182 |
+
# print("Length of : ", len(train_dataloader))
|
1183 |
+
# print("Length of val: ", len(val_loader))
|
1184 |
+
# for step, batch in enumerate(train_dataloader):
|
1185 |
+
for step in tqdm(range(total_iters)):
|
1186 |
+
# print("Dataloader things: ", batch)
|
1187 |
+
# print("Total batches: ", len(train_dataloader))
|
1188 |
+
|
1189 |
+
|
1190 |
+
if(device == 0):
|
1191 |
+
# if(step % 100 == 0):
|
1192 |
+
# if(step == train_loader_length):
|
1193 |
+
# break
|
1194 |
+
print("Step : ", step, "/", total_iters)
|
1195 |
+
print('Total batches: ', len(train_dataloader))
|
1196 |
+
print("Total gradient accumulation steps: ", gradient_accumulation_steps)
|
1197 |
+
print("Total tokens processed: ", token_count)
|
1198 |
+
|
1199 |
+
# all_gpus_avg_train_loss = None
|
1200 |
+
# all_gpus_avg_val_loss = None
|
1201 |
+
# every once in a while evaluate the loss on train and val sets
|
1202 |
+
if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
|
1203 |
+
losses = estimate_loss( val_loader, val_data_iterator, 'cuda')
|
1204 |
+
# avg_train_loss = losses['train']
|
1205 |
+
avg_val_loss = losses['val']
|
1206 |
+
# print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
1207 |
+
# if device == 0: # Only print on main process
|
1208 |
+
print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
|
1209 |
+
# print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f}")
|
1210 |
+
# print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
1211 |
+
# Log training loss more frequently
|
1212 |
+
# Aggregate average loss across all GPUs
|
1213 |
+
# avg_train_loss = torch.Tensor([losses['train']]).to(device)
|
1214 |
+
avg_val_loss = torch.Tensor([losses['val']]).to(device)
|
1215 |
+
# torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
|
1216 |
+
torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
|
1217 |
+
|
1218 |
+
if device == 0:
|
1219 |
+
# all_gpus_avg_train_loss = avg_train_loss / world_size
|
1220 |
+
# print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss.item():.4f}")
|
1221 |
+
all_gpus_avg_val_loss = avg_val_loss / world_size
|
1222 |
+
print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
|
1223 |
+
|
1224 |
+
# if device == 0:
|
1225 |
+
|
1226 |
+
# writer.add_scalar("All_GPUs_Train_losses", all_gpus_avg_train_loss.item(), global_step=step)
|
1227 |
+
# writer.add_scalar("All_GPUs_Val_losses", all_gpus_avg_val_loss.item(), global_step=step)
|
1228 |
+
# writer.add_scalar("training_step_loss", losses['train'], global_step=step)
|
1229 |
+
# writer.add_scalar("val_step_loss", losses['val'], global_step=step)
|
1230 |
+
# writer.add_scalar("GPU", device, global_step=step)
|
1231 |
+
# writer.add_scalar("Epoch", epoch, global_step=step)
|
1232 |
+
|
1233 |
+
wandb.log({
|
1234 |
+
# "Learning Rate": optimizer.param_groups[0]['lr'],
|
1235 |
+
# "All_GPUs_Train_losses": all_gpus_avg_train_loss,
|
1236 |
+
"All_GPUs_Val_losses": all_gpus_avg_val_loss,
|
1237 |
+
# "training_step_loss": losses['train'],
|
1238 |
+
"val_step_loss": losses['val'],
|
1239 |
+
# "Step": step,
|
1240 |
+
# "Epoch": epoch
|
1241 |
+
})
|
1242 |
+
|
1243 |
+
|
1244 |
+
|
1245 |
+
#Loading a checkpoint
|
1246 |
+
# if(os.path.exists('snapshot.pt')):
|
1247 |
+
# model, optimizer = _load_snapshot(model=model, optimizer=optimizer, epoch=epoch, step=step, snapshot_path='snapshot.pt')
|
1248 |
+
|
1249 |
+
# if(step % save_chechpoint_iter == 0 and device == 0 and step != 0):
|
1250 |
+
|
1251 |
+
# _save_snapshot(epoch=epoch, model=model, optimizer=optimizer, step=step)
|
1252 |
+
|
1253 |
+
if step % save_chechpoint_iter == 0 and device == 0 and step != 0:
|
1254 |
+
print(f"Saving the model checkpoint for step: {step}")
|
1255 |
+
_save_snapshot(model, optimizer, None, None, step)
|
1256 |
+
|
1257 |
+
accumulated_loss = 0.0
|
1258 |
+
|
1259 |
+
|
1260 |
+
optimizer.zero_grad(set_to_none=True)
|
1261 |
+
for micro_step in range(gradient_accumulation_steps):
|
1262 |
+
try:
|
1263 |
+
batch = next(train_data_iterator)
|
1264 |
+
except StopIteration:
|
1265 |
+
train_data_iterator = iter(train_dataloader)
|
1266 |
+
batch = next(train_data_iterator)
|
1267 |
+
# print(batch)
|
1268 |
+
# batch = next(train_data_iterator)
|
1269 |
+
# print(batch)
|
1270 |
+
# batch = {k: v.to(self.local_rank) for k, v in batch.items()}
|
1271 |
+
idx = batch['input_ids'].to(device)
|
1272 |
+
# idx, targets = get_batch(split='train')
|
1273 |
+
# print(f"Starting the train step: {step}...")
|
1274 |
+
# for idx, targets in train_loader:
|
1275 |
+
# idx, targets = next(iter(train_loader))
|
1276 |
+
|
1277 |
+
# print("Idx: ", idx)
|
1278 |
+
# print("Targets: ", targets)
|
1279 |
+
|
1280 |
+
# idx = idx.to(device)
|
1281 |
+
# print("Idx: ", idx)
|
1282 |
+
# print("Targets: ", targets)
|
1283 |
+
targets = batch['labels'].to(device)
|
1284 |
+
token_count += idx.numel() # count tokens processed (batch_size * block_size), not just sequences
|
1285 |
+
with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
|
1286 |
+
logits = model(idx)
|
1287 |
+
batch_size, block_size, embeddings_dims = logits.shape
|
1288 |
+
# print(logits.shape)
|
1289 |
+
# print(targets)
|
1290 |
+
logits = logits.view(batch_size*block_size, embeddings_dims)
|
1291 |
+
# print("OK")
|
1292 |
+
targets = targets.view(batch_size * block_size)
|
1293 |
+
# print("OK2")
|
1294 |
+
loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
|
1295 |
+
|
1296 |
+
loss = loss / gradient_accumulation_steps # scale so that, summed over all micro-steps, the gradient matches a single large batch; each micro-batch contributes equally
|
1297 |
+
accumulated_loss += loss.detach()
|
1298 |
+
|
1299 |
+
model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) # only synchronize gradients across GPUs on the final micro-step, not on every accumulation step
|
1300 |
+
scaler.scale(loss).backward()
|
1301 |
+
# Check for unused parameters
|
1302 |
+
unused_params = find_unused_parameters(model)
|
1303 |
+
if unused_params:
|
1304 |
+
print(f"Unused parameters: {unused_params}")
|
1305 |
+
# break
|
1306 |
+
|
1307 |
+
if(device == 0):
|
1308 |
+
if(micro_step % 10 == 0):
|
1309 |
+
# if(step == train_loader_length):
|
1310 |
+
# break
|
1311 |
+
|
1312 |
+
print("Micro Batch : ", micro_step)
|
1313 |
+
print("Step : ", step, "/", total_iters)
|
1314 |
+
print('Total batches: ', len(train_dataloader))
|
1315 |
+
print("Total gradient accumulation steps: ", gradient_accumulation_steps)
|
1316 |
+
print("Total tokens processed: ", token_count)
|
1317 |
+
# count += 1
|
1318 |
+
|
1319 |
+
lr = get_lr(step)
|
1320 |
+
for params in optimizer.param_groups:
|
1321 |
+
params['lr'] = lr
|
1322 |
+
|
1323 |
+
|
1324 |
+
|
1325 |
+
# Compute gradient norms before clipping
|
1326 |
+
if(ModelArgs.clip != 0.0):
|
1327 |
+
scaler.unscale_(optimizer) # unscale first so clipping is applied to the true gradient magnitudes
|
1328 |
+
total_norm_before = torch.norm(
|
1329 |
+
torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
|
1330 |
+
)
|
1331 |
+
|
1332 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
|
1333 |
+
|
1334 |
+
# Compute gradient norms after clipping
|
1335 |
+
total_norm_after = torch.norm(
|
1336 |
+
torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
if(device == 0 and step !=0):
|
1340 |
+
print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
|
1341 |
+
print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")
|
1342 |
+
|
1343 |
+
scaler.step(optimizer)
|
1344 |
+
scaler.update()
|
1345 |
+
|
1346 |
+
# optimizer.step()
|
1347 |
+
# new_scheduler.step()
|
1348 |
+
|
1349 |
+
torch.cuda.synchronize()
|
1350 |
+
torch.distributed.reduce(loss, dst=0, op=torch.distributed.ReduceOp.SUM)
|
1351 |
+
if(device == 0):
|
1352 |
+
wandb.log({
|
1353 |
+
"Learning Rate": lr,
|
1354 |
+
"All_GPUs_Train_losses": accumulated_loss.item(),
|
1355 |
+
# "All_GPUs_Val_losses": all_gpus_avg_val_loss,
|
1356 |
+
# "training_step_loss": losses['train'],
|
1357 |
+
# "val_step_loss": losses['val'],
|
1358 |
+
"Step": step,
|
1359 |
+
# "Epoch": epoch
|
1360 |
+
|
1361 |
+
})
|
1362 |
+
# print(loss.item())
|
1363 |
+
# if(step % 100 == 0):
|
1364 |
+
# print(f'Step : {step} | GPU: {device} Loss: {loss.item()}')
|
1365 |
+
# if device == 0:
|
1366 |
+
# print("loss: ", loss.item())
|
1367 |
+
# train_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
|
1368 |
+
# print(loss.item())
|
1369 |
+
# break
|
1370 |
+
|
1371 |
+
# if step != 0 and (step % eval_iters == 0 or step == total_steps -1) :
|
1372 |
+
# loss_values = estimate_loss()
|
1373 |
+
# print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))
|
1374 |
+
|
1375 |
+
# Add after a training step:
|
1376 |
+
# unused_params = find_unused_parameters(model)
|
1377 |
+
# print("Unused parameters:", unused_params)
|
1378 |
+
# break
|
1379 |
+
if device == 0 and step % 5 == 0:
|
1380 |
+
count = 3
|
1381 |
+
while(count): # Only generate text on the main process
|
1382 |
+
# print("Generating text...")
|
1383 |
+
|
1384 |
+
# alpaca_prompt = '''
|
1385 |
+
|
1386 |
+
# ### Instruction:
|
1387 |
+
# {}
|
1388 |
+
|
1389 |
+
# ### Input:
|
1390 |
+
# {}
|
1391 |
+
|
1392 |
+
# ### Response:
|
1393 |
+
|
1394 |
+
# '''
|
1395 |
+
|
1396 |
+
# prompt = alpaca_prompt.format("You are a helpful assistant.", "Say a joke.", "")
|
1397 |
+
# print("Generating text")
|
1398 |
+
prompt = "Once upon a time"
|
1399 |
+
generated_text = topk_sampling(model, prompt, max_length=50, top_k=50, temperature=1.0, device=device)
|
1400 |
+
|
1401 |
+
# generated_text = greedy_decode(
|
1402 |
+
# model,
|
1403 |
+
# tokenizer,
|
1404 |
+
# "Once upon a time",
|
1405 |
+
# max_length=40,
|
1406 |
+
# repetition_penalty=1.2,
|
1407 |
+
# context_window=10,
|
1408 |
+
# temperature=0.7, # Lower temperature for more deterministic output
|
1409 |
+
# device=device
|
1410 |
+
# )
|
1411 |
+
# generated_text = beam_search(model, tokenizer, "Once upon a time ", beam_width=5, max_length=50, temperature=0.6)
|
1412 |
+
print(f" Step: {step} | Generated Text: {generated_text}")
|
1413 |
+
# model.train()
|
1414 |
+
# save_to_file(generated_text)
|
1415 |
+
count -= 1
|
1416 |
+
|
1417 |
+
# if step != 0:
|
1418 |
+
# train_step_iterator.set_postfix({"Train loss": f"{all_gpus_avg_train_loss.item():.4f} | Val Loss : {all_gpus_avg_val_loss.item():.4f}"})
|
1419 |
+
|
1420 |
+
|
1421 |
+
# break
|
1422 |
+
# Cleanup
|
1423 |
+
if device == 0:
|
1424 |
+
# writer.close()
|
1425 |
+
wandb.finish()
|
1426 |
+
cleanup()
|
1427 |
+
|
1428 |
+
|
1429 |
+
world_size = torch.cuda.device_count()
|
1430 |
+
print(f"World size: {world_size}")
|
1431 |
+
train()
|
1432 |
+
|
1433 |
+
|
1434 |
+
|
1435 |
+
|
metric.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
|
2 |
+
import evaluate
|
3 |
+
|
4 |
+
from config import ModelArgs
|
5 |
+
from model import Llama
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
# Load the perplexity metric
|
10 |
+
perplexity = evaluate.load("perplexity")
|
11 |
+
|
12 |
+
|
13 |
+
def compute_perplexity(model_name, text):
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
results = perplexity.compute(predictions=[text], model_id=model_name)
|
18 |
+
|
19 |
+
return results["perplexities"][0]
|
20 |
+
|
21 |
+
# Example Usage
|
22 |
+
llama = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
|
23 |
+
llama = llama.to(ModelArgs.device)
|
24 |
+
|
25 |
+
text = "This is an example sentence for perplexity calculation."
|
26 |
+
|
27 |
+
ppl = compute_perplexity(llama, text)
|
28 |
+
print(f"Perplexity: {ppl}")
|
model.py
ADDED
@@ -0,0 +1,489 @@
1 |
+
|
2 |
+
from config import ModelArgs
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class Normalization(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
|
12 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
13 |
+
):
|
14 |
+
super().__init__()
|
15 |
+
self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
|
16 |
+
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
|
20 |
+
x = self.rmsnorm_layer(x)
|
21 |
+
return x
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
# import numpy as np
|
28 |
+
class RotaryEmbeddings(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
device,
|
32 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
33 |
+
block_size: int = ModelArgs.block_size,
|
34 |
+
batch_size: int = ModelArgs.batch_size
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.embeddings_dims = embeddings_dims
|
39 |
+
self.block_size = block_size
|
40 |
+
self.batch_size = batch_size
|
41 |
+
self.theta = 0
|
42 |
+
self.device=device
|
43 |
+
|
44 |
+
# self.d_model = embeddings_dims
|
45 |
+
# self.i = torch.arange(0, embeddings_dims, dtype=torch.float32)
|
46 |
+
# # self.pos = torch.arange(0, block_size, dtype=torch.float32)
|
47 |
+
# self.exp = ((2 * self.i)) / self.d_model
|
48 |
+
# self.theta = 10000 ** self.exp
|
49 |
+
# # print(self.theta.shape)
|
50 |
+
# self.x_reshaped = torch.randn(batch_size, block_size, embeddings_dims,dtype=torch.float32, device=device)
|
51 |
+
|
52 |
+
# self.cos = torch.cos((self.i / self.theta))
|
53 |
+
# self.sin = torch.sin((self.i / self.theta))
|
54 |
+
|
55 |
+
# self.even = self.sin[::2]
|
56 |
+
# self.odd = self.cos[1::2]
|
57 |
+
|
58 |
+
# # self.block = torch.empty((odd.size(0) + even.size(0),), dtype=self.even.dtype)
|
59 |
+
# self.x_reshaped[..., : , ::2] = self.even
|
60 |
+
# self.x_reshaped[..., : , 1::2] = self.odd
|
61 |
+
|
62 |
+
|
63 |
+
def apply_rope(self, seq):
|
64 |
+
batch_size, seq_len, embeds_dims = seq.shape
|
65 |
+
# print(seq.shape)
|
66 |
+
# print(self.embeddings_dims)
|
67 |
+
# self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
|
68 |
+
|
69 |
+
positions = torch.arange(0 , embeds_dims, 2, dtype=torch.float32, device = self.device).unsqueeze(0)
|
70 |
+
# dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
|
71 |
+
theta = 10000 ** (-2 * (positions) / embeds_dims)
|
72 |
+
angles = positions * theta
|
73 |
+
angles = angles.expand(seq_len, -1) # because this thing needs to be applied to every sequence in the batch but with embeds dims halved
|
74 |
+
x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)
|
75 |
+
|
76 |
+
cos_angles = torch.cos(angles)
|
77 |
+
sin_angles = torch.sin(angles)
|
78 |
+
# print(cos_angles.shape)
|
79 |
+
# print(sin_angles.shape)
|
80 |
+
# print(x_reshaped.shape)
|
81 |
+
# indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device)
|
82 |
+
|
83 |
+
out = torch.stack([x_reshaped[..., 0]*cos_angles - (x_reshaped[...,1] * sin_angles), x_reshaped[...,1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=-1)
|
84 |
+
out = out.view(batch_size, seq_len, embeds_dims)
|
85 |
+
return out
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
# print("X shape: ", x.shape)
|
89 |
+
# print("X is: ", x)
|
90 |
+
# B,T,C = x.shape
|
91 |
+
# print("MATRIX:",x)
|
92 |
+
# if(x > self.block_size or x < self.block_size):
|
93 |
+
# matrix = self.init_matrix(x)
|
94 |
+
# return matrix
|
95 |
+
# else:
|
96 |
+
# matrix = self.init_matrix(self.block_size)
|
97 |
+
|
98 |
+
# return matrix
|
99 |
+
# if(ModelArgs.inference):
|
100 |
+
res = self.apply_rope(x)
|
101 |
+
return res
|
102 |
+
# else:
|
103 |
+
# return self.x_reshaped
|
104 |
+
|
105 |
+
class RotaryAttentionHead(nn.Module):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
device,
|
109 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
110 |
+
no_of_heads: int = ModelArgs.no_of_heads,
|
111 |
+
attn_dropout: int = ModelArgs.attn_dropout
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
self.head_size = embeddings_dims // no_of_heads
|
115 |
+
self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
116 |
+
self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
117 |
+
self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
118 |
+
self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
|
119 |
+
self.dropout = nn.Dropout(p = attn_dropout)
|
120 |
+
self.device = device
|
121 |
+
def forward(self,x):
|
122 |
+
# print(x.shape)
|
123 |
+
# print("X is: ", x)
|
124 |
+
batch, block_size, embeddings_dims = x.shape
|
125 |
+
query = self.query(x)
|
126 |
+
# print(query)
|
127 |
+
key = self.key(x)
|
128 |
+
values = self.value(x)
|
129 |
+
# matrix = self.rotary_matrix(block_size)
|
130 |
+
rotary_q = self.rope(query)
|
131 |
+
rotary_k = self.rope(key)
|
132 |
+
|
133 |
+
# print(matrix.shape)
|
134 |
+
# print(query.shape)
|
135 |
+
masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
136 |
+
# rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
137 |
+
# rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
138 |
+
weights = rotary_q.permute(2,0,1) @ rotary_k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
139 |
+
weights_masked = weights.masked_fill(masked == 0, float('-inf'))
|
140 |
+
scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
|
141 |
+
scaled_weights = F.softmax(scaled_weights, dim=-1)
|
142 |
+
value = scaled_weights @ values
|
143 |
+
out = self.dropout(value)
|
144 |
+
return out
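# A small property check for the RotaryEmbeddings module defined above (illustrative only,
# never called): applying the rotation to channel pairs keeps the input's shape and, up to
# floating-point error, its per-token norm.
def _rope_sketch():
    rope = RotaryEmbeddings(device='cpu', embeddings_dims=64)
    x = torch.randn(2, 8, 64)
    y = rope(x)
    assert y.shape == x.shape
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)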
|
145 |
+
|
146 |
+
|
147 |
+
# # import numpy as np
|
148 |
+
# class RotaryEmbeddings(nn.Module):
|
149 |
+
# def __init__(
|
150 |
+
# self,
|
151 |
+
# device,
|
152 |
+
# embeddings_dims: int = ModelArgs.embeddings_dims,
|
153 |
+
# block_size: int = ModelArgs.block_size,
|
154 |
+
# batch_size: int = ModelArgs.batch_size
|
155 |
+
# ):
|
156 |
+
# super().__init__()
|
157 |
+
|
158 |
+
# self.embeddings_dims = embeddings_dims
|
159 |
+
# self.block_size = block_size
|
160 |
+
# self.batch_size = batch_size
|
161 |
+
# self.theta = 0
|
162 |
+
|
163 |
+
|
164 |
+
# # def init_matrix(self, seq_len):
|
165 |
+
# # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
|
166 |
+
# # for pos in range(seq_len):
|
167 |
+
# # for j in range(1, self.embeddings_dims // 2):
|
168 |
+
# # self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
|
169 |
+
# # self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
|
170 |
+
# # self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
|
171 |
+
# # self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
|
172 |
+
# # self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
|
173 |
+
# # return self.matrix
|
174 |
+
# self.device=device
|
175 |
+
|
176 |
+
# def init_matrix(self, seq_len):
|
177 |
+
# self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
|
178 |
+
|
179 |
+
# positions = torch.arange(0 , seq_len, 2, dtype=torch.float32, device = self.device).unsqueeze(1)
|
180 |
+
# # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
|
181 |
+
# theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
|
182 |
+
# angles = positions * theta
|
183 |
+
|
184 |
+
# cos_angles = torch.cos(angles)
|
185 |
+
# sin_angles = torch.sin(angles)
|
186 |
+
|
187 |
+
# indices = torch.arange(seq_len, dtype=torch.int64, device = self.device)
|
188 |
+
# # print(indices)
|
189 |
+
# # print(indices.shape)
|
190 |
+
# # print(indices[::2])
|
191 |
+
# even_indices = indices[::2]
|
192 |
+
# odd_indices = indices[1::2]
|
193 |
+
|
194 |
+
# self.matrix[:, even_indices, even_indices] = cos_angles
|
195 |
+
# self.matrix[:, odd_indices, odd_indices] = sin_angles
|
196 |
+
# self.matrix[:, odd_indices, even_indices] = -sin_angles
|
197 |
+
# self.matrix[:, even_indices, odd_indices] = cos_angles
|
198 |
+
|
199 |
+
# return self.matrix
|
200 |
+
|
201 |
+
# def forward(self, x):
|
202 |
+
# # B,T,C = x.shape
|
203 |
+
# # print("MATRIX:",x)
|
204 |
+
# if(x > self.block_size or x < self.block_size):
|
205 |
+
# matrix = self.init_matrix(x)
|
206 |
+
# return matrix
|
207 |
+
# else:
|
208 |
+
# matrix = self.init_matrix(self.block_size)
|
209 |
+
|
210 |
+
# return matrix
|
211 |
+
|
212 |
+
|
213 |
+
# class RotaryAttentionHead(nn.Module):
|
214 |
+
# def __init__(
|
215 |
+
# self,
|
216 |
+
# device,
|
217 |
+
# embeddings_dims: int = ModelArgs.embeddings_dims,
|
218 |
+
# no_of_heads: int = ModelArgs.no_of_heads,
|
219 |
+
# attn_dropout: int = ModelArgs.attn_dropout
|
220 |
+
# ):
|
221 |
+
# super().__init__()
|
222 |
+
# self.head_size = embeddings_dims // no_of_heads
|
223 |
+
# self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
224 |
+
# self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
225 |
+
# self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
|
226 |
+
# self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
|
227 |
+
# self.dropout = nn.Dropout(p = attn_dropout)
|
228 |
+
# self.device = device
|
229 |
+
# def forward(self,x):
|
230 |
+
# # print(x.shape)
|
231 |
+
# batch, block_size, embeddings_dims = x.shape
|
232 |
+
# query = self.query(x)
|
233 |
+
# # print(query)
|
234 |
+
# key = self.key(x)
|
235 |
+
# values = self.value(x)
|
236 |
+
# matrix = self.rotary_matrix(block_size)
|
237 |
+
|
238 |
+
# # print(matrix.shape)
|
239 |
+
# # print(query.shape)
|
240 |
+
# masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
241 |
+
# rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
242 |
+
# rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
243 |
+
# weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
244 |
+
# weights_masked = weights.masked_fill(masked == 0, float('-inf'))
|
245 |
+
# scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
|
246 |
+
# scaled_weights = F.softmax(scaled_weights, dim=-1)
|
247 |
+
# value = scaled_weights @ values
|
248 |
+
# out = self.dropout(value)
|
249 |
+
# return out
|
250 |
+
|
251 |
+
|
252 |
+
class MQA(nn.Module):
|
253 |
+
def __init__(
|
254 |
+
self,
|
255 |
+
device,
|
256 |
+
no_of_q_heads: int,
|
257 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
258 |
+
block_size: int = ModelArgs.block_size,
|
259 |
+
|
260 |
+
|
261 |
+
):
|
262 |
+
super().__init__()
|
263 |
+
|
264 |
+
|
265 |
+
# self.no_of_q_heads = no_of_heads // no_of_kv_heads
|
266 |
+
# self.no_of_q_heads = no_of_q_heads
|
267 |
+
self.no_of_kv_heads = 2 # I want to have a kv for each pair of query heads
|
268 |
+
self.head_size = embeddings_dims // no_of_q_heads
|
269 |
+
# self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
|
270 |
+
self.rotary= RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
|
271 |
+
# self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size, device = device)
|
272 |
+
# self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
|
273 |
+
self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
|
274 |
+
self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
|
275 |
+
self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
|
276 |
+
self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
|
277 |
+
self.device = device
|
278 |
+
self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device = self.device) for _ in range(self.no_of_kv_heads)])
|
279 |
+
|
280 |
+
def scaled_dot_product(self, q, k, v, block_size):
|
281 |
+
|
282 |
+
# masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
283 |
+
q = self.rotary(q)
|
284 |
+
masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
|
285 |
+
# rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
286 |
+
# rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
|
287 |
+
# print("Query: ", q.shape)
|
288 |
+
# print("Keys: ", k.shape)
|
289 |
+
# print(q.permute(2,0,1).shape)
|
290 |
+
# print(k.permute(2,0,1).transpose(-2, -1).shape)
|
291 |
+
# weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
|
292 |
+
# weights = q @ k.permute(2,1,0)
|
293 |
+
# print(weights.shape)
|
294 |
+
# print(masked.shape)
|
295 |
+
weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
|
296 |
+
masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
|
297 |
+
weights_normalized = nn.functional.softmax(masked_values, dim=-1) # normalize over the key/sequence dimension so each query position gets a distribution over the tokens it may attend to
|
298 |
+
weights_normalized = self.dropout(weights_normalized)
|
299 |
+
out = weights_normalized @ v
|
300 |
+
return out
|
301 |
+
|
302 |
+
def forward(self,x):
|
303 |
+
# print("MQA: ", x.shape)
|
304 |
+
batch, block_size, embeddings_dims = x.shape
|
305 |
+
|
306 |
+
# query = self.query(x)
|
307 |
+
# matrix = self.rotary_matrix(block_size)
|
308 |
+
|
309 |
+
|
310 |
+
key = self.key(x)
|
311 |
+
values = self.value(x)
|
312 |
+
# print("Keys: ", key.shape)
|
313 |
+
# print("Values: ", values.shape)
|
314 |
+
# rotary_value = self.rotary(values)
|
315 |
+
rotary_key = self.rotary(key)
|
316 |
+
multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
|
317 |
+
# print("Multi query: ", multi_query_concat.shape)
|
318 |
+
|
319 |
+
linear_layer= self.linear_layer(multi_query_concat)
|
320 |
+
# out = self.dropout(linear_layer)
|
321 |
+
return linear_layer
|
322 |
+
|
323 |
+
|
324 |
+
class GQA(nn.Module):
|
325 |
+
def __init__(
|
326 |
+
self,
|
327 |
+
device,
|
328 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
329 |
+
block_size: int = ModelArgs.block_size,
|
330 |
+
# no_of_q_heads: int = ModelArgs.no_of_heads,
|
331 |
+
mqa_heads: int = ModelArgs.no_kv_heads
|
332 |
+
):
|
333 |
+
super().__init__()
|
334 |
+
|
335 |
+
# self.no_of_kv_heads = no_of_kv_heads
|
336 |
+
self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
|
337 |
+
# self.head_dim = embeddings_dims // self.no_kv_heads
|
338 |
+
self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
|
339 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device)
|
340 |
+
self.device = device
|
341 |
+
self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
|
342 |
+
# self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)
|
343 |
+
def forward(self,x):
|
344 |
+
|
345 |
+
batch, block_size, embeddings_dims = x.shape
|
346 |
+
|
347 |
+
# res = self.mqa(x)
|
348 |
+
grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
|
349 |
+
|
350 |
+
linear_layer= self.linear_layer(grouped_query_concat) #Basically MQA is made into GQA with no_of_q_heads and this class right here is just to consolidate everything into one
|
351 |
+
out = self.dropout(linear_layer)
|
352 |
+
return out
|
353 |
+
|
354 |
+
|
355 |
+
class Swish(nn.Module):
|
356 |
+
def __init__(
|
357 |
+
self,
|
358 |
+
device,
|
359 |
+
block_size: int = ModelArgs.block_size,
|
360 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
361 |
+
):
|
362 |
+
super().__init__()
|
363 |
+
|
364 |
+
self.sig = torch.nn.Sigmoid()
|
365 |
+
|
366 |
+
|
367 |
+
def forward(self, x):
|
368 |
+
swish = x * self.sig(x)
|
369 |
+
|
370 |
+
return swish
|
371 |
+
|
372 |
+
|
373 |
+
|
374 |
+
class SWiGLU(nn.Module):
|
375 |
+
def __init__(
|
376 |
+
self,
|
377 |
+
device,
|
378 |
+
block_size: int = ModelArgs.block_size,
|
379 |
+
embeddings_dims: int = ModelArgs.embeddings_dims
|
380 |
+
):
|
381 |
+
super().__init__()
|
382 |
+
self.hidden_dims = int(2 * ( 4 * embeddings_dims) / 3)
|
383 |
+
self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
|
384 |
+
self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
|
385 |
+
self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
|
386 |
+
self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
|
387 |
+
|
388 |
+
|
389 |
+
|
390 |
+
|
391 |
+
def forward(self, x):
|
392 |
+
swish_res = self.swish(self.linear_layer1(x))
|
393 |
+
x_V = self.linear_layer2(x)
|
394 |
+
res = torch.mul(swish_res, x_V)
|
395 |
+
out = self.linear_layer3(res)
|
396 |
+
return out
|
397 |
+
|
398 |
+
|
399 |
+
|
400 |
+
class FFN(nn.Module):
|
401 |
+
def __init__(self,
|
402 |
+
device,
|
403 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
404 |
+
block_size: int = ModelArgs.block_size,
|
405 |
+
vocab_size: int = ModelArgs.vocab_size,
|
406 |
+
dropout = ModelArgs.dropout
|
407 |
+
|
408 |
+
):
|
409 |
+
super().__init__()
|
410 |
+
|
411 |
+
# self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
|
412 |
+
self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device)
|
413 |
+
self.dropout = nn.Dropout(p = dropout)
|
414 |
+
def forward(self, x):
|
415 |
+
|
416 |
+
x = self.swiglue(x)
|
417 |
+
# x = self.linear_layer(x)
|
418 |
+
x = self.dropout(x)
|
419 |
+
return x
|
420 |
+
|
421 |
+
|
422 |
+
class DecoderLayer(nn.Module):
|
423 |
+
def __init__(self,
|
424 |
+
device,
|
425 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
426 |
+
dropout = ModelArgs.dropout,
|
427 |
+
block_size: int = ModelArgs.block_size,
|
428 |
+
vocab_size: int = ModelArgs.vocab_size,
|
429 |
+
|
430 |
+
) :
|
431 |
+
super().__init__()
|
432 |
+
|
433 |
+
|
434 |
+
self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device)
|
435 |
+
self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device = device)
|
436 |
+
# self.norm = Normalization(embeddings_dims=embeddings_dims)
|
437 |
+
self.norm1 = Normalization(embeddings_dims=embeddings_dims)
|
438 |
+
self.norm2 = Normalization(embeddings_dims=embeddings_dims)
|
439 |
+
self.dropout = nn.Dropout(p = dropout)
|
440 |
+
def forward(self, x):
|
441 |
+
|
442 |
+
x = x + self.gqa(self.norm1(x))
|
443 |
+
x = x + self.feedforward_network(self.norm2(x))
|
444 |
+
return x
|
445 |
+
|
446 |
+
|
447 |
+
class Llama(nn.Module):
|
448 |
+
def __init__(self,
|
449 |
+
device,
|
450 |
+
embeddings_dims: int = ModelArgs.embeddings_dims,
|
451 |
+
no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
|
452 |
+
block_size: int = ModelArgs.block_size,
|
453 |
+
vocab_size: int = ModelArgs.vocab_size,
|
454 |
+
dropout = ModelArgs.dropout
|
455 |
+
|
456 |
+
) :
|
457 |
+
super().__init__()
|
458 |
+
|
459 |
+
self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device)
|
460 |
+
self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)])
|
461 |
+
self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device)
|
462 |
+
self.dropout = nn.Dropout(p = dropout)
|
463 |
+
# self.norm = Normalization(embeddings_dims)
|
464 |
+
|
465 |
+
|
466 |
+
#weight tying
|
467 |
+
self.embeddings.weight = self.linear_layer.weight
|
468 |
+
|
469 |
+
self.apply(self._init_weights)
|
470 |
+
|
471 |
+
def _init_weights(self, module):
|
472 |
+
if isinstance(module, nn.Linear):
|
473 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
474 |
+
|
475 |
+
if module.bias is not None:
|
476 |
+
nn.init.zeros_(module.bias)
|
477 |
+
elif isinstance(module, nn.Embedding):
|
478 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
479 |
+
|
480 |
+
|
481 |
+
|
482 |
+
def forward(self, x):
|
483 |
+
x = self.embeddings(x)
|
484 |
+
x = self.dropout(x)
|
485 |
+
x = self.decoder(x)
|
486 |
+
# x = self.norm(x)
|
487 |
+
x = self.linear_layer(x)
|
488 |
+
# out = self.norm(x)
|
489 |
+
return x
|
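
As a quick sanity check of the model code above (this snippet is not one of the uploaded files), the sketch below instantiates `Llama` with its defaults from `config.py` and runs a single forward pass. The CPU device and the random token batch are assumptions purely for illustrating the expected output shape.

```python
# Hypothetical smoke test for model.py / config.py from this repo.
import torch
from config import ModelArgs
from model import Llama

device = "cpu"  # assumption: any device works for a shape check
model = Llama(device=device)

# Fake batch: 2 sequences of block_size random token ids
idx = torch.randint(0, ModelArgs.vocab_size, (2, ModelArgs.block_size), device=device)

with torch.no_grad():
    logits = model(idx)

print(logits.shape)  # expected: (2, ModelArgs.block_size, ModelArgs.vocab_size)
```
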
tokenizer.py
ADDED
@@ -0,0 +1,21 @@
from transformers import AutoTokenizer
import os


class Tokenizer:

    def __init__(self) -> None:
        # Note: `hf_token` is not a documented from_pretrained keyword; if authentication is needed, the `token` argument is the supported way to pass it.
        self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '...')

        # GPT-2 has no pad token by default, so add one (this grows the vocabulary by one entry)
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    def ready_tokenizer(self):
        return self.tokenizer
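
A short, hedged usage example for this wrapper (again not part of the repo): the `[PAD]` token added above extends GPT-2's vocabulary by one entry, which is why the model's `vocab_size` has to account for it. The sample sentence and padding length are arbitrary.

```python
# Illustrative only: tokenize a story prompt with the Tokenizer wrapper above.
from tokenizer import Tokenizer

tok = Tokenizer().ready_tokenizer()

batch = tok(
    ["Once upon a time there was a tiny dragon."],
    padding="max_length",
    max_length=16,        # arbitrary demo length
    truncation=True,
    return_tensors="pt",
)

print(batch["input_ids"].shape)  # (1, 16)
print(tok.pad_token_id)          # id of the newly added [PAD] token
```
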
trainer.py
ADDED
@@ -0,0 +1,469 @@
import argparse
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import wandb
from tqdm import tqdm

from config import ModelArgs
from model import Llama
from inference import greedy_decode
from data import prepare_dataset
from tokenizer import Tokenizer


torch.set_float32_matmul_precision('high')

scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))


save_checkpoint_iter = 50
total_iters = 10000
eval_iters = 50
eval_check = 100
warmup_iters = 700
min_lr = 0.1 * ModelArgs.max_lr
lr_decay_iters = 10000
total_batch_size = 524288        # target number of tokens per optimizer step
micro_batch_size = ModelArgs.batch_size
gradient_accumulation_steps = total_batch_size // (micro_batch_size * (ModelArgs.block_size * torch.cuda.device_count()))

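# Worked example of the accumulation arithmetic above (GPU count is an assumption): with
# micro_batch_size = 64, block_size = 512 and 2 GPUs, one micro-step covers
# 64 * 512 * 2 = 65,536 tokens, so gradient_accumulation_steps = 524,288 // 65,536 = 8
# micro-steps per optimizer update.
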
def setup(rank=None, world_size=None):
    # torchrun supplies MASTER_ADDR, MASTER_PORT and LOCAL_RANK through the environment
    init_process_group("nccl")


def cleanup():
    destroy_process_group()


def _save_snapshot(model, optimizer, epoch, step, save_dir):
    snapshot = {}
    snapshot["MODEL_STATE"] = model.module.state_dict()
    snapshot["OPTIMIZER_STATE"] = optimizer.state_dict()
    snapshot["EPOCHS_RUN"] = epoch
    snapshot["STEP_RUN"] = step
    torch.save(snapshot, os.path.join(save_dir, "snapshot.pt"))
    print(f"Epoch: {epoch} | step {step} | Training snapshot saved at snapshot.pt")


class Trainer:
    # Thin wrapper kept from the original script; train() below drives the actual run.

    def __init__(self, model_args):
        self.model_args = model_args
        self.tokenizer = Tokenizer().ready_tokenizer()
        setup()


# Warmup helper: LR multiplier ramps up linearly for the first 2000 steps, then holds
# (not used below; get_lr handles warmup itself)
def warmup_fn(step):
    if step < 2000:
        return step / 2000
    return 1.0


# Learning rate decay scheduler (cosine with warmup), adapted from
# https://github.com/karpathy/nanoGPT/blob/master/train.py
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
    # 2) if it > lr_decay_iters, return the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, cosine decay down to the minimum learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (ModelArgs.max_lr - min_lr)

def train():

    setup()
    device = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(device)

    print(f"Start running DDP on rank {device}.")

    if device == 0:
        # Initialise the wandb run on the main process only
        wandb.init(
            # entity = 'rajceo2031',
            project = 'Llama-DDP-Pretrain-10-billion-tokens',
        )
        print("wandb initialized")

    model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)

    print(f"Model on device {device} is ready")

    optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=ModelArgs.eps)

    model = model.to(device)
    model = DDP(model, device_ids=[device])

    tokenizer = Tokenizer().ready_tokenizer()  # needed below for pad_token_id in the loss

    model.eval()
    world_size = torch.cuda.device_count()

    @torch.inference_mode()
    def estimate_loss(val_loader, val_iterator, device):
        out = {}
        epoch_losses = []

        for split in ['val']:
            print(f"Starting with {split} evaluation...")

            total_loss = 0
            total_batches = 0

            for _ in range(eval_check):
                try:
                    batch = next(val_iterator)
                except StopIteration:
                    val_iterator = iter(val_loader)
                    batch = next(val_iterator)

                idx = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)

                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits = model(idx)
                    batch_size, block_size, embeddings_dims = logits.shape
                    logits = logits.view(batch_size * block_size, embeddings_dims)
                    targets = targets.view(batch_size * block_size)
                    loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

                total_loss += loss.item()
                total_batches += 1

            epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
            epoch_losses.append(epoch_loss)

            out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0

        model.train()
        return out

    model.train()
    count = 0

    train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
    val_loader = prepare_dataset('val', device, ModelArgs.batch_size)

    print("Loaders ready both")
    epochs = ModelArgs.epochs

    train_loader_length = 0
    train_data_iterator = iter(train_dataloader)
    val_data_iterator = iter(val_loader)
    token_count = 0
    if device == 0:
        train_loader_length = len(train_dataloader)

    for step in tqdm(range(total_iters)):

        if device == 0:
            print("Step : ", step, "/", total_iters)
            print('Total batches: ', len(train_dataloader))
            print("Total gradient accumulation steps: ", gradient_accumulation_steps)
            print("Total tokens processed: ", token_count)

        if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
            losses = estimate_loss(val_loader, val_data_iterator, 'cuda')
            avg_val_loss = losses['val']

            print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")

            avg_val_loss = torch.Tensor([losses['val']]).to(device)
            torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

            if device == 0:
                all_gpus_avg_val_loss = avg_val_loss / world_size
                print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")

                wandb.log({
                    "All_GPUs_Val_losses": all_gpus_avg_val_loss,
                    "val_step_loss": losses['val'],
                })

        if step % save_checkpoint_iter == 0 and device == 0 and step != 0:
            print(f"Saving the model checkpoint for step: {step}")
            # Assumption: checkpoints go into ModelArgs.save_checkpoint_dir; the original call passed `step` in the save_dir slot.
            _save_snapshot(model, optimizer, None, step, ModelArgs.save_checkpoint_dir)

        accumulated_loss = 0.0

        optimizer.zero_grad(set_to_none=True)
        for micro_step in range(gradient_accumulation_steps):
            try:
                batch = next(train_data_iterator)
            except StopIteration:
                train_data_iterator = iter(train_dataloader)
                batch = next(train_data_iterator)

            idx = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            token_count += idx.numel()  # count tokens rather than sequences

            with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
                logits = model(idx)
                batch_size, block_size, embeddings_dims = logits.shape
                logits = logits.view(batch_size * block_size, embeddings_dims)
                targets = targets.view(batch_size * block_size)
                loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)

            # Scale the loss so the accumulated gradient matches one large batch of total_batch_size tokens:
            # each micro-batch contributes 1/gradient_accumulation_steps of the full-batch gradient.
            loss = loss / gradient_accumulation_steps
            accumulated_loss += loss.detach()

            # Only synchronize gradients across GPUs on the last micro-step
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
            scaler.scale(loss).backward()

            # Debugging hook from the original script; `find_unused_parameters` is not defined in this file:
            # unused_params = find_unused_parameters(model)
            # if unused_params:
            #     print(f"Unused parameters: {unused_params}")

            if device == 0 and micro_step % 10 == 0:
                print("Micro Batch : ", micro_step)
                print("Step : ", step, "/", total_iters)
                print('Total batches: ', len(train_dataloader))
                print("Total gradient accumulation steps: ", gradient_accumulation_steps)
                print("Total tokens processed: ", token_count)

        lr = get_lr(step)
        for params in optimizer.param_groups:
            params['lr'] = lr

        # Compute gradient norms before and after clipping
        if ModelArgs.clip != 0.0:
            scaler.unscale_(optimizer)  # unscale before clipping so we clip real gradients, not scaled ones
            total_norm_before = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
            )

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)

            total_norm_after = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
            )

            if device == 0 and step != 0:
                print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
                print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")

        scaler.step(optimizer)
        scaler.update()

        torch.cuda.synchronize()
        torch.distributed.reduce(loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        if device == 0:
            wandb.log({
                "Learning Rate": lr,
                "All_GPUs_Train_losses": accumulated_loss.item(),
                "Step": step,
            })

        if device == 0 and step % 5 == 0:
            # Only generate sample text on the main process.
            # `topk_sampling` is expected from the inference utilities; only `greedy_decode` is imported above.
            count = 3
            while count:
                prompt = "Once upon a time"
                generated_text = topk_sampling(model, prompt, max_length=50, top_k=50, temperature=1.0, device=device)

                print(f" Step: {step} | Generated Text: {generated_text}")
                count -= 1

    if device == 0:
        wandb.finish()
    cleanup()


world_size = torch.cuda.device_count()
print(f"World size: {world_size}")

def parse_args():
    parser = argparse.ArgumentParser(description="Model Training Arguments")

    # Add arguments for each field in ModelArgs
    parser.add_argument("--epochs", type=int, default=ModelArgs.epochs, help="Number of training epochs.")
    parser.add_argument("--block_size", type=int, default=ModelArgs.block_size, help="Block size (context length) for the model.")
    parser.add_argument("--batch_size", type=int, default=ModelArgs.batch_size, help="Batch size for training.")
    parser.add_argument("--embeddings_dims", type=int, default=ModelArgs.embeddings_dims, help="Embedding dimensions.")
    parser.add_argument("--attn_dropout", type=float, default=ModelArgs.attn_dropout, help="Attention dropout rate.")
    parser.add_argument("--no_of_heads", type=int, default=ModelArgs.no_of_heads, help="Number of attention heads.")
    parser.add_argument("--dropout", type=float, default=ModelArgs.dropout, help="Dropout rate.")
    parser.add_argument("--val_epochs", type=int, default=ModelArgs.val_epochs, help="Number of validation epochs.")
    parser.add_argument("--max_lr", type=float, default=ModelArgs.max_lr, help="Maximum learning rate.")
    parser.add_argument("--no_of_decoder_layers", type=int, default=ModelArgs.no_of_decoder_layers, help="Number of decoder layers.")
    parser.add_argument("--weight_decay_optim", type=float, default=ModelArgs.weight_decay_optim, help="Weight decay for the optimizer.")
    parser.add_argument("--beta_1", type=float, default=ModelArgs.beta_1, help="Beta1 for the Adam optimizer.")
    parser.add_argument("--beta_2", type=float, default=ModelArgs.beta_2, help="Beta2 for the Adam optimizer.")
    parser.add_argument("--clip", type=float, default=ModelArgs.clip, help="Gradient clipping value.")
    parser.add_argument("--device", type=str, default=ModelArgs.device, help="Device to run the model on (e.g., 'cuda' or 'cpu').")
    parser.add_argument("--no_kv_heads", type=int, default=ModelArgs.no_kv_heads, help="Number of key/value heads.")
    parser.add_argument("--vocab_size", type=int, default=ModelArgs.vocab_size, help="Vocabulary size.")
    parser.add_argument("--eps", type=float, default=ModelArgs.eps, help="Epsilon value for numerical stability.")
    parser.add_argument("--dtype", type=str, default=ModelArgs.dtype, help="Data type for tensors (e.g., 'float16' or 'bfloat16').")
    parser.add_argument("--save_checkpoint_dir", type=str, default=ModelArgs.save_checkpoint_dir, help="Directory to save model checkpoints.")
    parser.add_argument("--prompt", type=str, default=ModelArgs.prompt, help="Prompt used for test generations during training.")

    # Additional arguments
    parser.add_argument("--save_checkpoint_iter", type=int, default=ModelArgs.save_checkpoint_iter, help="Save a checkpoint every N iterations.")
    parser.add_argument("--total_iters", type=int, default=ModelArgs.total_iters, help="Total number of training iterations.")
    parser.add_argument("--eval_iters", type=int, default=ModelArgs.eval_iters, help="Run evaluation every N iterations.")
    parser.add_argument("--eval_check", type=int, default=ModelArgs.eval_check, help="Number of batches per evaluation pass.")
    parser.add_argument("--warmup_iters", type=int, default=ModelArgs.warmup_iters, help="Number of warmup iterations for the learning rate schedule.")
    parser.add_argument("--min_lr", type=float, default=ModelArgs.min_lr, help="Minimum learning rate.")
    parser.add_argument("--lr_decay_iters", type=int, default=ModelArgs.lr_decay_iters, help="Number of iterations over which the learning rate decays.")
    parser.add_argument("--total_batch_size", type=int, default=ModelArgs.total_batch_size, help="Total batch size (in tokens) across all devices.")
    parser.add_argument("--micro_batch_size", type=int, default=ModelArgs.micro_batch_size, help="Micro batch size per device.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=ModelArgs.gradient_accumulation_steps, help="Number of gradient accumulation steps.")

    args = parser.parse_args()
    return args


def initialize_model_args(args):
    # Create a ModelArgs instance from the parsed arguments
    model_args = ModelArgs(
        epochs=args.epochs,
        block_size=args.block_size,
        batch_size=args.batch_size,
        embeddings_dims=args.embeddings_dims,
        attn_dropout=args.attn_dropout,
        no_of_heads=args.no_of_heads,
        dropout=args.dropout,
        val_epochs=args.val_epochs,
        max_lr=args.max_lr,
        no_of_decoder_layers=args.no_of_decoder_layers,
        weight_decay_optim=args.weight_decay_optim,
        beta_1=args.beta_1,
        beta_2=args.beta_2,
        clip=args.clip,
        device=args.device,
        no_kv_heads=args.no_kv_heads,
        vocab_size=args.vocab_size,
        eps=args.eps,
        dtype=args.dtype,
        save_checkpoint_dir=args.save_checkpoint_dir,
        prompt=args.prompt,
        save_checkpoint_iter=args.save_checkpoint_iter,
        total_iters=args.total_iters,
        eval_iters=args.eval_iters,
        eval_check=args.eval_check,
        warmup_iters=args.warmup_iters,
        min_lr=args.min_lr,
        lr_decay_iters=args.lr_decay_iters,
        total_batch_size=args.total_batch_size,
        micro_batch_size=args.micro_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps
    )
    return model_args


if __name__ == "__main__":
    args = parse_args()
    model_args = initialize_model_args(args)
    # Kick off the DDP training loop; the original upload stopped after building model_args.
    # The script is intended to be launched with torchrun so that LOCAL_RANK is set.
    train()
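
The warmup-plus-cosine schedule in `get_lr` is easy to mis-wire, so here is a minimal, self-contained sketch (not part of the repo) that re-implements it with this repo's defaults (`max_lr = 6e-4`, `warmup_iters = 700`, `lr_decay_iters = 10000`) and prints the learning rate at a few milestones:

```python
# Standalone check of the warmup + cosine-decay schedule used in trainer.py.
import math

max_lr = 6e-4              # ModelArgs.max_lr
min_lr = 0.1 * max_lr      # floor used in trainer.py
warmup_iters = 700
lr_decay_iters = 10000

def get_lr(it):
    if it < warmup_iters:                      # 1) linear warmup
        return max_lr * (it + 1) / (warmup_iters + 1)
    if it > lr_decay_iters:                    # 2) past the decay horizon: hold at the floor
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 3) cosine decay
    return min_lr + coeff * (max_lr - min_lr)

for it in (0, 350, 700, 5000, 10000):
    print(it, f"{get_lr(it):.2e}")
# 0 -> ~8.6e-07, 350 -> ~3.0e-04, 700 -> 6.0e-04 (peak), 5000 -> ~3.6e-04, 10000 -> 6.0e-05
```
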