import sys
sys.path.append(".")
from modules.tokenizer.tokenizer import get_tokenizer_and_extra_tokens
from modules.audio_tokenizer.audio_tokenizer import get_audio_tokenizer
from modules.audio_detokenizer.audio_detokenizer import get_audio_detokenizer, detokenize, detokenize_noref
import torch
import os
from glob import glob
import base64
import io
import torchaudio
from transformers import AutoModelForCausalLM, GenerationConfig
import librosa
from tqdm import tqdm
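
# Two-speaker dialogue TTS: dialogue text (plus optional per-speaker reference
# audio and transcripts) is turned into a token prompt, a causal LM ("text2semantic",
# loaded from resources/text2semantic) generates semantic speech tokens turn by
# turn, and the audio detokenizer renders them to 24 kHz waveforms returned as
# base64-encoded mp3.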
class Model(object):
    def __init__(self):
        self.tokenizer, self.extra_tokens = get_tokenizer_and_extra_tokens()
        self.speech_token_offset = 163840
        print(self.extra_tokens)
        self.assistant_ids = self.tokenizer.encode("assistant")  # [110866]
        self.user_ids = self.tokenizer.encode("user")  # [1495]
        self.audio_ids = self.tokenizer.encode("audio")  # [26229]
        self.spk_0_ids = self.tokenizer.encode("0")  # [501]
        self.spk_1_ids = self.tokenizer.encode("1")  # [503]
        self.msg_end = self.extra_tokens.msg_end  # 260
        self.user_msg_start = self.extra_tokens.user_msg_start  # 261
        self.assistant_msg_start = self.extra_tokens.assistant_msg_start  # 262
        self.name_end = self.extra_tokens.name_end  # 272
        self.media_begin = self.extra_tokens.media_begin  # 273
        self.media_content = self.extra_tokens.media_content  # 274
        self.media_end = self.extra_tokens.media_end  # 275
        self.audio_tokenizer = get_audio_tokenizer()
        self.audio_detokenizer = get_audio_detokenizer()
        model_path = "resources/text2semantic"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="cuda:0",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            force_download=True,
        ).to(torch.cuda.current_device())
        self.generate_config = GenerationConfig(
            max_new_tokens=200 * 50,  # no more than 200s per turn
            do_sample=True,
            top_k=30,
            top_p=0.8,
            temperature=0.8,
            eos_token_id=self.media_end,
        )
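
    # Lightweight text normalization applied to reference transcripts and
    # dialogue text before BPE encoding.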
    def _clean_text(self, text):
        # you can add front-end processing here
        text = text.replace("“", "")
        text = text.replace("”", "")
        text = text.replace("...", " ")
        text = text.replace("…", " ")
        text = text.replace("*", "")
        text = text.replace(":", ",")
        text = text.replace("‘", "'")
        text = text.replace("’", "'")
        text = text.strip()
        return text
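
    # Encode the (optional) reference transcripts and every dialogue turn into
    # BPE ids, storing them back into the request dict.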
    @torch.inference_mode()
    def _process_text(self, js):
        if "role_mapping" in js:
            for role in js["role_mapping"].keys():
                js["role_mapping"][role]["ref_bpe_ids"] = self.tokenizer.encode(self._clean_text(js["role_mapping"][role]["ref_text"]))
        for turn in js["dialogue"]:
            turn["bpe_ids"] = self.tokenizer.encode(self._clean_text(turn["text"]))
        return js
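
    # Entry point: dispatch to the prompted (voice-cloning) or prompt-free path
    # depending on whether "role_mapping" is present. Returns a base64 mp3 string,
    # or a generator of base64 mp3 chunks when streaming=True.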
    def inference(self, js, streaming=False):
        js = self._process_text(js)
        if "role_mapping" not in js:
            gen = self.infer_without_prompt(js, streaming)
        else:
            gen = self.infer_with_prompt(js, streaming)
        if streaming:
            # streaming: hand the generator of base64 mp3 chunks to the caller
            return gen
        # non-streaming: both infer_* methods are generator functions (they contain
        # yield), so drive the generator and take the single base64 mp3 it produces
        return next(gen)
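
    # Prompted path: each speaker's reference transcript and reference speech
    # tokens are placed in the context so the generated turns follow the
    # corresponding reference voices.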
    @torch.inference_mode()
    def infer_with_prompt(self, js, streaming=False):
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]
        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
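        # Message framing: user text turns are [user_msg_start] user <spk> [name_end] <text bpe> [msg_end];
        # assistant audio turns are [assistant_msg_start] assistant <spk> [name_end] [media_begin] audio
        # [media_content] <speech tokens> [media_end] [msg_end].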
        prompt = []
        cur_role_dict = dict()
        for role, role_item in js["role_mapping"].items():
            # 24 kHz reference waveform for the detokenizer, 16 kHz for the audio tokenizer
            waveform_24k = librosa.load(role_item["ref_audio"], sr=24000)[0]
            waveform_24k = torch.tensor(waveform_24k).unsqueeze(0).to(torch.cuda.current_device())
            waveform_16k = librosa.load(role_item["ref_audio"], sr=16000)[0]
            waveform_16k = torch.tensor(waveform_16k).unsqueeze(0).to(torch.cuda.current_device())
            semantic_tokens = self.audio_tokenizer.tokenize(waveform_16k)
            semantic_tokens = semantic_tokens.to(torch.cuda.current_device())
            prompt_ids = semantic_tokens + self.speech_token_offset
            cur_role_dict[role] = {
                "ref_bpe_ids": role_item["ref_bpe_ids"],
                "wav_24k": waveform_24k,
                "semantic_tokens": semantic_tokens,
                "prompt_ids": prompt_ids
            }
        prompt = prompt + user_role_0_ids + cur_role_dict["0"]["ref_bpe_ids"] + [self.msg_end]
        prompt = prompt + user_role_1_ids + cur_role_dict["1"]["ref_bpe_ids"] + [self.msg_end]
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids
        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
        prompt = torch.cat([prompt, assistant_role_0_ids, media_start, cur_role_dict["0"]["prompt_ids"], media_end], dim=-1)
        prompt = torch.cat([prompt, assistant_role_1_ids, media_start, cur_role_dict["1"]["prompt_ids"], media_end], dim=-1)
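        # Context so far: both speakers' reference transcripts, every dialogue turn
        # as user text, then one assistant audio message per speaker carrying the
        # reference speech tokens; the loop below appends one assistant audio
        # message per dialogue turn and lets the LM continue it.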
        generation_config = self.generate_config
        # you can modify sampling strategy here
        wav_list = []
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # print(generation_config)
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt, generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)
            torch_token = output_token - self.speech_token_offset
            if streaming:
                # gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
                # yield detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
                for cur_chunk in detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"], streaming=True):
                    cur_chunk = cur_chunk.cpu()
                    cur_chunk = cur_chunk / cur_chunk.abs().max()
                    cur_buffer = io.BytesIO()
                    torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
                    audio_bytes = cur_buffer.getvalue()
                    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                    yield audio_b64
            else:
                gen_speech_fm = detokenize(self.audio_detokenizer, torch_token, cur_role_dict[role_id]["wav_24k"], cur_role_dict[role_id]["semantic_tokens"])
                gen_speech_fm = gen_speech_fm.cpu()
                gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
                wav_list.append(gen_speech_fm)
            del torch_token
        if not streaming:
            concat_wav = torch.cat(wav_list, dim=-1).cpu()
            # print(concat_wav.shape)
            buffer = io.BytesIO()
            torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
            audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            # yield rather than return: this method is a generator (the streaming
            # branch yields), so the non-streaming result is consumed via next()
            yield audio_b64
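
    # Prompt-free path: no reference audio or transcripts, so the model picks the
    # two voices on its own; the per-turn generation loop mirrors infer_with_prompt
    # but detokenizes without a reference waveform (detokenize_noref).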
    @torch.inference_mode()
    def infer_without_prompt(self, js, streaming=False):
        user_role_0_ids = [self.user_msg_start] + self.user_ids + self.spk_0_ids + [self.name_end]
        user_role_1_ids = [self.user_msg_start] + self.user_ids + self.spk_1_ids + [self.name_end]
        assistant_role_0_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_0_ids + [self.name_end]
        assistant_role_1_ids = [self.assistant_msg_start] + self.assistant_ids + self.spk_1_ids + [self.name_end]
        media_start = [self.media_begin] + self.audio_ids + [self.media_content]
        media_end = [self.media_end] + [self.msg_end]
        assistant_role_0_ids = torch.LongTensor(assistant_role_0_ids).unsqueeze(0).to(torch.cuda.current_device())
        assistant_role_1_ids = torch.LongTensor(assistant_role_1_ids).unsqueeze(0).to(torch.cuda.current_device())
        media_start = torch.LongTensor(media_start).unsqueeze(0).to(torch.cuda.current_device())
        media_end = torch.LongTensor(media_end).unsqueeze(0).to(torch.cuda.current_device())
        prompt = []
        for seg_id, turn in enumerate(js["dialogue"]):
            role_id = turn["role"]
            cur_user_ids = user_role_0_ids if role_id == "0" else user_role_1_ids
            cur_start_ids = cur_user_ids + turn["bpe_ids"] + [self.msg_end]
            prompt = prompt + cur_start_ids
        prompt = torch.LongTensor(prompt).unsqueeze(0).to(torch.cuda.current_device())
        generation_config = self.generate_config
        # you can modify sampling strategy here
        wav_list = []
        for seg_id, turn in tqdm(enumerate(js["dialogue"])):
            role_id = turn["role"]
            cur_assistant_ids = assistant_role_0_ids if role_id == "0" else assistant_role_1_ids
            prompt = torch.cat([prompt, cur_assistant_ids, media_start], dim=-1)
            len_prompt = prompt.shape[1]
            generation_config.min_length = len_prompt + 2
            # print(generation_config)
            # todo: add streaming support for generate function
            outputs = self.model.generate(prompt, generation_config=generation_config)
            if outputs[0, -1] == self.media_end:
                outputs = outputs[:, :-1]
            output_token = outputs[:, len_prompt:]
            prompt = torch.cat([outputs, media_end], dim=-1)
            torch_token = output_token - self.speech_token_offset
            if streaming:
                for cur_chunk in detokenize_noref(self.audio_detokenizer, torch_token, streaming=True):
                    cur_chunk = cur_chunk.cpu()
                    cur_chunk = cur_chunk / cur_chunk.abs().max()
                    cur_buffer = io.BytesIO()
                    torchaudio.save(cur_buffer, cur_chunk, sample_rate=24000, format="mp3")
                    audio_bytes = cur_buffer.getvalue()
                    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
                    yield audio_b64
            else:
                gen_speech_fm = detokenize_noref(self.audio_detokenizer, torch_token)
                gen_speech_fm = gen_speech_fm.cpu()
                gen_speech_fm = gen_speech_fm / gen_speech_fm.abs().max()
                wav_list.append(gen_speech_fm)
            del torch_token
        if not streaming:
            concat_wav = torch.cat(wav_list, dim=-1).cpu()
            # print(concat_wav.shape)
            buffer = io.BytesIO()
            torchaudio.save(buffer, concat_wav, sample_rate=24000, format="mp3")
            audio_bytes = buffer.getvalue()
            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            # yield rather than return: see infer_with_prompt
            yield audio_b64
if __name__ == "__main__":
    model = Model()
    # speakers should alternate (interleave) between turns
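    # Input schema: "role_mapping" maps a speaker id to its reference audio path and
    # reference transcript (e.g. an ASR output), and "dialogue" is a list of
    # {"role": "0" | "1", "text": ...} turns.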
    zh_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./zh_prompt0.wav",
                "ref_text": "可以每天都骑并且可能会让你爱上骑车,然后通过爱上骑车的你省了很多很多钱。",  # ASR output
            },
            "1": {
                "ref_audio": "./zh_prompt1.wav",
                "ref_text": "他最后就能让同样食材炒出来的菜味道大大提升。"  # ASR output
            }
        },
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
            }
        ]
    }
    audio_b64 = model.inference(zh_test_json)
    with open("tmp_generated_zh.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_b64))
    print("zh done")
    # speakers should alternate (interleave) between turns
    en_test_json = {
        "role_mapping": {
            "0": {
                "ref_audio": "./en_prompt0.wav",
                "ref_text": "Yeah, no, this is my backyard. It's never ending So just the way I like it. So social distancing has never been a problem.",  # ASR output
            },
            "1": {
                "ref_audio": "./en_prompt1.wav",
                "ref_text": "I'm doing great And. Look, it couldn't be any better than having you at your set, which is the outdoors."  # ASR output
            }
        },
        "dialogue": [
            {
                "role": "0",
                "text": "In an awesome time, And, we're even gonna do a second episode too So. This is part one part two, coming at some point in the future There. We are.",
            },
            {
                "role": "1",
                "text": "I love it. So grateful Thank you So I'm really excited. That's awesome. Yeah."
            },
            {
                "role": "0",
                "text": "All I was told, which is good because I don't want to really talk too much more is that you're really really into fitness and nutrition And overall holistic I love it Yes."
            },
            {
                "role": "1",
                "text": "Yeah So I started around thirteen Okay But my parents were fitness instructors as well. Awesome So I came from the beginning, and now it's this transition into this wholeness because I had to chart my. Own path and they weren't into nutrition at all So I had to learn that part."
            }
        ]
    }
    audio_b64 = model.inference(en_test_json)
    with open("tmp_generated_en.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_b64))
    print("en done")
    # inference without reference prompts is also supported
    # speakers should alternate (interleave) between turns
    without_prompt_test_json = {
        "dialogue": [
            {
                "role": "0",
                "text": "我觉得啊,就是经历了这么多年的经验, 就是补剂的作用就是九分的努力, 十分之一的补剂。 嗯,选的话肯定是九分更重要,但是我觉得补剂它能够让你九分的努力更加的有效率,更加的避免徒劳无功。 嗯,就是你,你你得先得真的锻炼,真的努力,真的健康饮食,然后再考虑补剂, 那你再加十十分之一的补剂的话,他可能就是说啊, 一半是心理作用,"
            },
            {
                "role": "1",
                "text": "对,其实很多时候心理作用是非常重要的。嗯,然后我每次用补剂的时候,我就会更加努力,就比如说我在健身之前我喝了一勺蛋白粉,我就会督促自己多练,"
            },
            {
                "role": "0",
                "text": "其实心理作用只要能实现你的预期目的就可以了。 就比如说给自行车链条加油, 它其实不是必要的,但是它可以让你骑行更顺畅, 然后提高你骑行的频率。"
            }
        ]
    }
    audio_b64 = model.inference(without_prompt_test_json)
    with open("tmp_generated_woprompt.mp3", "wb") as file_to_save:
        file_to_save.write(base64.b64decode(audio_b64))
    print("without prompt done")