from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Union

import numpy as np
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model


def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample=False):
    """Encode text into sentence pieces. Only supports py3."""
    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        # Subword regularization: sample from the 64-best segmentations with alpha=0.1.
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    return pieces
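
# Illustrative usage (sketch only; "spm.model" is a hypothetical model path):
#
#   sp = spm.SentencePieceProcessor()
#   sp.Load("spm.model")
#   encode_pieces(sp, "hello world")   # e.g. ['▁hello', '▁world']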


class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def mask(self):
        raise NotImplementedError('MASK is not provided for {} '
                                  'tokenizer'.format(self.name))


class SPieceTokenizer(AbstractTokenizer):

    def __init__(self, spm_file: str):
        super().__init__('Sentence Piece')
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(spm_file)
        self.eod_id = self.get_token_id('</s>')
        self.special_ids = set([
            self.sp_model.pad_id(),
            self.sp_model.eos_id(),
            self.sp_model.bos_id(),
            self.sp_model.unk_id(),
            self.eod_id,
        ])
        # initialize index_2_bytes
        self._initialize_index_2_bytes()

    def encode_pieces(self, text: str, sample=False):
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _initialize_index_2_bytes(self):
        # Map each vocabulary index to the UTF-8 byte length of its piece,
        # ignoring the SentencePiece whitespace marker '▁'.
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        self.index_2_numbytes = [0] * len(proto.pieces)
        for i, p in enumerate(proto.pieces):
            clean_piece = p.piece.replace('▁', '')
            self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))

    def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
            proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
            self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
            print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)

    def add_special_id(self, token_id):
        self.special_ids.add(token_id)

    def has_dummy_prefix(self):
        pieces = self.sp_model.EncodeAsPieces("hello")
        return pieces[0].startswith('▁')

    @property
    def vocab_size(self):
        return self.sp_model.GetPieceSize()

    @property
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        return self.sp_model

    def get_array_bytes(self, array):
        # Out-of-vocabulary indices are counted as 2 bytes.
        return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)

    def tokenize(self, text):
        tokens = encode_pieces(self.sp_model, text)
        return self.convert_tokens_to_ids(tokens)

    def encode(self, text: str, bos: bool = False, eos: bool = False, **kwargs: Any) -> list[int]:
        tokens = self.encode_pieces(text)
        t = self.convert_tokens_to_ids(tokens)
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]:
        if isinstance(tokens, str):
            return self.sp_model.PieceToId(tokens)
        return [self.sp_model.PieceToId(token) for token in tokens]

    def detokenize(self, token_ids):
        if isinstance(token_ids, list):
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids]
        else:
            # Assume a numpy array or tensor-like object that supports .tolist().
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]
        return pieces

    def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
        assert not skip_special_tokens, "skip_special_tokens is not supported"
        if isinstance(token_ids, (int, np.integer)):
            return self.detokenize([int(token_ids)])[0]
        return ''.join(self.detokenize(token_ids))

    def get_token_id(self, token):
        return self.sp_model.PieceToId(token)

    @property
    def inv_vocab(self):
        # TODO: to be implemented
        return {}

    def decode_pieces(self, pieces):
        return self.sp_model.DecodePieces(pieces)

    @property
    def eod(self):
        return self.eod_id

    @property
    def pad_id(self):
        return self.sp_model.pad_id()

    @property
    def eos_id(self):
        return self.sp_model.eos_id()

    @property
    def bos_id(self):
        return self.sp_model.bos_id()

    @property
    def unk_id(self):
        return self.sp_model.unk_id()

    @property
    def pad_token_id(self):
        return self.pad_id

    @property
    def eos_token_id(self):
        return self.eos_id
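
# Illustrative usage (sketch only; "spiece.model" stands in for a real SentencePiece
# model path):
#
#   tok = SPieceTokenizer("spiece.model")
#   ids = tok.encode("hello world")
#   tok.decode(ids)           # concatenates pieces; '▁' markers are not converted back to spaces
#   tok.get_array_bytes(ids)  # approximate UTF-8 byte count of the underlying text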


@dataclass
class ExtraTokens:
    msg_end: int
    user_msg_start: int
    assistant_msg_start: int
    name_end: int
    media_begin: int
    media_content: int
    media_end: int
    pad: int


def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
    if isinstance(tokenizer, SPieceTokenizer):
        map_fn = lambda x: tokenizer.convert_tokens_to_ids(x)
    else:
        raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")

    return ExtraTokens(
        msg_end=map_fn('[extra_id_0]'),
        user_msg_start=map_fn('[extra_id_1]'),
        assistant_msg_start=map_fn('[extra_id_2]'),
        name_end=map_fn('[extra_id_12]'),
        media_begin=map_fn('[extra_id_13]'),
        media_content=map_fn('[extra_id_14]'),
        media_end=map_fn('[extra_id_15]'),
        pad=tokenizer.pad_id,
    )


def get_tokenizer_and_extra_tokens():
    sp_model_path = "resources/tokenizer/160k.model"
    tokenizer = SPieceTokenizer(sp_model_path)
    # Disable SentencePiece's automatic leading whitespace ('▁') prefix.
    tokenizer.set_add_dummy_prefix(False)
    extra_tokens = instantiate_extra_tokens(tokenizer)
    return tokenizer, extra_tokens
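

# Minimal demo (sketch): builds a user/assistant frame from the extra tokens.
# Requires the model file at "resources/tokenizer/160k.model" to exist locally;
# the message layout below is an assumption for illustration only.
if __name__ == "__main__":
    tokenizer, extra = get_tokenizer_and_extra_tokens()
    ids = (
        [extra.user_msg_start]
        + tokenizer.encode("hello")
        + [extra.msg_end, extra.assistant_msg_start]
    )
    print(ids)                    # token ids, including the extra-token ids
    print(tokenizer.decode(ids))  # pieces joined as text, e.g. '[extra_id_1]hello[extra_id_0]...'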