# mooncast/modules/tokenizer/tokenizer.py
from abc import ABC
from abc import abstractmethod
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from typing import Any, Union
import numpy as np
from dataclasses import dataclass
def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample: bool = False):
    """Encode text into sentence pieces, optionally with subword-regularization sampling."""
    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        # Sample from the 64-best segmentations with smoothing alpha 0.1.
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    return pieces
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class SPieceTokenizer(AbstractTokenizer):
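    """SentencePiece-backed tokenizer.

    Wraps a ``sentencepiece.SentencePieceProcessor`` loaded from ``spm_file`` and
    implements the ``AbstractTokenizer`` interface, plus helpers for byte counting,
    dummy-prefix control, and special-token ids.
    """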
def __init__(self, spm_file: str):
super().__init__('Sentence Piece')
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(spm_file)
self.eod_id = self.get_token_id('</s>')
self.special_ids = set([
self.sp_model.pad_id(),
self.sp_model.eos_id(),
self.sp_model.bos_id(),
self.sp_model.unk_id(),
self.eod_id,
])
# initialize index_2_bytes
self._initialize_index_2_bytes()
def encode_pieces(self, text: str, sample=False):
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
else:
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
return pieces
def _initialize_index_2_bytes(self):
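        """Build a table mapping each piece id to its UTF-8 byte length.

        The SentencePiece word-boundary marker '▁' is stripped before counting;
        the table backs ``get_array_bytes``.
        """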
proto = sp_pb2_model.ModelProto()
proto.ParseFromString(self.sp_model.serialized_model_proto())
self.index_2_numbytes = [0] * len(proto.pieces)
for i, p in enumerate(proto.pieces):
clean_piece = p.piece.replace('▁', '')
self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))
def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
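        """Enable or disable the dummy '▁' whitespace prefix added during encoding.

        Rewrites the serialized model's normalizer spec and reloads the processor
        when the requested value differs from the current one.
        """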
proto = sp_pb2_model.ModelProto()
proto.ParseFromString(self.sp_model.serialized_model_proto())
if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)
def add_special_id(self, token_id):
self.special_ids.add(token_id)
@property
def has_dummy_prefix(self):
pieces = self.sp_model.EncodeAsPieces("hello")
return pieces[0].startswith('▁')
@property
def vocab_size(self):
return self.sp_model.GetPieceSize()
@property
def vocab(self):
"""Dictionary from vocab text token to id token."""
return self.sp_model
def get_array_bytes(self, array):
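        """Estimate the UTF-8 byte length of a sequence of token ids.

        Ids at or above the SentencePiece vocab size are counted as 2 bytes each.
        """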
return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)
def tokenize(self, text):
tokens = encode_pieces(self.sp_model, text)
return self.convert_tokens_to_ids(tokens)
def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs: Any) -> list[int]:
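        """Encode text to token ids, optionally adding BOS/EOS ids at the ends."""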
tokens = self.encode_pieces(text)
t = self.convert_tokens_to_ids(tokens)
if bos:
t.insert(0, self.bos_id)
if eos:
t.append(self.eos_id)
return t
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
if isinstance(tokens, str):
return self.sp_model.PieceToId(tokens)
return [self.sp_model.PieceToId(token) for token in tokens]
    def detokenize(self, token_ids):
        if isinstance(token_ids, list):
            pieces = [self.sp_model.IdToPiece(token_id) for token_id in token_ids]
        else:
            # e.g. a numpy array of ids
            pieces = [self.sp_model.IdToPiece(token_id) for token_id in token_ids.tolist()]
        return pieces
def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
assert not skip_special_tokens, "skip_special_tokens is not supported"
if isinstance(token_ids, (int, np.integer)):
return self.detokenize([int(token_ids)])[0]
return ''.join(self.detokenize(token_ids))
def get_token_id(self, token):
return self.sp_model.PieceToId(token)
    @property
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        # TODO: to be implemented
        return {}
def decode_pieces(self, pieces):
return self.sp_model.DecodePieces(pieces)
@property
def eod(self):
return self.eod_id
@property
def pad_id(self):
return self.sp_model.pad_id()
@property
def eos_id(self):
return self.sp_model.eos_id()
@property
def bos_id(self):
return self.sp_model.bos_id()
@property
def unk_id(self):
return self.sp_model.unk_id()
@property
def pad_token_id(self):
return self.pad_id
@property
def eos_token_id(self):
return self.eos_id
@dataclass
class ExtraTokens:
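    """Token ids of the extra control tokens: message, name, and media delimiters, plus pad."""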
msg_end: int
user_msg_start: int
assistant_msg_start: int
name_end: int
media_begin: int
media_content: int
media_end: int
pad: int
def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
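    """Resolve the ``[extra_id_*]`` control pieces to token ids for the given tokenizer."""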
if isinstance(tokenizer, SPieceTokenizer):
        map_fn = tokenizer.convert_tokens_to_ids
else:
raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")
return ExtraTokens(
msg_end=map_fn('[extra_id_0]'),
user_msg_start=map_fn('[extra_id_1]'),
assistant_msg_start=map_fn('[extra_id_2]'),
name_end=map_fn('[extra_id_12]'),
media_begin=map_fn('[extra_id_13]'),
media_content=map_fn('[extra_id_14]'),
media_end=map_fn('[extra_id_15]'),
pad=tokenizer.pad_id
)
def get_tokenizer_and_extra_tokens():
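    """Load the 160k SentencePiece model, disable its dummy prefix, and collect the extra-token ids."""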
sp_model_path = "resources/tokenizer/160k.model"
tokenizer = SPieceTokenizer(sp_model_path)
tokenizer.set_add_dummy_prefix(False)
extra_tokens = instantiate_extra_tokens(tokenizer)
return tokenizer, extra_tokens
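

# A minimal usage sketch (not part of the original module). It assumes the
# SentencePiece model file exists at resources/tokenizer/160k.model, the path
# hard-coded in get_tokenizer_and_extra_tokens() above.
if __name__ == "__main__":
    tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()
    ids = tokenizer.encode("Hello, world")
    print("token ids:", ids)
    print("pieces:", tokenizer.detokenize(ids))
    print("decoded:", tokenizer.decode(ids))
    print("msg_end id:", extra_tokens.msg_end)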