# mooncast/inference.py
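# Two-speaker dialogue TTS inference: dialogue text is BPE-encoded, a
# text2semantic LM generates speech tokens turn by turn, and an audio
# detokenizer renders them to 24 kHz waveforms, optionally cloning voices
# from reference audio supplied via "role_mapping".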
import sys
sys.path.append(".")
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref
import torch
import base64
import io
import torchaudio
from transformers import AutoModelForCausalLM, GenerationConfig
import librosa
from tqdm import tqdm
class Model:
def __init__(self):
self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
        self.speech_token_offset = 163840  # semantic tokens are mapped into the LM vocab at this offset
print(self.extra_tokens)
self.assistant_ids = self.tokenizer.encode("assistant") # [110866]
self.user_ids = self.tokenizer.encode("user") # [1495]
self.audio_ids = self.tokenizer.encode("audio") # [26229]
self.spk_0_ids = self.tokenizer.encode("0") # [501]
self.spk_1_ids = self.tokenizer.encode("1") # [503]
self.msg_end = self.extra_tokens.msg_end # 260
self.user_msg_start = self.extra_tokens.user_msg_start # 261
self.assistant_msg_start = self.extra_tokens.assistant_msg_start # 262
self.name_end = self.extra_tokens.name_end # 272
self.media_begin = self.extra_tokens.media_begin # 273
self.media_content = self.extra_tokens.media_content # 274
self.media_end = self.extra_tokens.media_end # 275
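        # Message layout used below: <*_msg_start> role speaker <name_end> content <msg_end>;
        # audio content is wrapped as <media_begin> "audio" <media_content> tokens <media_end>.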
self.audio_tokenizer = get_audio_tokenizer()
self.audio_detokenizer = get_audio_detokenizer()
model_path = "resources/text2semantic"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="cuda:0",  # device_map already places the weights on cuda:0; no extra .to() needed
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
self.generate_config = GenerationConfig(
max_new_tokens=200 * 50, # no more than 200s per turn
do_sample=True,
top_k=30,
top_p=0.8,
temperature=0.8,
eos_token_id=self.media_end,
)
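        # max_new_tokens assumes ~50 semantic tokens per second of audio (hence the
        # "200s per turn" cap above); tune top_k/top_p/temperature to trade sampling
        # diversity against stability.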
def _clean_text(self, text):
# you can add front-end processing here
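        # e.g. _clean_text('“Hi: there…”') -> 'Hi, there'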
text = text.replace("“", "")
text = text.replace("”", "")
text = text.replace("...", " ")
text = text.replace("…", " ")
text = text.replace("*", "")
text = text.replace(":", ",")
text = text.replace("‘", "'")
text = text.replace("’", "'")
text = text.strip()
return text
@torch.inference_mode()
def _process_text(self, js):
if "role_mapping" in js:
for role in js["role_mapping"].keys():
js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
for turn in js["dialogue"]:
turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
return js
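    # Expected input schema (see the examples under __main__):
    #   {
    #     "role_mapping": {    # optional; enables voice cloning from reference audio
    #       "0": {"ref_audio": <path>, "ref_text": <ASR transcript of ref_audio>},
    #       "1": {...},
    #     },
    #     "dialogue": [{"role": "0" | "1", "text": <turn text>}, ...],  # speakers interleaved
    #   }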
    def inference(self, js, streaming=False):
        """Synthesize a dialogue: returns a single base64 MP3 string when
        streaming=False, or a generator of base64 MP3 chunks when streaming=True."""
        js = self._process_text(js)
        if "role_mapping" in js:
            gen = self.infer_with_prompt(js, streaming)
        else:
            gen = self.infer_without_prompt(js, streaming)
        # both infer_* methods contain `yield`, so they always return generators;
        # in non-streaming mode the full base64 string is yielded exactly once
        return gen if streaming else next(gen)
@torch.inference_mode()
def infer_with_prompt(self, js, streaming=False):
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
media_end = [self.media_end] + [self.msg_end]
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
prompt = []
cur_role_dict = dict()
for role, role_item in js["role_mapping"].items():
            # the 24 kHz waveform is kept as the detokenizer's acoustic reference;
            # the 16 kHz copy feeds the semantic audio tokenizer
            waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
            waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
            # shift semantic ids into the LM vocabulary (text BPE ids occupy [0, 163840))
            prompt_ids = semantic_tokens + self.speech_token_offset
cur_role_dict[role] = {
"ref_bpe_ids": role_item["ref_bpe_ids"],
"wav_24k": waveform_24k,
"semantic_tokens": semantic_tokens,
"prompt_ids": prompt_ids
}
prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
for seg_id, turn in enumerate(js["dialogue"]):
role_id = turn["role"]
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
prompt = prompt + cur_start_ids
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
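        # Prompt layout: both speakers' reference transcripts as user turns, then every
        # dialogue turn's text as user turns, then each speaker's reference speech tokens
        # as the first two assistant turns, so generation continues the cloned voices.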
generation_config = self.generate_config
# you can modify sampling strategy here
wav_list = []
        for seg_id, turn in enumerate(tqdm(js["dialogue"])):
role_id = turn["role"]
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
len_prompt = prompt.shape[1]
generation_config.min_length = len_prompt + 2
# print(generation_config)
# todo: add streaming support for generate function
outputs = self.model.generate(prompt,
generation_config=generation_config)
if outputs[0, -1] == self.media_end:
outputs = outputs[:, :-1]
output_token = outputs[:, len_prompt:]
prompt = torch.cat([outputs, media_end], dim=-1)
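            # shift LM vocabulary ids back down to semantic-token ids for the detokenizer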
torch_token = output_token - self.speech_token_offset
if streaming:
for cur_chunk in detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"], streaming=True):
                    cur_chunk = cur_chunk.cpu()
                    # peak-normalize; clamp guards against division by zero on silent chunks
                    cur_chunk = cur_chunk / cur_chunk.abs().max().clamp(min=1e-8)
cur_buffer = io.BytesIO()
torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
audio_bytes = cur_buffer.getvalue()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
yield audio_b64
else:
gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
                gen_speech_fm = gen_speech_fm.cpu()
                # peak-normalize; clamp guards against division by zero on silent audio
                gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max().clamp(min=1e-8)
wav_list.append(gen_speech_fm)
del torch_token
        if not streaming:
            concat_wav = torch.cat(wav_list, dim=-1).cpu()
            buffer = io.BytesIO()
            torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
            audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            # this method is a generator (it contains `yield` above), so the result
            # must be yielded rather than returned; `inference` unwraps it with next()
            yield audio_b64
@torch.inference_mode()
def infer_without_prompt(self, js, streaming=False):
user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
media_start = [self.media_begin] + self.audio_ids + [self.media_content]
media_end = [self.media_end] + [self.msg_end]
assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
prompt = []
for seg_id, turn in enumerate(js["dialogue"]):
role_id = turn["role"]
cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
prompt = prompt + cur_start_ids
prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
generation_config = self.generate_config
# you can modify sampling strategy here
wav_list = []
        for seg_id, turn in enumerate(tqdm(js["dialogue"])):
role_id = turn["role"]
cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
len_prompt = prompt.shape[1]
generation_config.min_length = len_prompt + 2
# print(generation_config)
# todo: add streaming support for generate function
outputs = self.model.generate(prompt,
generation_config=generation_config)
if outputs[0, -1] == self.media_end:
outputs = outputs[:, :-1]
output_token = outputs[:, len_prompt:]
prompt = torch.cat([outputs, media_end], dim=-1)
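            # shift LM vocabulary ids back down to semantic-token ids for the detokenizer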
torch_token = output_token - self.speech_token_offset
if streaming:
for cur_chunk in detokenize_noref(self.audio_detokenizer, torch_token, streaming=True):
                    cur_chunk = cur_chunk.cpu()
                    # peak-normalize; clamp guards against division by zero on silent chunks
                    cur_chunk = cur_chunk / cur_chunk.abs().max().clamp(min=1e-8)
cur_buffer = io.BytesIO()
torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
audio_bytes = cur_buffer.getvalue()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
yield audio_b64
else:
gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
                gen_speech_fm = gen_speech_fm.cpu()
                # peak-normalize; clamp guards against division by zero on silent audio
                gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max().clamp(min=1e-8)
wav_list.append(gen_speech_fm)
del torch_token
        if not streaming:
            concat_wav = torch.cat(wav_list, dim=-1).cpu()
            buffer = io.BytesIO()
            torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
            audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            # this method is a generator (it contains `yield` above), so the result
            # must be yielded rather than returned; `inference` unwraps it with next()
            yield audio_b64
if __name__ == "__main__":
model = Model()
# speaker should be interleaved
zh_test_json = {
"role_mapping": {
"0": {
"ref_audio": "./zh_prompt0.wav",
"ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。", #asr output
},
"1": {
"ref_audio": "./zh_prompt1.wav",
"ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。" #asr output
}
},
"dialogue": [
{
"role": "0",
"text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
},
{
"role": "1",
"text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
},
{
"role": "0",
"text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
}
]
}
    audio_b64 = model.inference(zh_test_json)
    with open("tmp_generated_zh.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("zh done")
# speaker should be interleaved
en_test_json = {
"role_mapping": {
"0": {
"ref_audio": "./en_prompt0.wav",
"ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.", #asr output
},
"1": {
"ref_audio": "./en_prompt1.wav",
"ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors." #asr output
}
},
"dialogue": [
{
"role": "0",
"text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
},
{
"role": "1",
"text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah."
},
{
"role": "0",
"text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes."
},
{
"role": "1",
"text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
}
]
}
    audio_b64 = model.inference(en_test_json)
    with open("tmp_generated_en.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("en done")
# also support inference without prompt
# speaker should be interleaved
without_prompt_test_json = {
"dialogue": [
{
"role": "0",
"text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
},
{
"role": "1",
"text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
},
{
"role": "0",
"text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
}
]
}
    audio_b64 = model.inference(without_prompt_test_json)
    with open("tmp_generated_woprompt.mp3", "wb") as f:
        f.write(base64.b64decode(audio_b64))
    print("without prompt done")