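"""Minimal BLT generation demo: complete a single prompt with a BLT model.

Loads a consolidated Byte Latent Transformer checkpoint and its tokenizer
from hf-weights/<model_name>, builds a real-time entropy patcher, and prints
the model's completion for the given prompt.

Example (assuming this script is saved as demo.py):
    python demo.py "The capital of France is" --model-name blt-1b
"""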
import os

import torch
import typer

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer


def main(prompt: str, model_name: str = "blt-1b"):
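    # Set up torch.distributed (skipped if a process group is already
    # initialized); the loading utilities below run under this setup.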
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)
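
    # Consolidated checkpoints are expected under hf-weights/<model_name>.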
    checkpoint_path = os.path.join("hf-weights", model_name)
    print(f"Loading BLT model: {model_name}")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        checkpoint_path,
    )
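    # Guard against loading a checkpoint that is not a BLT model with a
    # byte-level tokenizer.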
    assert isinstance(model, ByteLatentTransformer)
    assert isinstance(tokenizer, BltTokenizer)
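
    # Copy the training-time patcher config and enable real-time patching so
    # prompt bytes are segmented into patches on the fly during generation.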
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    print("Loading entropy model and patcher")
    patcher_args.entropy_model_checkpoint_dir = os.path.join(
        checkpoint_path, "entropy_model"
    )
    patcher = patcher_args.build()
    prompts = [prompt]
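    # generate_nocache runs generation without a KV cache (per its name),
    # returning one sequence of byte tokens per prompt.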
    outputs = generate_nocache(
        prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
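    # Decode the generated byte tokens back into text.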
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


if __name__ == "__main__":
    typer.run(main)