from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Union

import numpy as np
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model

def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample: bool = False):
    """Encode text into sentence pieces."""
    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        # Subword regularization: sample a segmentation from the 64-best
        # candidates with smoothing parameter alpha=0.1.
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    return pieces
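
# Example usage (a quick sketch; "spm.model" below is a hypothetical path):
#
#   sp = spm.SentencePieceProcessor()
#   sp.Load("spm.model")
#   encode_pieces(sp, "hello world")               # e.g. ['▁hello', '▁world']
#   encode_pieces(sp, "hello world", sample=True)  # sampled segmentation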


class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def mask(self):
        raise NotImplementedError('MASK is not provided for {} '
                                  'tokenizer'.format(self.name))


class SPieceTokenizer(AbstractTokenizer):
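    """SentencePiece-backed implementation of AbstractTokenizer."""
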
    def __init__(self, spm_file: str):
        super().__init__('Sentence Piece')
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(spm_file)
        self.eod_id = self.get_token_id('</s>')

        self.special_ids = set([
            self.sp_model.pad_id(),
            self.sp_model.eos_id(),
            self.sp_model.bos_id(),
            self.sp_model.unk_id(),
            self.eod_id,
        ])

        # initialize index_2_bytes
        self._initialize_index_2_bytes()
    
    def encode_pieces(self, text: str, sample=False):
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _initialize_index_2_bytes(self):
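        # Precompute the UTF-8 byte length of every vocab piece (excluding the
        # '▁' word-boundary marker), indexed by token id.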
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        self.index_2_numbytes = [0] * len(proto.pieces)
        for i, p in enumerate(proto.pieces):
            clean_piece = p.piece.replace('▁', '')
            self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))

    def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
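        # Toggle the automatic leading '▁' (dummy prefix) by rewriting the
        # normalizer spec in the serialized model proto and reloading it.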
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
            proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
            self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
            print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)

    def add_special_id(self, token_id):
        self.special_ids.add(token_id)

    @property
    def has_dummy_prefix(self):
        pieces = self.sp_model.EncodeAsPieces("hello")
        return pieces[0].startswith('▁')

    @property
    def vocab_size(self):
        return self.sp_model.GetPieceSize()

    @property
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        return self.sp_model

    def get_array_bytes(self, array):
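        # Sum per-token byte counts; ids at or above the SentencePiece vocab
        # size fall back to an assumed 2 bytes.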
        return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)

    def tokenize(self, text):
        tokens = self.encode_pieces(text)
        return self.convert_tokens_to_ids(tokens)
    
    def encode(self, text: str, bos: bool=False, eos: bool=False, **kwargs: Any) -> list[int]:
        tokens = self.encode_pieces(text)
        t = self.convert_tokens_to_ids(tokens)
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        if isinstance(tokens, str):
            return self.sp_model.PieceToId(tokens)
        return [self.sp_model.PieceToId(token) for token in tokens]

    def detokenize(self, token_ids):
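        # Accepts a Python list or an array-like with .tolist() (e.g. a numpy
        # array); returns raw pieces, '▁' markers included.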
        if isinstance(token_ids, list):
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids]
        else:
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]
        return pieces
    
    def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
        assert not skip_special_tokens, "skip_special_tokens is not supported"
        if isinstance(token_ids, (int, np.integer)):
            return self.detokenize([int(token_ids)])[0]
        return ''.join(self.detokenize(token_ids))

    def get_token_id(self, token):
        return self.sp_model.PieceToId(token)

    @property
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        # TODO: to be implemented
        return {}

    def decode_pieces(self, pieces):
        return self.sp_model.DecodePieces(pieces)

    @property
    def eod(self):
        return self.eod_id

    @property
    def pad_id(self):
        return self.sp_model.pad_id()

    @property
    def eos_id(self):
        return self.sp_model.eos_id()

    @property
    def bos_id(self):
        return self.sp_model.bos_id()

    @property
    def unk_id(self):
        return self.sp_model.unk_id()
    
    @property
    def pad_token_id(self):
        return self.pad_id

    @property
    def eos_token_id(self):
        return self.eos_id

    
@dataclass
class ExtraTokens:
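    """Ids of the chat and media control tokens resolved from the tokenizer."""
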
    msg_end: int
    user_msg_start: int
    assistant_msg_start: int
    name_end: int
    media_begin: int
    media_content: int
    media_end: int
    pad: int


def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
    if isinstance(tokenizer, SPieceTokenizer):
        map_fn = tokenizer.convert_tokens_to_ids
    else:
        raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")

    return ExtraTokens(
        msg_end=map_fn('[extra_id_0]'),
        user_msg_start=map_fn('[extra_id_1]'),
        assistant_msg_start=map_fn('[extra_id_2]'),
        name_end=map_fn('[extra_id_12]'),
        media_begin=map_fn('[extra_id_13]'),
        media_content=map_fn('[extra_id_14]'),
        media_end=map_fn('[extra_id_15]'),
        pad=tokenizer.pad_id
    )

def get_tokenizer_and_extra_tokens():
    sp_model_path = "resources/tokenizer/160k.model"
    tokenizer = SPieceTokenizer(sp_model_path)
    tokenizer.set_add_dummy_prefix(False)
    extra_tokens = instantiate_extra_tokens(tokenizer)
    return tokenizer, extra_tokens
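

# Minimal usage sketch (assumes "resources/tokenizer/160k.model" exists on disk):
if __name__ == "__main__":
    tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()
    ids = tokenizer.tokenize("hello world")
    print(ids)
    print(tokenizer.decode(ids))
    print("eod id:", tokenizer.eod, "pad id:", extra_tokens.pad)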