import sys

sys.path.append(".")

from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref

import torch
import os
from glob import glob
import base64
import io
import torchaudio
from transformers import AutoModelForCausalLM, GenerationConfig
import librosa
from tqdm import tqdm

class Model(object):
    def __init__(self):
        # text (BPE) tokenizer plus the special control-token ids used below
        self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
        self.speech_token_offset = 163840
        print(self.extra_tokens)
        self.assistant_ids = self.tokenizer.encode("assistant")  # [110866]
        self.user_ids = self.tokenizer.encode("user")  # [1495]
        self.audio_ids = self.tokenizer.encode("audio")  # [26229]
        self.spk_0_ids = self.tokenizer.encode("0")  # [501]
        self.spk_1_ids = self.tokenizer.encode("1")  # [503]
        self.msg_end = self.extra_tokens.msg_end  # 260
        self.user_msg_start = self.extra_tokens.user_msg_start  # 261
        self.assistant_msg_start = self.extra_tokens.assistant_msg_start  # 262
        self.name_end = self.extra_tokens.name_end  # 272
        self.media_begin = self.extra_tokens.media_begin  # 273
        self.media_content = self.extra_tokens.media_content  # 274
        self.media_end = self.extra_tokens.media_end  # 275
        # audio tokenizer (waveform -> semantic tokens) and detokenizer (tokens -> waveform)
        self.audio_tokenizer = get_audio_tokenizer()
        self.audio_detokenizer = get_audio_detokenizer()
        # text-to-semantic language model
        model_path = "resources/text2semantic"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="cuda:0",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            force_download=True,
        ).to(torch.cuda.current_device())
        self.generate_config = GenerationConfig(
            max_new_tokens=200 * 50,  # at most ~200 s per turn (~50 speech tokens per second)
            do_sample=True,
            top_k=30,
            top_p=0.8,
            temperature=0.8,
            eos_token_id=self.media_end,
        )
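
    # Hedged convenience helper (not part of the original code): the inference
    # methods below note that the sampling strategy can be modified; this simply
    # overrides the shared GenerationConfig fields. Defaults mirror __init__ and
    # are illustrative, not recommended settings.
    def set_sampling(self, top_k=30, top_p=0.8, temperature=0.8):
        self.generate_config.top_k = top_k
        self.generate_config.top_p = top_p
        self.generate_config.temperature = temperature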

    def _clean_text(self, text):
        # you can add front-end processing here
        text = text.replace("“", "")
        text = text.replace("”", "")
        text = text.replace("...", " ")
        text = text.replace("…", " ")
        text = text.replace("*", "")
        text = text.replace(":", ",")
        text = text.replace("‘", "'")
        text = text.replace("’", "'")
        text = text.strip()
        return text
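
    # Minimal sketch of extra front-end processing (hypothetical helper, not
    # called by default): collapse runs of whitespace before BPE encoding, as a
    # starting point for the "front-end processing" hook noted in _clean_text.
    def _extra_clean_text(self, text):
        import re  # local import keeps this optional helper self-contained
        return re.sub(r"\s+", " ", text).strip()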

    def _process_text(self, js):
        # encode the (optional) reference texts and every dialogue turn into BPE ids
        if "role_mapping" in js:
            for role in js["role_mapping"].keys():
                js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
        for turn in js["dialogue"]:
            turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
        return js

    def inference(self, js, streaming=False):
        js = self._process_text(js)
        if "role_mapping" not in js:
            result = self.infer_without_prompt(js, streaming)
        else:
            result = self.infer_with_prompt(js, streaming)
        if streaming:
            return result  # generator of base64-encoded mp3 chunks
        # non-streaming: the infer_* methods contain `yield`, so they are generator
        # functions and carry the final base64 mp3 as their return value
        try:
            next(result)
        except StopIteration as stop:
            return stop.value

    def infer_with_prompt(self, js, streaming=False):
        # role/name headers and media delimiters used to build the prompt
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]
        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())

        prompt = []
        cur_role_dict = dict()
        # tokenize each speaker's reference audio into semantic tokens
        for role, role_item in js["role_mapping"].items():
            waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
            waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
            prompt_ids = semantic_tokens + self.speech_token_offset
            cur_role_dict[role] = {
                "ref_bpe_ids": role_item["ref_bpe_ids"],
                "wav_24k": waveform_24k,
                "semantic_tokens": semantic_tokens,
                "prompt_ids": prompt_ids,
            }

        # reference texts, then all dialogue turns, as user messages
        prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
        prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids
        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
        # reference speech tokens as assistant messages, one per speaker
        prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
        prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)

        generation_config = self.generate_config
        # you can modify the sampling strategy here
        wav_list = []
        # generate speech tokens turn by turn, appending each result to the prompt
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt, generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)
            torch_token = output_token - self.speech_token_offset
            if streaming:
                # yield base64-encoded mp3 chunks as they are detokenized
                for cur_chunk in detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"], streaming=True):
                    cur_chunk = cur_chunk.cpu()
                    cur_chunk = cur_chunk / cur_chunk.abs().max()
                    cur_buffer = io.BytesIO()
                    torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
                    audio_bytes = cur_buffer.getvalue()
                    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                    yield audio_b64
            else:
                gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
                gen_speech_fm = gen_speech_fm.cpu()
                gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
                wav_list.append(gen_speech_fm)
            del torch_token

        if not streaming:
            # concatenate all turns and return one base64-encoded mp3
            concat_wav = torch.cat(wav_list, dim=-1).cpu()
            buffer = io.BytesIO()
            torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
            audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            return audio_b64

    def infer_without_prompt(self, js, streaming=False):
        # same prompt layout as infer_with_prompt, but without reference audio/text
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]
        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())

        # all dialogue turns as user messages
        prompt = []
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids
        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())

        generation_config = self.generate_config
        # you can modify the sampling strategy here
        wav_list = []
        # generate speech tokens turn by turn, appending each result to the prompt
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt, generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)
            torch_token = output_token - self.speech_token_offset
            if streaming:
                # yield base64-encoded mp3 chunks as they are detokenized
                for cur_chunk in detokenize_noref(self.audio_detokenizer, torch_token, streaming=True):
                    cur_chunk = cur_chunk.cpu()
                    cur_chunk = cur_chunk / cur_chunk.abs().max()
                    cur_buffer = io.BytesIO()
                    torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
                    audio_bytes = cur_buffer.getvalue()
                    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                    yield audio_b64
            else:
                gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
                gen_speech_fm = gen_speech_fm.cpu()
                gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
                wav_list.append(gen_speech_fm)
            del torch_token

        if not streaming:
            # concatenate all turns and return one base64-encoded mp3
            concat_wav = torch.cat(wav_list, dim=-1).cpu()
            buffer = io.BytesIO()
            torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
            audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            return audio_b64


if __name__ == "__main__":
    model = Model()

    # speakers should alternate between turns
    zh_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./zh_prompt0.wav",
                "ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。",  # ASR output
            },
            "1": {
                "ref_audio": "./zh_prompt1.wav",
                "ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。",  # ASR output
            },
        },
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,",
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,",
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。",
            },
        ],
    }
    audio_b64 = model.inference(zh_test_json)
    with open("tmp_generated_zh.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("zh done")

    # speakers should alternate between turns
    en_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./en_prompt0.wav",
                "ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.",  # ASR output
            },
            "1": {
                "ref_audio": "./en_prompt1.wav",
                "ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors.",  # ASR output
            },
        },
        "dialogue": [
            {
                "role": "0",
                "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
            },
            {
                "role": "1",
                "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah.",
            },
            {
                "role": "0",
                "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes.",
            },
            {
                "role": "1",
                "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part.",
            },
        ],
    }
    audio_b64 = model.inference(en_test_json)
    with open("tmp_generated_en.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("en done")

    # inference is also supported without reference prompts
    # speakers should alternate between turns
    without_prompt_test_json = {
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,",
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,",
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。",
            },
        ],
    }
    audio_b64 = model.inference(without_prompt_test_json)
    with open("tmp_generated_woprompt.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("without prompt done")