# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import librosa
import torch
import torchaudio
import accelerate
import safetensors.torch
import numpy as np
import yaml
from IPython.display import display, Audio

from models.vc.flow_matching_transformer.fmt_model import FlowMatchingTransformer
from models.vc.autoregressive_transformer.ar_model import AutoregressiveTransformer
from models.codec.kmeans.repcodec_model import RepCodec
from models.codec.vevo.vevo_repcodec import VevoRepCodec
from models.codec.melvqgan.melspec import MelSpectrogram
from models.codec.amphion_codec.vocos import Vocos
from utils.util import load_config


def g2p_(text, language):
    from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p

    if language in ["zh", "en"]:
        return chn_eng_g2p(text)
    else:
        return g2p(text, sentence=None, language=language)


def transcribe_audio(audio_path, model=None):
    if model is None:
        import whisper

        model = whisper.load_model("medium")

    result = model.transcribe(audio_path)
    return result["text"]


# Semantic Features Extractor
def build_hubert_model(device):
    bundle = torchaudio.pipelines.HUBERT_LARGE
    hubert = bundle.get_model()
    hubert.eval()
    hubert.to(device)
    return hubert


# VQ-VAE Tokenizer
def build_vqvae_model(repcodec_cfg, device):
    vqvae = RepCodec(cfg=repcodec_cfg)
    vqvae.eval()
    vqvae.to(device)
    return vqvae


# Vevo VQ-VAE Tokenizer (pkl checkpoint)
def load_vevo_vqvae_checkpoint(repcodec_cfg, device):
    with open(repcodec_cfg.config_path) as fp:
        conf = yaml.load(fp, Loader=yaml.FullLoader)

    vqvae = VevoRepCodec(**conf)
    vqvae.quantizer.initial()
    vqvae.eval()

    pretrained_path = repcodec_cfg.pretrained_path
    if ".pkl" in pretrained_path:
        # Vevo paper
        vqvae.load_state_dict(
            torch.load(pretrained_path, map_location="cpu")["model"]["repcodec"]
        )
    elif ".safetensors" in pretrained_path:
        # Re-trained vevovq
        safetensors.torch.load_model(vqvae, pretrained_path)

    vqvae.to(device)
    return vqvae


# Autoregressive Transformer
def build_ar_model(cfg, device):
    model = AutoregressiveTransformer(cfg=cfg.model.autoregressive_transformer)
    model.eval()
    model.to(device)
    return model


# Flow Matching Transformer
def build_fmt_model(cfg, device):
    model = FlowMatchingTransformer(cfg=cfg.model.flow_matching_transformer)
    model.eval()
    model.to(device)
    return model


# Melspectrogram Extractor
def build_mel_model(cfg, device):
    mel_model = MelSpectrogram(
        sampling_rate=cfg.preprocess.sample_rate,
        n_fft=cfg.preprocess.n_fft,
        num_mels=cfg.preprocess.num_mels,
        hop_size=cfg.preprocess.hop_size,
        win_size=cfg.preprocess.win_size,
        fmin=cfg.preprocess.fmin,
        fmax=cfg.preprocess.fmax,
    )
    mel_model.eval()
    mel_model.to(device)
    return mel_model


# Vocoder
def build_vocoder_model(cfg, device):
    vocoder_model = Vocos(cfg=cfg.model.vocos)
    vocoder_model.eval()
    vocoder_model.to(device)
    return vocoder_model


def load_checkpoint(build_model_func, cfg, ckpt_path, device):
    model = build_model_func(cfg, device)
    accelerate.load_checkpoint_and_dispatch(model, ckpt_path)
    return model


def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total_params < 1e6:
        return f"{total_params} params"  # Parameters
    elif total_params < 1e9:
        return f"{total_params / 1e6:.2f} M"  # Millions
    else:
        return f"{total_params / 1e9:.2f} B"  # Billions


def load_wav(wav_path, device):
    speech = librosa.load(wav_path, sr=24000)[0]
    speech_tensor = torch.tensor(speech).unsqueeze(0).to(device)
    speech16k = torchaudio.functional.resample(speech_tensor, 24000, 16000)
    return speech, speech_tensor, speech16k


def display_audio_in_notebook(wav, rate=24000):
    display(Audio(wav, rate=rate))


def save_audio(
    waveform, sr=24000, output_path=None, target_sample_rate=None, target_db=-25.0
):
    """
    waveform: [1, T]
    """
    if target_sample_rate is not None and sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sr, new_freq=target_sample_rate
        )
        waveform = resampler(waveform)
    else:
        target_sample_rate = sr

    rms = torch.sqrt(torch.mean(waveform**2))
    current_db = 20 * torch.log10(rms + 1e-9)
    gain = target_db - current_db
    normalized_waveform = waveform * (10 ** (gain / 20))

    torchaudio.save(output_path, normalized_waveform, target_sample_rate)
    return output_path
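

# Worked example of the loudness normalization above (illustrative numbers, not from
# the original file): with target_db = -25, the output RMS becomes
# 10 ** (-25 / 20) ≈ 0.056. If the input RMS is 0.10 (current_db ≈ -20 dB), the gain
# is -5 dB, i.e. every sample is scaled by 10 ** (-5 / 20) ≈ 0.562.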


class VevoInferencePipeline:
    def __init__(
        self,
        content_tokenizer_ckpt_path=None,
        content_style_tokenizer_ckpt_path=None,
        ar_cfg_path=None,
        ar_ckpt_path=None,
        fmt_cfg_path=None,
        fmt_ckpt_path=None,
        vocoder_cfg_path=None,
        vocoder_ckpt_path=None,
        device=None,
    ):
        self.device = device

        if ar_cfg_path is not None and ar_ckpt_path is not None:
            self.ar_cfg = load_config(ar_cfg_path)
            self.ar_model = load_checkpoint(
                build_ar_model, self.ar_cfg, ar_ckpt_path, device
            )
            print(f"#Params of AR model: {count_parameters(self.ar_model)}")
        else:
            self.ar_cfg = None
            self.ar_model = None

        self.fmt_cfg = load_config(fmt_cfg_path)
        self.fmt_model = load_checkpoint(
            build_fmt_model, self.fmt_cfg, fmt_ckpt_path, device
        )
        print(f"#Params of Flow Matching model: {count_parameters(self.fmt_model)}")

        self.vocoder_cfg = load_config(vocoder_cfg_path)
        self.mel_model = build_mel_model(self.vocoder_cfg, device)
        self.vocoder_model = load_checkpoint(
            build_vocoder_model, self.vocoder_cfg, vocoder_ckpt_path, device
        )
        print(f"#Params of Vocoder model: {count_parameters(self.vocoder_model)}")

        self.content_tokenizer_ckpt_path = content_tokenizer_ckpt_path
        self.content_style_tokenizer_ckpt_path = content_style_tokenizer_ckpt_path
        self.init_vqvae_tokenizer()

    def init_vqvae_tokenizer(self):
        ## HuBERT features extraction ##
        self.hubert_model = build_hubert_model(self.device)

        stat = np.load(self.fmt_cfg.model.representation_stat_mean_var_path)
        self.hubert_feat_norm_mean = torch.tensor(stat["mean"])
        self.hubert_feat_norm_std = torch.tensor(stat["std"])

        ## Content Tokenizer ##
        if self.ar_model is not None and "input_repcodec" in self.ar_cfg.model:
            assert self.ar_cfg.model.vc_input_token_type == "hubert_vevo_codec"
            ckpt_path = getattr(
                self.ar_cfg.model.input_repcodec,
                "pretrained_path",
                self.content_tokenizer_ckpt_path,
            )
            self.ar_cfg.model.input_repcodec.pretrained_path = ckpt_path
            self.content_tokenizer = load_vevo_vqvae_checkpoint(
                self.ar_cfg.model.input_repcodec,
                self.device,
            )
            print(
                "#Params of Content Tokenizer: {}".format(
                    count_parameters(self.content_tokenizer)
                )
            )

        ## Content-Style Tokenizer ##
        ckpt_path = getattr(
            self.fmt_cfg.model.repcodec,
            "pretrained_path",
            self.content_style_tokenizer_ckpt_path,
        )
        self.content_style_tokenizer = load_checkpoint(
            build_vqvae_model,
            self.fmt_cfg.model.repcodec,
            ckpt_path,
            self.device,
        )
        print(
            "#Params of Content-Style Tokenizer: {}".format(
                count_parameters(self.content_style_tokenizer)
            )
        )

    def extract_mel_feature(self, speech):
        mel_feature = self.mel_model(speech)  # (B, d, T)
        mel_feature = mel_feature.transpose(1, 2)
        mel_feature = (mel_feature - self.vocoder_cfg.preprocess.mel_mean) / math.sqrt(
            self.vocoder_cfg.preprocess.mel_var
        )
        return mel_feature

    def extract_prompt_mel_feature(self, speech):
        """
        This is for the global encoder of the AR model.
        """
        if not hasattr(self, "prompt_mel_model"):
            self.prompt_mel_model = build_mel_model(self.ar_cfg, self.device)

        mel_feature = self.prompt_mel_model(speech)  # (B, d, T)
        mel_feature = mel_feature.transpose(1, 2)
        mel_feature = (mel_feature - self.ar_cfg.preprocess.mel_mean) / math.sqrt(
            self.ar_cfg.preprocess.mel_var
        )
        return mel_feature

    def extract_hubert_feature(self, wavs, wav_lens=None, output_layer=18):
        """
        Args:
            wavs: [B, T]
            wav_lens: [B,]
        Returns:
            feats: [B, T, D]
            feat_lengths: [B]
        """
        if wav_lens is None:
            wav_lens = torch.tensor([wavs.shape[1]] * wavs.shape[0]).to(wavs).int()

        feats, feat_lengths = self.hubert_model.extract_features(
            wavs, lengths=wav_lens, num_layers=output_layer
        )
        feats = feats[-1]
        return feats, feat_lengths

    def duration_reduction_func(self, token_seq, n_gram=1):
        """
        Args:
            token_seq: (T,)
        Returns:
            reduced_token_seq: (T')
            reduced_token_seq_len: T'
        """
        n_gram_seq = token_seq.unfold(0, n_gram, 1)
        mask = torch.all(n_gram_seq[1:] != n_gram_seq[:-1], dim=1)
        reduced_token_seq = torch.cat(
            (n_gram_seq[0, :n_gram], n_gram_seq[1:, -1][mask])
        )
        return reduced_token_seq, len(reduced_token_seq)
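
    # Illustrative example (added comment, not from the original file): with n_gram=1,
    # consecutive duplicate tokens are collapsed, e.g.
    #   duration_reduction_func(torch.tensor([5, 5, 5, 2, 2, 7]))
    #   -> (tensor([5, 2, 7]), 3)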

    def extract_hubert_codec(
        self,
        vqvae_model,
        wavs,
        wav_lens=None,
        output_layer=18,
        token_type="hubert_codec",
        duration_reduction=False,
        duration_reduction_n_gram=1,
    ):
        """
        Args:
            wavs: [B, T]
            wav_lens: [B,]
        Returns:
            codecs: [B, T]
            codec_masks: [B, T]
        """
        # Extract features and normalize
        feats, feat_lengths = self.extract_hubert_feature(wavs, wav_lens, output_layer)

        if token_type == "hubert_codec":
            feats = (
                feats - self.hubert_feat_norm_mean.to(feats)
            ) / self.hubert_feat_norm_std.to(feats)
            codecs, _ = vqvae_model.quantize(feats)  # (B, T)
        elif token_type == "hubert_vevo_codec":
            x = vqvae_model.encoder(feats.transpose(1, 2))
            z = vqvae_model.projector(x)
            _, idx = vqvae_model.quantizer.codebook.forward_index(z.transpose(2, 1))
            codecs = idx[0]  # (B, T)
        else:
            raise ValueError("Invalid token_type")

        if not duration_reduction:
            T = codecs.shape[1]
            arange_tensor = torch.arange(T).expand(codecs.shape[0], T).to(codecs)
            codec_masks = (
                arange_tensor < feat_lengths.unsqueeze(-1)
            ).int()  # 1 means valid
            return codecs, codec_masks
        else:
            reduced_codecs = []
            reduced_masks = []
            for i, token_seq_len in enumerate(feat_lengths):
                token_seq = codecs[i, :token_seq_len]
                reduced_token_seq, reduced_token_seq_len = self.duration_reduction_func(
                    token_seq, n_gram=duration_reduction_n_gram
                )
                reduced_codecs.append(reduced_token_seq)
                reduced_masks.append(
                    torch.ones(reduced_token_seq_len, dtype=torch.int).to(codecs)
                )
            reduced_codecs = torch.nn.utils.rnn.pad_sequence(
                reduced_codecs, batch_first=True, padding_value=0
            )
            reduced_masks = torch.nn.utils.rnn.pad_sequence(
                reduced_masks, batch_first=True, padding_value=0
            )
            return reduced_codecs, reduced_masks

    def random_mask_codec(self, codecs, codec_masks, ratio, mask_value):
        """
        Args:
            codecs: [B, T]
            codec_masks: [B, T], 0 means not to mask
            ratio: float
            mask_value: int
        Returns:
            masked_codecs: [B, T]
        """
        rand_mask = (torch.rand_like(codecs.float(), device=codecs.device) < ratio) & (
            codec_masks == 1
        )
        masked_codecs = codecs.masked_fill(rand_mask, mask_value)
        return masked_codecs

    def inference_ar_and_fm(
        self,
        src_wav_path,
        src_text,
        style_ref_wav_path,
        timbre_ref_wav_path,
        style_ref_wav_text=None,
        src_text_language=None,
        style_ref_wav_text_language=None,
        vc_input_mask_ratio=-1,
        use_global_guided_inference=False,
        flow_matching_steps=32,
        display_audio=False,
    ):
        assert self.ar_model is not None

        if src_wav_path is None:
            # TTS
            task = "tts"
            assert src_text is not None
            if src_text_language is None:
                src_text_language = "zh"
            if style_ref_wav_text_language is None:
                style_ref_wav_text_language = "zh"

            if display_audio:
                print("-" * 20)
                print("Source Text: [{}]".format(src_text))
        else:
            # VC
            task = "vc"
            assert src_text is None
            src_speech, src_speech24k, src_speech16k = load_wav(
                src_wav_path, self.device
            )

            if display_audio:
                print("-" * 20)
                print("Source audio:")
                display_audio_in_notebook(src_speech, rate=24000)

        style_ref_speech, style_ref_speech24k, style_ref_speech16k = load_wav(
            style_ref_wav_path, self.device
        )
        timbre_ref_speech, timbre_ref_speech24k, timbre_ref_speech16k = load_wav(
            timbre_ref_wav_path, self.device
        )

        if display_audio:
            if style_ref_wav_path == timbre_ref_wav_path:
                print("Both Style and Timbre Reference audio:")
                display_audio_in_notebook(style_ref_speech, rate=24000)
            else:
                print("Style Reference audio:")
                display_audio_in_notebook(style_ref_speech, rate=24000)
                print("Timbre Reference audio:")
                display_audio_in_notebook(timbre_ref_speech, rate=24000)
            print("-" * 20)

        ## AR ##
        if task == "tts":
            ar_input_ids = g2p_(src_text, src_text_language)[1]
            ar_input_ids = torch.tensor([ar_input_ids], dtype=torch.long).to(
                self.device
            )
            if display_audio:
                print("Src text input_ids:", ar_input_ids.shape)

            if not use_global_guided_inference:
                assert style_ref_wav_text is not None
                style_ref_input_ids = g2p_(
                    style_ref_wav_text, style_ref_wav_text_language
                )[1]
                style_ref_input_ids = torch.tensor(
                    [style_ref_input_ids], dtype=torch.long
                ).to(self.device)
                ar_input_ids = torch.cat([style_ref_input_ids, ar_input_ids], dim=1)
                if display_audio:
                    print("AR input_ids:", ar_input_ids.shape)

        elif task == "vc":
            if not use_global_guided_inference:
                src_speech16k = torch.cat([style_ref_speech16k, src_speech16k], dim=1)

            # [1, T]
            ar_input_ids, _ = self.extract_hubert_codec(
                self.content_tokenizer,
                src_speech16k,
                token_type=self.ar_cfg.model.vc_input_token_type,
                duration_reduction=True,
                duration_reduction_n_gram=getattr(
                    self.ar_cfg.model, "vc_input_reduced_n_gram", 1
                ),
            )

            if vc_input_mask_ratio > 0:
                ar_input_masks = torch.ones_like(
                    ar_input_ids, dtype=torch.int, device=self.device
                )
                if not use_global_guided_inference:
                    total_len = ar_input_ids.shape[1]
                    style_ref_ratio = (
                        style_ref_speech16k.shape[1] / src_speech16k.shape[1]
                    )
                    ar_input_masks[:, : int(total_len * style_ref_ratio)] = 0

                ar_input_ids = self.random_mask_codec(
                    codecs=ar_input_ids,
                    codec_masks=ar_input_masks,
                    ratio=vc_input_mask_ratio,
                    mask_value=self.ar_cfg.model.vc_input_vocab_size,
                )

            if self.ar_cfg.model.train_both_vc_and_tts:
                # [Important] When training both VC and TTS, the VC input_ids should be
                # shifted, since the Llama backbone uses a unified codebook.
                ar_input_ids += self.ar_cfg.model.tts_input_vocab_size

            if display_audio:
                print("AR input_ids:", ar_input_ids.shape)

        if use_global_guided_inference:
            prompt_output_ids = None
        else:
            prompt_output_ids, _ = self.extract_hubert_codec(
                self.content_style_tokenizer,
                style_ref_speech16k,
                duration_reduction=False,
            )
            if display_audio:
                print("Prompt output_ids:", prompt_output_ids.shape)

        # [1, T]
        predicted_hubert_codecs = self.ar_model.generate(
            input_ids=ar_input_ids,
            prompt_mels=self.extract_prompt_mel_feature(style_ref_speech16k),
            prompt_output_ids=prompt_output_ids,
        )

        ## Diffusion ##
        timbre_ref_hubert_codecs, _ = self.extract_hubert_codec(
            self.content_style_tokenizer, timbre_ref_speech16k, duration_reduction=False
        )
        diffusion_input_codecs = torch.cat(
            [timbre_ref_hubert_codecs, predicted_hubert_codecs], dim=1
        )

        # [1, T, D]
        predict_mel_feat = self.fmt_model.reverse_diffusion(
            cond=self.fmt_model.cond_emb(diffusion_input_codecs),
            prompt=self.extract_mel_feature(timbre_ref_speech24k),
            n_timesteps=flow_matching_steps,
        )

        ## Vocoder and Display ##
        # [1, 1, T] -> [1, T]
        synthesized_audio = (
            self.vocoder_model(predict_mel_feat.transpose(1, 2)).detach().cpu()
        )[0]

        if display_audio:
            # [T]
            audio = synthesized_audio.numpy()[0]
            display_audio_in_notebook(audio, rate=24000)

        return synthesized_audio

    def inference_fm(
        self,
        src_wav_path,
        timbre_ref_wav_path,
        flow_matching_steps=32,
        display_audio=False,
    ):
        src_speech, src_speech24k, src_speech16k = load_wav(src_wav_path, self.device)
        timbre_ref_speech, timbre_ref_speech24k, timbre_ref_speech16k = load_wav(
            timbre_ref_wav_path, self.device
        )

        if display_audio:
            print("-" * 20)
            if src_wav_path == timbre_ref_wav_path:
                print("Audio:")
                display_audio_in_notebook(src_speech, rate=24000)
            else:
                print("Source audio:")
                display_audio_in_notebook(src_speech, rate=24000)
                print("Timbre Reference audio:")
                display_audio_in_notebook(timbre_ref_speech, rate=24000)
            print("-" * 20)

        ## Diffusion ##
        src_hubert_codecs, _ = self.extract_hubert_codec(
            self.content_style_tokenizer, src_speech16k, duration_reduction=False
        )
        timbre_ref_hubert_codecs, _ = self.extract_hubert_codec(
            self.content_style_tokenizer, timbre_ref_speech16k, duration_reduction=False
        )
        diffusion_input_codecs = torch.cat(
            [timbre_ref_hubert_codecs, src_hubert_codecs], dim=1
        )

        # [1, T, D]
        predict_mel_feat = self.fmt_model.reverse_diffusion(
            cond=self.fmt_model.cond_emb(diffusion_input_codecs),
            prompt=self.extract_mel_feature(timbre_ref_speech24k),
            n_timesteps=flow_matching_steps,
        )

        ## Vocoder and Display ##
        # [1, 1, T] -> [1, T]
        synthesized_audio = (
            self.vocoder_model(predict_mel_feat.transpose(1, 2)).detach().cpu()
        )[0]

        if display_audio:
            # [T]
            audio = synthesized_audio.numpy()[0]
            display_audio_in_notebook(audio, rate=24000)

        return synthesized_audio
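

# Minimal usage sketch (illustrative only): the config and checkpoint paths below are
# placeholders, not files shipped with this module. The AR branch is optional and is
# only needed for `inference_ar_and_fm`; zero-shot timbre conversion uses just the
# content-style tokenizer, the flow-matching model, and the vocoder.
if __name__ == "__main__":
    pipeline = VevoInferencePipeline(
        content_style_tokenizer_ckpt_path="path/to/content_style_tokenizer",  # placeholder
        fmt_cfg_path="path/to/fmt_config.json",  # placeholder
        fmt_ckpt_path="path/to/fmt_checkpoint",  # placeholder
        vocoder_cfg_path="path/to/vocoder_config.json",  # placeholder
        vocoder_ckpt_path="path/to/vocoder_checkpoint",  # placeholder
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Returns a [1, T] waveform at 24 kHz
    converted = pipeline.inference_fm(
        src_wav_path="path/to/source.wav",  # placeholder
        timbre_ref_wav_path="path/to/timbre_reference.wav",  # placeholder
        flow_matching_steps=32,
    )
    save_audio(converted, sr=24000, output_path="converted.wav")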