YuvrajSingh9886 commited on
Commit
5bb6ad4
·
verified ·
1 Parent(s): 65b3f00

Upload 12 files

Browse files
Files changed (12) hide show
  1. .gitignore +14 -0
  2. README.md +187 -10
  3. config.py +41 -0
  4. data.py +117 -0
  5. download_model_weight.py +131 -0
  6. fine_tune.py +1282 -0
  7. inference.py +84 -0
  8. llama_torchrun.py +1435 -0
  9. metric.py +28 -0
  10. model.py +489 -0
  11. tokenizer.py +21 -0
  12. trainer.py +469 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ snapshot.pt
2
+ snapshot2.pt
3
+ llama.py
4
+ snapshot_3.pt
5
+ metric.py
6
+ weights/
7
+ gpt4all.json
8
+ fine_tune.py
9
+ old_files/
10
+ snapshot_4650.pt
11
+ snapshot (1).pt
12
+
13
+
14
+
README.md CHANGED
@@ -1,13 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: StoryLlama
3
- emoji: 🐠
4
- colorFrom: indigo
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.21.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Introducing StoryLlama - A Smaller Language Model for Bedtime Stories!
3
+
4
+ - So, I trained a Llama a 88M architecture I coded from ground up to build a small instruct model, going through the below-mentioned stages from scratch.
5
+ - Trained on TiyStories dataset form HuggingFace consisting of 4B tokens for a total of 5000 steps
6
+
7
+
8
+
9
+ ### Pretraining
10
+
11
+ #### Dataset
12
+
13
+ - I used the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset from HuggingFace.
14
+
15
+ 1) Train dataset - 2 M records approx
16
+ 2) Val dataset - 26K records approx
17
+
18
+
19
+
20
+ ---
21
+
22
+ #### ModelArgs (Hyperparameters)
23
+
24
+
25
+ Below is a table summarizing the configuration parameters for the model:
26
+
27
+ | Parameter | Description | Default Value | Type |
28
+ |--------------------------------|-----------------------------------------------------------------------------|-----------------------------------|-----------|
29
+ | `epochs` | Number of training epochs | `4` | `int` |
30
+ | `block_size` | Size of each block (context length) | `512` | `int` |
31
+ | `batch_size` | Batch size for training | `64` | `int` |
32
+ | `inference` | Inference mode (not specified) | `None` | `None` |
33
+ | `embeddings_dims` | Dimensionality of embeddings | `512` | `int` |
34
+ | `attn_dropout` | Dropout rate for attention layers | `0.1` | `float` |
35
+ | `no_of_heads` | Number of attention heads | `8` | `int` |
36
+ | `dropout` | Dropout rate for the model | `0.1` | `float` |
37
+ | `val_epochs` | Number of validation epochs | `2` | `int` |
38
+ | `max_lr` | Maximum learning rate | `6e-4` | `float` |
39
+ | `no_of_decoder_layers` | Number of decoder layers | `8` | `int` |
40
+ | `weight_decay_optim` | Weight decay for the optimizer | `0.1` | `float` |
41
+ | `beta_1` | Beta 1 for Adam optimizer | `0.9` | `float` |
42
+ | `beta_2` | Beta 2 for Adam optimizer | `0.95` | `float` |
43
+ | `clip` | Gradient clipping value | `1.0` | `float` |
44
+ | `device` | Device to run the model (`cuda` or `cpu`) | `'cuda'` | `str` |
45
+ | `no_kv_heads` | Number of key-value heads | `2` | `int` |
46
+ | `vocab_size` | Size of the vocabulary | `50304` | `int` |
47
+ | `eps` | Epsilon value for numerical stability | `1e-5` | `float` |
48
+ | `dtype` | Data type for tensors (`bfloat16` if supported, else `float16`) | `'bfloat16'` or `'float16'` | `str` |
49
+ | `save_checkpoint_dir` | Directory to save model checkpoints | `"checkpoints"` | `str` |
50
+ | `prompt` | Default prompt for inference | `"Once upon a time"` | `str` |
51
+ | `save_checkpoint_iter` | Save checkpoint every N iterations | `50` | `int` |
52
+ | `total_iters` | Total number of training iterations | `10000` | `int` |
53
+ | `eval_iters` | Evaluate model every N iterations | `50` | `int` |
54
+ | `eval_check` | Check evaluation metrics every N iterations | `100` | `int` |
55
+ | `warmup_iters` | Number of warmup iterations for learning rate scheduling | `700` | `int` |
56
+ | `min_lr` | Minimum learning rate (10% of `max_lr`) | `0.1 * max_lr` | `float` |
57
+ | `lr_decay_iters` | Number of iterations for learning rate decay | `10000` | `int` |
58
+ | `total_batch_size` | Total batch size across all devices | `524288` | `int` |
59
+ | `micro_batch_size` | Micro batch size per device | `batch_size` | `int` |
60
+ | `gradient_accumulation_steps` | Gradient accumulation steps | 524288 | `int` |
61
+ ---
62
+ #### Hardware Setup
63
+
64
+ - Used DPP using Pytorch torchrun consisting of 2x GeForce RTX A100 AXM (80gb VRAM each) rented on runpod.io
65
+ - The model is a 0.768GB in size but needs around 4 GB of VRAM when loaded in fp32 precision
66
+ ---
67
+
68
+ #### Frameworks:
69
+ **Pytorch**
70
+
71
+
72
+ ---
73
+
74
+ #### Epochs/Steps
75
+ - Iterations (train) = 5k
76
+
77
+ - Val iterations = every 50 steps
78
+ ---
79
+
80
+ #### Losses
81
+ - Train loss - 1.43
82
+
83
+ - Val loss - 1.45
84
+
85
+ ---
86
+
87
+ #### Screenshots of the loss curves
88
+
89
+ - Loss Curves (Train and Val)
90
+
91
+ ![Loss Curves (Train and Val)](images/loss_curves.jpg)
92
+
93
+ ---
94
+ #### Output
95
+
96
+ - Prompt: Once upon a time
97
+
98
+ ![Prompt: Once upon a time](images/sample.jpg)
99
+
100
  ---
101
+
102
+ ### Local setup
103
+
104
+
105
+ ### Requirements
106
+
107
+
108
+
109
+ ```python
110
+ git [clone the repo](https://github.com/YuvrajSingh-mist/StoryLlama.git)
111
+ cd StoryLlama
112
+ bash ./install.sh
113
+
114
+ ```
115
+ - A wandb.ai account for plotting graphs for your loss curves
116
+
117
+ - On your terminal run
118
+ ```python
119
+ wandb login
120
+ ```
121
+
122
+ - Enter the api key and follow the instructions and once you are succesfully logged in follow the given steps
123
+
124
+
125
+ - Download the model
126
+
127
+ ```python
128
+ python download_model_weight.py
129
+ ```
130
+
131
+
132
  ---
133
 
134
+ ### Running
135
+
136
+
137
+ #### Training a model
138
+
139
+ - Kindly change 'device' to any of your available cuda gpus.
140
+
141
+ To run:
142
+
143
+ ```python
144
+ bash ./install.sh
145
+ ```
146
+
147
+ ```python
148
+ torchrun --standalone --nproc_per_node=gpu trainer.py \
149
+ --epochs 10 \
150
+ --block_size 256 \
151
+ --batch_size 128 \
152
+ --embeddings_dims 768 \
153
+ --attn_dropout 0.2 \
154
+ --no_of_heads 12 \
155
+ --dropout 0.2 \
156
+ --val_epochs 3 \
157
+ --max_lr 5e-4 \
158
+ --no_of_decoder_layers 6 \
159
+ --weight_decay_optim 0.01 \
160
+ --beta_1 0.85 \
161
+ --beta_2 0.99 \
162
+ --clip 0.5 \
163
+ --device "cuda" \
164
+ --no_kv_heads 4 \
165
+ --vocab_size 50257 \
166
+ --eps 1e-6 \
167
+ --dtype "float16" \
168
+ --save_checkpoint_dir "model_checkpoints" \
169
+ --prompt "Once upon a time" \
170
+ --save_checkpoint_iter 100 \
171
+ --total_iters 5000 \
172
+ --eval_iters 200 \
173
+ --eval_check 500 \
174
+ --warmup_iters 1000 \
175
+ --min_lr 1e-5 \
176
+ --lr_decay_iters 2000 \
177
+ --total_batch_size 262144 \
178
+ --micro_batch_size 128 \
179
+ --gradient_accumulation_steps 4
180
+
181
+ ```
182
+ --standalone - if all the gpu are on one server
183
+ --npro_per_node - number of gpus available and use the keyword gpu to use all
184
+
185
+ #### Inference on a model
186
+
187
+ ```python
188
+ python inference.py --prompt "Once upon a time" --max_length 100 --temperature 0.8 --topk 50
189
+ ```
190
+
config.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from dataclasses import dataclass
3
+ import torch
4
+
5
+ @dataclass
6
+ class ModelArgs:
7
+
8
+ epochs: int = 4
9
+ block_size: int = 512
10
+ batch_size: int = 64
11
+ inference = None
12
+ embeddings_dims: int = 512
13
+ attn_dropout: float = 0.1
14
+ no_of_heads: int = 8
15
+ dropout: float = 0.1
16
+ val_epochs: int = 2
17
+ max_lr: float = 6e-4
18
+ no_of_decoder_layers: int = 8
19
+ weight_decay_optim: float = 0.1
20
+ beta_1: float = 0.9
21
+ beta_2: float = 0.95
22
+ clip: float = 1.0
23
+ device: str = 'cuda'
24
+ no_kv_heads: int = 2
25
+ vocab_size: int = 50304
26
+ eps: float = 1e-5
27
+ dtype: str = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
28
+ save_checkpoint_dir: str = "checkpoints"
29
+ prompt: str = "Once upon a time"
30
+
31
+
32
+ save_checkpoint_iter: int = 50
33
+ total_iters: int = 10000
34
+ eval_iters: int = 50
35
+ eval_check: int = 100
36
+ warmup_iters: int = 700
37
+ min_lr: float = 0.1 * max_lr
38
+ lr_decay_iters: int = 10000
39
+ total_batch_size: int = 524288
40
+ micro_batch_size: int = batch_size
41
+ gradient_accumulation_steps: int = total_batch_size // (micro_batch_size * (block_size * torch.cuda.device_count()))
data.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch.nn.functional as F
3
+
4
+ import torch.multiprocessing as mp
5
+ from torch.utils.data.distributed import DistributedSampler
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+
8
+
9
+ from datasets import load_dataset
10
+ from torch.utils.data import DataLoader
11
+ from tokenizer import Tokenizer
12
+ from config import ModelArgs
13
+
14
+
15
+
16
+ tokenizer = Tokenizer().ready_tokenizer()
17
+
18
+
19
+ tinystories = True
20
+ fw = False
21
+ fw_train = None
22
+ fw_test = None
23
+ if(tinystories):
24
+ fw_train = load_dataset("roneneldan/TinyStories", split="train")
25
+ fw_test = load_dataset("roneneldan/TinyStories", split="validation")
26
+ print(fw_train)
27
+ print(fw_test)
28
+ if(fw):
29
+ fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
30
+ fw_train = fw_train.train_test_split(test_size=0.01)
31
+ print(fw_train)
32
+ print(fw_train)
33
+
34
+
35
+
36
+
37
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
38
+
39
+ def tokenize_function(examples):
40
+ return tokenizer(
41
+ examples['text'],
42
+ max_length=ModelArgs.block_size,
43
+ padding='max_length',
44
+ truncation=True,
45
+ return_tensors='pt'
46
+ )
47
+
48
+
49
+
50
+
51
+ def prepare_dataset(split, device, batch_size):
52
+ print("Device is: ", device)
53
+
54
+ def collate_fn(batch):
55
+ # Extract text data
56
+ texts = [item ["text"] for item in batch]
57
+
58
+
59
+ input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
60
+
61
+ input_encodings["labels"] = input_encodings["input_ids"].clone()
62
+
63
+ input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:]
64
+ input_encodings["labels"][:, -1] = tokenizer.eos_token_id
65
+
66
+ return input_encodings
67
+
68
+
69
+ dataloader = None
70
+ if(tinystories):
71
+ if(split == 'train'):
72
+ data_loader = DataLoader(
73
+ fw_train,
74
+ # generator=generator,
75
+ batch_size=batch_size,
76
+
77
+ sampler=DistributedSampler(fw_train, shuffle=True),
78
+ collate_fn=collate_fn,
79
+ drop_last=True,
80
+ shuffle=False
81
+ )
82
+ elif(split == 'val'):
83
+ data_loader = DataLoader(
84
+ fw_test,
85
+
86
+
87
+ batch_size=batch_size,
88
+ sampler=DistributedSampler(fw_test, shuffle=True),
89
+ collate_fn=collate_fn,
90
+ drop_last=True,
91
+ shuffle=False
92
+ )
93
+ elif(fw):
94
+ if(split == 'train'):
95
+ data_loader = DataLoader(
96
+ fw_train['train'],
97
+ batch_size=batch_size,
98
+
99
+
100
+ sampler=DistributedSampler(fw_train['train'], shuffle=True),
101
+ collate_fn=collate_fn,
102
+ drop_last=True,
103
+ shuffle=False
104
+ )
105
+ elif(split == 'val'):
106
+ data_loader = DataLoader(
107
+ fw_train['test'],
108
+ batch_size=batch_size,
109
+ # generator=generator,
110
+ sampler=DistributedSampler(fw_train["test"]),
111
+ collate_fn=collate_fn,
112
+
113
+ drop_last=True,
114
+ shuffle=False
115
+ )
116
+ return data_loader
117
+
download_model_weight.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import gdown
2
+ # import os
3
+ # import argparse
4
+
5
+ # def download_model(model_id, folder, filename):
6
+ # os.makedirs(folder, exist_ok=True)
7
+ # url = f"https://drive.google.com/uc?id={model_id}"
8
+ # output_path = os.path.join(folder, filename)
9
+ # print(f"Downloading model to {output_path}...")
10
+ # gdown.download(url, output_path, quiet=False)
11
+ # print("Download complete!")
12
+
13
+ # def main():
14
+ # parser = argparse.ArgumentParser(description="Download models using gdown and organize them into appropriate folders.")
15
+ # parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
16
+ # parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
17
+ # parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
18
+
19
+ # args = parser.parse_args()
20
+
21
+ # pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
22
+ # fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
23
+ # dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"
24
+
25
+ # if args.pretrained:
26
+ # download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt")
27
+ # if args.sft:
28
+ # download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt")
29
+ # if args.dpo:
30
+ # download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt")
31
+
32
+ # if __name__ == "__main__":
33
+ # main()
34
+
35
+
36
+ # import os
37
+ # import argparse
38
+
39
+ # def download_model(model_id, folder, filename, access_token):
40
+ # os.makedirs(folder, exist_ok=True)
41
+ # output_path = os.path.join(folder, filename)
42
+
43
+ # url = f"https://www.googleapis.com/drive/v3/files/{model_id}?alt=media"
44
+ # command = f"curl -H \"Authorization: Bearer {access_token}\" {url} -o {output_path}"
45
+
46
+ # print(f"Downloading model to {output_path}...")
47
+ # os.system(command)
48
+ # print("Download complete!")
49
+
50
+ # def main():
51
+ # parser = argparse.ArgumentParser(description="Download models using Google Drive API and organize them into appropriate folders.")
52
+ # parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
53
+ # parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
54
+ # parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
55
+ # parser.add_argument("--token", type=str, required=True, help="Google Drive API Access Token")
56
+
57
+ # args = parser.parse_args()
58
+
59
+ # pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
60
+ # fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
61
+ # dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"
62
+
63
+ # if args.pretrained:
64
+ # download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt", args.token)
65
+ # if args.sft:
66
+ # download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt", args.token)
67
+ # if args.dpo:
68
+ # download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt", args.token)
69
+
70
+ # if __name__ == "__main__":
71
+ # main()
72
+
73
+
74
+
75
+ # download_model_weight.py
76
+ import os
77
+ import argparse
78
+ from huggingface_hub import hf_hub_download, login
79
+
80
+ def download_model(repo_id, filename, cache_dir):
81
+
82
+ try:
83
+ model_path = hf_hub_download(
84
+ repo_id=repo_id,
85
+ filename=filename,
86
+ cache_dir=cache_dir,
87
+ resume_download=True,
88
+ force_download=False,
89
+ token=os.getenv("HF_TOKEN")
90
+ )
91
+
92
+ if os.path.exists(model_path) and os.path.getsize(model_path) > 1024*1024:
93
+ return model_path
94
+ raise ValueError("Downloaded file is too small or invalid")
95
+ except Exception as e:
96
+ print(f"Download failed: {str(e)}")
97
+ raise
98
+
99
+ def main():
100
+ parser = argparse.ArgumentParser(description="Download models from Hugging Face Hub")
101
+ parser.add_argument("--model_type",
102
+ choices=["pretrained"],
103
+ required=True,
104
+ help="Type of model to download")
105
+
106
+ args = parser.parse_args()
107
+
108
+ model_config = {
109
+
110
+ "pretrained": {
111
+ "repo_id": "YuvrajSingh9886/StoryLlama",
112
+ "filename": "snapshot_4650.pt",
113
+ "cache_dir": "weights/pretrained"
114
+ }
115
+ }
116
+
117
+ config = model_config[args.model_type]
118
+ os.makedirs(config["cache_dir"], exist_ok=True)
119
+
120
+ print(f"Downloading {args.model_type} model...")
121
+ model_path = download_model(
122
+ config["repo_id"],
123
+ config["filename"],
124
+ config["cache_dir"]
125
+ )
126
+ print(f"Successfully downloaded to: {model_path}")
127
+
128
+ if __name__ == "__main__":
129
+
130
+ login(token=os.getenv("HF_TOKEN"))
131
+ main()
fine_tune.py ADDED
@@ -0,0 +1,1282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 185860
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from dataclasses import dataclass
8
+ # from torchtune.modules import RMSNorm
9
+ from tokenizers import Tokenizer
10
+ from pathlib import Path
11
+ import torch.multiprocessing as mp
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ from torch.distributed import init_process_group, destroy_process_group
15
+ import torch
16
+ from datasets import Dataset
17
+ from torch.utils.data import DataLoader
18
+ from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoderModelOutput
19
+ import wandb
20
+ from tqdm import tqdm
21
+ from functools import partial
22
+
23
+ import torch.optim as optim
24
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
25
+
26
+
27
+ # Load model directly
28
+ from transformers import AutoTokenizer, AutoModelForCausalLM
29
+ import os
30
+
31
+
32
+ # import wandb
33
+ # wandb.login()
34
+
35
+
36
+ # from torch.utils.tensorboard import SummaryWriter
37
+
38
+
39
+ from datasets import load_dataset, concatenate_datasets
40
+ # use name="sample-10BT" to use the 10BT sample
41
+ # fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
42
+ # print(fw_train)
43
+ # Select only 1000 rows from the dataset
44
+ # fw_train = fw_train.select(range(1000000))
45
+ # alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
46
+ # dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train')
47
+ # merged_dataset = concatenate_datasets([alpaca, dolly])
48
+ dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True)
49
+ # print(fw_train)
50
+ # Split the dataset into training and validation sets
51
+ merged_dataset = dataset.train_test_split(test_size=0.1)
52
+ print(merged_dataset)
53
+ # fw_train = fw_train.train_test_split(test_size=0.2)
54
+ # print(fw_train)
55
+
56
+ # Access the splits
57
+ # train_dataset = train_val_split['train']
58
+ # val_dataset = train_val_split['test']
59
+
60
+ # train_dataset = fw_train.train_test_split(test_size=0.2)
61
+
62
+
63
+ def setup(rank=None, world_size=None):
64
+ # os.environ['MASTER_ADDR'] = 'localhost'
65
+ # os.environ['MASTER_PORT'] = '12355'
66
+ init_process_group("nccl")
67
+ # torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
68
+
69
+ def cleanup():
70
+ destroy_process_group()
71
+
72
+
73
+
74
+ @dataclass
75
+ class ModelArgs:
76
+ #Hyperparameters
77
+
78
+ epochs = 5
79
+ block_size = 128
80
+ batch_size = 64
81
+ embeddings_dims = 786
82
+ attn_dropout = 0.1
83
+ no_of_heads = 6 #IMP needs to be thoroughly calculated
84
+ dropout = 0.1
85
+ # epochs = 100
86
+ val_epochs = 2
87
+ max_lr = 2e-4
88
+ no_of_decoder_layers = 6 #IMP needs to be thoroughly calculated
89
+ weight_decay_optim = 0.1
90
+ beta_1 = 0.9
91
+ beta_2 = 0.95
92
+ clip = 1.0
93
+ device = 'cuda'
94
+ no_kv_heads = 2
95
+ vocab_size = 50258
96
+
97
+
98
+
99
+ from pathlib import Path
100
+ data_path = Path('data')
101
+ data_path.mkdir(exist_ok=True)
102
+ # !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
103
+ # !cp input.txt data/input.txt
104
+
105
+
106
+
107
+ #Datasets
108
+
109
+ # Using tinyshakespeare
110
+
111
+ # with open('data/input.txt', 'r', encoding='utf-8') as f:
112
+ # text = f.read()
113
+
114
+
115
+ # Load the tokenizer
116
+ # tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json")
117
+
118
+ # Encode and decode functions
119
+ # encode = lambda s: tokenizer.encode(s).ids
120
+ # decode = lambda l: tokenizer.decode(l)
121
+
122
+
123
+
124
+ def _save_snapshot(model, optimizer, scheduler, epoch, step):
125
+ snapshot = {
126
+ "MODEL_STATE": model.module.state_dict(),
127
+ "OPTIMIZER_STATE": optimizer.state_dict(),
128
+ "SCHEDULER_STATE": scheduler.state_dict(), # NEW: Save scheduler state
129
+ "EPOCHS_RUN": epoch,
130
+ "STEP_RUN": step
131
+ }
132
+ torch.save(snapshot, "/kaggle/working/snapshot_fine_tuned_model_with_gradient_clipping_3.pt")
133
+ print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")
134
+
135
+ def _load_snapshot(snapshot_path, model, optimizer, scheduler):
136
+ snapshot = torch.load(snapshot_path)
137
+ model.load_state_dict(snapshot["MODEL_STATE"])
138
+ # optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
139
+ # scheduler.load_state_dict(snapshot["SCHEDULER_STATE"]) # Load scheduler state
140
+ epoch = snapshot["EPOCHS_RUN"]
141
+ step = snapshot["STEP_RUN"]
142
+ print(f"Resuming from Epoch {epoch}, Step {step}")
143
+ return epoch, step
144
+
145
+ #Subword level tokenization
146
+
147
+ #Loading custom trained BPE
148
+ # Load the tokenizer
149
+ # tokenizer = Tokenizer.from_file("data/bpe_tokenizer_tinyshakespeare_1k.json")
150
+ # vocab_size = tokenizer.get_vocab_size()
151
+ # Encode and decode functions
152
+ # encode = lambda s: tokenizer.encode(s).ids
153
+ # decode = lambda l: tokenizer.decode(l)
154
+
155
+
156
+
157
+
158
+
159
+ ###############################################################################
160
+ #Character level tokenization
161
+
162
+ # # here are all the unique characters that occur in this text
163
+ # chars = sorted(list(set(text)))
164
+ # vocab_size = len(chars)
165
+
166
+
167
+ # # create a mapping from characters to integers
168
+ # stoi = { ch: i for i,ch in enumerate(chars) }
169
+ # itos = { i:ch for i,ch in enumerate(chars) }
170
+ # encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
171
+ # decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
172
+
173
+
174
+ # Convert the dataset to Hugging Face Dataset format
175
+ # train_hf_dataset = Dataset.from_dict({"text": train_dataset['train']['text']})
176
+ # val_hf_dataset = Dataset.from_dict({"text": train_dataset['test']['text']})
177
+
178
+ # Tokenize the dataset using the `map` function
179
+
180
+
181
+ # from google.colab import userdata
182
+ # HF_TOKEN = userdata.get('HF_TOKEN')
183
+
184
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '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')
185
+ # tokenizer.pad_token = tokenizer.eos_token
186
+ # if tokenizer.pad_token is None:
187
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
188
+ # print("ADDED THE TOKENS: ", tokenizer.pad_token_id)
189
+ # tokenizer.bos_token = "[INST]"
190
+ # tokenizer.eos_token = "[/INST]"
191
+ # model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
192
+
193
+ def tokenize_function(examples):
194
+ return tokenizer(
195
+ examples['text'],
196
+ max_length=ModelArgs.block_size,
197
+ padding='max_length',
198
+ truncation=True,
199
+ return_tensors='pt'
200
+ )
201
+
202
+
203
+ ## Load the tokenizer
204
+ # tokenizer = Tokenizer.from_file("bpe_tokenizer_30k.json")
205
+
206
+ # # Tokenization functions
207
+ # def encode_train(examples):
208
+ # tokens = []
209
+ # for example in examples['text']:
210
+ # out = tokenizer.encode(example).ids
211
+ # tokens.append(out) # Append the tokenized sequence (do not flatten)
212
+ # return {"tokens": tokens}
213
+
214
+ # def encode_val(examples):
215
+ # tokens = []
216
+ # for example in examples['text']:
217
+ # out = tokenizer.encode(example).ids
218
+ # tokens.append(out) # Append the tokenized sequence (do not flatten)
219
+ # return {"tokens": tokens}
220
+
221
+ # Apply tokenization with batching
222
+ # train_data = train_dataset['train'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8)
223
+ # val_data = train_dataset['test'].map(tokenize_function, batched=True, batch_size=8000, remove_columns=['id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'], num_proc=8)
224
+
225
+ # # # Extract tokens from the processed datasets
226
+ # # train_tokens = train_data['tokens']
227
+ # # val_tokens = val_data['tokens']
228
+
229
+ # # Flatten the tokenized data
230
+ # # train_tokens = [token_id for seq in train_data['input_ids'] for token_id in seq]
231
+ # # val_tokens = [token_id for seq in val_data['input_ids'] for token_id in seq]
232
+
233
+ # try:
234
+ # train_tensors = [torch.tensor(seq) for seq in tqdm(train_data['input_ids'], desc="Converting train_data to tensors")]
235
+ # train_data_tensor = torch.cat(train_tensors)
236
+ # except Exception as e:
237
+ # print(f"Error during tensor conversion: {e}")
238
+
239
+ # try:
240
+ # train_tensors = [torch.tensor(seq) for seq in tqdm(val_data['input_ids'], desc="Converting train_data to tensors")]
241
+ # val_data_tensor = torch.cat(train_tensors)
242
+ # except Exception as e:
243
+ # print(f"Error during tensor conversion: {e}")
244
+ # print("Train tokens count: ", train_data_tensor)
245
+ # print("Val tokens count: ", val_data_tensor)
246
+
247
+
248
+ def prepare_dataset(split, batch_size):
249
+
250
+ # alpaca_prompt = '''
251
+
252
+
253
+ # ### Instruction:
254
+ # {}
255
+
256
+ # ### Response:
257
+ # {}
258
+ # '''
259
+ # Load a subset of the C4 dataset with a glob pattern for specific training files
260
+ # dataset = load_dataset("allenai/c4", data_files=["en/c4-train.00001-of-01024.json.gz"], trust_remote_code=True)
261
+
262
+ # Initialize tokenizer
263
+ # tokenizer = AutoTokenizer.from_pretrained("gpt2")
264
+
265
+ def collate_fn(batch):
266
+ # Extract text data
267
+ # texts = [item ["text"] for item in batch]
268
+
269
+ # Set the pad token if it isn't set already
270
+ # if tokenizer.pad_token is None:
271
+ # tokenizer.pad_token = tokenizer.eos_token
272
+ outputs = []
273
+ texts = []
274
+ for item in batch:
275
+ instruction = item['prompt']
276
+ # input = item['input']
277
+ output = item['completion']
278
+ # out = alpaca_prompt.format(instruction, output)
279
+ texts.append(instruction)
280
+ outputs.append(output)
281
+ # Tokenize text data
282
+ input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
283
+ # output_encodings = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
284
+ input_encodings["labels"] = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
285
+ # out = {"input": input_encodings}
286
+ # input_encodings["labels"] = input_encodings["input_ids"].clone() # Use `input_ids` as labels
287
+ # input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:] # Shift right
288
+ # input_encodings["labels"][:, -1] = tokenizer.pad_token_id # Ignore the last token (no target for it)
289
+ # Return tokenized input tensors
290
+ # return out
291
+ return input_encodings
292
+
293
+ # Create DistributedSampler for proper shuffling and partitioning across processes
294
+ # dist_sampler = DistributedSampler(fw_train["text"], shuffle=True)
295
+
296
+ # Create DataLoader with custom collate_fn
297
+ # print(fw_dataset)
298
+ dataloader = None
299
+ if(split == 'train'):
300
+ data_loader = DataLoader(
301
+ merged_dataset['train'],
302
+ batch_size=batch_size,
303
+ sampler=DistributedSampler(merged_dataset['train'], shuffle=True),
304
+ collate_fn=collate_fn,
305
+ drop_last=True,
306
+ shuffle=False
307
+ )
308
+ elif(split == 'val'):
309
+ data_loader = DataLoader(
310
+ merged_dataset['test'],
311
+ batch_size=batch_size,
312
+ sampler=DistributedSampler(merged_dataset["test"], shuffle=True),
313
+ collate_fn=collate_fn,
314
+ drop_last=True,
315
+ shuffle=False
316
+ )
317
+
318
+ return data_loader
319
+ # Convert to tensors
320
+ # train_data_tensor = torch.tensor(train_tokens, dtype=torch.long)
321
+ # val_data_tensor = torch.tensor(val_tokens, dtype=torch.long)
322
+
323
+ # # Debug output
324
+ # print("Number of train tokens:", len(train_data_tensor))
325
+ # print("Number of validation tokens:", len(val_data_tensor))
326
+
327
+
328
+ # def create_sequences(data, block_size):
329
+ # sequences = []
330
+
331
+ # for seq in data:
332
+ # if len(seq) < block_size:
333
+ # # while(len(sequence) < block_size):
334
+ # # sequence = data[i:i + block_size + 1]
335
+
336
+ # # Pad the sequence if it's shorter than block_size
337
+ # padding_length = block_size - len(seq)
338
+ # seq = torch.cat([seq, torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long)])
339
+ # sequences.append(seq)
340
+ # out = torch.tensor(sequences, dtype=torch.long)
341
+ # return out
342
+
343
+ # train_data = create_sequences(train_data['input_ids'], ModelArgs.block_size)
344
+ # val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size)
345
+
346
+
347
+ def get_batch(split):
348
+ # generate a small batch of data of inputs x and targets y
349
+ data = train_data if split == 'train' else val_data
350
+ ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
351
+ x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
352
+ y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
353
+ x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
354
+ return x, y
355
+
356
+ from torch.utils.data import Dataset
357
+
358
+ class TokenDataset(Dataset):
359
+ def __init__(self, data, block_size):
360
+ self.data = data
361
+ self.block_size = block_size
362
+
363
+ def __len__(self):
364
+ return len(self.data) - self.block_size # Ensure valid indexing
365
+
366
+ def __getitem__(self, idx):
367
+ x = self.data[idx:idx + self.block_size]
368
+ y = self.data[idx + 1:idx + self.block_size + 1]
369
+ return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
370
+
371
+
372
+ # train_rows = 11895089
373
+ # encoded_data = torch.tensor(encode(fw_train['text']), dtype=torch.long)
374
+ # train_data = train_data[:train_rows]
375
+ # val_data = val_data[train_rows:]
376
+
377
+ # train_dataset = TokenDataset(train_data_tensor, ModelArgs.block_size)
378
+ # val_dataset = TokenDataset(val_data_tensor, ModelArgs.block_size)
379
+ # encoded_data = torch.tensor(encode(text), dtype=torch.long)
380
+
381
+ # print(train_data)
382
+ # print(val_data)
383
+ # train_dataset = TextDataset(train_data, ModelArgs.block_size)
384
+ # val_dataset = TextDataset(val_data, ModelArgs.block_size)
385
+
386
+ # print(train_dataset)
387
+ # print(val_dataset)
388
+
389
+
390
+ # # Convert the tokenized data into a list of sequences
391
+ # train_sequences = [train_data[i:i + ModelArgs.block_size] for i in range(0, len(train_data) - ModelArgs.block_size)]
392
+ # val_sequences = [val_data[i:i + ModelArgs.block_size] for i in range(0, len(val_data) - ModelArgs.block_size)]
393
+
394
+ # Define collate_fn
395
+ # def collate_fn(batch):
396
+ # block_size = ModelArgs.block_size
397
+ # batch_size = len(batch)
398
+ # x = torch.zeros((batch_size, block_size), dtype=torch.long)
399
+ # y = torch.zeros((batch_size, block_size), dtype=torch.long)
400
+ # for i, sequence in enumerate(batch):
401
+ # print("Shape x: ", sequence[:-1].shape)
402
+ # print("Shape of y: ", len(sequence[1:]))
403
+ # x[i] = sequence[:-1] # Input is all tokens except the last one
404
+ # y[i] = sequence[1:] # Target is all tokens except the first one
405
+ # return x, y
406
+
407
+
408
+
409
+ def create_sequences(data, block_size):
410
+ sequences = []
411
+
412
+ for seq in data:
413
+ len(seq)
414
+ if len(seq) < block_size:
415
+ # while(len(sequence) < block_size):
416
+ # sequence = data[i:i + block_size + 1]
417
+
418
+ # Pad the sequence if it's shorter than block_size
419
+ padding_length = block_size - len(seq)
420
+ seq = torch.cat([seq, torch.full((padding_length,), tokenizer.encode('[PAD]').ids[0], dtype=torch.long)])
421
+
422
+ else:
423
+ if len(seq) > block_size:
424
+ seq = seq[:block_size]
425
+ # while(len(sequence) < block_size):
426
+ # sequence = data[i:i + block_size + 1]
427
+
428
+ # Pad the sequence if it's shorter than block_size
429
+ # padding_length = block_size - len(seq)
430
+ # seq = torch.cat([seq, torch.full((padding_length,), tokenizer.encode('[PAD]').ids[0], dtype=torch.long)])
431
+ sequences.append(seq)
432
+ out = torch.tensor(sequences, dtype=torch.long)
433
+ return out
434
+
435
+ # train_data = create_sequences(train_data_flat['input_ids'], ModelArgs.block_size)
436
+ # val_data = create_sequences(val_data['input_ids'], ModelArgs.block_size)
437
+
438
+
439
+ # Define collate_fn
440
+ def collate_fn(split , batch):
441
+ block_size = ModelArgs.block_size
442
+ batch_size = len(batch)
443
+ if(split == 'train'):
444
+ data = train_data_tensor
445
+ elif(split == 'test'):
446
+ data = val_data_tensor
447
+ ix = torch.randint(len(data) - ModelArgs.block_size, (ModelArgs.batch_size,))
448
+ x = torch.stack([data[i:i+ModelArgs.block_size] for i in ix])
449
+ y = torch.stack([data[i+1:i+ModelArgs.block_size+1] for i in ix])
450
+
451
+ # print("Shape of x: ", len(x))
452
+ # print("Length of y: ", len(y))
453
+ # x, y = x.to(ModelArgs.device), y.to(ModelArgs.device)
454
+ # x = torch.zeros((batch_size, block_size), dtype=torch.long)
455
+ # y = torch.zeros((batch_size, block_size), dtype=torch.long)
456
+ # for i, sequence in enumerate(batch):
457
+ # print("Seq: ", sequence)
458
+ # print("Shape x: ", sequence[:-1].shape)
459
+ # print("Shape of y: ", len(sequence[1:]))
460
+ # x[i] = sequence[:-1] # Input is all tokens except the last one
461
+ # y[i] = sequence[1:] # Target is all tokens except the first one
462
+ return x, y
463
+
464
+
465
+ class Normalization(nn.Module):
466
+ def __init__(
467
+ self,
468
+
469
+ embeddings_dims: int = ModelArgs.embeddings_dims
470
+ ):
471
+ super().__init__()
472
+ self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
473
+
474
+
475
+ def forward(self, x):
476
+
477
+ x = self.rmsnorm_layer(x)
478
+ return x
479
+
480
+
481
+
482
+ # import numpy as np
483
+ class RotaryEmbeddings(nn.Module):
484
+ def __init__(
485
+ self,
486
+ device,
487
+ embeddings_dims: int = ModelArgs.embeddings_dims,
488
+ block_size: int = ModelArgs.block_size,
489
+ batch_size: int = ModelArgs.batch_size
490
+ ):
491
+ super().__init__()
492
+
493
+ self.embeddings_dims = embeddings_dims
494
+ self.block_size = block_size
495
+ self.batch_size = batch_size
496
+ self.theta = 0
497
+
498
+
499
+ # def init_matrix(self, seq_len):
500
+ # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
501
+ # for pos in range(seq_len):
502
+ # for j in range(1, self.embeddings_dims // 2):
503
+ # self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
504
+ # self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
505
+ # self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
506
+ # self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
507
+ # self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
508
+ # return self.matrix
509
+ self.device=device
510
+
511
+ def init_matrix(self, seq_len):
512
+ self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
513
+
514
+ positions = torch.arange(seq_len, dtype=torch.float32, device = self.device).unsqueeze(1)
515
+ # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
516
+ theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
517
+ angles = positions * theta
518
+
519
+ cos_angles = torch.cos(angles)
520
+ sin_angles = torch.sin(angles)
521
+
522
+ indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device)
523
+ # print(indices)
524
+ # print(indices.shape)
525
+ # print(indices[::2])
526
+ even_indices = indices[::2]
527
+ odd_indices = indices[1::2]
528
+
529
+ self.matrix[:, even_indices, even_indices] = cos_angles
530
+ self.matrix[:, odd_indices, odd_indices] = sin_angles
531
+ self.matrix[:, odd_indices, even_indices] = -sin_angles
532
+ self.matrix[:, even_indices, odd_indices] = cos_angles
533
+
534
+ return self.matrix
535
+
536
+ def forward(self, x):
537
+ # B,T,C = x.shape
538
+ # print("MATRIX:",x)
539
+ if(x > self.block_size or x < self.block_size):
540
+ matrix = self.init_matrix(x)
541
+ return matrix
542
+ else:
543
+ matrix = self.init_matrix(self.block_size)
544
+
545
+ return matrix
546
+
547
+
548
+ class RotaryAttentionHead(nn.Module):
549
+ def __init__(
550
+ self,
551
+ device,
552
+ embeddings_dims: int = ModelArgs.embeddings_dims,
553
+ no_of_heads: int = ModelArgs.no_of_heads,
554
+ attn_dropout: int = ModelArgs.attn_dropout
555
+ ):
556
+ super().__init__()
557
+ self.head_size = embeddings_dims // no_of_heads
558
+ self.query = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
559
+ self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
560
+ self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
561
+ self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device = device)
562
+ self.dropout = nn.Dropout(p = attn_dropout)
563
+ self.device = device
564
+ def forward(self,x):
565
+ # print(x.shape)
566
+ batch, block_size, embeddings_dims = x.shape
567
+ query = self.query(x)
568
+ # print(query)
569
+ key = self.key(x)
570
+ values = self.value(x)
571
+ matrix = self.rotary_matrix(block_size)
572
+
573
+ # print(matrix.shape)
574
+ # print(query.shape)
575
+ masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
576
+ rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
577
+ rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
578
+ weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
579
+ weights_masked = weights.masked_fill(masked == 0, float('-inf'))
580
+ scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
581
+ scaled_weights = F.softmax(scaled_weights, dim=-1)
582
+ value = scaled_weights @ values
583
+ out = self.dropout(value)
584
+ return out
585
+
586
+
587
+ class MQA(nn.Module):
588
+ def __init__(
589
+ self,
590
+ device,
591
+ embeddings_dims: int = ModelArgs.embeddings_dims,
592
+ block_size: int = ModelArgs.block_size,
593
+ no_of_kv_heads: int = ModelArgs.no_of_heads,
594
+ no_of_heads: int = ModelArgs.no_of_heads,
595
+
596
+ ):
597
+ super().__init__()
598
+
599
+ self.no_of_kv_heads = no_of_kv_heads
600
+ self.no_of_q_heads = no_of_heads // no_of_kv_heads
601
+ self.head_size = embeddings_dims // self.no_of_q_heads
602
+ self.rotary_matrix = RotaryEmbeddings(embeddings_dims=embeddings_dims, device = device)
603
+ # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
604
+ self.key = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
605
+ self.value = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
606
+ self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
607
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
608
+ self.device = device
609
+ self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, device = self.device) for _ in range(self.no_of_q_heads)])
610
+
611
+ def scaled_dot_product(self, q, k, v, block_size, matrix):
612
+
613
+ # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
614
+
615
+ masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
616
+ rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
617
+ rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
618
+ weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
619
+ weights_masked = weights.masked_fill(masked == 0, float('-inf'))
620
+ scaled_weights = weights_masked / (torch.sqrt(torch.tensor(k.shape[-1])))
621
+ scaled_weights = F.softmax(scaled_weights, dim=-1)
622
+ value = scaled_weights @ v
623
+ out = self.dropout(value)
624
+ return value
625
+
626
+ def forward(self,x):
627
+ # print("MQA: ", x.shape)
628
+ batch, block_size, embeddings_dims = x.shape
629
+
630
+ # query = self.query(x)
631
+ matrix = self.rotary_matrix(block_size)
632
+
633
+
634
+ key = self.key(x)
635
+ values = self.value(x)
636
+
637
+ multi_query_concat = torch.cat([self.scaled_dot_product(query(x), key, values, block_size, matrix) for query in self.multi_query], dim=-1)
638
+
639
+
640
+ linear_layer= self.linear_layer(multi_query_concat)
641
+ out = self.dropout(linear_layer)
642
+ return out
643
+
644
+
645
+ class GQA(nn.Module):
646
+ def __init__(
647
+ self,
648
+ device,
649
+ embeddings_dims: int = ModelArgs.embeddings_dims,
650
+ block_size: int = ModelArgs.block_size,
651
+ no_of_q_heads: int = ModelArgs.no_of_heads,
652
+ no_of_kv_heads: int = ModelArgs.no_kv_heads
653
+ ):
654
+ super().__init__()
655
+
656
+ self.no_of_kv_heads = no_of_kv_heads
657
+ self.no_of_q_heads = no_of_q_heads
658
+ self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
659
+ self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_kv_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device)
660
+ self.device = device
661
+ self.mqa = nn.ModuleList([MQA(embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_kv_heads)])
662
+
663
+ def forward(self,x):
664
+
665
+ batch, block_size, embeddings_dims = x.shape
666
+
667
+
668
+ grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
669
+
670
+ linear_layer= self.linear_layer(grouped_query_concat)
671
+ out = self.dropout(linear_layer)
672
+ return out
673
+
674
+
675
+ class Swish(nn.Module):
676
+ def __init__(
677
+ self,
678
+ device,
679
+ block_size: int = ModelArgs.block_size,
680
+ embeddings_dims: int = ModelArgs.embeddings_dims
681
+ ):
682
+ super().__init__()
683
+
684
+ self.sig = torch.nn.Sigmoid()
685
+
686
+
687
+ def forward(self, x):
688
+ swish = x * self.sig(x)
689
+
690
+ return swish
691
+
692
+
693
+
694
+ class SWiGLU(nn.Module):
695
+ def __init__(
696
+ self,
697
+ device,
698
+ block_size: int = ModelArgs.block_size,
699
+ embeddings_dims: int = ModelArgs.embeddings_dims
700
+ ):
701
+ super().__init__()
702
+
703
+ self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
704
+ self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
705
+ self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
706
+ self.linear_layer3 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
707
+
708
+
709
+
710
+
711
+ def forward(self, x):
712
+ swish_res = self.swish(self.linear_layer1(x))
713
+ x_V = self.linear_layer2(x)
714
+ res = torch.mul(swish_res, x_V)
715
+ out = self.linear_layer3(res)
716
+ return out
717
+
718
+
719
+
720
+ class FFN(nn.Module):
721
+ def __init__(self,
722
+ device,
723
+ embeddings_dims: int = ModelArgs.embeddings_dims,
724
+ block_size: int = ModelArgs.block_size,
725
+ vocab_size: int = ModelArgs.vocab_size,
726
+ dropout = ModelArgs.dropout
727
+
728
+ ):
729
+ super().__init__()
730
+
731
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
732
+ self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device)
733
+ self.dropout = nn.Dropout(p = dropout)
734
+ def forward(self, x):
735
+
736
+ x = self.swiglue(x)
737
+ x = self.linear_layer(x)
738
+ x = self.dropout(x)
739
+ return x
740
+
741
+
742
+ class DecoderLayer(nn.Module):
743
+ def __init__(self,
744
+ device,
745
+ embeddings_dims: int = ModelArgs.embeddings_dims,
746
+ dropout = ModelArgs.dropout,
747
+ block_size: int = ModelArgs.block_size,
748
+ vocab_size: int = ModelArgs.vocab_size,
749
+
750
+ ) :
751
+ super().__init__()
752
+
753
+
754
+ self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device)
755
+ self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, no_of_kv_heads=ModelArgs.no_kv_heads, no_of_q_heads=ModelArgs.no_of_heads, device = device)
756
+ # self.norm = Normalization(embeddings_dims=embeddings_dims)
757
+ self.norm1 = Normalization(embeddings_dims=embeddings_dims)
758
+ self.norm2 = Normalization(embeddings_dims=embeddings_dims)
759
+ self.dropout = nn.Dropout(p = dropout)
760
+ def forward(self, x):
761
+
762
+ x = self.norm1(x + self.gqa(x))
763
+ x = self.norm2(x + self.feedforward_network(x))
764
+ return x
765
+
766
+
767
+ class Llama(nn.Module):
768
+ def __init__(self,
769
+ device,
770
+ embeddings_dims: int = ModelArgs.embeddings_dims,
771
+ no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
772
+ block_size: int = ModelArgs.block_size,
773
+ vocab_size: int = ModelArgs.vocab_size,
774
+ dropout = ModelArgs.dropout
775
+
776
+ ) :
777
+ super().__init__()
778
+
779
+ self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device)
780
+ self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)])
781
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device)
782
+ self.dropout = nn.Dropout(p = dropout)
783
+ # self.norm = Normalization(embeddings_dims)
784
+ def forward(self, x):
785
+ x = self.embeddings(x)
786
+ x = self.dropout(x)
787
+ x = self.decoder(x)
788
+ # x = self.norm(x)
789
+ x = self.linear_layer(x)
790
+ # out = self.norm(x)
791
+ return x
792
+
793
+
794
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
795
+ # # device = "cpu"
796
+ # ModelArgs.device = device
797
+ # model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
798
+ # model = model.to(ModelArgs.device)
799
+
800
+ #Printing a summary of the architecture
801
+ # !pip install torchinfo
802
+ # from torchinfo import summary
803
+ # # idx, targets = get_batch('test')
804
+ # idx = torch.randint(
805
+ # low=0,
806
+ # high=ModelArgs.vocab_size,
807
+ # size=(ModelArgs.batch_size, ModelArgs.block_size),
808
+ # dtype=torch.long
809
+ # )
810
+ # # sample_idx = random.randint(range(len(train_dataset)))
811
+ # # idx, targets = train_dataset[0]
812
+ # idx = idx.to(ModelArgs.device)
813
+ # # targets = targets.to(ModelArgs.device)
814
+ # summary(model=model,
815
+ # input_data=idx,
816
+ # # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
817
+ # col_names=["input_size", "output_size", "num_params", "trainable"],
818
+ # col_width=20,
819
+ # row_settings=["var_names"])
820
+
821
+
822
+ def find_unused_parameters(model):
823
+ unused = []
824
+ for name, param in model.named_parameters():
825
+ if param.grad is None:
826
+ unused.append(name)
827
+ return unused
828
+
829
+ def greedy_decode(
830
+ model,
831
+ tokenizer,
832
+ prompt,
833
+ max_length=50,
834
+ repetition_penalty=1.2,
835
+ context_window=10,
836
+ temperature=1.0,
837
+ eos_token_id=None
838
+ ):
839
+
840
+ device = next(model.parameters()).device
841
+ input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
842
+ generated_tokens = []
843
+ eos_token_id = eos_token_id or tokenizer.eos_token_id # Use EOS token if provided
844
+
845
+ for _ in range(max_length):
846
+ outputs = model(input_ids)
847
+ logits = outputs[:, -1, :] # Get logits for the last token
848
+
849
+ # Apply temperature scaling
850
+ if temperature != 1.0:
851
+ logits = logits / temperature
852
+
853
+ # Apply repetition penalty
854
+ if repetition_penalty != 1.0 and len(generated_tokens) > 0:
855
+ for token in set(generated_tokens[-context_window:]): # Penalize recent tokens
856
+ logits[0, token] /= repetition_penalty
857
+
858
+ # Greedy selection
859
+ next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
860
+ generated_tokens.append(next_token.item())
861
+
862
+ # Stop if EOS token is generated
863
+ if next_token.item() == eos_token_id:
864
+ break
865
+
866
+ # Append the new token to the input
867
+ input_ids = torch.cat([input_ids, next_token], dim=1)
868
+
869
+ # Decode the generated tokens
870
+ return tokenizer.decode(generated_tokens, skip_special_tokens=True)
871
+
872
+
873
+
874
+ def save_to_file(text):
875
+
876
+ with open('generations.txt', 'a') as f:
877
+ f.writelines(text + "\n\n")
878
+
879
+
880
+ #Train the model
881
+
882
+
883
+ # writer = SummaryWriter(log_dir="runs/experiment")
884
+
885
+ from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
886
+
887
+ # Warmup phase for 2000 steps
888
+ def warmup_fn(step):
889
+ if step < 2000:
890
+ return step / 2000 # LR gradually increases
891
+ return 1.0
892
+
893
+
894
+ from torch.optim.lr_scheduler import LambdaLR
895
+
896
+ def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
897
+ """
898
+ Trapezoidal learning rate scheduler:
899
+ - Increases linearly for `warmup_steps` steps.
900
+ - Remains constant for `plateau_steps` steps.
901
+ - Decreases linearly for `decay_steps` steps.
902
+ """
903
+ def lr_lambda(step):
904
+ if step < warmup_steps:
905
+ # Linear warmup
906
+ return float(step) / float(max(1, warmup_steps))
907
+ elif step < warmup_steps + plateau_steps:
908
+ # Constant plateau
909
+ return 1.0
910
+ else:
911
+ # Linear decay
912
+ decay_step = step - (warmup_steps + plateau_steps)
913
+ return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps)))
914
+
915
+ return LambdaLR(optimizer, lr_lambda)
916
+
917
+
918
+ torch.set_float32_matmul_precision('high')
919
+
920
+
921
+ def train():
922
+ setup()
923
+ device = int(os.environ["LOCAL_RANK"])
924
+
925
+ torch.cuda.set_device(int(device))
926
+
927
+ # train_dataloader = prepare_dataset(ModelArgs.batch_size)
928
+ # rank = torch.distributed.get_rank()
929
+ print(f"Start running DDP on rank {device}.")
930
+ # # create model and move it to GPU with id rank
931
+ # device_id = rank % torch.cuda.device_count()
932
+ # CFG = ModelArgs()
933
+
934
+ if(device == 0):
935
+
936
+
937
+
938
+ # # Initialise run
939
+ wandb.init(
940
+ # entity = 'rajceo2031',
941
+ project = 'Llama-DDP-Pretrain-10-billion-tokens',
942
+ # config = CFG,
943
+ # save_code = True,
944
+ #group = 'ANN',
945
+ #job_type = 'train'
946
+ )
947
+
948
+ model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
949
+ # Optimizer setup and scheduler steup
950
+
951
+ model = model.to(device)
952
+
953
+ print(f"Model on device {device} is ready")
954
+ # Wrap model with DDP after moving to GPU
955
+ # model = DDP(model, device_ids=[device])
956
+ optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim)
957
+ # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4000, T_mult=1, eta_min=1e-5)
958
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30000, eta_min=1e-6)
959
+ _load_snapshot('/kaggle/input/models/snapshot2.pt', model, optimizer, scheduler)
960
+ optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim)
961
+
962
+ # model = torch.compile(model)
963
+
964
+ # Define the trapezoidal learning rate scheduler
965
+ total_steps = 100000 # Total steps (40k + 20k + 40k)
966
+ warmup_steps = 40000 # Steps for warmup (increase)
967
+ plateau_steps = 20000 # Steps for plateau (constant)
968
+ decay_steps = 40000 # Steps for decay (decrease)
969
+
970
+
971
+ # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25000, eta_min=1e-6) #with the prev optim snapshot
972
+ new_scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps)
973
+
974
+ # warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_fn)
975
+ # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
976
+ # Cosine decay after warmup
977
+ # new_scheduler = CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
978
+
979
+ # Combine both schedulers
980
+ # scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, new_scheduler], milestones=[2000])
981
+
982
+ # Reset learning rate to 1e-4
983
+ # for param_group in optimizer.param_groups:
984
+ # param_group['lr'] = ModelArgs.max_lr
985
+ # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1, eta_min=1e-6)
986
+ # print("Old optimizer with new lr ready")
987
+ model = DDP(model, device_ids=[device])
988
+ print(f"Model on device {device} is ready")
989
+
990
+
991
+ # optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)
992
+ # Create DataLoader with collate_fn
993
+ # train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
994
+ # val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
995
+ # print("Loader is ready")
996
+ # print(train_loader)
997
+ # print(next(iter(train_loader)))
998
+
999
+ save_chechpoint_iter = 1000
1000
+ total_iters = 20000
1001
+ eval_iters = 200
1002
+ eval_check = 100
1003
+ # for X,y in train_loader:
1004
+ # print(X.shape)
1005
+ # print(y.shape)
1006
+
1007
+ # alpaca_prompt = '''
1008
+
1009
+ # ### Instruction:
1010
+ # {instruction}
1011
+
1012
+ # ### Input:
1013
+ # {input}
1014
+
1015
+ # ### Response:
1016
+
1017
+ # '''
1018
+ # Only create progress bar for rank 0
1019
+ # eval_epoch_iterator = range(eval_iters)
1020
+ # train_epoch_iterator = range(total_iters)
1021
+ # if device == 0:
1022
+ # train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training")
1023
+
1024
+ # train_epoch_iterator = range(ModelArgs.epochs)
1025
+ # if device == 0: # Ensure tqdm only runs on rank 0
1026
+ # train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training Progress", position=0, leave=True)
1027
+
1028
+ # lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters)
1029
+ world_size = torch.cuda.device_count()
1030
+ @torch.inference_mode()
1031
+ def estimate_loss(val_loader, train_loader=None):
1032
+ out = {}
1033
+ # train_loader = prepare_dataset('train', ModelArgs.batch_size)
1034
+ model.eval()
1035
+ loader = None
1036
+ epoch_loss = None
1037
+ epoch_losses = []
1038
+ # print("Starting the eval...")
1039
+ for split in ['train', 'val']:
1040
+ print(f"Starting with {split} evaluation...")
1041
+ # losses = torch.zeros(ModelArgs.val_epochs)
1042
+ if(split == 'train'):
1043
+ loader = train_loader
1044
+ if(split == 'val'):
1045
+ loader = val_loader
1046
+ for step in range(eval_check):
1047
+ total_loss = 0
1048
+ # loader.sampler.set_epoch(step)
1049
+ total_batches = 0
1050
+ batch = next(iter(loader))
1051
+ # for batch in loader: # Loop through DataLoader batches
1052
+ idx = batch['input_ids']
1053
+ targets = batch['labels']['input_ids']
1054
+ idx = idx.to(device)
1055
+ targets = targets.to(device)
1056
+
1057
+ logits = model(idx)
1058
+ batch_size, block_size, embeddings_dims = logits.shape
1059
+ logits = logits.view(batch_size * block_size, embeddings_dims) # Flatten tokens
1060
+ targets = targets.view(batch_size * block_size)
1061
+
1062
+ loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
1063
+
1064
+ total_loss += loss.item()
1065
+ total_batches += 1
1066
+
1067
+ # Compute mean loss for this epoch
1068
+ epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
1069
+ epoch_losses.append(epoch_loss)
1070
+
1071
+ # print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}")
1072
+
1073
+ # Compute mean loss across all evaluation epochs
1074
+ out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
1075
+ epoch_loss = None
1076
+ epoch_losses = []
1077
+
1078
+ model.train()
1079
+ return out
1080
+
1081
+ # model = model.to(rank)
1082
+ model.train()
1083
+ train_dataloader = prepare_dataset('train', ModelArgs.batch_size)
1084
+ val_loader= prepare_dataset('val', ModelArgs.batch_size)
1085
+ # for step in tqdm(range(total_iters)):
1086
+ for epoch in range(ModelArgs.epochs):
1087
+ # torch.cuda.synchronize()
1088
+
1089
+ train_dataloader.sampler.set_epoch(epoch)
1090
+
1091
+ val_loader.sampler.set_epoch(epoch)
1092
+ print("Loaders ready both")
1093
+ epochs = ModelArgs.epochs
1094
+
1095
+ # train_step_iterator = range(len(train_dataloader))
1096
+ # if device == 0: # Only create progress bar on rank 0
1097
+ # train_step_iterator = tqdm(train_step_iterator, desc="Training Progress", position=0, leave=True)
1098
+
1099
+ # Print progress on rank 0
1100
+ train_loader_length = 0
1101
+ if(device == 0):
1102
+ train_loader_length = len(train_dataloader)
1103
+ print("Total batches: ", train_loader_length)
1104
+ # print("Length of : ", len(train_dataloader))
1105
+ # print("Length of val: ", len(val_loader))
1106
+ for step, batch in enumerate(train_dataloader):
1107
+ # print("Dataloader things: ", batch)
1108
+ # print("Total batches: ", len(train_dataloader))
1109
+ if(device == 0):
1110
+ if(step % 100 == 0):
1111
+ # if(step == train_loader_length):
1112
+ # break
1113
+ print("Batch : ", step, "/", len(train_dataloader))
1114
+ # all_gpus_avg_train_loss = None
1115
+ # all_gpus_avg_val_loss = None
1116
+ # every once in a while evaluate the loss on train and val sets
1117
+ if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
1118
+ losses = estimate_loss( val_loader, train_dataloader)
1119
+ avg_train_loss = losses['train']
1120
+ avg_val_loss = losses['val']
1121
+ # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
1122
+ # if device == 0: # Only print on main process
1123
+ print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f} | Val Loss: {losses['val']:.4f}")
1124
+ # print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f}")
1125
+ # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
1126
+ # Log training loss more frequently
1127
+ # Aggregate average loss across all GPUs
1128
+ avg_train_loss = torch.Tensor([losses['train']]).to(device)
1129
+ avg_val_loss = torch.Tensor([losses['val']]).to(device)
1130
+ torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
1131
+ torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
1132
+
1133
+ if device == 0:
1134
+ all_gpus_avg_train_loss = avg_train_loss / world_size
1135
+ print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss.item():.4f}")
1136
+ all_gpus_avg_val_loss = avg_val_loss / world_size
1137
+ print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
1138
+
1139
+ # if device == 0:
1140
+
1141
+ # writer.add_scalar("All_GPUs_Train_losses", all_gpus_avg_train_loss.item(), global_step=step)
1142
+ # writer.add_scalar("All_GPUs_Val_losses", all_gpus_avg_val_loss.item(), global_step=step)
1143
+ # writer.add_scalar("training_step_loss", losses['train'], global_step=step)
1144
+ # writer.add_scalar("val_step_loss", losses['val'], global_step=step)
1145
+ # writer.add_scalar("GPU", device, global_step=step)
1146
+ # writer.add_scalar("Epoch", epoch, global_step=step)
1147
+
1148
+ wandb.log({
1149
+ "Learning Rate": new_scheduler.get_last_lr()[0] ,
1150
+ "All_GPUs_Train_losses": all_gpus_avg_train_loss,
1151
+ "All_GPUs_Val_losses": all_gpus_avg_val_loss,
1152
+ "training_step_loss": losses['train'],
1153
+ "val_step_loss": losses['val'],
1154
+ "Step": step,
1155
+ "Epoch": epoch
1156
+ })
1157
+
1158
+
1159
+
1160
+ #Loading a checkpoint
1161
+ # if(os.path.exists('snapshot.pt')):
1162
+ # model, optimizer = _load_snapshot(model=model, optimizer=optimizer, epoch=epoch, step=step, snapshot_path='snapshot.pt')
1163
+
1164
+ # if(step % save_chechpoint_iter == 0 and device == 0 and step != 0):
1165
+
1166
+ # _save_snapshot(epoch=epoch, model=model, optimizer=optimizer, step=step)
1167
+
1168
+ if step % save_chechpoint_iter == 0 and device == 0 and step != 0:
1169
+ print(f"Saving the model checkpoint for step: {step}")
1170
+ _save_snapshot(model, optimizer, scheduler, epoch, step)
1171
+
1172
+ # batch = {k: v.to(self.local_rank) for k, v in batch.items()}
1173
+ idx = batch['input_ids'].to(device)
1174
+ # idx, targets = get_batch(split='train')
1175
+ # print(f"Starting the train step: {step}...")
1176
+ # for idx, targets in train_loader:
1177
+ # idx, targets = next(iter(train_loader))
1178
+
1179
+ # print("Idx: ", idx)
1180
+ # print("Targets: ", targets)
1181
+
1182
+ # idx = idx.to(device)
1183
+ # print("Idx: ", idx)
1184
+ # print("Targets: ", targets)
1185
+ targets = batch['labels']['input_ids'].to(device)
1186
+ # with torch.autocast(device_type=device, dtype=torch.bfloat16()):
1187
+ logits = model(idx)
1188
+ batch_size, block_size, embeddings_dims = logits.shape
1189
+ # print(logits.shape)
1190
+ # print(targets)
1191
+ logits = logits.view(batch_size*block_size, embeddings_dims)
1192
+ # print("OK")
1193
+ targets = targets.view(batch_size * block_size)
1194
+ # print("OK2")
1195
+ loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
1196
+
1197
+ optimizer.zero_grad(set_to_none=True)
1198
+ loss.backward()
1199
+ # Compute gradient norms before clipping
1200
+ total_norm_before = torch.norm(
1201
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
1202
+ )
1203
+
1204
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
1205
+
1206
+ # Compute gradient norms after clipping
1207
+ total_norm_after = torch.norm(
1208
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
1209
+ )
1210
+
1211
+ if(device == 0 and step !=0 and step % 100 == 0):
1212
+ print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
1213
+ print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")
1214
+
1215
+ optimizer.step()
1216
+ new_scheduler.step()
1217
+ # torch.cuda.synchronize()
1218
+ # print(loss.item())
1219
+ # if(step % 100 == 0):
1220
+ # print(f'Step : {step} | GPU: {device} Loss: {loss.item()}')
1221
+ # if device == 0:
1222
+ # print("loss: ", loss.item())
1223
+ # train_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
1224
+ # print(loss.item())
1225
+ # break
1226
+
1227
+ # if step != 0 and (step % eval_iters == 0 or step == total_steps -1) :
1228
+ # loss_values = estimate_loss()
1229
+ # print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))
1230
+
1231
+ # Add after a training step:
1232
+ # unused_params = find_unused_parameters(model)
1233
+ # print("Unused parameters:", unused_params)
1234
+ # break
1235
+ # if device == 0 and step % 200 == 0 and step != 0:
1236
+ # count = 5
1237
+ # while(count): # Only generate text on the main process
1238
+ # print("Generating text...")
1239
+
1240
+ # alpaca_prompt = '''
1241
+
1242
+ # ### Instruction:
1243
+ # {}
1244
+
1245
+ # ### Input:
1246
+ # {}
1247
+
1248
+ # ### Response:
1249
+
1250
+ # '''
1251
+
1252
+ # prompt = alpaca_prompt.format("You are a helpful assistant.", "Say a joke.", "")
1253
+ # generated_text = greedy_decode(
1254
+ # model,
1255
+ # tokenizer,
1256
+ # prompt,
1257
+ # max_length=60,
1258
+ # repetition_penalty=1.2,
1259
+ # context_window=10,
1260
+ # temperature=0.7 # Lower temperature for more deterministic output
1261
+ # )
1262
+ # # generated_text = beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0)
1263
+ # print(f" Step: {step} | Generated Text: {generated_text}")
1264
+ # save_to_file(generated_text)
1265
+ # count -= 1
1266
+
1267
+ # if step != 0:
1268
+ # train_step_iterator.set_postfix({"Train loss": f"{all_gpus_avg_train_loss.item():.4f} | Val Loss : {all_gpus_avg_val_loss.item():.4f}"})
1269
+
1270
+
1271
+ # break
1272
+ # Cleanup
1273
+ if device == 0:
1274
+ # writer.close()
1275
+ wandb.finish()
1276
+ cleanup()
1277
+
1278
+
1279
+ world_size = torch.cuda.device_count()
1280
+ print(f"World size: {world_size}")
1281
+ train()
1282
+
inference.py ADDED
@@ -0,0 +1,84 @@
1
+ from config import ModelArgs
2
+ from model import Llama
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from tokenizer import Tokenizer
6
+ import argparse
7
+
8
+
9
+ tokenizer = Tokenizer()
10
+ tokenizer = tokenizer.ready_tokenizer()
11
+
12
+
13
+ def remove_prefix(state_dict, prefix):
14
+ new_state_dict = {}
15
+ for key, value in state_dict.items():
16
+ if key.startswith(prefix):
17
+ new_key = key[len(prefix):] # Remove the prefix
18
+ new_state_dict[new_key] = value
19
+ else:
20
+ new_state_dict[key] = value
21
+ return new_state_dict
22
+
23
+
24
+ def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
25
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
26
+ generated_tokens = []
27
+ ModelArgs.inference=True
28
+ for _ in range(max_length):
29
+ with torch.no_grad():
30
+ outputs = model(input_ids)
31
+ logits = outputs[:, -1, :]
32
+
33
+ probs = F.softmax(logits, dim=-1)
34
+
35
+ # Top-k filtering
36
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
37
+
38
+
39
+ # Apply temperature scaling
40
+ # probs = probs / temperature
41
+
42
+ # Sample from top-k
43
+ next_token = torch.multinomial(top_k_probs, num_samples=1)
44
+
45
+ # generated_tokens.append(next_token.item())
46
+
47
+ xcol = torch.gather(top_k_indices, -1, next_token)
48
+ input_ids = torch.cat([input_ids, xcol], dim=1) # dim=1 because it is the sequence dimension
49
+
50
+ return tokenizer.decode(input_ids[0], skip_special_tokens=True)
51
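The `torch.gather` call is the step that maps a sample drawn over the k filtered probabilities back to real vocabulary ids. A standalone sketch with assumed shapes (not part of inference.py):

```python
# Sketch: top-k filtering followed by sampling, mirroring the loop body above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 50304)                                # (batch, vocab) for the last position
probs = F.softmax(logits, dim=-1)
top_k_probs, top_k_indices = torch.topk(probs, k=50, dim=-1)  # keep the 50 most likely tokens
choice = torch.multinomial(top_k_probs, num_samples=1)        # index into the top-k slice
next_token_id = torch.gather(top_k_indices, -1, choice)       # map back to a vocabulary id
print(next_token_id.shape)                                    # torch.Size([1, 1])
```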
+
52
+ def main():
53
+
54
+ torch.set_float32_matmul_precision('high')
55
+
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument("--prompt", type=str, default="Once upon a time")
58
+ parser.add_argument("--max_length", type=int, default=128)
59
+ parser.add_argument("--temperature", type=float, default=1.0)
60
+ parser.add_argument("--top_k", type=int, default=50)
61
+
62
+ # parser.add_argument("--repetition_penalty", type=float, default=1.2)
63
+ args = parser.parse_args()
64
+
65
+ model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
66
+ # model = torch.compile(model)
67
+ model = model.to(ModelArgs.device)
68
+
69
+ dict_model = torch.load('weights/pretrained/snapshot_4650.pt')
70
+ dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
71
+ model.load_state_dict(dict_model['MODEL_STATE'])
72
+ model.eval()
73
+ print("Model ready")
74
+ # prompt = 'Its a secret'
75
+
76
+ with torch.no_grad():
77
+ generated_text = topk_sampling(model, args.prompt, max_length=args.max_length, top_k=args.top_k, temperature=args.temperature, device=ModelArgs.device)
78
+ print("Gnerated: ", generated_text)
79
+ # generated_text = beam_search(model, tokenizer, args.prompt, beam_width=5, max_length=50, temperature=1.0)
80
+ print(args.prompt + generated_text)
81
+
82
+
83
+ if __name__ == '__main__':
84
+ main()
llama_torchrun.py ADDED
@@ -0,0 +1,1435 @@
1
+ #Based on Llama from Meta (https://github.com/meta-llama/llama/blob/main/llama/model.py)
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
+ from pathlib import Path
10
+ import torch.multiprocessing as mp
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+ from torch.distributed import init_process_group, destroy_process_group
14
+ import torch
15
+ from datasets import Dataset
16
+ from torch.utils.data import DataLoader
17
+ from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoderModelOutput
18
+ import wandb
19
+ from tqdm import tqdm
20
+ from functools import partial
21
+ import tiktoken
22
+ import torch.optim as optim
23
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
24
+
25
+ # Load model directly
26
+ from transformers import AutoTokenizer, AutoModelForCausalLM
27
+ import os
28
+
29
+ torch.manual_seed(1337)
30
+ torch.cuda.manual_seed(1337)
31
+
32
+
33
+
34
+
35
+ # import wandb
36
+ # wandb.login()
37
+
38
+
39
+ # from torch.utils.tensorboard import SummaryWriter
40
+
41
+
42
+ from datasets import load_dataset, concatenate_datasets
43
+
44
+ # data = {}
45
+ # texts = []
46
+ # with open('data/input.txt', 'r') as f:
47
+ # texts.append(f.readlines())
48
+
49
+ # # print(texts)
50
+ # # print(len(texts[0]))
51
+ # data = {
52
+ # "text": texts[0]
53
+ # }
54
+ # fw_train = Dataset.from_dict(data)
55
+ # print(fw_train)
56
+ # fw_train = load_dataset("karpathy/tiny_shakespeare", split="train", trust_remote_code=True)
57
+ # print(fw_train['text'])
58
+ # text = fw_train['text'][0].split("\n")
59
+ # print(text)
60
+ # filtered_lines = [line for line in text if line != '']
61
+ # print(len(filtered_lines))
62
+ # use name="sample-10BT" to use the 10BT sample
63
+
64
+ tinystories = True
65
+ fw = False
66
+ fw_train = None
67
+ fw_test = None
68
+ if(tinystories):
69
+ fw_train = load_dataset("roneneldan/TinyStories", split="train")
70
+ fw_test = load_dataset("roneneldan/TinyStories", split="validation")
71
+ print(fw_train)
72
+ print(fw_test)
73
+ if(fw):
74
+ fw_train = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False)
75
+ fw_train = fw_train.train_test_split(test_size=0.01)
76
+ print(fw_train)
77
+ print(fw_train)
78
+ # Select only 1000 rows from the dataset
79
+ # fw_train = fw_train.select(range(1000000))
80
+ # alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
81
+ # dolly = load_dataset("llm-wizard/dolly-15k-instruction-alpaca-format", split='train')
82
+ # merged_dataset = concatenate_datasets([alpaca, dolly])
83
+ # dataset = load_dataset("swype/instruct", split='train', trust_remote_code=True)
84
+ # print(fw_train)
85
+ # Split the dataset into training and validation sets
86
+ # Split the dataset into training and validation sets
87
+ # fw_train = fw_train.train_test_split(test_size=0.01)
88
+ # print(fw_train)
89
+
90
+
91
+ # Access the splits
92
+ # train_dataset = train_val_split['train']
93
+ # val_dataset = train_val_split['test']
94
+
95
+ # train_dataset = fw_train.train_test_split(test_size=0.2)
96
+
97
+
98
+ def setup(rank=None, world_size=None):
99
+ # os.environ['MASTER_ADDR'] = 'localhost'
100
+ # os.environ['MASTER_PORT'] = '12355'
101
+ init_process_group("nccl")
102
+ # torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
103
+
104
+ def cleanup():
105
+ destroy_process_group()
106
+
107
+
108
+
109
+ @dataclass
110
+ class ModelArgs:
111
+ #Hyperparameters
112
+
113
+ epochs = 4
114
+ block_size = 512
115
+ batch_size = 64
116
+ embeddings_dims = 512
117
+ attn_dropout = 0.1
118
+ no_of_heads = 8
119
+ dropout = 0.1
120
+ # epochs = 100
121
+ val_epochs = 2
122
+ max_lr = 6e-4
123
+ no_of_decoder_layers = 8 #IMP needs to be thoroughly calculated
124
+ weight_decay_optim = 0.1
125
+ beta_1 = 0.9
126
+ beta_2 = 0.95
127
+ clip = 1.0
128
+ device = 'cuda'
129
+ no_kv_heads = 2
130
+ vocab_size = 50304 #powers of 2 so nice!
131
+ eps = 1e-5
132
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
133
+ # dtype = 'bfloat16'
134
+
135
+
136
+
137
+
138
+
139
+ def _save_snapshot(model, optimizer, scheduler, epoch, step):
140
+ snapshot = {
141
+ "MODEL_STATE": model.module.state_dict(),
142
+ "OPTIMIZER_STATE": optimizer.state_dict(),
143
+ # "SCHEDULER_STATE": scheduler.state_dict(),
144
+ "EPOCHS_RUN": epoch,
145
+ "STEP_RUN": step
146
+ }
147
+ torch.save(snapshot, f"snapshot_{step}.pt")
148
+ print(f"Epoch: {epoch} | Step: {step} | Snapshot saved.")
149
+
150
+ def _load_snapshot(snapshot_path, model, optimizer, scheduler):
151
+ snapshot = torch.load(snapshot_path)
152
+ model.load_state_dict(snapshot["MODEL_STATE"])
153
+ optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
154
+ # scheduler.load_state_dict(snapshot["SCHEDULER_STATE"]) # Load scheduler state
155
+ epoch = snapshot["EPOCHS_RUN"]
156
+ step = snapshot["STEP_RUN"]
157
+ print(f"Resuming from Epoch {epoch}, Step {step}")
158
+ return epoch, step
159
+
160
+
161
+
162
+
163
+
164
+
165
+ tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '...')
166
+
167
+ # tokenizer.pad_token = tokenizer.eos_token
168
+ # if tokenizer.pad_token is None:
169
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
170
+ # print("ADDED THE TOKENS: ", tokenizer.pad_token_id)
171
+ # tokenizer.bos_token = "[INST]"
172
+ # tokenizer.eos_token = "[/INST]"
173
+ # model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
174
+
175
+ def tokenize_function(examples):
176
+ return tokenizer(
177
+ examples['text'],
178
+ max_length=ModelArgs.block_size,
179
+ padding='max_length',
180
+ truncation=True,
181
+ return_tensors='pt'
182
+ )
183
+
184
+
185
+
186
+
187
+ def prepare_dataset(split, device, batch_size):
188
+ print("Device is: ", device)
189
+ # alpaca_prompt = '''
190
+
191
+
192
+ # ### Instruction:
193
+ # {}
194
+
195
+ # ### Response:
196
+ # {}
197
+ # '''
198
+ # Load a subset of the C4 dataset with a glob pattern for specific training files
199
+ # dataset = load_dataset("allenai/c4", data_files=["en/c4-train.00001-of-01024.json.gz"], trust_remote_code=True)
200
+
201
+ # Initialize tokenizer
202
+ # tokenizer = AutoTokenizer.from_pretrained("gpt2")
203
+ # generator = torch.Generator(device=device)
204
+ def collate_fn(batch):
205
+ # Extract text data
206
+ texts = [item["text"] for item in batch]
207
+
208
+ # Set the pad token if it isn't set already
209
+ # if tokenizer.pad_token is None:
210
+ # tokenizer.pad_token = tokenizer.eos_token
211
+ # outputs = []
212
+ # texts = []
213
+ # for item in batch:
214
+ # instruction = item['prompt']
215
+ # # input = item['input']
216
+ # output = item['completion']
217
+ # # out = alpaca_prompt.format(instruction, output)
218
+ # texts.append(instruction)
219
+ # outputs.append(output)
220
+ # Tokenize text data
221
+ input_encodings = tokenizer(texts, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
222
+ # output_encodings = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
223
+ # input_encodings["labels"] = tokenizer(outputs, max_length = ModelArgs.block_size, padding='max_length', truncation=True, return_tensors="pt")
224
+ # out = {"input": input_encodings}
225
+ # input_encodings['input_ids'][: , input_encodings["attention_mask"] == 0] = -100
226
+ input_encodings["labels"] = input_encodings["input_ids"].clone() # Use `input_ids` as labels
227
+
228
+ input_encodings["labels"][:, :-1] = input_encodings["input_ids"][:, 1:] # Shift right
229
+ input_encodings["labels"][:, -1] = tokenizer.eos_token_id # Let the last token be end
230
+ # Return tokenized input tensors
231
+ # return out
232
+ return input_encodings
233
+
234
+ # Create DistributedSampler for proper shuffling and partitioning across processes
235
+ # dist_sampler = DistributedSampler(fw_train["text"], shuffle=True)
236
+
237
+ # Create DataLoader with custom collate_fn
238
+ # print(fw_dataset)
239
+ data_loader = None
240
+ if(tinystories):
241
+ if(split == 'train'):
242
+ data_loader = DataLoader(
243
+ fw_train,
244
+ # generator=generator,
245
+ batch_size=batch_size,
246
+
247
+ sampler=DistributedSampler(fw_train, shuffle=True),
248
+ collate_fn=collate_fn,
249
+ drop_last=True,
250
+ shuffle=False
251
+ )
252
+ elif(split == 'val'):
253
+ data_loader = DataLoader(
254
+ fw_test,
255
+
256
+
257
+ batch_size=batch_size,
258
+ sampler=DistributedSampler(fw_test, shuffle=True),
259
+ collate_fn=collate_fn,
260
+ drop_last=True,
261
+ shuffle=False
262
+ )
263
+ elif(fw):
264
+ if(split == 'train'):
265
+ data_loader = DataLoader(
266
+ fw_train['train'],
267
+ batch_size=batch_size,
268
+
269
+
270
+ sampler=DistributedSampler(fw_train['train'], shuffle=True),
271
+ collate_fn=collate_fn,
272
+ drop_last=True,
273
+ shuffle=False
274
+ )
275
+ elif(split == 'val'):
276
+ data_loader = DataLoader(
277
+ fw_train['test'],
278
+ batch_size=batch_size,
279
+ # generator=generator,
280
+ sampler=DistributedSampler(fw_train["test"]),
281
+ collate_fn=collate_fn,
282
+
283
+ drop_last=True,
284
+ shuffle=False
285
+ )
286
+ return data_loader
287
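The key detail in `collate_fn` is the label shift: `labels[t]` holds the token the model should predict after reading `input_ids[t]`. A toy illustration (the ids and the eos value are assumptions for illustration, not taken from the dataset):

```python
# Sketch: the right-shift performed on labels inside collate_fn.
import torch

eos_token_id = 50256                      # assumed GPT-2 eos id
input_ids = torch.tensor([[10, 11, 12, 13]])
labels = input_ids.clone()
labels[:, :-1] = input_ids[:, 1:]         # each position predicts the next token
labels[:, -1] = eos_token_id              # last position predicts end-of-sequence
print(labels)                             # tensor([[   11,    12,    13, 50256]])
```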
+
288
+
289
+
290
+
291
+
292
+
293
+ class Normalization(nn.Module):
294
+ def __init__(
295
+ self,
296
+
297
+ embeddings_dims: int = ModelArgs.embeddings_dims
298
+ ):
299
+ super().__init__()
300
+ self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
301
+
302
+
303
+ def forward(self, x):
304
+
305
+ x = self.rmsnorm_layer(x)
306
+ return x
307
+
308
+
309
+
310
+
311
+
312
+ # import numpy as np
313
+ class RotaryEmbeddings(nn.Module):
314
+ def __init__(
315
+ self,
316
+ device,
317
+ embeddings_dims: int = ModelArgs.embeddings_dims,
318
+ block_size: int = ModelArgs.block_size,
319
+ batch_size: int = ModelArgs.batch_size
320
+ ):
321
+ super().__init__()
322
+
323
+ self.embeddings_dims = embeddings_dims
324
+ self.block_size = block_size
325
+ self.batch_size = batch_size
326
+ self.theta = 0
327
+ self.device=device
328
+
329
+ # self.d_model = embeddings_dims
330
+ # self.i = torch.arange(0, embeddings_dims, dtype=torch.float32)
331
+ # # self.pos = torch.arange(0, block_size, dtype=torch.float32)
332
+ # self.exp = ((2 * self.i)) / self.d_model
333
+ # self.theta = 10000 ** self.exp
334
+ # # print(self.theta.shape)
335
+ # self.x_reshaped = torch.randn(batch_size, block_size, embeddings_dims,dtype=torch.float32, device=device)
336
+
337
+ # self.cos = torch.cos((self.i / self.theta))
338
+ # self.sin = torch.sin((self.i / self.theta))
339
+
340
+ # self.even = self.sin[::2]
341
+ # self.odd = self.cos[1::2]
342
+
343
+ # # self.block = torch.empty((odd.size(0) + even.size(0),), dtype=self.even.dtype)
344
+ # self.x_reshaped[..., : , ::2] = self.even
345
+ # self.x_reshaped[..., : , 1::2] = self.odd
346
+
347
+
348
+ def apply_rope(self, seq):
349
+ batch_size, seq_len, embeds_dims = seq.shape
350
+ # print(seq.shape)
351
+ # print(self.embeddings_dims)
352
+ # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
353
+
354
+ positions = torch.arange(0 , embeds_dims, 2, dtype=torch.float32, device = self.device).unsqueeze(0)
355
+ # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
356
+ theta = 10000 ** (-2 * (positions) / embeds_dims)
357
+ angles = positions * theta
358
+ angles = angles.expand(seq_len, -1) # because this thing needs to be applied to every sequence in the batch but with embeds dims halved
359
+ x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)
360
+
361
+ cos_angles = torch.cos(angles)
362
+ sin_angles = torch.sin(angles)
363
+ # print(cos_angles.shape)
364
+ # print(sin_angles.shape)
365
+ # print(x_reshaped.shape)
366
+ # indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device)
367
+
368
+ out = torch.stack([x_reshaped[..., 0]*cos_angles - (x_reshaped[...,1] * sin_angles), x_reshaped[...,1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=-1)
369
+ out = out.view(batch_size, seq_len, embeds_dims)
370
+ return out
371
+
372
+ def forward(self, x):
373
+ # print("X shape: ", x.shape)
374
+ # print("X is: ", x)
375
+ # B,T,C = x.shape
376
+ # print("MATRIX:",x)
377
+ # if(x > self.block_size or x < self.block_size):
378
+ # matrix = self.init_matrix(x)
379
+ # return matrix
380
+ # else:
381
+ # matrix = self.init_matrix(self.block_size)
382
+
383
+ # return matrix
384
+ # if(ModelArgs.inference):
385
+ res = self.apply_rope(x)
386
+ return res
387
+ # else:
388
+ # return self.x_reshaped
389
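Since `apply_rope` only rotates channel pairs by cos/sin angles, it leaves every token vector's norm unchanged. A small CPU check, assuming the class above is importable (the `device='cpu'` and small dims are illustrative):

```python
# Sketch: rotary embeddings act as a pure rotation, so per-token norms are preserved.
import torch

rope = RotaryEmbeddings(device='cpu', embeddings_dims=64, block_size=16, batch_size=2)
x = torch.randn(2, 16, 64)
out = rope(x)
print(torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-4))  # True
```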
+
390
+ class RotaryAttentionHead(nn.Module):
391
+ def __init__(
392
+ self,
393
+ device,
394
+ embeddings_dims: int = ModelArgs.embeddings_dims,
395
+ no_of_heads: int = ModelArgs.no_of_heads,
396
+ attn_dropout: int = ModelArgs.attn_dropout
397
+ ):
398
+ super().__init__()
399
+ self.head_size = embeddings_dims // no_of_heads
400
+ self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
401
+ self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
402
+ self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
403
+ self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
404
+ self.dropout = nn.Dropout(p = attn_dropout)
405
+ self.device = device
406
+ def forward(self,x):
407
+ # print(x.shape)
408
+ # print("X is: ", x)
409
+ batch, block_size, embeddings_dims = x.shape
410
+ query = self.query(x)
411
+ # print(query)
412
+ key = self.key(x)
413
+ values = self.value(x)
414
+ # matrix = self.rotary_matrix(block_size)
415
+ rotary_q = self.rope(query)
416
+ rotary_k = self.rope(key)
417
+
418
+ # print(matrix.shape)
419
+ # print(query.shape)
420
+ masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
421
+ # rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
422
+ # rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
423
+ weights = rotary_q.permute(2,0,1) @ rotary_k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
424
+ weights_masked = weights.masked_fill(masked == 0, float('-inf'))
425
+ scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
426
+ scaled_weights = F.softmax(scaled_weights, dim=-1)
427
+ value = scaled_weights @ values
428
+ out = self.dropout(value)
429
+ return out
430
+
431
+
432
+ # # import numpy as np
433
+ # class RotaryEmbeddings(nn.Module):
434
+ # def __init__(
435
+ # self,
436
+ # device,
437
+ # embeddings_dims: int = ModelArgs.embeddings_dims,
438
+ # block_size: int = ModelArgs.block_size,
439
+ # batch_size: int = ModelArgs.batch_size
440
+ # ):
441
+ # super().__init__()
442
+
443
+ # self.embeddings_dims = embeddings_dims
444
+ # self.block_size = block_size
445
+ # self.batch_size = batch_size
446
+ # self.theta = 0
447
+
448
+
449
+ # # def init_matrix(self, seq_len):
450
+ # # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
451
+ # # for pos in range(seq_len):
452
+ # # for j in range(1, self.embeddings_dims // 2):
453
+ # # self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
454
+ # # self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
455
+ # # self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
456
+ # # self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
457
+ # # self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
458
+ # # return self.matrix
459
+ # self.device=device
460
+
461
+ # def init_matrix(self, seq_len):
462
+ # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
463
+
464
+ # positions = torch.arange(0 , seq_len, 2, dtype=torch.float32, device = self.device).unsqueeze(1)
465
+ # # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
466
+ # theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
467
+ # angles = positions * theta
468
+
469
+ # cos_angles = torch.cos(angles)
470
+ # sin_angles = torch.sin(angles)
471
+
472
+ # indices = torch.arange(seq_len, dtype=torch.int64, device = self.device)
473
+ # # print(indices)
474
+ # # print(indices.shape)
475
+ # # print(indices[::2])
476
+ # even_indices = indices[::2]
477
+ # odd_indices = indices[1::2]
478
+
479
+ # self.matrix[:, even_indices, even_indices] = cos_angles
480
+ # self.matrix[:, odd_indices, odd_indices] = sin_angles
481
+ # self.matrix[:, odd_indices, even_indices] = -sin_angles
482
+ # self.matrix[:, even_indices, odd_indices] = cos_angles
483
+
484
+ # return self.matrix
485
+
486
+ # def forward(self, x):
487
+ # # B,T,C = x.shape
488
+ # # print("MATRIX:",x)
489
+ # if(x > self.block_size or x < self.block_size):
490
+ # matrix = self.init_matrix(x)
491
+ # return matrix
492
+ # else:
493
+ # matrix = self.init_matrix(self.block_size)
494
+
495
+ # return matrix
496
+
497
+
498
+ # class RotaryAttentionHead(nn.Module):
499
+ # def __init__(
500
+ # self,
501
+ # device,
502
+ # embeddings_dims: int = ModelArgs.embeddings_dims,
503
+ # no_of_heads: int = ModelArgs.no_of_heads,
504
+ # attn_dropout: int = ModelArgs.attn_dropout
505
+ # ):
506
+ # super().__init__()
507
+ # self.head_size = embeddings_dims // no_of_heads
508
+ # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
509
+ # self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
510
+ # self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
511
+ # self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
512
+ # self.dropout = nn.Dropout(p = attn_dropout)
513
+ # self.device = device
514
+ # def forward(self,x):
515
+ # # print(x.shape)
516
+ # batch, block_size, embeddings_dims = x.shape
517
+ # query = self.query(x)
518
+ # # print(query)
519
+ # key = self.key(x)
520
+ # values = self.value(x)
521
+ # matrix = self.rotary_matrix(block_size)
522
+
523
+ # # print(matrix.shape)
524
+ # # print(query.shape)
525
+ # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
526
+ # rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
527
+ # rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
528
+ # weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
529
+ # weights_masked = weights.masked_fill(masked == 0, float('-inf'))
530
+ # scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
531
+ # scaled_weights = F.softmax(scaled_weights, dim=-1)
532
+ # value = scaled_weights @ values
533
+ # out = self.dropout(value)
534
+ # return out
535
+
536
+
537
+ class MQA(nn.Module):
538
+ def __init__(
539
+ self,
540
+ device,
541
+ no_of_q_heads: int,
542
+ embeddings_dims: int = ModelArgs.embeddings_dims,
543
+ block_size: int = ModelArgs.block_size,
544
+
545
+
546
+ ):
547
+ super().__init__()
548
+
549
+
550
+ # self.no_of_q_heads = no_of_heads // no_of_kv_heads
551
+ # self.no_of_q_heads = no_of_q_heads
552
+ self.no_of_kv_heads = 2 # I want to have a kv for each pair of query heads
553
+ self.head_size = embeddings_dims // no_of_q_heads
554
+ # self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
555
+ self.rotary= RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
556
+ # self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size, device = device)
557
+ # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
558
+ self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
559
+ self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
560
+ self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
561
+ self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
562
+ self.device = device
563
+ self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device = self.device) for _ in range(self.no_of_kv_heads)])
564
+
565
+ def scaled_dot_product(self, q, k, v, block_size):
566
+
567
+ # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
568
+ q = self.rotary(q)
569
+ masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
570
+ # rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
571
+ # rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
572
+ # print("Query: ", q.shape)
573
+ # print("Keys: ", k.shape)
574
+ # print(q.permute(2,0,1).shape)
575
+ # print(k.permute(2,0,1).transpose(-2, -1).shape)
576
+ # weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
577
+ # weights = q @ k.permute(2,1,0)
578
+ # print(weights.shape)
579
+ # print(masked.shape)
580
+ weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
581
+ masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
582
+ weights_normalized = nn.functional.softmax(masked_values, dim=-1) #Normalize along the embeddings dimension for all the tokens
583
+ weights_normalized = self.dropout(weights_normalized)
584
+ out = weights_normalized @ v
585
+ return out
586
+
587
+ def forward(self,x):
588
+ # print("MQA: ", x.shape)
589
+ batch, block_size, embeddings_dims = x.shape
590
+
591
+ # query = self.query(x)
592
+ # matrix = self.rotary_matrix(block_size)
593
+
594
+
595
+ key = self.key(x)
596
+ values = self.value(x)
597
+ # print("Keys: ", key.shape)
598
+ # print("Values: ", values.shape)
599
+ # rotary_value = self.rotary(values)
600
+ rotary_key = self.rotary(key)
601
+ multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
602
+ # print("Multi query: ", multi_query_concat.shape)
603
+
604
+ linear_layer= self.linear_layer(multi_query_concat)
605
+ # out = self.dropout(linear_layer)
606
+ return linear_layer
607
+
608
+
609
+ class GQA(nn.Module):
610
+ def __init__(
611
+ self,
612
+ device,
613
+ embeddings_dims: int = ModelArgs.embeddings_dims,
614
+ block_size: int = ModelArgs.block_size,
615
+ # no_of_q_heads: int = ModelArgs.no_of_heads,
616
+ mqa_heads: int = ModelArgs.no_kv_heads
617
+ ):
618
+ super().__init__()
619
+
620
+ # self.no_of_kv_heads = no_of_kv_heads
621
+ self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
622
+ # self.head_dim = embeddings_dims // self.no_kv_heads
623
+ self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
624
+ self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device)
625
+ self.device = device
626
+ self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
627
+ # self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)
628
+ def forward(self,x):
629
+
630
+ batch, block_size, embeddings_dims = x.shape
631
+
632
+ # res = self.mqa(x)
633
+ grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
634
+
635
+ linear_layer = self.linear_layer(grouped_query_concat) # concatenating the MQA groups and projecting back to embeddings_dims is what turns the MQA blocks into grouped-query attention
636
+ out = self.dropout(linear_layer)
637
+ return out
638
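A shape sanity check for the block above (a sketch run on CPU with illustrative sizes; the remaining defaults come from ModelArgs): GQA keeps the embedding width and only changes how query heads share key/value projections.

```python
# Sketch: GQA maps (batch, seq, embeddings_dims) back to the same shape.
import torch

gqa = GQA(device='cpu', embeddings_dims=512, block_size=16, mqa_heads=2)
x = torch.randn(2, 16, 512)
print(gqa(x).shape)  # torch.Size([2, 16, 512])
```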
+
639
+
640
+ class Swish(nn.Module):
641
+ def __init__(
642
+ self,
643
+ device,
644
+ block_size: int = ModelArgs.block_size,
645
+ embeddings_dims: int = ModelArgs.embeddings_dims
646
+ ):
647
+ super().__init__()
648
+
649
+ self.sig = torch.nn.Sigmoid()
650
+
651
+
652
+ def forward(self, x):
653
+ swish = x * self.sig(x)
654
+
655
+ return swish
656
+
657
+
658
+
659
+ class SWiGLU(nn.Module):
660
+ def __init__(
661
+ self,
662
+ device,
663
+ block_size: int = ModelArgs.block_size,
664
+ embeddings_dims: int = ModelArgs.embeddings_dims
665
+ ):
666
+ super().__init__()
667
+ self.hidden_dims = int(2 * ( 4 * embeddings_dims) / 3)
668
+ self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
669
+ self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
670
+ self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
671
+ self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
672
+
673
+
674
+
675
+
676
+ def forward(self, x):
677
+ swish_res = self.swish(self.linear_layer1(x))
678
+ x_V = self.linear_layer2(x)
679
+ res = torch.mul(swish_res, x_V)
680
+ out = self.linear_layer3(res)
681
+ return out
682
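The hidden width follows the common 2/3 · 4d heuristic used in Llama-style SwiGLU feed-forward blocks, which keeps the three-projection FFN roughly parameter-matched to a plain 4d MLP. For the configured embedding size:

```python
# Worked example: SwiGLU hidden width for embeddings_dims = 512.
embeddings_dims = 512
hidden_dims = int(2 * (4 * embeddings_dims) / 3)
print(hidden_dims)  # 1365
```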
+
683
+
684
+
685
+ class FFN(nn.Module):
686
+ def __init__(self,
687
+ device,
688
+ embeddings_dims: int = ModelArgs.embeddings_dims,
689
+ block_size: int = ModelArgs.block_size,
690
+ vocab_size: int = ModelArgs.vocab_size,
691
+ dropout = ModelArgs.dropout
692
+
693
+ ):
694
+ super().__init__()
695
+
696
+ # self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
697
+ self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device)
698
+ self.dropout = nn.Dropout(p = dropout)
699
+ def forward(self, x):
700
+
701
+ x = self.swiglue(x)
702
+ # x = self.linear_layer(x)
703
+ x = self.dropout(x)
704
+ return x
705
+
706
+
707
+ class DecoderLayer(nn.Module):
708
+ def __init__(self,
709
+ device,
710
+ embeddings_dims: int = ModelArgs.embeddings_dims,
711
+ dropout = ModelArgs.dropout,
712
+ block_size: int = ModelArgs.block_size,
713
+ vocab_size: int = ModelArgs.vocab_size,
714
+
715
+ ) :
716
+ super().__init__()
717
+
718
+
719
+ self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device)
720
+ self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device = device)
721
+ # self.norm = Normalization(embeddings_dims=embeddings_dims)
722
+ self.norm1 = Normalization(embeddings_dims=embeddings_dims)
723
+ self.norm2 = Normalization(embeddings_dims=embeddings_dims)
724
+ self.dropout = nn.Dropout(p = dropout)
725
+ def forward(self, x):
726
+
727
+ x = x + self.gqa(self.norm1(x))
728
+ x = x + self.feedforward_network(self.norm2(x))
729
+ return x
730
+
731
+
732
+ class Llama(nn.Module):
733
+ def __init__(self,
734
+ device,
735
+ embeddings_dims: int = ModelArgs.embeddings_dims,
736
+ no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
737
+ block_size: int = ModelArgs.block_size,
738
+ vocab_size: int = ModelArgs.vocab_size,
739
+ dropout = ModelArgs.dropout
740
+
741
+ ) :
742
+ super().__init__()
743
+
744
+ self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device)
745
+ self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)])
746
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device)
747
+ self.dropout = nn.Dropout(p = dropout)
748
+ # self.norm = Normalization(embeddings_dims)
749
+
750
+
751
+ #weight tying
752
+ self.embeddings.weight = self.linear_layer.weight
753
+
754
+ self.apply(self._init_weights)
755
+
756
+ def _init_weights(self, module):
757
+ if isinstance(module, nn.Linear):
758
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
759
+
760
+ if module.bias is not None:
761
+ nn.init.zeros_(module.bias)
762
+ elif isinstance(module, nn.Embedding):
763
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
764
+
765
+
766
+
767
+ def forward(self, x):
768
+ x = self.embeddings(x)
769
+ x = self.dropout(x)
770
+ x = self.decoder(x)
771
+ # x = self.norm(x)
772
+ x = self.linear_layer(x)
773
+ # out = self.norm(x)
774
+ return x
775
+
776
+
777
+ # from andrej karapathy github
778
+ def topk_sampling(model, prompt, device, max_length=50, top_k=50, temperature=1.0):
779
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
780
+ generated_tokens = []
781
+ ModelArgs.inference=True
782
+ for _ in range(max_length):
783
+ with torch.no_grad():
784
+ outputs = model.module(input_ids)
785
+ logits = outputs[:, -1, :]
786
+
787
+ probs = F.softmax(logits, dim=-1)
788
+
789
+ # Top-k filtering
790
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
791
+
792
+
793
+ # Apply temperature scaling
794
+ # probs = probs / temperature
795
+
796
+ # Sample from top-k
797
+ next_token = torch.multinomial(top_k_probs, num_samples=1)
798
+
799
+ # generated_tokens.append(next_token.item())
800
+
801
+ xcol = torch.gather(top_k_indices, -1, next_token)
802
+ input_ids = torch.cat([input_ids, xcol], dim=1) # dim=1 because it is the sequence dimension
803
+
804
+ return tokenizer.decode(input_ids[0], skip_special_tokens=True)
805
+
806
+ def beam_search(model, tokenizer, prompt, beam_width=5, max_length=50, temperature=1.0):
807
+ device = next(model.module.parameters()).device
808
+ input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
809
+ beam_scores = torch.zeros(beam_width, device=device)
810
+ beam_sequences = input_ids.repeat(beam_width, 1)
811
+
812
+ for _ in range(max_length):
813
+ outputs = model(beam_sequences)
814
+ logits = outputs[:, -1, :] / temperature
815
+ probs = F.softmax(logits, dim=-1)
816
+ top_probs, top_indices = torch.topk(probs, beam_width, dim=-1)
817
+
818
+ # Expand beams
819
+ beam_scores = beam_scores.unsqueeze(-1) + torch.log(top_probs)
820
+ beam_scores = beam_scores.view(-1)
821
+ top_indices = top_indices.view(-1)
822
+
823
+ # Select top beams
824
+ beam_scores, top_beams = torch.topk(beam_scores, beam_width)
825
+ beam_sequences = torch.cat([beam_sequences[top_beams // beam_width], top_indices[top_beams].unsqueeze(-1)], dim=-1)
826
+
827
+ # Return the best sequence
828
+ best_sequence = beam_sequences[0]
829
+ return tokenizer.decode(best_sequence, skip_special_tokens=True)
830
+
831
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
832
+ # device = "cpu"
833
+ # ModelArgs.device = device
834
+ model = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
835
+ model = model.to(ModelArgs.device)
836
+
837
+ # Printing a summary of the architecture
838
+ # !pip install torchinfo
839
+ from torchinfo import summary
840
+ # idx, targets = get_batch('test')
841
+ idx = torch.randint(
842
+ low=0,
843
+ high=ModelArgs.vocab_size,
844
+ size=(ModelArgs.batch_size, ModelArgs.block_size),
845
+ dtype=torch.long
846
+ )
847
+ # sample_idx = random.randint(range(len(train_dataset)))
848
+ # idx, targets = train_dataset[0]
849
+ idx = idx.to(ModelArgs.device)
850
+ # targets = targets.to(ModelArgs.device)
851
+ summary(model=model,
852
+ input_data=idx,
853
+ # input_size=(ModelArgs.batch_size, ModelArgs.block_size, ModelArgs.embeddings_dims),
854
+ col_names=["input_size", "output_size", "num_params", "trainable"],
855
+ col_width=20,
856
+ row_settings=["var_names"])
857
+
858
+
859
+ def find_unused_parameters(model):
860
+ unused = []
861
+ for name, param in model.named_parameters():
862
+ if param.grad is None:
863
+ unused.append(name)
864
+ return unused
865
+
866
+ def greedy_decode(
867
+ model,
868
+ tokenizer,
869
+ prompt,
870
+ device,
871
+ max_length=50,
872
+ repetition_penalty=1.2,
873
+ context_window=10,
874
+ temperature=1.0,
875
+ eos_token_id=None,
876
+
877
+ ):
878
+ # model.eval()
879
+ # device = next(model.parameters()).device
880
+ input_ids = tokenizer(prompt, return_tensors="pt").to(device)['input_ids']
881
+ generated_tokens = []
882
+ eos_token_id = eos_token_id or tokenizer.eos_token_id # Use EOS token if provided
883
+
884
+ for _ in range(max_length):
885
+ with torch.no_grad():
886
+ outputs = model.module(input_ids)
887
+ logits = outputs[:, -1, :] # Get logits for the last token
888
+
889
+ # Apply temperature scaling
890
+ # if temperature != 1.0:
891
+ # logits = logits / temperature
892
+
893
+ # Apply repetition penalty
894
+ # if repetition_penalty != 1.0 and len(generated_tokens) > 0:
895
+ # for token in set(generated_tokens[-context_window:]): # Penalize recent tokens
896
+ # logits[0, token] /= repetition_penalty
897
+
898
+ # Greedy selection
899
+ next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
900
+ generated_tokens.append(next_token.item())
901
+
902
+ # Stop if EOS token is generated
903
+ # if next_token.item() == eos_token_id:
904
+ # break
905
+
906
+ # Append the new token to the input
907
+ input_ids = torch.cat([input_ids, next_token], dim=1)
908
+
909
+ # Decode the generated tokens
910
+ return tokenizer.decode(generated_tokens, skip_special_tokens=True)
911
+
912
+
913
+
914
+ def save_to_file(text):
915
+
916
+ with open('generations.txt', 'a') as f:
917
+ f.writelines(text + "\n\n")
918
+
919
+
920
+ #Train the model
921
+
922
+
923
+ # writer = SummaryWriter(log_dir="runs/experiment")
924
+
925
+ from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
926
+
927
+ # Warmup phase for 2000 steps
928
+ def warmup_fn(step):
929
+ if step < 2000:
930
+ return step / 2000 # LR gradually increases
931
+ return 1.0
932
+
933
+
934
+ from torch.optim.lr_scheduler import LambdaLR
935
+
936
+ def trapezoidal_lr_scheduler(optimizer, max_lr, total_steps, warmup_steps, plateau_steps, decay_steps):
937
+ """
938
+ Trapezoidal learning rate scheduler:
939
+ - Increases linearly for `warmup_steps` steps.
940
+ - Remains constant for `plateau_steps` steps.
941
+ - Decreases linearly for `decay_steps` steps.
942
+ """
943
+ def lr_lambda(step):
944
+ if step < warmup_steps:
945
+ # Linear warmup
946
+ return float(step) / float(max(1, warmup_steps))
947
+ elif step < warmup_steps + plateau_steps:
948
+ # Constant plateau
949
+ return 1.0
950
+ else:
951
+ # Linear decay
952
+ decay_step = step - (warmup_steps + plateau_steps)
953
+ return max(0.0, float(decay_steps - decay_step) / float(max(1, decay_steps)))
954
+
955
+ return LambdaLR(optimizer, lr_lambda)
956
+
957
+
958
+ torch.set_float32_matmul_precision('high')
959
+
960
+ scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))
961
+
962
+ save_chechpoint_iter = 50
963
+ total_iters = 10000
964
+ eval_iters = 50
965
+ eval_check = 100
966
+ warmup_iters = 700
967
+ min_lr = 0.1 * ModelArgs.max_lr
968
+ lr_decay_iters = 10000
969
+ total_batch_size = 524288
970
+ micro_batch_size = ModelArgs.batch_size
971
+ gradient_accumulation_steps = total_batch_size // (micro_batch_size * (ModelArgs.block_size * torch.cuda.device_count()))
972
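For a sense of scale (the GPU count here is an assumption for illustration; the script reads it from `torch.cuda.device_count()`), the accumulation count works out as follows:

```python
# Worked example: micro-batches accumulated per optimizer step.
total_batch_size = 524288            # tokens per optimizer step
micro_batch_size = 64                # sequences per GPU per forward pass
block_size = 512                     # tokens per sequence
num_gpus = 2                         # assumed for this example
print(total_batch_size // (micro_batch_size * block_size * num_gpus))  # 8
```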
+
973
+ # learning rate decay scheduler (cosine with warmup) from https://github.com/karpathy/nanoGPT/blob/master/train.py
974
+
975
+ def get_lr(it):
976
+ # 1) linear warmup for warmup_iters steps
977
+ if it < warmup_iters:
978
+ return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
979
+ # 2) if it > lr_decay_iters, return min learning rate
980
+ if it > lr_decay_iters:
981
+ return min_lr
982
+ # 3) in between, use cosine decay down to min learning rate
983
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
984
+ assert 0 <= decay_ratio <= 1
985
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
986
+ return min_lr + coeff * (ModelArgs.max_lr - min_lr)
987
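The usual way such a function is consumed (as in the nanoGPT script cited above) is to write the value into the optimizer's param groups each iteration rather than through a torch scheduler object. A minimal sketch with a throwaway optimizer and a few sampled iteration numbers:

```python
# Sketch: applying get_lr by hand, assuming the constants defined above are in scope.
import torch

dummy = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(dummy, lr=ModelArgs.max_lr)
for it in (0, warmup_iters, lr_decay_iters // 2, lr_decay_iters):
    lr = get_lr(it)
    for param_group in opt.param_groups:
        param_group['lr'] = lr
    print(it, f"{lr:.2e}")  # warmup start, peak 6e-4, mid-decay, floor 6e-5
```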
+
988
+
989
+ def train():
990
+ setup()
991
+ device = int(os.environ["LOCAL_RANK"])
992
+
993
+ torch.cuda.set_device(int(device))
994
+ # torch.set_default_device('cuda')
995
+ # train_dataloader = prepare_dataset(ModelArgs.batch_size)
996
+ # rank = torch.distributed.get_rank()
997
+ print(f"Start running DDP on rank {device}.")
998
+ # # create model and move it to GPU with id rank
999
+ # device_id = rank % torch.cuda.device_count()
1000
+ # CFG = ModelArgs()
1001
+
1002
+ if(device == 0):
1003
+
1004
+
1005
+
1006
+ # # Initialise run
1007
+ wandb.init(
1008
+ # entity = 'rajceo2031',
1009
+ project = 'Llama-DDP-Pretrain-10-billion-tokens',
1010
+ # config = CFG,
1011
+ # save_code = True,
1012
+ #group = 'ANN',
1013
+ #job_type = 'train'
1014
+ )
1015
+ print("wand initialized")
1016
+
1017
+ model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
1018
+
1019
+ # print(f"Model on device {device} is ready")
1020
+ print(f"Model on device {device} is ready")
1021
+
1022
+ # Wrap model with DDP after moving to GPU
1023
+ # model = DDP(model, device_ids=[device])
1024
+ # optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=1e-8)
1025
+ # # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=4000, T_mult=1, eta_min=1e-5)
1026
+ # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(None, T_max=30000, eta_min=1e-6)
1027
+ # _load_snapshot('/kaggle/input/models/snapshot2.pt', model.module, None, None)
1028
+ optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=ModelArgs.eps)
1029
+
1030
+ # model = torch.compile(model)
1031
+ model = model.to(device)
1032
+
1033
+ model = DDP(model, device_ids=[device])
1034
+
1035
+
1036
+ # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25000, eta_min=1e-6) #with the prev optim snapshot
1037
+ # new_scheduler = trapezoidal_lr_scheduler(optimizer, ModelArgs.max_lr, total_steps, warmup_steps, plateau_steps, decay_steps)
1038
+
1039
+ # warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_fn)
1040
+ # new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
1041
+ # Cosine decay after warmup
1042
+ # new_scheduler = CosineAnnealingLR(optimizer, T_max=20000, eta_min=1e-6)
1043
+
1044
+ # Combine both schedulers
1045
+ # scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, new_scheduler], milestones=[2000])
1046
+
1047
+ # Reset learning rate to 1e-4
1048
+ # for param_group in optimizer.param_groups:
1049
+ # param_group['lr'] = ModelArgs.max_lr
1050
+ # scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2000, T_mult=1, eta_min=1e-6)
1051
+ # print("Old optimizer with new lr ready")
1052
+
1053
+
1054
+
1055
+ # optimizer = torch.optim.AdamW(params=model.parameters(), lr=ModelArgs.max_lr)
1056
+ # Create DataLoader with collate_fn
1057
+ # train_loader = DataLoader(train_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
1058
+ # val_loader = DataLoader(val_dataset, batch_size=ModelArgs.batch_size, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True, num_replicas=int(os.environ["WORLD_SIZE"]), rank=device))
1059
+ # print("Loader is ready")
1060
+ # print(train_loader)
1061
+ # print(next(iter(train_loader)))
1062
+
1063
+
1064
+ # for X,y in train_loader:
1065
+ # print(X.shape)
1066
+ # print(y.shape)
1067
+
1068
+ # alpaca_prompt = '''
1069
+
1070
+ # ### Instruction:
1071
+ # {instruction}
1072
+
1073
+ # ### Input:
1074
+ # {input}
1075
+
1076
+ # ### Response:
1077
+
1078
+ # '''
1079
+ # Only create progress bar for rank 0
1080
+ # eval_epoch_iterator = range(eval_iters)
1081
+ # train_epoch_iterator = range(total_iters)
1082
+ # if device == 0:
1083
+ # train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training")
1084
+
1085
+ # train_epoch_iterator = range(ModelArgs.epochs)
1086
+ # if device == 0: # Ensure tqdm only runs on rank 0
1087
+ # train_epoch_iterator = tqdm(train_epoch_iterator, desc="Training Progress", position=0, leave=True)
1088
+
1089
+ # lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters)
1090
+
1091
+
1092
+
1093
+ model.eval()
1094
+ world_size = torch.cuda.device_count()
1095
+ @torch.inference_mode()
1096
+ def estimate_loss(val_loader, val_iterator, device):
1097
+ out = {}
1098
+ # train_loader = prepare_dataset('train', ModelArgs.batch_size)
1099
+
1100
+ # val_loader_iterator = iter(val_loader)
1101
+ loader = None
1102
+ epoch_loss = None
1103
+ epoch_losses = []
1104
+ # print("Starting the eval...")
1105
+ for split in ['val']:
1106
+ print(f"Starting with {split} evaluation...")
1107
+ # losses = torch.zeros(ModelArgs.val_epochs)
1108
+ # if(split == 'train'):
1109
+ # loader = train_loader
1110
+ # if(split == 'val'):
1111
+ # loader = val_loader
1112
+ total_loss = 0
1113
+ total_batches = 0
1114
+ for step in range(eval_check):
1115
+ try:
1116
+ batch = next(val_iterator)
1117
+ except StopIteration:
1118
+ val_iterator = iter(val_loader)
1119
+ batch = next(val_iterator)
1120
+
1121
+ # loader.sampler.set_epoch(step)
1122
+ # batch = next(val_loader_iterator)
1123
+ # for batch in loader: # Loop through DataLoader batches
1124
+ idx = batch['input_ids']
1125
+ targets = batch['labels']
1126
+ idx = idx.to(device)
1127
+ targets = targets.to(device)
1128
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
1129
+
1130
+ logits = model(idx)
1131
+ batch_size, block_size, embeddings_dims = logits.shape
1132
+ logits = logits.view(batch_size * block_size, embeddings_dims) # Flatten tokens
1133
+ targets = targets.view(batch_size * block_size)
1134
+
1135
+ loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
1136
+
1137
+ total_loss += loss.item()
1138
+ total_batches += 1
1139
+
1140
+ # Compute mean loss for this epoch
1141
+ epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
1142
+ epoch_losses.append(epoch_loss)
1143
+
1144
+ # print(f"Epoch {epoch + 1}/{ModelArgs.val_epochs}: Loss = {epoch_loss:.4f}")
1145
+
1146
+ # Compute mean loss across all evaluation epochs
1147
+ out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
1148
+ epoch_loss = None
1149
+ epoch_losses = []
1150
+
1151
+ model.train()
1152
+ return out
1153
+
1154
+ # model = model.to(rank)
1155
+ model.train()
1156
+ count = 0
1157
+
1158
+ train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
1159
+ val_loader= prepare_dataset('val', device, ModelArgs.batch_size)
1160
+ # for step in tqdm(range(total_iters)):
1161
+ # for epoch in range(ModelArgs.epochs):
1162
+ # torch.cuda.synchronize()
1163
+
1164
+ # train_dataloader.sampler.set_epoch(epoch)
1165
+
1166
+ # val_loader.sampler.set_epoch(epoch)
1167
+ print("Loaders ready both")
1168
+ epochs = ModelArgs.epochs
1169
+
1170
+ # train_step_iterator = range(len(train_dataloader))
1171
+ # if device == 0: # Only create progress bar on rank 0
1172
+ # train_step_iterator = tqdm(train_step_iterator, desc="Training Progress", position=0, leave=True)
1173
+
1174
+ # Print progress on rank 0
1175
+ train_loader_length = 0
1176
+ train_data_iterator = iter(train_dataloader)
1177
+ val_data_iterator = iter(val_loader)
1178
+ token_count = 0
1179
+ if(device == 0):
1180
+ train_loader_length = len(train_dataloader)
1181
+ # print("Total batches: ", train_loader_length)
1182
+ # print("Length of : ", len(train_dataloader))
1183
+ # print("Length of val: ", len(val_loader))
1184
+ # for step, batch in enumerate(train_dataloader):
1185
+ for step in tqdm(range(total_iters)):
1186
+ # print("Dataloader things: ", batch)
1187
+ # print("Total batches: ", len(train_dataloader))
1188
+
1189
+
1190
+ if(device == 0):
1191
+ # if(step % 100 == 0):
1192
+ # if(step == train_loader_length):
1193
+ # break
1194
+ print("Step : ", step, "/", total_iters)
1195
+ print('Total batches: ', len(train_dataloader))
1196
+ print("Total gradient accumulation steps: ", gradient_accumulation_steps)
1197
+ print("Total tokens processed: ", token_count)
1198
+
1199
+ # all_gpus_avg_train_loss = None
1200
+ # all_gpus_avg_val_loss = None
1201
+ # every once in a while evaluate the loss on train and val sets
1202
+ if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
1203
+ losses = estimate_loss( val_loader, val_data_iterator, 'cuda')
1204
+ # avg_train_loss = losses['train']
1205
+ avg_val_loss = losses['val']
1206
+ # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
1207
+ # if device == 0: # Only print on main process
1208
+ print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
1209
+ # print(f"[GPU {device}] | Epoch {epoch}/{ModelArgs.epochs}| |Step: {step} | Train Loss: {losses['train']:.4f}")
1210
+ # print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
1211
+ # Log training loss more frequently
1212
+ # Aggregate average loss across all GPUs
1213
+ # avg_train_loss = torch.Tensor([losses['train']]).to(device)
1214
+ avg_val_loss = torch.Tensor([losses['val']]).to(device)
1215
+ # torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
1216
+ torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
1217
+
1218
+ if device == 0:
1219
+ # all_gpus_avg_train_loss = avg_train_loss / world_size
1220
+ # print(f"All_GPUs_Train_losses: {all_gpus_avg_train_loss.item():.4f}")
1221
+ all_gpus_avg_val_loss = avg_val_loss / world_size
1222
+ print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
1223
+
1224
+ # if device == 0:
1225
+
1226
+ # writer.add_scalar("All_GPUs_Train_losses", all_gpus_avg_train_loss.item(), global_step=step)
1227
+ # writer.add_scalar("All_GPUs_Val_losses", all_gpus_avg_val_loss.item(), global_step=step)
1228
+ # writer.add_scalar("training_step_loss", losses['train'], global_step=step)
1229
+ # writer.add_scalar("val_step_loss", losses['val'], global_step=step)
1230
+ # writer.add_scalar("GPU", device, global_step=step)
1231
+ # writer.add_scalar("Epoch", epoch, global_step=step)
1232
+
1233
+ wandb.log({
1234
+ # "Learning Rate": optimizer.param_groups[0]['lr'],
1235
+ # "All_GPUs_Train_losses": all_gpus_avg_train_loss,
1236
+ "All_GPUs_Val_losses": all_gpus_avg_val_loss,
1237
+ # "training_step_loss": losses['train'],
1238
+ "val_step_loss": losses['val'],
1239
+ # "Step": step,
1240
+ # "Epoch": epoch
1241
+ })
1242
+
1243
+
1244
+
1245
+ #Loading a checkpoint
1246
+ # if(os.path.exists('snapshot.pt')):
1247
+ # model, optimizer = _load_snapshot(model=model, optimizer=optimizer, epoch=epoch, step=step, snapshot_path='snapshot.pt')
1248
+
1249
+ # if(step % save_chechpoint_iter == 0 and device == 0 and step != 0):
1250
+
1251
+ # _save_snapshot(epoch=epoch, model=model, optimizer=optimizer, step=step)
1252
+
1253
+ if step % save_chechpoint_iter == 0 and device == 0 and step != 0:
1254
+ print(f"Saving the model checkpoint for step: {step}")
1255
+ _save_snapshot(model, optimizer, None, None, step)
1256
+
1257
+ accumulated_loss = 0.0
1258
+
1259
+
1260
+ optimizer.zero_grad(set_to_none=True)
1261
+ for micro_step in range(gradient_accumulation_steps):
1262
+ try:
1263
+ batch = next(train_data_iterator)
1264
+ except StopIteration:
1265
+ train_data_iterator = iter(train_dataloader)
1266
+ batch = next(train_data_iterator)
1267
+ # print(batch)
1268
+ # batch = next(train_data_iterator)
1269
+ # print(batch)
1270
+ # batch = {k: v.to(self.local_rank) for k, v in batch.items()}
1271
+ idx = batch['input_ids'].to(device)
1272
+ # idx, targets = get_batch(split='train')
1273
+ # print(f"Starting the train step: {step}...")
1274
+ # for idx, targets in train_loader:
1275
+ # idx, targets = next(iter(train_loader))
1276
+
1277
+ # print("Idx: ", idx)
1278
+ # print("Targets: ", targets)
1279
+
1280
+ # idx = idx.to(device)
1281
+ # print("Idx: ", idx)
1282
+ # print("Targets: ", targets)
1283
+ targets = batch['labels'].to(device)
1284
+ token_count += idx.numel() # count tokens, not just sequences
1285
+ with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
1286
+ logits = model(idx)
1287
+ batch_size, block_size, embeddings_dims = logits.shape
1288
+ # print(logits.shape)
1289
+ # print(targets)
1290
+ logits = logits.view(batch_size*block_size, embeddings_dims)
1291
+ # print("OK")
1292
+ targets = targets.view(batch_size * block_size)
1293
+ # print("OK2")
1294
+ loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
1295
+
1296
+ loss = loss / gradient_accumulation_steps # scale so that, summed over the accumulation window, gradients match those of one large batch (each micro-batch contributes equally)
1297
+ accumulated_loss += loss.detach()
1298
+
1299
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) # only synchronize gradients across GPUs on the last micro-step
1300
+ scaler.scale(loss).backward()
1301
+ # Check for unused parameters
1302
+ unused_params = find_unused_parameters(model)
1303
+ if unused_params:
1304
+ print(f"Unused parameters: {unused_params}")
1305
+ # break
1306
+
1307
+ if(device == 0):
1308
+ if(micro_step % 10 == 0):
1309
+ # if(step == train_loader_length):
1310
+ # break
1311
+
1312
+ print("Micro Batch : ", micro_step)
1313
+ print("Step : ", step, "/", total_iters)
1314
+ print('Total batches: ', len(train_dataloader))
1315
+ print("Total gradient accumulation steps: ", gradient_accumulation_steps)
1316
+ print("Total tokens processed: ", token_count)
1317
+ # count += 1
1318
+
1319
+ lr = get_lr(step)
1320
+ for params in optimizer.param_groups:
1321
+ params['lr'] = lr
1322
+
1323
+
1324
+
1325
+ # Compute gradient norms before clipping
1326
+ if(ModelArgs.clip != 0.0):
1327
+ scaler.unscale_(optimizer) # unscale gradients first so clipping sees their true magnitudes
1328
+ total_norm_before = torch.norm(
1329
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
1330
+ )
1331
+
1332
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
1333
+
1334
+ # Compute gradient norms after clipping
1335
+ total_norm_after = torch.norm(
1336
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
1337
+ )
1338
+
1339
+ if(device == 0 and step !=0):
1340
+ print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
1341
+ print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")
1342
+
1343
+ scaler.step(optimizer)
1344
+ scaler.update()
1345
+
1346
+ # optimizer.step()
1347
+ # new_scheduler.step()
1348
+
1349
+ torch.cuda.synchronize()
1350
+ torch.distributed.reduce(loss, dst=0, op=torch.distributed.ReduceOp.SUM)
1351
+ if(device == 0):
1352
+ wandb.log({
1353
+ "Learning Rate": lr,
1354
+ "All_GPUs_Train_losses": accumulated_loss.item(),
1355
+ # "All_GPUs_Val_losses": all_gpus_avg_val_loss,
1356
+ # "training_step_loss": losses['train'],
1357
+ # "val_step_loss": losses['val'],
1358
+ "Step": step,
1359
+ # "Epoch": epoch
1360
+
1361
+ })
1362
+ # print(loss.item())
1363
+ # if(step % 100 == 0):
1364
+ # print(f'Step : {step} | GPU: {device} Loss: {loss.item()}')
1365
+ # if device == 0:
1366
+ # print("loss: ", loss.item())
1367
+ # train_epoch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
1368
+ # print(loss.item())
1369
+ # break
1370
+
1371
+ # if step != 0 and (step % eval_iters == 0 or step == total_steps -1) :
1372
+ # loss_values = estimate_loss()
1373
+ # print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))
1374
+
1375
+ # Add after a training step:
1376
+ # unused_params = find_unused_parameters(model)
1377
+ # print("Unused parameters:", unused_params)
1378
+ # break
1379
+ if device == 0 and step % 5 == 0:
1380
+ count = 3
1381
+ while(count): # Only generate text on the main process
1382
+ # print("Generating text...")
1383
+
1384
+ # alpaca_prompt = '''
1385
+
1386
+ # ### Instruction:
1387
+ # {}
1388
+
1389
+ # ### Input:
1390
+ # {}
1391
+
1392
+ # ### Response:
1393
+
1394
+ # '''
1395
+
1396
+ # prompt = alpaca_prompt.format("You are a helpful assistant.", "Say a joke.", "")
1397
+ # print("Generating text")
1398
+ prompt = "Once upon a time"
1399
+ generated_text = topk_sampling(model, prompt, max_length=50, top_k=50, temperature=1.0, device=device)
1400
+
1401
+ # generated_text = greedy_decode(
1402
+ # model,
1403
+ # tokenizer,
1404
+ # "Once upon a time",
1405
+ # max_length=40,
1406
+ # repetition_penalty=1.2,
1407
+ # context_window=10,
1408
+ # temperature=0.7, # Lower temperature for more deterministic output
1409
+ # device=device
1410
+ # )
1411
+ # generated_text = beam_search(model, tokenizer, "Once upon a time ", beam_width=5, max_length=50, temperature=0.6)
1412
+ print(f" Step: {step} | Generated Text: {generated_text}")
1413
+ # model.train()
1414
+ # save_to_file(generated_text)
1415
+ count -= 1
1416
+
1417
+ # if step != 0:
1418
+ # train_step_iterator.set_postfix({"Train loss": f"{all_gpus_avg_train_loss.item():.4f} | Val Loss : {all_gpus_avg_val_loss.item():.4f}"})
1419
+
1420
+
1421
+ # break
1422
+ # Cleanup
1423
+ if device == 0:
1424
+ # writer.close()
1425
+ wandb.finish()
1426
+ cleanup()
1427
+
1428
+
1429
+ world_size = torch.cuda.device_count()
1430
+ print(f"World size: {world_size}")
1431
+ train()
1432
+
1433
+
1434
+
1435
+
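One note on the mixed-precision setup in this file: the autocast regions use bfloat16, while the GradScaler is only enabled when `ModelArgs.dtype == 'float16'`. That combination is fine, because a disabled scaler passes values through unchanged, but it means no loss scaling actually happens on bfloat16 runs (which is expected, since bfloat16 has the same exponent range as float32). A tiny illustration of the pattern:

```python
import torch

dtype = "bfloat16"  # assumed value of ModelArgs.dtype here
scaler = torch.amp.GradScaler(enabled=(dtype == "float16"))

print(scaler.is_enabled())  # False: scale()/unscale_()/step()/update() become pass-throughs
```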
metric.py ADDED
@@ -0,0 +1,28 @@
1
+
2
+ import evaluate
3
+
4
+ from config import ModelArgs
5
+ from model import Llama
6
+
7
+
8
+
9
+ # Load the perplexity metric
10
+ perplexity = evaluate.load("perplexity")
11
+
12
+
13
+ def compute_perplexity(model_name, text):
14
+
15
+
16
+
17
+ results = perplexity.compute(predictions=[text], model_id=model_name)
18
+
19
+ return results["perplexities"][0]
20
+
21
+ # Example usage (note: evaluate's "perplexity" metric expects a Hugging Face Hub model id string for model_id, so passing the local Llama module below will not work as-is; see the sketch after this file)
22
+ llama = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
23
+ llama = llama.to(ModelArgs.device)
24
+
25
+ text = "This is an example sentence for perplexity calculation."
26
+
27
+ ppl = compute_perplexity(llama, text)
28
+ print(f"Perplexity: {ppl}")
model.py ADDED
@@ -0,0 +1,489 @@
1
+
2
+ from config import ModelArgs
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Normalization(nn.Module):
9
+ def __init__(
10
+ self,
11
+
12
+ embeddings_dims: int = ModelArgs.embeddings_dims
13
+ ):
14
+ super().__init__()
15
+ self.rmsnorm_layer = torch.nn.RMSNorm(normalized_shape=embeddings_dims)
16
+
17
+
18
+ def forward(self, x):
19
+
20
+ x = self.rmsnorm_layer(x)
21
+ return x
22
+
23
+
24
+
25
+
26
+
27
+ # import numpy as np
28
+ class RotaryEmbeddings(nn.Module):
29
+ def __init__(
30
+ self,
31
+ device,
32
+ embeddings_dims: int = ModelArgs.embeddings_dims,
33
+ block_size: int = ModelArgs.block_size,
34
+ batch_size: int = ModelArgs.batch_size
35
+ ):
36
+ super().__init__()
37
+
38
+ self.embeddings_dims = embeddings_dims
39
+ self.block_size = block_size
40
+ self.batch_size = batch_size
41
+ self.theta = 0
42
+ self.device=device
43
+
44
+ # self.d_model = embeddings_dims
45
+ # self.i = torch.arange(0, embeddings_dims, dtype=torch.float32)
46
+ # # self.pos = torch.arange(0, block_size, dtype=torch.float32)
47
+ # self.exp = ((2 * self.i)) / self.d_model
48
+ # self.theta = 10000 ** self.exp
49
+ # # print(self.theta.shape)
50
+ # self.x_reshaped = torch.randn(batch_size, block_size, embeddings_dims,dtype=torch.float32, device=device)
51
+
52
+ # self.cos = torch.cos((self.i / self.theta))
53
+ # self.sin = torch.sin((self.i / self.theta))
54
+
55
+ # self.even = self.sin[::2]
56
+ # self.odd = self.cos[1::2]
57
+
58
+ # # self.block = torch.empty((odd.size(0) + even.size(0),), dtype=self.even.dtype)
59
+ # self.x_reshaped[..., : , ::2] = self.even
60
+ # self.x_reshaped[..., : , 1::2] = self.odd
61
+
62
+
63
+ def apply_rope(self, seq):
64
+ batch_size, seq_len, embeds_dims = seq.shape
65
+ # print(seq.shape)
66
+ # print(self.embeddings_dims)
67
+ # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
68
+
69
+ positions = torch.arange(0 , embeds_dims, 2, dtype=torch.float32, device = self.device).unsqueeze(0)
70
+ # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
71
+ theta = 10000 ** (-2 * (positions) / embeds_dims)
72
+ angles = positions * theta
73
+ angles = angles.expand(seq_len, -1) # broadcast the per-dimension angles across the sequence length (the angles here depend only on the feature index, not on the token position)
74
+ x_reshaped = seq.view(batch_size, seq_len, embeds_dims // 2, 2)
75
+
76
+ cos_angles = torch.cos(angles)
77
+ sin_angles = torch.sin(angles)
78
+ # print(cos_angles.shape)
79
+ # print(sin_angles.shape)
80
+ # print(x_reshaped.shape)
81
+ # indices = torch.arange(self.embeddings_dims, dtype=torch.int64, device = self.device)
82
+
83
+ out = torch.stack([x_reshaped[..., 0]*cos_angles - (x_reshaped[...,1] * sin_angles), x_reshaped[...,1] * cos_angles + x_reshaped[..., 0] * sin_angles], dim=-1)
84
+ out = out.view(batch_size, seq_len, embeds_dims)
85
+ return out
86
+
87
+ def forward(self, x):
88
+ # print("X shape: ", x.shape)
89
+ # print("X is: ", x)
90
+ # B,T,C = x.shape
91
+ # print("MATRIX:",x)
92
+ # if(x > self.block_size or x < self.block_size):
93
+ # matrix = self.init_matrix(x)
94
+ # return matrix
95
+ # else:
96
+ # matrix = self.init_matrix(self.block_size)
97
+
98
+ # return matrix
99
+ # if(ModelArgs.inference):
100
+ res = self.apply_rope(x)
101
+ return res
102
+ # else:
103
+ # return self.x_reshaped
104
+
105
+ class RotaryAttentionHead(nn.Module):
106
+ def __init__(
107
+ self,
108
+ device,
109
+ embeddings_dims: int = ModelArgs.embeddings_dims,
110
+ no_of_heads: int = ModelArgs.no_of_heads,
111
+ attn_dropout: int = ModelArgs.attn_dropout
112
+ ):
113
+ super().__init__()
114
+ self.head_size = embeddings_dims // no_of_heads
115
+ self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
116
+ self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
117
+ self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
118
+ self.rope = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
119
+ self.dropout = nn.Dropout(p = attn_dropout)
120
+ self.device = device
121
+ def forward(self,x):
122
+ # print(x.shape)
123
+ # print("X is: ", x)
124
+ batch, block_size, embeddings_dims = x.shape
125
+ query = self.query(x)
126
+ # print(query)
127
+ key = self.key(x)
128
+ values = self.value(x)
129
+ # matrix = self.rotary_matrix(block_size)
130
+ rotary_q = self.rope(query)
131
+ rotary_k = self.rope(key)
132
+
133
+ # print(matrix.shape)
134
+ # print(query.shape)
135
+ masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
136
+ # rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
137
+ # rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
138
+ weights = rotary_q.permute(2,0,1) @ rotary_k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
139
+ weights_masked = weights.masked_fill(masked == 0, float('-inf'))
140
+ scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
141
+ scaled_weights = F.softmax(scaled_weights, dim=-1)
142
+ value = scaled_weights @ values
143
+ out = self.dropout(value)
144
+ return out
145
+
146
+
147
+ # # import numpy as np
148
+ # class RotaryEmbeddings(nn.Module):
149
+ # def __init__(
150
+ # self,
151
+ # device,
152
+ # embeddings_dims: int = ModelArgs.embeddings_dims,
153
+ # block_size: int = ModelArgs.block_size,
154
+ # batch_size: int = ModelArgs.batch_size
155
+ # ):
156
+ # super().__init__()
157
+
158
+ # self.embeddings_dims = embeddings_dims
159
+ # self.block_size = block_size
160
+ # self.batch_size = batch_size
161
+ # self.theta = 0
162
+
163
+
164
+ # # def init_matrix(self, seq_len):
165
+ # # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False)
166
+ # # for pos in range(seq_len):
167
+ # # for j in range(1, self.embeddings_dims // 2):
168
+ # # self.theta = 10000 ** (-2*(pos-1) / self.embeddings_dims)
169
+ # # self.matrix[pos, 2*j + 1, 2*j + 1] = np.cos((pos*self.theta))
170
+ # # self.matrix[pos, 2*j + 1, j + 1] = -np.sin((pos* self.theta))
171
+ # # self.matrix[pos, 2*j , 2*j ] = -np.cos((pos* self.theta))
172
+ # # self.matrix[pos, 2*j + 1, 2*j + 1] = np.sin((pos* self.theta))
173
+ # # return self.matrix
174
+ # self.device=device
175
+
176
+ # def init_matrix(self, seq_len):
177
+ # self.matrix = torch.zeros((seq_len, self.embeddings_dims, self.embeddings_dims), dtype=torch.float32, requires_grad=False, device = self.device)
178
+
179
+ # positions = torch.arange(0 , seq_len, 2, dtype=torch.float32, device = self.device).unsqueeze(1)
180
+ # # dims = torch.arange(1, self.embeddings_dims // 2, dtype=torch.float32)
181
+ # theta = 10000 ** (-2 * (positions - 1) / self.embeddings_dims)
182
+ # angles = positions * theta
183
+
184
+ # cos_angles = torch.cos(angles)
185
+ # sin_angles = torch.sin(angles)
186
+
187
+ # indices = torch.arange(seq_len, dtype=torch.int64, device = self.device)
188
+ # # print(indices)
189
+ # # print(indices.shape)
190
+ # # print(indices[::2])
191
+ # even_indices = indices[::2]
192
+ # odd_indices = indices[1::2]
193
+
194
+ # self.matrix[:, even_indices, even_indices] = cos_angles
195
+ # self.matrix[:, odd_indices, odd_indices] = sin_angles
196
+ # self.matrix[:, odd_indices, even_indices] = -sin_angles
197
+ # self.matrix[:, even_indices, odd_indices] = cos_angles
198
+
199
+ # return self.matrix
200
+
201
+ # def forward(self, x):
202
+ # # B,T,C = x.shape
203
+ # # print("MATRIX:",x)
204
+ # if(x > self.block_size or x < self.block_size):
205
+ # matrix = self.init_matrix(x)
206
+ # return matrix
207
+ # else:
208
+ # matrix = self.init_matrix(self.block_size)
209
+
210
+ # return matrix
211
+
212
+
213
+ # class RotaryAttentionHead(nn.Module):
214
+ # def __init__(
215
+ # self,
216
+ # device,
217
+ # embeddings_dims: int = ModelArgs.embeddings_dims,
218
+ # no_of_heads: int = ModelArgs.no_of_heads,
219
+ # attn_dropout: int = ModelArgs.attn_dropout
220
+ # ):
221
+ # super().__init__()
222
+ # self.head_size = embeddings_dims // no_of_heads
223
+ # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
224
+ # self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
225
+ # self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, dtype=torch.float32, device = device)
226
+ # self.rotary_matrix = RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
227
+ # self.dropout = nn.Dropout(p = attn_dropout)
228
+ # self.device = device
229
+ # def forward(self,x):
230
+ # # print(x.shape)
231
+ # batch, block_size, embeddings_dims = x.shape
232
+ # query = self.query(x)
233
+ # # print(query)
234
+ # key = self.key(x)
235
+ # values = self.value(x)
236
+ # matrix = self.rotary_matrix(block_size)
237
+
238
+ # # print(matrix.shape)
239
+ # # print(query.shape)
240
+ # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
241
+ # rotary_query = matrix @ query.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
242
+ # rotary_key = matrix @ key.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
243
+ # weights = rotary_query.permute(2,0,1) @ rotary_key.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
244
+ # weights_masked = weights.masked_fill(masked == 0, float('-inf'))
245
+ # scaled_weights = weights_masked / (torch.sqrt(torch.tensor(key.shape[-1])))
246
+ # scaled_weights = F.softmax(scaled_weights, dim=-1)
247
+ # value = scaled_weights @ values
248
+ # out = self.dropout(value)
249
+ # return out
250
+
251
+
252
+ class MQA(nn.Module):
253
+ def __init__(
254
+ self,
255
+ device,
256
+ no_of_q_heads: int,
257
+ embeddings_dims: int = ModelArgs.embeddings_dims,
258
+ block_size: int = ModelArgs.block_size,
259
+
260
+
261
+ ):
262
+ super().__init__()
263
+
264
+
265
+ # self.no_of_q_heads = no_of_heads // no_of_kv_heads
266
+ # self.no_of_q_heads = no_of_q_heads
267
+ self.no_of_kv_heads = 2 # I want to have a kv for each pair of query heads
268
+ self.head_size = embeddings_dims // no_of_q_heads
269
+ # self.kv_head_size = (embeddings_dims // self.no_of_kv_heads) * 2
270
+ self.rotary= RotaryEmbeddings(embeddings_dims=self.head_size, device = device)
271
+ # self.rotary_k = RotaryEmbeddings(embeddings_dims=self.kv_head_size, device = device)
272
+ # self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False)
273
+ self.key = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
274
+ self.value = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, dtype=torch.float32, bias=False, device = device)
275
+ self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
276
+ self.linear_layer = nn.Linear(in_features=self.head_size * self.no_of_kv_heads, out_features=embeddings_dims, dtype=torch.float32, bias=False, device = device)
277
+ self.device = device
278
+ self.multi_query = nn.ModuleList([nn.Linear(in_features=embeddings_dims, out_features=self.head_size, bias=False, device = self.device) for _ in range(self.no_of_kv_heads)])
279
+
280
+ def scaled_dot_product(self, q, k, v, block_size):
281
+
282
+ # masked = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
283
+ q = self.rotary(q)
284
+ masked_table = torch.tril(torch.ones((block_size, block_size), requires_grad=False, device = self.device))
285
+ # rotary_query = matrix @ q.permute(1,2,0) # (B,T, C,C) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
286
+ # rotary_key = matrix @ k.permute(1,2,0) # (B,T, C,C ) @ (B,T,C) -> (B,C,T) = (B,T,C,T)
287
+ # print("Query: ", q.shape)
288
+ # print("Keys: ", k.shape)
289
+ # print(q.permute(2,0,1).shape)
290
+ # print(k.permute(2,0,1).transpose(-2, -1).shape)
291
+ # weights = q.permute(2,0,1) @ k.permute(2,0,1).transpose(-2, -1)#(B,T,C,T) @ (B,T,C,T) = (T,C,C,T)
292
+ # weights = q @ k.permute(2,1,0)
293
+ # print(weights.shape)
294
+ # print(masked.shape)
295
+ weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
296
+ masked_values = weights.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
297
+ weights_normalized = nn.functional.softmax(masked_values, dim=-1) #Normalize along the embeddings dimension for all the tokens
298
+ weights_normalized = self.dropout(weights_normalized)
299
+ out = weights_normalized @ v
300
+ return out
301
+
302
+ def forward(self,x):
303
+ # print("MQA: ", x.shape)
304
+ batch, block_size, embeddings_dims = x.shape
305
+
306
+ # query = self.query(x)
307
+ # matrix = self.rotary_matrix(block_size)
308
+
309
+
310
+ key = self.key(x)
311
+ values = self.value(x)
312
+ # print("Keys: ", key.shape)
313
+ # print("Values: ", values.shape)
314
+ # rotary_value = self.rotary(values)
315
+ rotary_key = self.rotary(key)
316
+ multi_query_concat = torch.cat([self.scaled_dot_product(query(x), rotary_key, values, block_size) for query in self.multi_query], dim=-1)
317
+ # print("Multi query: ", multi_query_concat.shape)
318
+
319
+ linear_layer= self.linear_layer(multi_query_concat)
320
+ # out = self.dropout(linear_layer)
321
+ return linear_layer
322
+
323
+
324
+ class GQA(nn.Module):
325
+ def __init__(
326
+ self,
327
+ device,
328
+ embeddings_dims: int = ModelArgs.embeddings_dims,
329
+ block_size: int = ModelArgs.block_size,
330
+ # no_of_q_heads: int = ModelArgs.no_of_heads,
331
+ mqa_heads: int = ModelArgs.no_kv_heads
332
+ ):
333
+ super().__init__()
334
+
335
+ # self.no_of_kv_heads = no_of_kv_heads
336
+ self.no_of_q_heads = ModelArgs.no_of_heads // mqa_heads
337
+ # self.head_dim = embeddings_dims // self.no_kv_heads
338
+ self.dropout = nn.Dropout(p = ModelArgs.attn_dropout)
339
+ self.linear_layer = nn.Linear(in_features=embeddings_dims * self.no_of_q_heads, out_features=embeddings_dims , dtype=torch.float32, bias=False, device = device)
340
+ self.device = device
341
+ self.mqa = nn.ModuleList([MQA(no_of_q_heads=self.no_of_q_heads, embeddings_dims=embeddings_dims, device = self.device, block_size=block_size) for _ in range(self.no_of_q_heads)])
342
+ # self.mqa = MQA(no_of_q_heads=self.no_of_q_heads, device=self.device, embeddings_dims=embeddings_dims, block_size=block_size)
343
+ def forward(self,x):
344
+
345
+ batch, block_size, embeddings_dims = x.shape
346
+
347
+ # res = self.mqa(x)
348
+ grouped_query_concat = torch.cat([group(x) for group in self.mqa], dim=-1)
349
+
350
+ linear_layer= self.linear_layer(grouped_query_concat) # project the concatenated MQA group outputs back to the embedding dimension; this wrapper is what turns the MQA blocks into grouped-query attention
351
+ out = self.dropout(linear_layer)
352
+ return out
353
+
354
+
355
+ class Swish(nn.Module):
356
+ def __init__(
357
+ self,
358
+ device,
359
+ block_size: int = ModelArgs.block_size,
360
+ embeddings_dims: int = ModelArgs.embeddings_dims
361
+ ):
362
+ super().__init__()
363
+
364
+ self.sig = torch.nn.Sigmoid()
365
+
366
+
367
+ def forward(self, x):
368
+ swish = x * self.sig(x)
369
+
370
+ return swish
371
+
372
+
373
+
374
+ class SWiGLU(nn.Module):
375
+ def __init__(
376
+ self,
377
+ device,
378
+ block_size: int = ModelArgs.block_size,
379
+ embeddings_dims: int = ModelArgs.embeddings_dims
380
+ ):
381
+ super().__init__()
382
+ self.hidden_dims = int(2 * ( 4 * embeddings_dims) / 3)
383
+ self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims, device=device)
384
+ self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
385
+ self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=self.hidden_dims, bias=False, dtype=torch.float32, device = device)
386
+ self.linear_layer3 = nn.Linear(in_features=self.hidden_dims, out_features=embeddings_dims, bias=False, dtype=torch.float32, device = device)
387
+
388
+
389
+
390
+
391
+ def forward(self, x):
392
+ swish_res = self.swish(self.linear_layer1(x))
393
+ x_V = self.linear_layer2(x)
394
+ res = torch.mul(swish_res, x_V)
395
+ out = self.linear_layer3(res)
396
+ return out
397
+
398
+
399
+
400
+ class FFN(nn.Module):
401
+ def __init__(self,
402
+ device,
403
+ embeddings_dims: int = ModelArgs.embeddings_dims,
404
+ block_size: int = ModelArgs.block_size,
405
+ vocab_size: int = ModelArgs.vocab_size,
406
+ dropout = ModelArgs.dropout
407
+
408
+ ):
409
+ super().__init__()
410
+
411
+ # self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, dtype=torch.float32, device = device)
412
+ self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims, device = device)
413
+ self.dropout = nn.Dropout(p = dropout)
414
+ def forward(self, x):
415
+
416
+ x = self.swiglue(x)
417
+ # x = self.linear_layer(x)
418
+ x = self.dropout(x)
419
+ return x
420
+
421
+
422
+ class DecoderLayer(nn.Module):
423
+ def __init__(self,
424
+ device,
425
+ embeddings_dims: int = ModelArgs.embeddings_dims,
426
+ dropout = ModelArgs.dropout,
427
+ block_size: int = ModelArgs.block_size,
428
+ vocab_size: int = ModelArgs.vocab_size,
429
+
430
+ ) :
431
+ super().__init__()
432
+
433
+
434
+ self.feedforward_network = FFN(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, device = device)
435
+ self.gqa = GQA(embeddings_dims=embeddings_dims, block_size=block_size, mqa_heads=2, device = device)
436
+ # self.norm = Normalization(embeddings_dims=embeddings_dims)
437
+ self.norm1 = Normalization(embeddings_dims=embeddings_dims)
438
+ self.norm2 = Normalization(embeddings_dims=embeddings_dims)
439
+ self.dropout = nn.Dropout(p = dropout)
440
+ def forward(self, x):
441
+
442
+ x = x + self.gqa(self.norm1(x))
443
+ x = x + self.feedforward_network(self.norm2(x))
444
+ return x
445
+
446
+
447
+ class Llama(nn.Module):
448
+ def __init__(self,
449
+ device,
450
+ embeddings_dims: int = ModelArgs.embeddings_dims,
451
+ no_of_decoder_layers: int = ModelArgs.no_of_decoder_layers,
452
+ block_size: int = ModelArgs.block_size,
453
+ vocab_size: int = ModelArgs.vocab_size,
454
+ dropout = ModelArgs.dropout
455
+
456
+ ) :
457
+ super().__init__()
458
+
459
+ self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embeddings_dims, dtype=torch.float32, device = device)
460
+ self.decoder = nn.Sequential(*[DecoderLayer(embeddings_dims=embeddings_dims, block_size=block_size, vocab_size=vocab_size, dropout=dropout, device = device) for _ in range(no_of_decoder_layers)])
461
+ self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, dtype=torch.float32, device = device)
462
+ self.dropout = nn.Dropout(p = dropout)
463
+ # self.norm = Normalization(embeddings_dims)
464
+
465
+
466
+ #weight tying
467
+ self.embeddings.weight = self.linear_layer.weight
468
+
469
+ self.apply(self._init_weights)
470
+
471
+ def _init_weights(self, module):
472
+ if isinstance(module, nn.Linear):
473
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
474
+
475
+ if module.bias is not None:
476
+ nn.init.zeros_(module.bias)
477
+ elif isinstance(module, nn.Embedding):
478
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
479
+
480
+
481
+
482
+ def forward(self, x):
483
+ x = self.embeddings(x)
484
+ x = self.dropout(x)
485
+ x = self.decoder(x)
486
+ # x = self.norm(x)
487
+ x = self.linear_layer(x)
488
+ # out = self.norm(x)
489
+ return x
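A quick way to sanity-check the model definition above is to push a dummy batch through it and confirm the output shape is `(batch, block_size, vocab_size)`; a minimal sketch using the defaults from `config.py` follows. As an aside, `apply_rope` derives its rotation angles from the feature index only, so the token position never enters the rotation; that differs from the usual RoPE formulation and may be worth revisiting.

```python
import torch

from config import ModelArgs
from model import Llama

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model with the repo's default hyperparameters and run a dummy batch through it.
model = Llama(device=device).to(device)
x = torch.randint(0, ModelArgs.vocab_size, (2, ModelArgs.block_size), device=device)

with torch.no_grad():
    logits = model(x)

print(logits.shape)  # expected: (2, ModelArgs.block_size, ModelArgs.vocab_size)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```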
tokenizer.py ADDED
@@ -0,0 +1,21 @@
1
+
2
+ from transformers import AutoTokenizer
3
+ import os
4
+
5
+
6
+ class Tokenizer:
7
+
8
+ def __init__(self) -> None:
9
+
10
+ self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", hf_token = '...')
11
+
12
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
13
+
14
+ def ready_tokenizer(self):
15
+
16
+ return self.tokenizer
17
+
18
+
19
+
20
+
21
+
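Since a `[PAD]` token is added on top of GPT-2's 50257-token vocabulary, the tokenizer ends up with 50258 entries, and the model's `vocab_size` (and the tied embedding/output matrices) has to match. A small check, assuming the public GPT-2 tokenizer files download without an access token:

```python
from tokenizer import Tokenizer

tok = Tokenizer().ready_tokenizer()

print(len(tok))          # 50257 GPT-2 tokens + the added [PAD] -> 50258
print(tok.pad_token_id)  # id of the appended [PAD] token (expected 50257)

ids = tok("Once upon a time", return_tensors="pt")["input_ids"]
print(ids.shape, tok.decode(ids[0]))
```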
trainer.py ADDED
@@ -0,0 +1,469 @@
1
+
2
+ import argparse
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+ from torch.distributed import init_process_group, destroy_process_group
8
+ import torch
9
+
10
+ import wandb
11
+
12
+
13
+ import torch.optim as optim
14
+
15
+
16
+ import os
17
+ from config import ModelArgs
18
+ from model import Llama
19
+
20
+ from inference import greedy_decode
21
+ from data import prepare_dataset
22
+ from tokenizer import Tokenizer
23
+ import math
+ from tqdm import tqdm
24
+
25
+ torch.set_float32_matmul_precision('high')
26
+
27
+ scaler = torch.amp.GradScaler(enabled=(ModelArgs.dtype == 'float16'))
28
+
29
+
30
+
31
+ save_chechpoint_iter = 50
32
+ total_iters = 10000
33
+ eval_iters = 50
34
+ eval_check = 100
35
+ warmup_iters = 700
36
+ min_lr = 0.1 * ModelArgs.max_lr
37
+ lr_decay_iters = 10000
38
+ total_batch_size = 524288
39
+ micro_batch_size = ModelArgs.batch_size
40
+ gradient_accumulation_steps = total_batch_size // (micro_batch_size * (ModelArgs.block_size * torch.cuda.device_count()))
41
+
42
+
43
+
44
+ class Trainer:
45
+
46
+ def __init__(self, model_args):
47
+
48
+
49
+ def setup(rank=None, world_size=None):
50
+ # os.environ['MASTER_ADDR'] = 'localhost'
51
+ # os.environ['MASTER_PORT'] = '12355'
52
+ init_process_group("nccl")
53
+ # torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
54
+
55
+ self.model_args = model_args
56
+ self.tokenizer = Tokenizer().ready_tokenizer()
57
+ setup()
58
+
59
+ def cleanup(self):
60
+ destroy_process_group()
61
+
62
+ def _save_snapshot(self, model, optimizer, epoch, step, save_dir):
63
+ snapshot = {}
64
+ snapshot["MODEL_STATE"] = model.module.state_dict()
65
+ snapshot["OPTIMIZER_STATE"]= optimizer.state_dict()
66
+ snapshot["EPOCHS_RUN"] = epoch
67
+ snapshot["STEP_RUN"] = step
68
+ torch.save(snapshot, os.path.join(save_dir, "snapshot.pt"))
69
+ print(f"Epoch: {epoch} | step {step} | Training snapshot saved at snapshot.pt")
70
+
71
+ # Warmup phase for 2000 steps
72
+ def warmup_fn(step):
73
+ if step < 2000:
74
+ return step / 2000 # LR gradually increases
75
+ return 1.0
76
+
77
+
78
+ # learning rate decay scheduler (cosine with warmup) from https://github.com/karpathy/nanoGPT/blob/master/train.py
79
+
80
+ def get_lr(it):
81
+ # 1) linear warmup for warmup_iters steps
82
+ if it < warmup_iters:
83
+ return ModelArgs.max_lr * (it + 1) / (warmup_iters + 1)
84
+ # 2) if it > lr_decay_iters, return min learning rate
85
+ if it > lr_decay_iters:
86
+ return min_lr
87
+ # 3) in between, use cosine decay down to min learning rate
88
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
89
+ assert 0 <= decay_ratio <= 1
90
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
91
+ return min_lr + coeff * (ModelArgs.max_lr - min_lr)
92
+
93
+
94
+ def train():
95
+
96
+ setup()
97
+ device = int(os.environ["LOCAL_RANK"])
98
+
99
+ torch.cuda.set_device(int(device))
100
+
101
+ print(f"Start running DDP on rank {device}.")
102
+
103
+ if(device == 0):
104
+
105
+
106
+
107
+ # # Initialise run
108
+ wandb.init(
109
+ # entity = 'rajceo2031',
110
+ project = 'Llama-DDP-Pretrain-10-billion-tokens',
111
+ # config = CFG,
112
+ # save_code = True,
113
+ #group = 'ANN',
114
+ #job_type = 'train'
115
+ )
116
+ print("wand initialized")
117
+
118
+ model = Llama(embeddings_dims=ModelArgs.embeddings_dims, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout, device=device)
119
+
120
+ # print(f"Model on device {device} is ready")
121
+ print(f"Model on device {device} is ready")
122
+
123
+
124
+ optimizer = optim.AdamW(model.parameters(), lr=ModelArgs.max_lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2), weight_decay=ModelArgs.weight_decay_optim, eps=ModelArgs.eps)
125
+
126
+ # model = torch.compile(model)
127
+ model = model.to(device)
128
+
129
+ model = DDP(model, device_ids=[device])
130
+
131
+
132
+
133
+
134
+
135
+ model.eval()
136
+ world_size = torch.cuda.device_count()
137
+ @torch.inference_mode()
138
+ def estimate_loss(val_loader, val_iterator, device):
139
+ out = {}
140
+
141
+ loader = None
142
+ epoch_loss = None
143
+ epoch_losses = []
144
+
145
+ for split in ['val']:
146
+ print(f"Starting with {split} evaluation...")
147
+
148
+ for step in range(eval_check):
149
+ try:
150
+ batch = next(val_iterator)
151
+ except StopIteration:
152
+ val_loader_iterator = iter(val_loader)
153
+ batch = next(val_loader_iterator)
154
+
155
+ total_loss = 0
156
+
157
+ total_batches = 0
158
+
159
+ idx = batch['input_ids']
160
+ targets = batch['labels']
161
+ idx = idx.to(device)
162
+ targets = targets.to(device)
163
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
164
+
165
+ logits = model(idx)
166
+ batch_size, block_size, embeddings_dims = logits.shape
167
+ logits = logits.view(batch_size * block_size, embeddings_dims)
168
+ targets = targets.view(batch_size * block_size)
169
+
170
+ loss = F.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
171
+
172
+ total_loss += loss.item()
173
+ total_batches += 1
174
+
175
+
176
+ epoch_loss = total_loss / total_batches if total_batches > 0 else 0.0
177
+ epoch_losses.append(epoch_loss)
178
+
179
+
180
+ out[split] = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0
181
+ epoch_loss = None
182
+ epoch_losses = []
183
+
184
+ model.train()
185
+ return out
186
+
187
+
188
+ model.train()
189
+ count = 0
190
+
191
+ train_dataloader = prepare_dataset('train', device, ModelArgs.batch_size)
192
+ val_loader= prepare_dataset('val', device, ModelArgs.batch_size)
193
+
194
+ print("Loaders ready both")
195
+ epochs = ModelArgs.epochs
196
+
197
+ train_loader_length = 0
198
+ train_data_iterator = iter(train_dataloader)
199
+ val_data_iterator = iter(val_loader)
200
+ token_count = 0
201
+ if(device == 0):
202
+ train_loader_length = len(train_dataloader)
203
+
204
+ for step in tqdm(range(total_iters)):
205
+
206
+
207
+ if(device == 0):
208
+
209
+ print("Step : ", step, "/", total_iters)
210
+ print('Total batches: ', len(train_dataloader))
211
+ print("Total gradient accumulation steps: ", gradient_accumulation_steps)
212
+ print("Total tokens processed: ", token_count)
213
+
214
+
215
+ if (step % eval_iters == 0 and step != 0) or step == total_iters - 1:
216
+ losses = estimate_loss( val_loader, val_data_iterator, 'cuda')
217
+ # avg_train_loss = losses['train']
218
+ avg_val_loss = losses['val']
219
+
220
+ print(f"[GPU {device}] | Step: {step} / {total_iters} | Val Loss: {losses['val']:.4f}")
221
+
222
+ avg_val_loss = torch.Tensor([losses['val']]).to(device)
223
+ # torch.distributed.reduce(avg_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
224
+ torch.distributed.reduce(avg_val_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
225
+
226
+ if device == 0:
227
+
228
+ all_gpus_avg_val_loss = avg_val_loss / world_size
229
+ print(f"All_GPUs_Val_losses: {all_gpus_avg_val_loss.item():.4f}")
230
+
231
+ wandb.log({
232
+ # "Learning Rate": optimizer.param_groups[0]['lr'],
233
+ # "All_GPUs_Train_losses": all_gpus_avg_train_loss,
234
+ "All_GPUs_Val_losses": all_gpus_avg_val_loss,
235
+ # "training_step_loss": losses['train'],
236
+ "val_step_loss": losses['val'],
237
+ # "Step": step,
238
+ # "Epoch": epoch
239
+ })
240
+
241
+
242
+
243
+
244
+
245
+ if step % save_chechpoint_iter == 0 and device == 0 and step != 0:
246
+ print(f"Saving the model checkpoint for step: {step}")
247
+ _save_snapshot(model, optimizer, None, None, step)
248
+
249
+ accumulated_loss = 0.0
250
+
251
+
252
+ optimizer.zero_grad(set_to_none=True)
253
+ for micro_step in range(gradient_accumulation_steps):
254
+ try:
255
+ batch = next(train_data_iterator)
256
+ except StopIteration:
257
+ train_data_iterator = iter(train_dataloader)
258
+ batch = next(train_data_iterator)
259
+ # print(batch)
260
+ # batch = next(train_data_iterator)
261
+ # print(batch)
262
+ # batch = {k: v.to(self.local_rank) for k, v in batch.items()}
263
+ idx = batch['input_ids'].to(device)
264
+ # idx, targets = get_batch(split='train')
265
+ # print(f"Starting the train step: {step}...")
266
+ # for idx, targets in train_loader:
267
+ # idx, targets = next(iter(train_loader))
268
+
269
+ # print("Idx: ", idx)
270
+ # print("Targets: ", targets)
271
+
272
+ # idx = idx.to(device)
273
+ # print("Idx: ", idx)
274
+ # print("Targets: ", targets)
275
+ targets = batch['labels'].to(device)
276
+ token_count += idx.numel() # count tokens, not just sequences
277
+ with torch.autocast(device_type=ModelArgs.device, dtype=torch.bfloat16):
278
+ logits = model(idx)
279
+ batch_size, block_size, embeddings_dims = logits.shape
280
+ # print(logits.shape)
281
+ # print(targets)
282
+ logits = logits.view(batch_size*block_size, embeddings_dims)
283
+ # print("OK")
284
+ targets = targets.view(batch_size * block_size)
285
+ # print("OK2")
286
+ loss = nn.functional.cross_entropy(logits, targets, ignore_index=tokenizer.pad_token_id)
287
+
288
+ loss = loss / gradient_accumulation_steps # scale so that, summed over the accumulation window, gradients match those of one large batch (each micro-batch contributes equally)
289
+ accumulated_loss += loss.detach()
290
+
291
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) # only synchronize gradients across GPUs on the last micro-step
292
+ scaler.scale(loss).backward()
293
+ # Check for unused parameters
294
+ unused_params = find_unused_parameters(model)
295
+ if unused_params:
296
+ print(f"Unused parameters: {unused_params}")
297
+ # break
298
+
299
+ if(device == 0):
300
+ if(micro_step % 10 == 0):
301
+ # if(step == train_loader_length):
302
+ # break
303
+
304
+ print("Micro Batch : ", micro_step)
305
+ print("Step : ", step, "/", total_iters)
306
+ print('Total batches: ', len(train_dataloader))
307
+ print("Total gradient accumulation steps: ", gradient_accumulation_steps)
308
+ print("Total tokens processed: ", token_count)
309
+ # count += 1
310
+
311
+ lr = get_lr(step)
312
+ for params in optimizer.param_groups:
313
+ params['lr'] = lr
314
+
315
+
316
+
317
+ # Compute gradient norms before clipping
318
+ if(ModelArgs.clip != 0.0):
319
+ scaler.unscale_(optimizer) # unscale gradients first so clipping sees their true magnitudes
320
+ total_norm_before = torch.norm(
321
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
322
+ )
323
+
324
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=ModelArgs.clip)
325
+
326
+ # Compute gradient norms after clipping
327
+ total_norm_after = torch.norm(
328
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2
329
+ )
330
+
331
+ if(device == 0 and step !=0):
332
+ print(f"Gradient Norm Before Clipping: {total_norm_before.item():.4f}")
333
+ print(f"Gradient Norm After Clipping: {total_norm_after.item():.4f}")
334
+
335
+ scaler.step(optimizer)
336
+ scaler.update()
337
+
338
+ # optimizer.step()
339
+ # new_scheduler.step()
340
+
341
+ torch.cuda.synchronize()
342
+ torch.distributed.reduce(loss, dst=0, op=torch.distributed.ReduceOp.SUM)
343
+ if(device == 0):
344
+ wandb.log({
345
+ "Learning Rate": lr,
346
+ "All_GPUs_Train_losses": accumulated_loss.item(),
347
+ # "All_GPUs_Val_losses": all_gpus_avg_val_loss,
348
+ # "training_step_loss": losses['train'],
349
+ # "val_step_loss": losses['val'],
350
+ "Step": step,
351
+ # "Epoch": epoch
352
+
353
+ })
354
+ # print(loss.item())
355
+
356
+ # break
357
+ if device == 0 and step % 5 == 0:
358
+ count = 3
359
+ while(count): # Only generate text on the main process
360
+
361
+ prompt = "Once upon a time"
362
+ generated_text = topk_sampling(model, prompt, max_length=50, top_k=50, temperature=1.0, device=device)
363
+
364
+
365
+ print(f" Step: {step} | Generated Text: {generated_text}")
366
+
367
+ count -= 1
368
+
369
+
370
+ if device == 0:
371
+
372
+ wandb.finish()
373
+ cleanup()
374
+
375
+
376
+ world_size = torch.cuda.device_count()
377
+ print(f"World size: {world_size}")
378
+
379
+
380
+
381
+
382
+ def parse_args():
383
+ parser = argparse.ArgumentParser(description="Model Training Arguments")
384
+
385
+ # Add arguments for each field in ModelArgs
386
+ parser.add_argument("--epochs", type=int, default=ModelArgs.epochs, help="Number of training epochs.")
387
+ parser.add_argument("--block_size", type=int, default=ModelArgs.block_size, help="Block size for the model.")
388
+ parser.add_argument("--batch_size", type=int, default=ModelArgs.batch_size, help="Batch size for training.")
389
+ # parser.add_argument("--inference", type=lambda x: (str(x).lower() == 'true'), default=ModelArgs.inference, help="Whether to run in inference mode.")
390
+ parser.add_argument("--embeddings_dims", type=int, default=ModelArgs.embeddings_dims, help="Embedding dimensions.")
391
+ parser.add_argument("--attn_dropout", type=float, default=ModelArgs.attn_dropout, help="Attention dropout rate.")
392
+ parser.add_argument("--no_of_heads", type=int, default=ModelArgs.no_of_heads, help="Number of attention heads.")
393
+ parser.add_argument("--dropout", type=float, default=ModelArgs.dropout, help="Dropout rate.")
394
+ parser.add_argument("--val_epochs", type=int, default=ModelArgs.val_epochs, help="Number of validation epochs.")
395
+ parser.add_argument("--max_lr", type=float, default=ModelArgs.max_lr, help="Learning rate.")
396
+ parser.add_argument("--no_of_decoder_layers", type=int, default=ModelArgs.no_of_decoder_layers, help="Number of decoder layers.")
397
+ parser.add_argument("--weight_decay_optim", type=float, default=ModelArgs.weight_decay_optim, help="Weight decay for optimizer.")
398
+ parser.add_argument("--beta_1", type=float, default=ModelArgs.beta_1, help="Beta1 for Adam optimizer.")
399
+ parser.add_argument("--beta_2", type=float, default=ModelArgs.beta_2, help="Beta2 for Adam optimizer.")
400
+ parser.add_argument("--clip", type=float, default=ModelArgs.clip, help="Gradient clipping value.")
401
+ parser.add_argument("--device", type=str, default=ModelArgs.device, help="Device to run the model on (e.g., 'cuda' or 'cpu').")
402
+ parser.add_argument("--no_kv_heads", type=int, default=ModelArgs.no_kv_heads, help="Number of key/value heads.")
403
+ parser.add_argument("--vocab_size", type=int, default=ModelArgs.vocab_size, help="Vocabulary size.")
404
+ parser.add_argument("--eps", type=float, default=ModelArgs.eps, help="Epsilon value for numerical stability.")
405
+ parser.add_argument("--dtype", type=str, default=ModelArgs.dtype, help="Data type for tensors (e.g., 'float16' or 'bfloat16').")
406
+ parser.add_argument("--save_checkpoint_dir", type=str, default=ModelArgs.save_checkpoint_dir, help="Directory to save model checkpoints.")
407
+ parser.add_argument("--prompt", type=str, default=ModelArgs.prompt, help="Prompt for testing during training.")
408
+
409
+ # Additional arguments
410
+ parser.add_argument("--save_checkpoint_iter", type=int, default=ModelArgs.save_checkpoint_iter, help="Save checkpoint every N iterations.")
411
+ parser.add_argument("--total_iters", type=int, default=ModelArgs.total_iters, help="Total number of training iterations.")
412
+ parser.add_argument("--eval_iters", type=int, default=ModelArgs.eval_iters, help="Number of iterations for evaluation.")
413
+ parser.add_argument("--eval_check", type=int, default=ModelArgs.eval_check, help="Evaluate model every N iterations.")
414
+ parser.add_argument("--warmup_iters", type=int, default=ModelArgs.warmup_iters, help="Number of warmup iterations for learning rate scheduling.")
415
+ parser.add_argument("--min_lr", type=float, default=ModelArgs.min_lr, help="Minimum learning rate.")
416
+ parser.add_argument("--lr_decay_iters", type=int, default=ModelArgs.lr_decay_iters, help="Number of iterations for learning rate decay.")
417
+ parser.add_argument("--total_batch_size", type=int, default=ModelArgs.total_batch_size, help="Total batch size across all devices.")
418
+ parser.add_argument("--micro_batch_size", type=int, default=ModelArgs.micro_batch_size, help="Micro batch size per device.")
419
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=ModelArgs.gradient_accumulation_steps, help="Number of gradient accumulation steps.")
420
+
421
+ args = parser.parse_args()
422
+ return args
423
+
424
+
425
+ def initialize_model_args(args):
426
+ # Create a ModelArgs instance from the parsed arguments
427
+ model_args = ModelArgs(
428
+ epochs=args.epochs,
429
+ block_size=args.block_size,
430
+ batch_size=args.batch_size,
431
+ # inference=args.inference,
432
+ embeddings_dims=args.embeddings_dims,
433
+ attn_dropout=args.attn_dropout,
434
+ no_of_heads=args.no_of_heads,
435
+ dropout=args.dropout,
436
+ val_epochs=args.val_epochs,
437
+ max_lr=args.max_lr,
438
+ no_of_decoder_layers=args.no_of_decoder_layers,
439
+ weight_decay_optim=args.weight_decay_optim,
440
+ beta_1=args.beta_1,
441
+ beta_2=args.beta_2,
442
+ clip=args.clip,
443
+ device=args.device,
444
+ no_kv_heads=args.no_kv_heads,
445
+ vocab_size=args.vocab_size,
446
+ eps=args.eps,
447
+ dtype=args.dtype,
448
+ save_checkpoint_dir=args.save_checkpoint_dir,
449
+ prompt=args.prompt,
450
+ save_checkpoint_iter=args.save_checkpoint_iter,
451
+ total_iters=args.total_iters,
452
+ eval_iters=args.eval_iters,
453
+ eval_check=args.eval_check,
454
+ warmup_iters=args.warmup_iters,
455
+ min_lr=args.min_lr,
456
+ lr_decay_iters=args.lr_decay_iters,
457
+ total_batch_size=args.total_batch_size,
458
+ micro_batch_size=args.micro_batch_size,
459
+ gradient_accumulation_steps=args.gradient_accumulation_steps
460
+ )
461
+ return model_args
462
+
463
+
464
+ if __name__ == "__main__":
465
+ args = parse_args()
466
+
467
+
468
+ model_args = initialize_model_args(args)
469
+
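As uploaded, the `__main__` block above only parses the CLI arguments and builds a `ModelArgs`; it never constructs the `Trainer` or starts training, and `train()` still leans on helpers (`topk_sampling`, `find_unused_parameters`, the snapshot utilities) defined in `llama_torchrun.py`. A hedged sketch of how the entry point would presumably be completed under those assumptions:

```python
# Hypothetical completion of the entry point; the uploaded file stops after
# building model_args, so the lines below are an assumption, not repo code.
if __name__ == "__main__":
    args = parse_args()
    model_args = initialize_model_args(args)

    trainer = Trainer(model_args)  # sets up the NCCL process group
    # trainer.train()              # would start the DDP training loop once train() is a proper method

# Launched via torchrun so LOCAL_RANK / WORLD_SIZE are set, e.g.:
#   torchrun --standalone --nproc_per_node=<num_gpus> trainer.py --batch_size 64
```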