Commit 96d51b5 (unverified) by par-meta · Parent(s): e299427

Open source weights! (#97)

Summary:

Add code to download weights and demo code for running the model.

Weights at:
- https://huggingface.co./collections/facebook/blt-6801263d4ac1704702a192a6
- https://huggingface.co./facebook/blt
- https://huggingface.co./facebook/blt-1b
- https://huggingface.co./facebook/blt-7b

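The script added in this commit (`download_blt_weights.py`, below) fetches these repos with `huggingface_hub.snapshot_download`. A minimal sketch of the same step for a single checkpoint, assuming gated access has already been approved and `huggingface-cli login` has been run:

```python
# Minimal sketch, not the commit's script: fetch one BLT checkpoint into ./hf-weights.
# Assumes access to the gated repo was approved and `huggingface-cli login` was run.
from huggingface_hub import snapshot_download

snapshot_download("facebook/blt-1b", local_dir="hf-weights/blt-1b")
```
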
Test Plan:

.gitignore CHANGED
```diff
@@ -169,4 +169,4 @@ internal/
 jobs_parallel-copy/
 wandb/
 *.ipynb
-
+hf-weights/
```
README.md CHANGED
````diff
@@ -46,7 +46,25 @@ Once that is done you can activate the environment
 conda activate blt_<date>
 ```
 
-use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
+## Downloading HF Model Weights and Generating Text
+
+We have released weights on HF for the [BLT 1B Model](https://huggingface.co/facebook/blt-1b) and the [BLT 7B Model](https://huggingface.co/facebook/blt-7b).
+We are actively working with HF to make BLT available in [Transformers](https://huggingface.co/docs/transformers/en/index) and will update this README when it is.
+In the meantime, you can follow these instructions to load the model weights, initialize a model, and generate text.
+These instructions have been tested on H100 GPUs; for other hardware we can only offer suggestions.
+
+1. On the model weights HF page, create a HuggingFace account, request access to the weights, and wait for approval.
+2. Log in with the HuggingFace CLI: `huggingface-cli login`
+3. Download the model weights with `python download_blt_weights.py`, which downloads them to `hf-weights`.
+4. Run the generation demo: `python demo.py "A BLT has"`.
+
+The demo generates text, but it is also a good starting point for loading BLT in your own code.
+
+## Downloading Training Data
+
+Note: the following instructions are not well tested in the BLT code, since it is based on the lingua code, from which we have diverged.
+
+Use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
 This command will download the `fineweb_edu` and prepare it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
 
 ```bash
````
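The new README section points to the demo as a starting point for loading BLT in your own code; below is a condensed sketch of the core calls, mirroring the `demo.py` added further down in this commit (paths assume the weights were downloaded into `./hf-weights`):

```python
# Condensed sketch of the load-and-generate flow; mirrors demo.py from this commit.
# Assumes the blt-1b weights (with their entropy_model subfolder) live in ./hf-weights.
import os

import torch

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache

checkpoint_path = os.path.join("hf-weights", "blt-1b")

# Single-process distributed setup is needed before loading the consolidated weights.
distributed_args = DistributedArgs()
distributed_args.configure_world()
if not torch.distributed.is_initialized():
    setup_torch_distributed(distributed_args)

model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(checkpoint_path)

# Real-time patching requires the entropy model shipped alongside the main weights.
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = True
patcher_args.entropy_model_checkpoint_dir = os.path.join(checkpoint_path, "entropy_model")
patcher = patcher_args.build()

outputs = generate_nocache(["A BLT has"], model=model, tokenizer=tokenizer, patcher=patcher)
print(tokenizer.decode(outputs[0]))
```
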
bytelatent/data/patcher.py CHANGED
```diff
@@ -476,12 +476,19 @@ class Patcher:
             assert (
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
+            maybe_consolidated = os.path.join(
+                patcher_args.entropy_model_checkpoint_dir,
+                "consolidated/consolidated.pth",
+            )
+            if os.path.exists(maybe_consolidated):
+                state_path = maybe_consolidated
+            else:
+                state_path = os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
+                )
             entropy_model = load_entropy_model(
                 patcher_args.entropy_model_checkpoint_dir,
-                os.path.join(
-                    patcher_args.entropy_model_checkpoint_dir,
-                    "consolidated/consolidated.pth",
-                ),
+                state_path,
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
```
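The hunk above lets the entropy-model loader accept either checkpoint layout. A standalone sketch of the same fallback, using a hypothetical helper name that is not part of the repo:

```python
# Hedged sketch of the fallback introduced above; the helper name is hypothetical.
import os


def resolve_entropy_checkpoint(checkpoint_dir: str) -> str:
    """Prefer the nested training layout (consolidated/consolidated.pth), else fall
    back to a flat consolidated.pth (presumably the layout of the released HF weights)."""
    nested = os.path.join(checkpoint_dir, "consolidated/consolidated.pth")
    if os.path.exists(nested):
        return nested
    return os.path.join(checkpoint_dir, "consolidated.pth")
```
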
bytelatent/eval.py CHANGED
```diff
@@ -206,7 +206,9 @@ def eval_ppl_on_path(
             pred = model(x, patch_lengths=patch_lengths)
         else:
             pred = model(x)
-        loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
+        loss = F.cross_entropy(
+            pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
+        )
         total_loss += loss.item()
     else:
         raise NotImplementedError()
```
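For context, `eval_ppl_on_path` accumulates this sum-reduced cross entropy (skipping targets equal to 0) into `total_loss`; an illustrative sketch, not code from `eval.py`, of how such a total becomes a perplexity:

```python
# Illustrative only: how a sum-reduced, pad-ignoring cross entropy becomes perplexity.
import torch
import torch.nn.functional as F

pred = torch.randn(2, 8, 260)       # (batch, seq, vocab) logits
y = torch.randint(0, 260, (2, 8))   # targets; index 0 is skipped via ignore_index

total_loss = F.cross_entropy(
    pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
)
n_tokens = (y != 0).sum()                # only non-ignored positions contribute
ppl = torch.exp(total_loss / n_tokens)   # perplexity over the evaluated tokens
```
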
bytelatent/generate.py CHANGED
```diff
@@ -25,7 +25,11 @@ from bytelatent.checkpoint import (
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
-from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
+from bytelatent.distributed import (
+    DistributedArgs,
+    get_global_rank,
+    setup_torch_distributed,
+)
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -388,10 +392,7 @@ class PackedCausalTransformerGenerator:
         return generation, loglikelihood, greedy
 
 
-def load_consolidated_model_and_tokenizer(
-    consolidated_path,
-    init_distributed=False
-):
+def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
     if init_distributed:
         distributed_args = DistributedArgs()
         distributed_args.configure_world()
```
demo.py ADDED
```diff
@@ -0,0 +1,43 @@
+import os
+
+import torch
+import typer
+
+from bytelatent.distributed import DistributedArgs, setup_torch_distributed
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.generate_blt import generate_nocache
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+
+def main(prompt: str, model_name: str = "blt-1b"):
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(distributed_args)
+    checkpoint_path = os.path.join("hf-weights", model_name)
+    print(f"Loading BLT model: {model_name}")
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        checkpoint_path,
+    )
+    assert isinstance(model, ByteLatentTransformer)
+    assert isinstance(tokenizer, BltTokenizer)
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = True
+    print("Loading entropy model and patcher")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        checkpoint_path, "entropy_model"
+    )
+    patcher = patcher_args.build()
+    prompts = [prompt]
+    outputs = generate_nocache(
+        prompts, model=model, tokenizer=tokenizer, patcher=patcher
+    )
+    text_outputs = [tokenizer.decode(t) for t in outputs]
+    for p, t in zip(prompts, text_outputs):
+        print(f'Prompt: "{p}" Completion: "{t}"')
+        print()
+
+
+if __name__ == "__main__":
+    typer.run(main)
```
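Since `main` is exposed through `typer.run`, the checkpoint can be selected on the command line; with typer's default option naming for `model_name`, something like `python demo.py "A BLT has" --model-name blt-7b` should run the 7B weights instead of the default `blt-1b`, assuming they were downloaded into `hf-weights/blt-7b`.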
download_blt_weights.py ADDED
```diff
@@ -0,0 +1,15 @@
+import os
+
+import typer
+from huggingface_hub import snapshot_download
+
+
+def main(models: list[str] = ["blt-1b", "blt-7b"]):
+    if not os.path.exists("hf-weights"):
+        os.makedirs("hf-weights")
+    for model in models:
+        snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")
+
+
+if __name__ == "__main__":
+    typer.run(main)
```
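Because `models` is a `list[str]` option under typer, the download can likely be restricted to a single checkpoint by passing it explicitly, e.g. `python download_blt_weights.py --models blt-1b`; as the README changes above note, access to the gated repos must be approved and `huggingface-cli login` run first.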