# mooncast/modules/audio_detokenizer/audio_detokenizer.py
# Uploaded by jzq11111 via huggingface_hub (commit a3e05e8).
import torch
from modules.audio_detokenizer.bigvgan_wrapper import BigVGANWrapper
from modules.audio_detokenizer.semantic_fm_prefix_streaming import StreamingSemanticFMWrapper
class PrefixStreamingFlowMatchingDetokenizer:
    """Streaming semantic-token -> waveform detokenizer.

    ``semantic_fm`` (a flow-matching model) turns each chunk of semantic tokens
    into mel frames, and ``vocoder`` turns mel frames into audio. Between
    consecutive chunks, the held-back tail of the previous output is
    cross-faded with the head of the next one using a Hamming window.
    """

    def __init__(self, vocoder: "BigVGANWrapper", fm: "StreamingSemanticFMWrapper", look_ahead_tokens: int = 0) -> None:
        """
        Arguments:
            vocoder: mel -> waveform model; cast to ``self.dtype`` (bfloat16) here.
            fm: streaming flow-matching model producing mel frames from semantic tokens.
            look_ahead_tokens: forwarded to ``fm.infer_chunk`` (multiplied by
                ``upsample_factor``) for every non-final chunk; 0 leaves the
                ``previous_chunk_left`` cache unused (None).
        """
        self.dtype = torch.bfloat16
        print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer")
        self.vocoder = vocoder
        self.vocoder.to_dtype(self.dtype)
        self.semantic_fm = fm
        # Position-id budget: near this limit the FM state is reset to the prefill snapshot.
        self.max_pos_size = 4096
        self.is_timbre_semantic_token = False  # not read within this class
        self.frame_size = 480  # how many samples in a frame
        # Cross-fade history: mel frames / waveform samples held back from the previous chunk.
        self.pre_mel = None
        self.pre_wav = None
        # FM state captured right after prefill(); restored when position ids run out.
        self.state_dict_backup = None
        # gen_speech_len -> Hamming window of shape [1, gen_speech_len].
        self.hamming_window_cache = {}
        # Look-ahead cache passed to semantic_fm.infer_chunk (used only when look_ahead_tokens != 0).
        self.previous_chunk_left = None
        self.look_ahead_tokens = look_ahead_tokens
        self.clear_states()

    @classmethod
    def from_pretrained(cls, vocoder_config, vocoder_ckpt, fm_config, fm_ckpt, device,
                        look_ahead_tokens=0,
                        max_prompt_chunk=2, max_kv_cache_tokens=900,
                        use_cfg=False, use_cfg_rescale=True, cfg_init=1.5, cfg_scale=7.5, cfg_schedule="linear"):
        """Load the vocoder and the flow-matching model from config/checkpoint files onto ``device``.

        ``max_prompt_chunk``, ``max_kv_cache_tokens`` and the ``use_cfg*`` /
        ``cfg_*`` options are forwarded unchanged to
        ``StreamingSemanticFMWrapper.from_pretrained``.
        """
        bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)
        semantic_fm = StreamingSemanticFMWrapper.from_pretrained(
            fm_config, fm_ckpt, device,
            max_prompt_chunk=max_prompt_chunk, max_kv_cache_tokens=max_kv_cache_tokens,
            use_cfg=use_cfg, cfg_scale=cfg_scale, use_cfg_rescale=use_cfg_rescale,
            cfg_init=cfg_init, cfg_schedule=cfg_schedule)
        return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)

    @torch.inference_mode()
    def prefill(self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None):
        """Prime the flow-matching model with a reference (timbre) prompt.

        Arguments:
            timbre_speech: torch.Tensor, shape [B, N_speech_24k]
            timbre_semantic_token: torch.Tensor, shape [B, N]
            chunk_size: int, chunk size for prefilling
            timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech

        Only B == 1 is supported. The model state after prefilling is stored in
        ``self.state_dict_backup`` so streaming can roll back to it.
        """
        if timbre_mel is None:
            assert timbre_speech is not None, "timbre_speech should not be None if timbre_mel is None"
            assert len(timbre_semantic_token.shape) == 2 and len(timbre_speech.shape) == 2 and chunk_size > 0
            assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
            mel_spec = self.vocoder.extract_mel_from_wav(wav_data=timbre_speech.squeeze(0))
        else:
            assert len(timbre_mel.shape) == 3 and len(timbre_semantic_token.shape) == 2 and chunk_size > 0
            assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
            mel_spec = timbre_mel.squeeze(0)
        n_tokens = timbre_semantic_token.shape[1]
        # Force exactly one mel frame per semantic token: zero-pad or truncate along time.
        if mel_spec.shape[0] < n_tokens:
            mel_spec = torch.nn.functional.pad(mel_spec, (0, 0, 0, n_tokens - mel_spec.shape[0]))
        elif mel_spec.shape[0] > n_tokens:
            mel_spec = mel_spec[:n_tokens, :]
        self.semantic_fm.clear_all_states()
        self.semantic_fm.prefill(mel_spec, timbre_semantic_token.squeeze(0), chunk_size=chunk_size, verbose=False)
        self.state_dict_backup = self.semantic_fm.state_dict()

    @torch.inference_mode()
    def detokenize_streaming(self, semantic_token, ode_step=30, verbose=False, ode_solver="neural_ode_euler", is_final=False, upsample_factor=1):
        """Generate audio for the next chunk of semantic tokens.

        Arguments:
            semantic_token: torch.Tensor, shape [1, N]
            ode_step: forwarded to ``infer_chunk`` as ``ode_steps`` (must be > 0)
            verbose: forwarded to ``infer_chunk``
            ode_solver: forwarded to ``infer_chunk``
            is_final: marks the last chunk. All of its audio (including the
                held-back history) is returned, and the streaming state and
                the prefill snapshot are cleared afterwards.
            upsample_factor: each token is repeated this many times along time before inference.

        Returns:
            float32 waveform (2-D, presumably [1, num_samples]). For non-final
            chunks part of the audio is held back in ``self.pre_wav`` and
            cross-faded into the next call's output.
        """
        assert len(semantic_token.shape) == 2 and ode_step > 0
        assert semantic_token.shape[0] == 1
        semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1).squeeze(0)
        if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:
            # Prepend tokens carried in the look-ahead cache (presumably filled by infer_chunk via ``cache=``).
            semantic_token = torch.cat([self.previous_chunk_left["semantic_token"], semantic_token], dim=-1)
        x_t_chunk = torch.randn(semantic_token.shape[0], 80).to(semantic_token.device).to(self.dtype)
        if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:
            self.previous_chunk_left = {"semantic_token": None}
        speech_mel = self.semantic_fm.infer_chunk(
            xt_chunk=x_t_chunk,
            semantic_tokens_chunk=semantic_token,
            start_position_id=self.semantic_fm.start_position_id,
            ode_steps=ode_step,
            verbose=verbose,
            look_ahead_tokens=self.look_ahead_tokens * upsample_factor if not is_final else 0,
            cache=self.previous_chunk_left,
            ode_solver=ode_solver,
        )
        chunk_size = speech_mel.shape[0]  # number of mel frames produced for this chunk
        self.semantic_fm.start_position_id += chunk_size
        self.semantic_fm.update_incremental_state()
        self.semantic_fm.reserve_kv_cache_tokens += self.semantic_fm.ode_wrapper.kv_cache_tokens

        if self.pre_mel is None:
            # First chunk: there is no history to cross-fade with.
            concat_reconstructed_wav = self.vocoder.decode_mel(speech_mel)
            if is_final:
                ret_wav = concat_reconstructed_wav
            else:
                # Return the first half of the audio now; hold back the second half
                # (and the matching mel frames) as history for the next call.
                half_wav_len = self.frame_size * chunk_size // 2
                ret_wav = concat_reconstructed_wav[:, :half_wav_len]
                self.pre_wav = self._tail(concat_reconstructed_wav, half_wav_len, dim=1)
                self.pre_mel = self._tail(speech_mel, (chunk_size + 1) // 2, dim=0)  # ceil(chunk_size / 2) frames
        else:
            # Re-vocode the held-back mel frames together with the new ones so the
            # vocoder has context across the chunk boundary.
            concat_reconstructed_wav = self.vocoder.decode_mel(torch.cat([self.pre_mel, speech_mel], dim=0))
            if is_final:
                ret_wav = concat_reconstructed_wav
            else:
                prev_speech_len = self.pre_wav.shape[1]
                total_len = concat_reconstructed_wav.shape[1]
                # Emit at most twice the history length; the remainder becomes the new history.
                if total_len > prev_speech_len * 2:
                    gen_speech_len = prev_speech_len * 2
                else:
                    gen_speech_len = total_len // 2
                ret_wav = concat_reconstructed_wav[:, :gen_speech_len]
                half = gen_speech_len // 2
                if gen_speech_len not in self.hamming_window_cache:
                    self.hamming_window_cache[gen_speech_len] = (
                        torch.hamming_window(gen_speech_len).to(self.dtype).to(semantic_token.device).unsqueeze(0)
                    )
                window = self.hamming_window_cache[gen_speech_len]
                # Cross-fade the first half: old history weighted by the window's tail,
                # new audio by the window's head. Writes in place (ret_wav is a view of
                # concat_reconstructed_wav); the region kept as history below does not overlap.
                ret_wav[:, :half] = (self.pre_wav[:, :half] * window[:, gen_speech_len - half:]
                                     + ret_wav[:, :half] * window[:, :half])
                res_speech_len = total_len - gen_speech_len
                # _tail (not x[-n:]) so a zero-length remainder keeps nothing instead of everything.
                self.pre_wav = self._tail(concat_reconstructed_wav, res_speech_len, dim=1)
                self.pre_mel = self._tail(speech_mel, res_speech_len // self.frame_size, dim=0)

        if is_final:
            self.clear_states()
            self.state_dict_backup = None
        elif self.semantic_fm.start_position_id + 2 * chunk_size > self.max_pos_size:
            # Out of position ids: reset the FM model and, when a prefill snapshot
            # exists, restore it. Without prefill (no reference prompt) there is nothing to load.
            self.semantic_fm.clear_all_states()
            if self.state_dict_backup is not None:
                self.semantic_fm.load_state_dict(self.state_dict_backup)
        return ret_wav.float()

    def clear_states(self):
        """Call ``semantic_fm.clear_all_states()`` and drop all cross-chunk history.

        ``state_dict_backup`` (the prefill snapshot) is left untouched.
        """
        self.semantic_fm.clear_all_states()
        self.previous_chunk_left = None
        self.pre_mel = None
        self.pre_wav = None

    @staticmethod
    def _tail(tensor, n, dim):
        """Return the last ``n`` entries of ``tensor`` along ``dim`` (a view).

        Unlike ``tensor[..., -n:]``, ``n == 0`` yields an empty slice rather than
        the whole dimension. ``n`` larger than the dimension returns all of it.
        """
        size = tensor.shape[dim]
        n = max(0, min(int(n), size))
        index = [slice(None)] * tensor.dim()
        index[dim] = slice(size - n, size)
        return tensor[tuple(index)]
def get_audio_detokenizer():
    """Build the default streaming detokenizer from the bundled resource files.

    The flow-matching model is loaded from ``resources/audio_detokenizer`` and
    the vocoder from ``resources/vocoder``, both onto the current CUDA device,
    with CFG disabled, ``look_ahead_tokens=12`` and ``max_prompt_chunk=10``.
    """
    loader_kwargs = dict(
        fm_config="resources/audio_detokenizer/config.yaml",
        fm_ckpt="resources/audio_detokenizer/model.pt",
        vocoder_config="resources/vocoder/config.json",
        vocoder_ckpt="resources/vocoder/model.pt",
        max_prompt_chunk=10,  # upstream note: 10 * 3 = 30s of prompt
        use_cfg=False,
        look_ahead_tokens=12,
    )
    loader_kwargs["device"] = torch.cuda.current_device()
    return PrefixStreamingFlowMatchingDetokenizer.from_pretrained(**loader_kwargs)
# Chunking schedule, in semantic tokens: a shorter first chunk (presumably to
# reduce latency to the first audio), then fixed-size chunks. The reference
# prompt is prefilled with the same fixed chunk size.
_FIRST_CHUNK_TOKENS = 100
_CHUNK_TOKENS = 150


def _generate_wav_chunks(detokenizer, tokens, reference=None):
    """Reset ``detokenizer``, optionally prefill it, then yield one waveform per token chunk.

    Arguments:
        detokenizer: a PrefixStreamingFlowMatchingDetokenizer (or compatible object).
        tokens: semantic tokens, shape [1, N].
        reference: optional ``(ref_wav, ref_tokens)`` pair passed to ``detokenizer.prefill``.

    Nothing runs until the generator is first advanced. Grad mode is disabled
    only around detokenizer calls, never across a ``yield``, so the consumer's
    autograd state is not altered between chunks.
    """
    with torch.no_grad():
        detokenizer.clear_states()
        if reference is not None:
            ref_wav, ref_tokens = reference
            detokenizer.prefill(ref_wav, ref_tokens, chunk_size=_CHUNK_TOKENS)
    num_tokens = tokens.size(1)
    with torch.no_grad():
        wav = detokenizer.detokenize_streaming(
            tokens[:, :_FIRST_CHUNK_TOKENS], is_final=num_tokens <= _FIRST_CHUNK_TOKENS)
    yield wav
    for start in range(_FIRST_CHUNK_TOKENS, num_tokens, _CHUNK_TOKENS):
        stop = start + _CHUNK_TOKENS
        with torch.no_grad():
            wav = detokenizer.detokenize_streaming(tokens[:, start:stop], is_final=stop >= num_tokens)
        yield wav


def detokenize(detokenizer, tokens, ref_wav, ref_tokens, streaming=False):
    """Convert semantic ``tokens`` to audio, conditioned on a reference prompt.

    Arguments:
        detokenizer: the streaming detokenizer; its state is reset and prefilled with the reference.
        tokens: semantic tokens, shape [1, N].
        ref_wav, ref_tokens: reference prompt passed to ``detokenizer.prefill``.
        streaming: see Returns.

    Returns:
        If ``streaming`` is true, a lazy generator yielding the waveform chunks
        returned by ``detokenizer.detokenize_streaming``. Otherwise, all chunks
        concatenated along the last dimension.

    This is a plain function, not a generator, so that ``streaming=False``
    actually returns the tensor rather than a generator object.
    """
    chunks = _generate_wav_chunks(detokenizer, tokens, reference=(ref_wav, ref_tokens))
    if streaming:
        return chunks
    return torch.cat(list(chunks), dim=-1)


def detokenize_noref(detokenizer, tokens, streaming=False):
    """Convert semantic ``tokens`` to audio without a reference prompt (no prefill).

    Uses the same chunking and the same return convention as ``detokenize``.
    It returns a lazy generator when ``streaming`` is true, and the
    concatenated waveform otherwise.
    """
    chunks = _generate_wav_chunks(detokenizer, tokens)
    if streaming:
        return chunks
    return torch.cat(list(chunks), dim=-1)