Open source weights! (#97)
Summary:
Add code to download weights and demo code for running the model.
Weights at:
- https://huggingface.co./collections/facebook/blt-6801263d4ac1704702a192a6
- https://huggingface.co./facebook/blt
- https://huggingface.co./facebook/blt-1b
- https://huggingface.co./facebook/blt-7b
Test Plan:

Files changed:
- .gitignore +1 -1
- README.md +19 -1
- bytelatent/data/patcher.py +11 -4
- bytelatent/eval.py +3 -1
- bytelatent/generate.py +6 -5
- demo.py +43 -0
- download_blt_weights.py +15 -0
.gitignore
CHANGED
@@ -169,4 +169,4 @@ internal/
 jobs_parallel-copy/
 wandb/
 *.ipynb
-
+hf-weights/
README.md
CHANGED
@@ -46,7 +46,25 @@ Once that is done you can activate the environment
 conda activate blt_<date>
 ```
 
-
+## Downloading HF Model Weights and Generating Text
+
+We have released weights on HF for the [BLT 1B Model](https://huggingface.co/facebook/blt-1b) and [BLT 7B Model](https://huggingface.co/facebook/blt-7b).
+We are actively working with HF to make BLT available in [Transformers](https://huggingface.co/docs/transformers/en/index) and will update this when it is.
+In the meantime, you can follow these instructions to load model weights, initialize a model, and generate text.
+These instructions have been tested on H100 GPUs, but we can only offer suggestions on running on other hardware.
+
+1. On the model weights HF page, create a HuggingFace account, request access to the weights, and wait for approval.
+2. Log in with the Hugging Face CLI: `huggingface-cli login`
+3. Download the model weights with `python download_blt_weights.py`, which will download them to `hf-weights`.
+4. Run the generate demo: `python demo.py "A BLT has"`.
+
+The demo generates text, but is also a good starting point for loading BLT in your own code.
+
+## Downloading Training Data
+
+Note: The following instructions are not well tested in the BLT code as it is based on the lingua code, which we have diverged from.
+
+Use the provided script to download and prepare data from huggingface (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
 This command will download the `fineweb_edu` and prepare it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
 
 ```bash
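For readers skimming this PR, the core of that workflow is only a few calls. The sketch below is condensed from `demo.py` (added in this PR) and assumes the 1B weights have already been downloaded to `hf-weights/blt-1b`, the path `demo.py` constructs; note that `demo.py` additionally initializes torch.distributed before loading, which is omitted here for brevity.

```python
# Condensed from demo.py below; illustrative only, not a separate API.
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache

model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer("hf-weights/blt-1b")

# Reuse the training patcher config, switching on realtime patching with the
# bundled entropy model (directory layout taken from demo.py).
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = True
patcher_args.entropy_model_checkpoint_dir = "hf-weights/blt-1b/entropy_model"
patcher = patcher_args.build()

outputs = generate_nocache(
    ["A BLT has"], model=model, tokenizer=tokenizer, patcher=patcher
)
print(tokenizer.decode(outputs[0]))
```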
bytelatent/data/patcher.py
CHANGED
@@ -476,12 +476,19 @@ class Patcher:
             assert (
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
+            maybe_consolidated = os.path.join(
+                patcher_args.entropy_model_checkpoint_dir,
+                "consolidated/consolidated.pth",
+            )
+            if os.path.exists(maybe_consolidated):
+                state_path = maybe_consolidated
+            else:
+                state_path = os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir, "consolidated.pth"
+                )
             entropy_model = load_entropy_model(
                 patcher_args.entropy_model_checkpoint_dir,
-                os.path.join(
-                    patcher_args.entropy_model_checkpoint_dir,
-                    "consolidated/consolidated.pth",
-                ),
+                state_path,
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
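The patcher change is a lookup fallback: it first tries the nested training-checkpoint layout (`consolidated/consolidated.pth`) and otherwise falls back to a top-level `consolidated.pth`, presumably the layout of the released HF entropy-model weights. A standalone sketch of the same logic (the helper name is ours, not part of the PR):

```python
import os


def resolve_entropy_state_path(checkpoint_dir: str) -> str:
    # Hypothetical helper mirroring the Patcher change: prefer the nested
    # training-checkpoint layout, otherwise fall back to a flat consolidated.pth.
    nested = os.path.join(checkpoint_dir, "consolidated/consolidated.pth")
    if os.path.exists(nested):
        return nested
    return os.path.join(checkpoint_dir, "consolidated.pth")
```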
bytelatent/eval.py
CHANGED
@@ -206,7 +206,9 @@ def eval_ppl_on_path(
                 pred = model(x, patch_lengths=patch_lengths)
             else:
                 pred = model(x)
-            loss = F.cross_entropy(
+            loss = F.cross_entropy(
+                pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
+            )
             total_loss += loss.item()
         else:
             raise NotImplementedError()
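The eval change makes the per-batch loss an explicit sum over byte positions, skipping positions whose target id is 0. A toy example of the same call (shapes and the vocabulary size are made up for illustration, not taken from BLT's config):

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 4, 260          # illustrative sizes only
pred = torch.randn(batch, seq_len, vocab)  # model logits
y = torch.randint(1, vocab, (batch, seq_len))
y[0, -1] = 0                               # a padded position, excluded via ignore_index=0

# Flatten (batch, seq) into one axis so each byte position is one classification.
loss = F.cross_entropy(
    pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
)
print(loss.item())  # summed negative log-likelihood over the 7 non-padded positions
```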
bytelatent/generate.py
CHANGED
@@ -25,7 +25,11 @@ from bytelatent.checkpoint import (
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
-from bytelatent.distributed import
+from bytelatent.distributed import (
+    DistributedArgs,
+    get_global_rank,
+    setup_torch_distributed,
+)
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer

@@ -388,10 +392,7 @@ class PackedCausalTransformerGenerator:
         return generation, loglikelihood, greedy


-def load_consolidated_model_and_tokenizer(
-    consolidated_path,
-    init_distributed=False
-):
+def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=False):
     if init_distributed:
         distributed_args = DistributedArgs()
         distributed_args.configure_world()
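The reflowed signature keeps the `init_distributed` flag with its default of `False`; `demo.py` below instead sets up torch.distributed itself (`DistributedArgs().configure_world()` followed by `setup_torch_distributed`) and then calls the loader with the default.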
demo.py
ADDED
@@ -0,0 +1,43 @@
+import os
+
+import torch
+import typer
+
+from bytelatent.distributed import DistributedArgs, setup_torch_distributed
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.generate_blt import generate_nocache
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+
+def main(prompt: str, model_name: str = "blt-1b"):
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(distributed_args)
+    checkpoint_path = os.path.join("hf-weights", model_name)
+    print(f"Loading BLT model: {model_name}")
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        checkpoint_path,
+    )
+    assert isinstance(model, ByteLatentTransformer)
+    assert isinstance(tokenizer, BltTokenizer)
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = True
+    print("Loading entropy model and patcher")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        checkpoint_path, "entropy_model"
+    )
+    patcher = patcher_args.build()
+    prompts = [prompt]
+    outputs = generate_nocache(
+        prompts, model=model, tokenizer=tokenizer, patcher=patcher
+    )
+    text_outputs = [tokenizer.decode(t) for t in outputs]
+    for p, t in zip(prompts, text_outputs):
+        print(f'Prompt: "{p}" Completion: "{t}"')
+        print()
+
+
+if __name__ == "__main__":
+    typer.run(main)
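To try the larger model, the `model_name` argument can be overridden on the command line; with typer's default option naming that should be `python demo.py "A BLT has" --model-name blt-7b` (flag spelling inferred from typer's conventions, not spelled out in the PR).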
download_blt_weights.py
ADDED
@@ -0,0 +1,15 @@
+import os
+
+import typer
+from huggingface_hub import snapshot_download
+
+
+def main(models: list[str] = ["blt-1b", "blt-7b"]):
+    if not os.path.exists("hf-weights"):
+        os.makedirs("hf-weights")
+    for model in models:
+        snapshot_download(f"facebook/{model}", local_dir=f"hf-weights/{model}")
+
+
+if __name__ == "__main__":
+    typer.run(main)
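By default the script fetches both released checkpoints. To grab only one, typer exposes the `models` list as a repeatable option, so something like `python download_blt_weights.py --models blt-1b` should work (again inferred from typer's conventions). Note that `snapshot_download` will only succeed after access to the gated weights has been approved and `huggingface-cli login` has been run, as described in the README changes above.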