|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
def map_phone_to_tokendict(item, pad_bos_eos=True):
    """Merge phone ids and tone codes into a single token sequence.

    Reads ``item['txt_token']`` (phone ids) and ``item['tone']`` (tone codes),
    presumably tensors of identical shape — confirm against the caller.
    Neither input tensor is modified.

    Tone codes are first remapped to a tone index:
    4 -> 1, 11 -> 2, 12 -> 3, 13 -> 4, 14 -> 5, 15 -> 6 (any other code is
    left as-is). Phone ids in the inclusive range [3, 100] (presumably the
    Chinese phone range — confirm) are then replaced by
    ``(phone - 3) * 6 + 200 + tone_index``; all other phone ids pass through
    unchanged.

    NOTE(review): a phone in [3, 100] whose tone code is not one of the six
    listed codes keeps its raw code as the "index", which yields a token
    outside the 201..788 pattern — confirm inputs never contain that case.

    Args:
        item: mapping with ``'txt_token'`` and ``'tone'`` tensors.
        pad_bos_eos: when true, prepend token 798 and append token 799
            along the last dimension.

    Returns:
        The merged token tensor (padded if ``pad_bos_eos``).
    """
    raw_tone = item['tone']
    tone_index = raw_tone.clone()
    # Masks are taken from the untouched original codes; none of the targets
    # is also a later source, so this equals the step-by-step remapping.
    for code, index in ((4, 1), (11, 2), (12, 3), (13, 4), (14, 5), (15, 6)):
        tone_index[raw_tone == code] = index

    phones = item['txt_token']
    in_merge_range = (phones >= 3) & (phones <= 100)
    tokens = phones.clone()
    tokens[in_merge_range] = (phones[in_merge_range] - 3) * 6 + 200 + tone_index[in_merge_range]

    if not pad_bos_eos:
        return tokens
    tokens = F.pad(tokens, (1, 0), mode='constant', value=798)
    return F.pad(tokens, (0, 1), mode='constant', value=799)
|
|
|
def split_ph_timestamp(ph_timestamp):
    """Split an interleaved token/timestamp sequence into phones, tones and durations.

    Input: ph_timestamp, shape [T], laid out as
    ``[tok_0, end_0, tok_1, end_1, ...]``: even positions hold (possibly
    merged) phone tokens, odd positions hold cumulative end timestamps.
    Every value >= 800 first has 800 subtracted (presumably an offset that
    separates timestamp/phone token ranges — confirm against the producer).
    The caller's tensor is NOT modified; the adjustment is applied to a copy.

    Tokens in [201, 788] are decoded as ``200 + (phone - 3) * 6 + tone_idx``
    with ``tone_idx`` in 1..6; ``tone_idx`` 1 maps to tone 4 and
    ``tone_idx`` k > 1 maps to tone ``k + 9``. Any other token is returned
    as the phone itself with tone 3.

    Returns:
        (ph_seq, tone_seq, dur_seq, last): three int64 tensors of length T/2,
        where ``dur_seq[i] = end_i - end_{i-1}`` (with ``end_{-1} = 0``), and
        ``last``, the final adjusted element of the sequence (0-d tensor).

    Raises:
        AssertionError: if T is odd (a token without a timestamp).
        IndexError: if T == 0.
    """
    # Copy first so the caller's tensor is left untouched.
    ph_timestamp = ph_timestamp.clone()
    ph_timestamp[ph_timestamp >= 800] -= 800

    values = ph_timestamp.tolist()
    tokens, ends = values[0::2], values[1::2]
    assert len(tokens) == len(ends), f"{len(tokens)}, {len(ends)}"

    ph_list = []
    tone_list = []
    for token in tokens:
        # The smallest merged token is 200 + 0 * 6 + 1 = 201, so 200 is not
        # part of the merged range and must pass through unchanged.
        if 201 <= token <= 788:
            offset = token - 201
            ph_list.append(offset // 6 + 3)
            tone_idx = offset % 6 + 1
            tone_list.append(4 if tone_idx == 1 else tone_idx + 9)
        else:
            ph_list.append(token)
            tone_list.append(3)

    # Durations are differences between consecutive cumulative end times.
    dur_list = [end - start for start, end in zip([0] + ends[:-1], ends)]

    ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list)
    return ph_seq, tone_seq, dur_seq, ph_timestamp[-1]
|
|
|
def split_ph(ph_seq):
    """Decode a merged token sequence back into phone ids and tone codes.

    Input: ph_seq, shape [T] (a 1-D tensor or any iterable of integers).

    Tokens in [201, 788] are decoded as ``200 + (phone - 3) * 6 + tone_idx``
    with ``tone_idx`` in 1..6; ``tone_idx`` 1 maps to tone 4 and
    ``tone_idx`` k > 1 maps to tone ``k + 9``. Any other token is returned
    as the phone itself with tone 3.

    Returns:
        (ph_seq, tone_seq): two int64 tensors of length T.
    """
    ph_list = []
    tone_list = []
    for item in ph_seq:
        token = int(item)
        # The smallest merged token is 200 + 0 * 6 + 1 = 201, so 200 is not
        # part of the merged range and must pass through unchanged.
        if 201 <= token <= 788:
            offset = token - 201
            ph_list.append(offset // 6 + 3)
            tone_idx = offset % 6 + 1
            tone_list.append(4 if tone_idx == 1 else tone_idx + 9)
        else:
            ph_list.append(token)
            tone_list.append(3)

    return torch.LongTensor(ph_list), torch.LongTensor(tone_list)