# Copyright (c) Meta Platforms, Inc. and affiliates.
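"""Download a tokenizer model to a local directory.

For the Llama and Gemma names listed in TOKENIZER, the tokenizer file is
fetched from the Hugging Face Hub; any other name is treated as a tiktoken
encoding and its files are cached under the target directory.
"""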
import argparse
import os
from typing import Optional

from requests.exceptions import HTTPError
TOKENIZER = {
    "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
    "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
    "gemma": ("google/gemma-2-9b", "tokenizer.model"),
}
def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
    if tokenizer_name in TOKENIZER:
        # Known Hugging Face tokenizer: download the tokenizer.model file from the Hub.
        repo_id, filename = TOKENIZER[tokenizer_name]
        from huggingface_hub import hf_hub_download

        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=path_to_save,
                local_dir_use_symlinks=False,
                token=api_key if api_key else None,
            )
        except HTTPError as e:
            if e.response.status_code == 401:
                print(
                    "You need to pass a valid `--api_key=...` to download private checkpoints."
                )
            else:
                raise
    else:
        # Otherwise treat the name as a tiktoken encoding (e.g. cl100k_base) and
        # let tiktoken cache its files under the target directory.
        from tiktoken import get_encoding

        if "TIKTOKEN_CACHE_DIR" not in os.environ:
            os.environ["TIKTOKEN_CACHE_DIR"] = path_to_save
        try:
            get_encoding(tokenizer_name)
        except ValueError:
            print(
                f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
            )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download a tokenizer model.")
    parser.add_argument(
        "tokenizer_name",
        type=str,
        help="tokenizer to download, e.g. llama2, llama3, gemma, or a tiktoken encoding name",
    )
    parser.add_argument(
        "tokenizer_dir",
        type=str,
        help="directory to save the tokenizer files to",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="",
        help="Hugging Face token for downloading private checkpoints",
    )
    args = parser.parse_args()
    main(
        tokenizer_name=args.tokenizer_name,
        path_to_save=args.tokenizer_dir,
        api_key=args.api_key,
    )
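# Example invocations (assuming this script is saved as download_tokenizer.py;
# the file name and token value below are placeholders):
#   python download_tokenizer.py llama3 ./tokenizers --api_key=hf_xxx
#   python download_tokenizer.py cl100k_base ./tokenizers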