Spaces:
Paused
Paused
File size: 4,416 Bytes
d66c48f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
import numpy as np
import librosa
from safetensors.torch import load_model
import os
from utils.util import load_config
from models.vc.Noro.noro_trainer import NoroTrainer
from models.vc.Noro.noro_model import Noro_VCmodel
from processors.content_extractor import HubertExtractor
from utils.mel import mel_spectrogram_torch
from utils.f0 import get_f0_features_using_dio, interpolate
from torch.nn.utils.rnn import pad_sequence
def build_trainer(args, cfg):
    """Instantiate the trainer registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded to the trainer.
        cfg: Loaded configuration; its ``model_type`` selects the trainer.

    Returns:
        The trainer instance created as ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not a registered model type.
    """
    # Only the voice-conversion ("VC") trainer is registered here.
    registry = {"VC": NoroTrainer}
    return registry[cfg.model_type](args, cfg)
def _load_padded_audio(path, multiple, device):
    """Load ``path`` at 16 kHz and zero-pad it to a multiple of ``multiple``.

    Args:
        path: Audio file to read with ``librosa.load``.
        multiple: Sample count the padded length must be divisible by.
        device: Torch device the resulting tensor is moved to.

    Returns:
        The padded samples as a tensor with a leading singleton axis added.
    """
    samples, _ = librosa.load(path, sr=16000)
    # NOTE: a signal whose length is already a multiple of ``multiple`` still
    # receives one full extra block of zeros; kept to match prior behavior.
    samples = np.pad(samples, (0, multiple - len(samples) % multiple))
    return torch.from_numpy(samples).to(device)[None, :]


def main():
    """Convert one source utterance towards a reference voice and save the mel.

    Parses the command line, loads the HuBERT content extractor and the Noro
    model onto the selected CUDA device, extracts content features and a
    normalized F0 contour from the source audio plus a mel prompt from the
    reference audio, runs ``model.inference`` and writes the result to
    ``<output_dir>/recon_mel.npy``.

    Raises:
        SystemExit: Raised by ``argparse`` when a required option is missing.
    """
    parser = argparse.ArgumentParser()
    # ``--config`` is required, so a default value would never be used.
    parser.add_argument(
        "--config",
        help="JSON file for configurations.",
        required=True,
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Checkpoint for resume training or fine-tuning.",
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        help="Output path",
        required=True,
    )
    # Both audio paths are consumed unconditionally below, so fail fast in
    # argparse instead of letting ``librosa.load(None)`` raise later.
    parser.add_argument(
        "--ref_path",
        type=str,
        help="Reference voice path",
        required=True,
    )
    parser.add_argument(
        "--source_path",
        type=str,
        help="Source voice path",
        required=True,
    )
    parser.add_argument("--cuda_id", type=int, default=0, help="CUDA id for training.")
    parser.add_argument("--local_rank", default=-1, type=int)
    args = parser.parse_args()

    cfg = load_config(args.config)

    # ``exist_ok`` avoids the check-then-create race of a separate exists test.
    os.makedirs(args.output_dir, exist_ok=True)

    # The parsed ``--local_rank`` value is replaced by the target device.
    device = torch.device(f"cuda:{args.cuda_id}")
    args.local_rank = device
    print("Local rank:", args.local_rank)
    args.content_extractor = "mhubert"

    with torch.cuda.device(device):
        torch.cuda.empty_cache()

    # Content (HuBERT) feature extractor, used in eval mode.
    w2v = HubertExtractor(cfg)
    w2v = w2v.to(device=device)
    w2v.eval()

    # Voice-conversion model with weights restored from the safetensors file.
    model = Noro_VCmodel(cfg=cfg.model)
    print("Loading model")
    load_model(model, args.checkpoint_path)
    print("Model loaded")
    model.cuda(device)
    model.eval()

    # Source is padded to a multiple of 1600 samples, reference to a multiple
    # of 200 (presumably chosen to align with the feature hop sizes -- confirm).
    audio = _load_padded_audio(args.source_path, 1600, device)
    ref_audio = _load_padded_audio(args.ref_path, 200, device)

    with torch.no_grad():
        # Reference mel prompt and an all-True mask over its first two axes.
        ref_mel = mel_spectrogram_torch(ref_audio, cfg)
        ref_mel = ref_mel.transpose(1, 2).to(device=device)
        ref_mask = torch.ones(ref_mel.shape[0], ref_mel.shape[1]).to(device).bool()

        _, content_feature = w2v.extract_content_features(audio)
        content_feature = content_feature.to(device=device)

        # F0 is computed on the CPU from the padded source waveform.
        wav = audio.cpu().numpy()[0, :]
        pitch_raw = get_f0_features_using_dio(wav, cfg.preprocess)
        pitch_raw, _ = interpolate(pitch_raw)
        frame_num = len(wav) // cfg.preprocess.hop_size
        # A single utterance only needs a leading batch axis, no padding.
        pitch = torch.from_numpy(pitch_raw[:frame_num]).float().unsqueeze(0)
        # Per-utterance normalization; the epsilon guards against zero std.
        pitch = (pitch - pitch.mean(dim=1, keepdim=True)) / (
            pitch.std(dim=1, keepdim=True) + 1e-6
        )
        pitch = pitch.to(device=device)

        x0 = model.inference(
            content_feature=content_feature,
            pitch=pitch,
            x_ref=ref_mel,
            x_ref_mask=ref_mask,
            inference_steps=200,
            sigma=1.2,
        )  # 150-300 0.95-1.5 (presumably the suggested steps / sigma ranges)

    recon_path = os.path.join(args.output_dir, "recon_mel.npy")
    np.save(recon_path, x0.transpose(1, 2).detach().cpu().numpy())
    print(f"Mel spectrogram saved to: {recon_path}")
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|