|
class Vocab:
    """Character vocabulary mapping characters to integer ids and back.

    Ids 0-3 are reserved for special tokens: ``pad`` (0), ``go`` / start of
    sequence (1), ``eos`` / end of sequence (2) and ``mask_token`` (3). The
    entries of ``chars`` are assigned ids starting at 4, in order.
    """

    # Number of reserved special ids that precede the character ids.
    _NUM_SPECIAL = 4

    def __init__(self, chars):
        """Build the character <-> id lookup tables.

        Args:
            chars: The alphabet, as a ``str`` or a sequence of strings. The
                entry at position ``i`` gets id ``i + 4``. If an entry occurs
                more than once, ``c2i`` maps it to its last position.
        """
        self.pad = 0
        self.go = 1
        self.eos = 2
        self.mask_token = 3

        self.chars = chars

        offset = self._NUM_SPECIAL
        self.c2i = {c: i + offset for i, c in enumerate(chars)}
        self.i2c = {i + offset: c for i, c in enumerate(chars)}

        # Printable names for the reserved ids.
        self.i2c[self.pad] = '<pad>'
        self.i2c[self.go] = '<sos>'
        self.i2c[self.eos] = '<eos>'
        self.i2c[self.mask_token] = '*'

    def encode(self, chars):
        """Return the ids of ``chars`` wrapped in ``go`` ... ``eos``.

        Raises:
            KeyError: if a character is not in the vocabulary.
        """
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        """Convert a sequence of ids back into a string.

        A ``go`` id in the first position is skipped. Decoding stops at the
        first ``eos`` id, so anything after it (e.g. padding) is ignored.
        Special ids appearing elsewhere are rendered by their ``i2c`` names.

        Args:
            ids: Any iterable of integer ids.

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        ids = list(ids)
        # Only a start token at position 0 marks the start of the sequence;
        # a ``go`` id further in must not cause the first real id to be lost.
        first = 1 if ids and ids[0] == self.go else 0
        last = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[i] for i in ids[first:last])

    def __len__(self):
        """Return the number of ids (special tokens plus one per entry of chars).

        Counted from ``i2c`` rather than ``c2i`` so that every id ``encode``
        can emit is strictly smaller than ``len(self)``, even when ``chars``
        contains duplicates.
        """
        return len(self.i2c)

    def batch_decode(self, arr):
        """Decode each id sequence in ``arr``; return the list of strings."""
        return [self.decode(ids) for ids in arr]

    def __str__(self):
        """Return the alphabet as a single string."""
        # join() keeps str input unchanged and also handles a list of strings,
        # for which returning self.chars directly would raise TypeError.
        return ''.join(self.chars)
|
|