# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import librosa
import torch
import torchaudio
import accelerate
import safetensors
import numpy as np
import yaml
from IPython.display import display, Audio
from models.vc.flow_matching_transformer.fmt_model import FlowMatchingTransformer
from models.vc.autoregressive_transformer.ar_model import AutoregressiveTransformer
from models.codec.kmeans.repcodec_model import RepCodec
from models.codec.vevo.vevo_repcodec import VevoRepCodec
from models.codec.melvqgan.melspec import MelSpectrogram
from models.codec.amphion_codec.vocos import Vocos
from utils.util import load_config
def g2p_(text, language):
from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
if language in ["zh", "en"]:
return chn_eng_g2p(text)
else:
return g2p(text, sentence=None, language=language)
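# Usage note: downstream code only consumes index [1] of the g2p output — the integer
# phone-id sequence — e.g. in inference_ar_and_fm:
#   ar_input_ids = g2p_("How are you?", "en")[1]  # list of phone ids fed to the AR model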
def transcribe_audio(audio_path, model=None):
if model is None:
import whisper
model = whisper.load_model("medium")
result = model.transcribe(audio_path)
return result["text"]
# Semantic Features Extractor
def build_hubert_model(device):
bundle = torchaudio.pipelines.HUBERT_LARGE
hubert = bundle.get_model()
hubert.eval()
hubert.to(device)
return hubert
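# The torchaudio HUBERT_LARGE bundle operates on 16 kHz mono audio (bundle.sample_rate == 16000);
# load_wav below therefore returns a 16 kHz copy of every waveform. Feature-extraction sketch
# (assuming `wav16k` is a [1, T] tensor on the same device as the model):
#   feats_list, lengths = hubert.extract_features(wav16k, num_layers=18)
#   feats = feats_list[-1]  # [1, T', 1024] — layer-18 features, as used by the pipeline below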
# VQ-VAE Tokenizer
def build_vqvae_model(repcodec_cfg, device):
vqvae = RepCodec(cfg=repcodec_cfg)
vqvae.eval()
vqvae.to(device)
return vqvae
# Vevo VQ-VAE Tokenizer (pkl or safetensors checkpoint)
def load_vevo_vqvae_checkpoint(repcodec_cfg, device):
with open(repcodec_cfg.config_path) as fp:
conf = yaml.load(fp, Loader=yaml.FullLoader)
vqvae = VevoRepCodec(**conf)
vqvae.quantizer.initial()
vqvae.eval()
pretrained_path = repcodec_cfg.pretrained_path
if ".pkl" in pretrained_path:
# Vevo paper
vqvae.load_state_dict(
torch.load(pretrained_path, map_location="cpu")["model"]["repcodec"]
)
elif ".safetensors" in pretrained_path:
# Re-trained vevovq
safetensors.torch.load_model(vqvae, pretrained_path)
vqvae.to(device)
return vqvae
# Autoregressive Transformer
def build_ar_model(cfg, device):
model = AutoregressiveTransformer(cfg=cfg.model.autoregressive_transformer)
model.eval()
model.to(device)
return model
# Flow Matching Transformer
def build_fmt_model(cfg, device):
model = FlowMatchingTransformer(cfg=cfg.model.flow_matching_transformer)
model.eval()
model.to(device)
return model
# Melspectrogram Extractor
def build_mel_model(cfg, device):
mel_model = MelSpectrogram(
sampling_rate=cfg.preprocess.sample_rate,
n_fft=cfg.preprocess.n_fft,
num_mels=cfg.preprocess.num_mels,
hop_size=cfg.preprocess.hop_size,
win_size=cfg.preprocess.win_size,
fmin=cfg.preprocess.fmin,
fmax=cfg.preprocess.fmax,
)
mel_model.eval()
mel_model.to(device)
return mel_model
# Vocoder
def build_vocoder_model(cfg, device):
vocoder_model = Vocos(cfg=cfg.model.vocos)
vocoder_model.eval()
vocoder_model.to(device)
return vocoder_model
def load_checkpoint(build_model_func, cfg, ckpt_path, device):
model = build_model_func(cfg, device)
accelerate.load_checkpoint_and_dispatch(model, ckpt_path)
return model
def count_parameters(model):
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if total_params < 1e6:
return f"{total_params} params" # Parameters
elif total_params < 1e9:
return f"{total_params / 1e6:.2f} M" # Millions
else:
return f"{total_params / 1e9:.2f} B" # Billions
def load_wav(wav_path, device):
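    """
    Load a wav file at 24 kHz.
    Returns:
        speech: numpy array at 24 kHz (used for notebook playback)
        speech_tensor: [1, T] float tensor at 24 kHz on `device`
        speech16k: [1, T] float tensor resampled to 16 kHz (HuBERT / tokenizer input)
    """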
speech = librosa.load(wav_path, sr=24000)[0]
speech_tensor = torch.tensor(speech).unsqueeze(0).to(device)
speech16k = torchaudio.functional.resample(speech_tensor, 24000, 16000)
return speech, speech_tensor, speech16k
def display_audio_in_notebook(wav, rate=24000):
display(Audio(wav, rate=rate))
def save_audio(
waveform, sr=24000, output_path=None, target_sample_rate=None, target_db=-25.0
):
"""
waveform: [1, T]
"""
if target_sample_rate is not None and sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=sr, new_freq=target_sample_rate
)
waveform = resampler(waveform)
else:
target_sample_rate = sr
rms = torch.sqrt(torch.mean(waveform**2))
current_db = 20 * torch.log10(rms + 1e-9)
gain = target_db - current_db
normalized_waveform = waveform * (10 ** (gain / 20))
torchaudio.save(output_path, normalized_waveform, target_sample_rate)
return output_path
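# Worked example of the loudness normalization above (illustrative numbers): if the input
# RMS is 0.1, current_db = 20 * log10(0.1) = -20 dB; with target_db = -25 the gain is -5 dB,
# so the waveform is scaled by 10 ** (-5 / 20) ≈ 0.56 before being written to disk.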
class VevoInferencePipeline:
def __init__(
self,
content_tokenizer_ckpt_path=None,
content_style_tokenizer_ckpt_path=None,
ar_cfg_path=None,
ar_ckpt_path=None,
fmt_cfg_path=None,
fmt_ckpt_path=None,
vocoder_cfg_path=None,
vocoder_ckpt_path=None,
device=None,
):
self.device = device
if ar_cfg_path is not None and ar_ckpt_path is not None:
self.ar_cfg = load_config(ar_cfg_path)
self.ar_model = load_checkpoint(
build_ar_model, self.ar_cfg, ar_ckpt_path, device
)
print(f"#Params of AR model: {count_parameters(self.ar_model)}")
else:
self.ar_cfg = None
self.ar_model = None
self.fmt_cfg = load_config(fmt_cfg_path)
self.fmt_model = load_checkpoint(
build_fmt_model, self.fmt_cfg, fmt_ckpt_path, device
)
print(f"#Params of Flow Matching model: {count_parameters(self.fmt_model)}")
self.vocoder_cfg = load_config(vocoder_cfg_path)
self.mel_model = build_mel_model(self.vocoder_cfg, device)
self.vocoder_model = load_checkpoint(
build_vocoder_model, self.vocoder_cfg, vocoder_ckpt_path, device
)
print(f"#Params of Vocoder model: {count_parameters(self.vocoder_model)}")
self.content_tokenizer_ckpt_path = content_tokenizer_ckpt_path
self.content_style_tokenizer_ckpt_path = content_style_tokenizer_ckpt_path
self.init_vqvae_tokenizer()
def init_vqvae_tokenizer(self):
## HuBERT features extraction ##
self.hubert_model = build_hubert_model(self.device)
stat = np.load(self.fmt_cfg.model.representation_stat_mean_var_path)
self.hubert_feat_norm_mean = torch.tensor(stat["mean"])
self.hubert_feat_norm_std = torch.tensor(stat["std"])
## Content Tokenizer ##
if self.ar_model is not None and "input_repcodec" in self.ar_cfg.model:
assert self.ar_cfg.model.vc_input_token_type == "hubert_vevo_codec"
ckpt_path = getattr(
self.ar_cfg.model.input_repcodec,
"pretrained_path",
self.content_tokenizer_ckpt_path,
)
self.ar_cfg.model.input_repcodec.pretrained_path = ckpt_path
self.content_tokenizer = load_vevo_vqvae_checkpoint(
self.ar_cfg.model.input_repcodec,
self.device,
)
print(
"#Params of Content Tokenizer: {}".format(
count_parameters(self.content_tokenizer)
)
)
## Content-Style Tokenizer ##
ckpt_path = getattr(
self.fmt_cfg.model.repcodec,
"pretrained_path",
self.content_style_tokenizer_ckpt_path,
)
self.content_style_tokenizer = load_checkpoint(
build_vqvae_model,
self.fmt_cfg.model.repcodec,
ckpt_path,
self.device,
)
print(
"#Params of Content-Style Tokenizer: {}".format(
count_parameters(self.content_style_tokenizer)
)
)
@torch.no_grad()
def extract_mel_feature(self, speech):
mel_feature = self.mel_model(speech) # (B, d, T)
mel_feature = mel_feature.transpose(1, 2)
mel_feature = (mel_feature - self.vocoder_cfg.preprocess.mel_mean) / math.sqrt(
self.vocoder_cfg.preprocess.mel_var
)
return mel_feature
@torch.no_grad()
def extract_prompt_mel_feature(self, speech):
"""
        This is for the global (prompt) encoder of the AR model.
"""
if not hasattr(self, "prompt_mel_model"):
self.prompt_mel_model = build_mel_model(self.ar_cfg, self.device)
mel_feature = self.prompt_mel_model(speech) # (B, d, T)
mel_feature = mel_feature.transpose(1, 2)
mel_feature = (mel_feature - self.ar_cfg.preprocess.mel_mean) / math.sqrt(
self.ar_cfg.preprocess.mel_var
)
return mel_feature
@torch.no_grad()
def extract_hubert_feature(self, wavs, wav_lens=None, output_layer=18):
"""
Args:
wavs: [B, T]
wav_lens: [B,]
Returns:
feats: [B, T, D]
feat_lengths: [B]
"""
if wav_lens is None:
wav_lens = torch.tensor([wavs.shape[1]] * wavs.shape[0]).to(wavs).int()
feats, feat_lengths = self.hubert_model.extract_features(
wavs, lengths=wav_lens, num_layers=output_layer
)
feats = feats[-1]
return feats, feat_lengths
def duration_reduction_func(self, token_seq, n_gram=1):
"""
Args:
token_seq: (T,)
Returns:
reduced_token_seq: (T')
reduced_token_seq_len: T'
"""
n_gram_seq = token_seq.unfold(0, n_gram, 1)
mask = torch.all(n_gram_seq[1:] != n_gram_seq[:-1], dim=1)
reduced_token_seq = torch.cat(
(n_gram_seq[0, :n_gram], n_gram_seq[1:, -1][mask])
)
return reduced_token_seq, len(reduced_token_seq)
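    # Worked example (n_gram=1): token_seq = tensor([5, 5, 5, 2, 2, 7])
    #   change mask over consecutive tokens -> [False, False, True, False, True]
    #   reduced_token_seq -> tensor([5, 2, 7]), reduced_token_seq_len = 3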
@torch.no_grad()
def extract_hubert_codec(
self,
vqvae_model,
wavs,
wav_lens=None,
output_layer=18,
token_type="hubert_codec",
duration_reduction=False,
duration_reduction_n_gram=1,
):
"""
Args:
wavs: [B, T]
wav_lens: [B,]
Returns:
codecs: [B, T]
codec_masks: [B, T]
"""
# Extract features and normalize
feats, feat_lengths = self.extract_hubert_feature(wavs, wav_lens, output_layer)
if token_type == "hubert_codec":
feats = (
feats - self.hubert_feat_norm_mean.to(feats)
) / self.hubert_feat_norm_std.to(feats)
codecs, _ = vqvae_model.quantize(feats) # (B, T)
elif token_type == "hubert_vevo_codec":
x = vqvae_model.encoder(feats.transpose(1, 2))
z = vqvae_model.projector(x)
_, idx = vqvae_model.quantizer.codebook.forward_index(z.transpose(2, 1))
codecs = idx[0] # (B, T)
else:
            raise ValueError(f"Invalid token_type: {token_type}")
if not duration_reduction:
T = codecs.shape[1]
arange_tensor = torch.arange(T).expand(codecs.shape[0], T).to(codecs)
codec_masks = (
arange_tensor < feat_lengths.unsqueeze(-1)
).int() # 1 means valid
return codecs, codec_masks
else:
reduced_codecs = []
reduced_masks = []
for i, token_seq_len in enumerate(feat_lengths):
token_seq = codecs[i, :token_seq_len]
reduced_token_seq, reduced_token_seq_len = self.duration_reduction_func(
token_seq, n_gram=duration_reduction_n_gram
)
reduced_codecs.append(reduced_token_seq)
reduced_masks.append(
torch.ones(reduced_token_seq_len, dtype=torch.int).to(codecs)
)
reduced_codecs = torch.nn.utils.rnn.pad_sequence(
reduced_codecs, batch_first=True, padding_value=0
)
reduced_masks = torch.nn.utils.rnn.pad_sequence(
reduced_masks, batch_first=True, padding_value=0
)
return reduced_codecs, reduced_masks
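    # Summary of the two token types handled above:
    #   "hubert_codec":      standardize HuBERT features with the precomputed mean/std and
    #                        quantize them with the RepCodec model (content-style tokens).
    #   "hubert_vevo_codec": pass features through the VevoRepCodec encoder/projector/quantizer
    #                        to get content token indices (the AR model's VC input).
    # With duration_reduction=True, consecutive repeated tokens are collapsed per utterance
    # and the variable-length sequences are zero-padded back into a batch.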
def random_mask_codec(self, codecs, codec_masks, ratio, mask_value):
"""
Args:
codecs: [B, T]
codec_masks: [B, T], 0 means not to mask
ratio: float
mask_value: int
Returns:
masked_codecs: [B, T]
"""
rand_mask = (torch.rand_like(codecs.float(), device=codecs.device) < ratio) & (
codec_masks == 1
)
masked_codecs = codecs.masked_fill(rand_mask, mask_value)
return masked_codecs
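    # Example: with ratio=0.2, roughly 20% of the positions where codec_masks == 1 are replaced
    # by mask_value (inference_ar_and_fm passes vc_input_vocab_size as the mask id).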
def inference_ar_and_fm(
self,
src_wav_path,
src_text,
style_ref_wav_path,
timbre_ref_wav_path,
style_ref_wav_text=None,
src_text_language=None,
style_ref_wav_text_language=None,
vc_input_mask_ratio=-1,
use_global_guided_inference=False,
flow_matching_steps=32,
display_audio=False,
):
assert self.ar_model is not None
if src_wav_path is None:
# TTS
task = "tts"
assert src_text is not None
if src_text_language is None:
src_text_language = "zh"
if style_ref_wav_text_language is None:
style_ref_wav_text_language = "zh"
if display_audio:
print("-" * 20)
print("Source Text: [{}]".format(src_text))
else:
# VC
task = "vc"
assert src_text is None
src_speech, src_speech24k, src_speech16k = load_wav(
src_wav_path, self.device
)
if display_audio:
print("-" * 20)
print("Source audio:")
display_audio_in_notebook(src_speech, rate=24000)
style_ref_speech, style_ref_speech24k, style_ref_speech16k = load_wav(
style_ref_wav_path, self.device
)
timbre_ref_speech, timbre_ref_speech24k, timbre_ref_speech16k = load_wav(
timbre_ref_wav_path, self.device
)
if display_audio:
if style_ref_wav_path == timbre_ref_wav_path:
print("Both Style and Timbre Reference audio:")
display_audio_in_notebook(style_ref_speech, rate=24000)
else:
print("Style Reference audio:")
display_audio_in_notebook(style_ref_speech, rate=24000)
print("Timbre Reference audio:")
display_audio_in_notebook(timbre_ref_speech, rate=24000)
print("-" * 20)
## AR ##
if task == "tts":
ar_input_ids = g2p_(src_text, src_text_language)[1]
ar_input_ids = torch.tensor([ar_input_ids], dtype=torch.long).to(
self.device
)
if display_audio:
print("Src text input_ids:", ar_input_ids.shape)
if not use_global_guided_inference:
assert style_ref_wav_text is not None
style_ref_input_ids = g2p_(
style_ref_wav_text, style_ref_wav_text_language
)[1]
style_ref_input_ids = torch.tensor(
[style_ref_input_ids], dtype=torch.long
).to(self.device)
ar_input_ids = torch.cat([style_ref_input_ids, ar_input_ids], dim=1)
if display_audio:
print("AR input_ids:", ar_input_ids.shape)
elif task == "vc":
if not use_global_guided_inference:
src_speech16k = torch.cat([style_ref_speech16k, src_speech16k], dim=1)
# [1, T]
ar_input_ids, _ = self.extract_hubert_codec(
self.content_tokenizer,
src_speech16k,
token_type=self.ar_cfg.model.vc_input_token_type,
duration_reduction=True,
duration_reduction_n_gram=getattr(
self.ar_cfg.model, "vc_input_reduced_n_gram", 1
),
)
if vc_input_mask_ratio > 0:
ar_input_masks = torch.ones_like(
ar_input_ids, dtype=torch.int, device=self.device
)
if not use_global_guided_inference:
total_len = ar_input_ids.shape[1]
style_ref_ratio = (
style_ref_speech16k.shape[1] / src_speech16k.shape[1]
)
ar_input_masks[:, : int(total_len * style_ref_ratio)] = 0
ar_input_ids = self.random_mask_codec(
codecs=ar_input_ids,
codec_masks=ar_input_masks,
ratio=vc_input_mask_ratio,
mask_value=self.ar_cfg.model.vc_input_vocab_size,
)
if self.ar_cfg.model.train_both_vc_and_tts:
                # [Important] When training both VC and TTS, the VC input_ids should be shifted, since Llama uses a unified codebook
ar_input_ids += self.ar_cfg.model.tts_input_vocab_size
if display_audio:
print("AR input_ids:", ar_input_ids.shape)
if use_global_guided_inference:
prompt_output_ids = None
else:
prompt_output_ids, _ = self.extract_hubert_codec(
self.content_style_tokenizer,
style_ref_speech16k,
duration_reduction=False,
)
if display_audio:
print("Prompt output_ids:", prompt_output_ids.shape)
# [1, T]
predicted_hubert_codecs = self.ar_model.generate(
input_ids=ar_input_ids,
prompt_mels=self.extract_prompt_mel_feature(style_ref_speech16k),
prompt_output_ids=prompt_output_ids,
)
## Diffusion ##
timbre_ref_hubert_codecs, _ = self.extract_hubert_codec(
self.content_style_tokenizer, timbre_ref_speech16k, duration_reduction=False
)
diffusion_input_codecs = torch.cat(
[timbre_ref_hubert_codecs, predicted_hubert_codecs], dim=1
)
# [1, T, D]
predict_mel_feat = self.fmt_model.reverse_diffusion(
cond=self.fmt_model.cond_emb(diffusion_input_codecs),
prompt=self.extract_mel_feature(timbre_ref_speech24k),
n_timesteps=flow_matching_steps,
)
## Vocoder and Display ##
# [1, 1, T] -> [1, T]
synthesized_audio = (
self.vocoder_model(predict_mel_feat.transpose(1, 2)).detach().cpu()
)[0]
if display_audio:
# [T]
audio = synthesized_audio.numpy()[0]
display_audio_in_notebook(audio, rate=24000)
return synthesized_audio
def inference_fm(
self,
src_wav_path,
timbre_ref_wav_path,
flow_matching_steps=32,
display_audio=False,
):
src_speech, src_speech24k, src_speech16k = load_wav(src_wav_path, self.device)
timbre_ref_speech, timbre_ref_speech24k, timbre_ref_speech16k = load_wav(
timbre_ref_wav_path, self.device
)
if display_audio:
print("-" * 20)
if src_wav_path == timbre_ref_wav_path:
print("Audio:")
                display_audio_in_notebook(src_speech, rate=24000)
else:
print("Source audio:")
display_audio_in_notebook(src_speech, rate=24000)
print("Timbre Reference audio:")
display_audio_in_notebook(timbre_ref_speech, rate=24000)
print("-" * 20)
## Diffusion ##
src_hubert_codecs, _ = self.extract_hubert_codec(
self.content_style_tokenizer, src_speech16k, duration_reduction=False
)
timbre_ref_hubert_codecs, _ = self.extract_hubert_codec(
self.content_style_tokenizer, timbre_ref_speech16k, duration_reduction=False
)
diffusion_input_codecs = torch.cat(
[timbre_ref_hubert_codecs, src_hubert_codecs], dim=1
)
# [1, T, D]
predict_mel_feat = self.fmt_model.reverse_diffusion(
cond=self.fmt_model.cond_emb(diffusion_input_codecs),
prompt=self.extract_mel_feature(timbre_ref_speech24k),
n_timesteps=flow_matching_steps,
)
## Vocoder and Display ##
# [1, 1, T] -> [1, T]
synthesized_audio = (
self.vocoder_model(predict_mel_feat.transpose(1, 2)).detach().cpu()
)[0]
if display_audio:
# [T]
audio = synthesized_audio.numpy()[0]
display_audio_in_notebook(audio, rate=24000)
return synthesized_audio
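
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline. All paths below are
    # placeholders — point them at your local Vevo configs and checkpoints before running.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = VevoInferencePipeline(
        content_style_tokenizer_ckpt_path="ckpts/tokenizer/model.safetensors",  # placeholder
        fmt_cfg_path="ckpts/fmt/config.json",  # placeholder
        fmt_ckpt_path="ckpts/fmt/model.safetensors",  # placeholder
        vocoder_cfg_path="ckpts/vocoder/config.json",  # placeholder
        vocoder_ckpt_path="ckpts/vocoder/model.safetensors",  # placeholder
        device=device,
    )
    # Timbre conversion with the flow-matching model only (no AR model needed).
    wav = pipeline.inference_fm(
        src_wav_path="source.wav",  # placeholder
        timbre_ref_wav_path="reference.wav",  # placeholder
        flow_matching_steps=32,
    )
    save_audio(wav, sr=24000, output_path="converted.wav")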