# NOTE(review): removed three upload-page artifact lines ("...'s picture",
# "Upload 855 files", commit hash) that were not Python and broke parsing.
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
import numpy as np
import librosa
from safetensors.torch import load_model
import os
from utils.util import load_config
from models.vc.Noro.noro_trainer import NoroTrainer
from models.vc.Noro.noro_model import Noro_VCmodel
from processors.content_extractor import HubertExtractor
from utils.mel import mel_spectrogram_torch
from utils.f0 import get_f0_features_using_dio, interpolate
from torch.nn.utils.rnn import pad_sequence
def build_trainer(args, cfg):
    """Instantiate the trainer registered for ``cfg.model_type``.

    Raises KeyError if the model type has no registered trainer.
    """
    # Registry of supported model types -> trainer classes.
    registry = {
        "VC": NoroTrainer,
    }
    trainer_cls = registry[cfg.model_type]
    return trainer_cls(args, cfg)
def main():
    """Run Noro voice-conversion inference from the command line.

    Loads the model checkpoint, extracts content features and normalized F0
    from the source utterance, conditions on the reference utterance's mel
    spectrogram, and saves the reconstructed mel to
    ``<output_dir>/recon_mel.npy``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        # no default: the argument is required, so a default is dead code
        help="JSON file for configurations.",
        required=True,
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint for resume training or fine-tuning.",
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        help="Output path",
        required=True,
    )
    parser.add_argument(
        "--ref_path",
        type=str,
        help="Reference voice path",
    )
    parser.add_argument(
        "--source_path",
        type=str,
        help="Source voice path",
    )
    parser.add_argument("--cuda_id", type=int, default=0, help="CUDA id for training.")
    parser.add_argument("--local_rank", default=-1, type=int)
    args = parser.parse_args()
    cfg = load_config(args.config)

    # exist_ok avoids the check-then-create race of the original
    # os.path.exists() guard.
    os.makedirs(args.output_dir, exist_ok=True)

    cuda_id = args.cuda_id
    # NOTE: args.local_rank is repurposed to hold a torch.device object;
    # everything downstream treats it as a device, not a distributed rank.
    args.local_rank = torch.device(f"cuda:{cuda_id}")
    print("Local rank:", args.local_rank)
    args.content_extractor = "mhubert"

    with torch.cuda.device(args.local_rank):
        torch.cuda.empty_cache()
        ckpt_path = args.checkpoint_path

        # Content (semantic) feature extractor.
        w2v = HubertExtractor(cfg)
        w2v = w2v.to(device=args.local_rank)
        w2v.eval()

        model = Noro_VCmodel(cfg=cfg.model)
        print("Loading model")
        load_model(model, ckpt_path)
        print("Model loaded")
        model.cuda(args.local_rank)
        model.eval()

        # Load source / reference audio at 16 kHz and zero-pad so lengths
        # are multiples of the required block sizes. Fix: the original
        # `block - len % block` padded a FULL extra block (1600 or 200
        # zeros) when the length was already aligned; `(-len) % block`
        # pads zero samples in that case.
        wav, _ = librosa.load(args.source_path, sr=16000)
        wav = np.pad(wav, (0, (-len(wav)) % 1600))
        audio = torch.from_numpy(wav).to(args.local_rank)[None, :]

        ref_wav, _ = librosa.load(args.ref_path, sr=16000)
        ref_wav = np.pad(ref_wav, (0, (-len(ref_wav)) % 200))
        ref_audio = torch.from_numpy(ref_wav).to(args.local_rank)[None, :]

        with torch.no_grad():
            # Reference mel transposed to (batch, frames, n_mels) and an
            # all-ones (no padding) boolean validity mask.
            ref_mel = mel_spectrogram_torch(ref_audio, cfg)
            ref_mel = ref_mel.transpose(1, 2).to(device=args.local_rank)
            ref_mask = (
                torch.ones(ref_mel.shape[0], ref_mel.shape[1])
                .to(args.local_rank)
                .bool()
            )

            _, content_feature = w2v.extract_content_features(audio)
            content_feature = content_feature.to(device=args.local_rank)

            pitch = _normalized_pitch(audio, cfg).to(device=args.local_rank)

            x0 = model.inference(
                content_feature=content_feature,
                pitch=pitch,
                x_ref=ref_mel,
                x_ref_mask=ref_mask,
                inference_steps=200,
                sigma=1.2,
            )  # suggested ranges: inference_steps 150-300, sigma 0.95-1.5

            recon_path = f"{args.output_dir}/recon_mel.npy"
            # Save as (batch, n_mels, frames) for the downstream vocoder.
            np.save(recon_path, x0.transpose(1, 2).detach().cpu().numpy())
            print(f"Mel spectrogram saved to: {recon_path}")


def _normalized_pitch(audio, cfg):
    """Extract per-utterance z-normalized F0 from a (1, T) audio tensor.

    Uses DIO for raw F0, interpolates unvoiced frames, truncates to the
    mel frame count, and normalizes to zero mean / unit variance.
    Returns a (1, frames) float tensor on CPU.
    """
    wav = audio.cpu().numpy()[0, :]
    pitch_raw = get_f0_features_using_dio(wav, cfg.preprocess)
    pitch_raw, _ = interpolate(pitch_raw)
    frame_num = len(wav) // cfg.preprocess.hop_size
    pitch_raw = torch.from_numpy(pitch_raw[:frame_num]).float()
    # pad_sequence kept for parity with the training pipeline's batching.
    pitch = pad_sequence([pitch_raw], batch_first=True, padding_value=0).float()
    # eps guards against zero variance (e.g. constant/silent input).
    return (pitch - pitch.mean(dim=1, keepdim=True)) / (
        pitch.std(dim=1, keepdim=True) + 1e-6
    )
# Script entry point: run voice-conversion inference when executed directly.
if __name__ == "__main__":
    main()