File size: 4,416 Bytes
d66c48f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import torch
import numpy as np
import librosa
from safetensors.torch import load_model
import os
from utils.util import load_config
from models.vc.Noro.noro_trainer import NoroTrainer
from models.vc.Noro.noro_model import Noro_VCmodel
from processors.content_extractor import HubertExtractor
from utils.mel import mel_spectrogram_torch
from utils.f0 import get_f0_features_using_dio, interpolate
from torch.nn.utils.rnn import pad_sequence


def build_trainer(args, cfg):
    """Construct the trainer registered for ``cfg.model_type``.

    Only ``"VC"`` (mapped to ``NoroTrainer``) is registered; any other value
    raises ``KeyError``. This inference script's ``main()`` does not call it.

    Args:
        args: Parsed command-line namespace forwarded to the trainer.
        cfg: Configuration object; must expose ``model_type``.

    Returns:
        A new trainer instance built as ``TrainerClass(args, cfg)``.
    """
    trainers_by_type = {"VC": NoroTrainer}
    return trainers_by_type[cfg.model_type](args, cfg)


def pad_to_multiple(samples, multiple):
    """Right-pad a 1-D array with zeros so its length is a multiple of ``multiple``.

    Args:
        samples: 1-D NumPy array (e.g. audio samples).
        multiple: Positive int; the result length is the smallest multiple of
            it that is >= ``len(samples)``.

    Returns:
        A new array with the same dtype as ``samples``. No padding is added
        when ``len(samples)`` is already a multiple of ``multiple``.
    """
    # Python's % is non-negative for a positive modulus, so this is 0 for an
    # already-aligned length (``multiple - len % multiple`` would add a whole
    # extra ``multiple`` of zeros in that case).
    return np.pad(samples, (0, -len(samples) % multiple))


def _load_padded_wav(path, multiple):
    """Load ``path`` resampled to 16 kHz and zero-pad it to a multiple of ``multiple``.

    Raises:
        ValueError: if the decoded audio contains no samples.
    """
    wav, _ = librosa.load(path, sr=16000)
    if len(wav) == 0:
        raise ValueError(f"Audio file contains no samples: {path}")
    return pad_to_multiple(wav, multiple)


def _normalized_pitch(wav, cfg, device):
    """Extract a per-utterance z-normalized F0 contour for the 1-D array ``wav``.

    Returns a float tensor of shape ``(1, frames)`` on ``device``, truncated to
    at most ``len(wav) // cfg.preprocess.hop_size`` frames.
    """
    pitch_raw = get_f0_features_using_dio(wav, cfg.preprocess)
    # interpolate() returns a pair; only the interpolated contour is used.
    pitch_raw, _ = interpolate(pitch_raw)
    frame_num = len(wav) // cfg.preprocess.hop_size
    pitch = torch.from_numpy(pitch_raw[:frame_num]).float().unsqueeze(0)
    # Zero-mean / unit-variance per utterance; epsilon guards a flat contour.
    pitch = (pitch - pitch.mean(dim=1, keepdim=True)) / (
        pitch.std(dim=1, keepdim=True) + 1e-6
    )
    return pitch.to(device)


def main():
    """Run Noro voice conversion on one source/reference pair and save the mel.

    Parses CLI arguments, loads the HuBERT content extractor and the Noro model
    checkpoint onto ``cuda:<cuda_id>``, converts ``--source_path`` using the
    voice in ``--ref_path`` as reference, and writes the predicted mel
    spectrogram to ``<output_dir>/recon_mel.npy``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="config.json",
        help="JSON file for configurations.",
        required=True,
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Model checkpoint (safetensors) to load for inference.",
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        help="Output path",
        required=True,
    )
    # Both audio paths are used unconditionally below, so they are mandatory.
    parser.add_argument(
        "--ref_path",
        type=str,
        help="Reference voice path",
        required=True,
    )
    parser.add_argument(
        "--source_path",
        type=str,
        help="Source voice path",
        required=True,
    )
    parser.add_argument(
        "--cuda_id", type=int, default=0, help="CUDA device id used for inference."
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=200,
        help="Number of inference steps for model.inference (typical: 150-300).",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.2,
        help="sigma passed to model.inference (typical: 0.95-1.5).",
    )
    # Accepted for launcher compatibility; replaced with the CUDA device below.
    parser.add_argument("--local_rank", default=-1, type=int)
    args = parser.parse_args()
    cfg = load_config(args.config)
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device(f"cuda:{args.cuda_id}")
    args.local_rank = device
    print("Local rank:", args.local_rank)

    args.content_extractor = "mhubert"

    with torch.cuda.device(device):
        torch.cuda.empty_cache()

    w2v = HubertExtractor(cfg)
    w2v = w2v.to(device=device)
    w2v.eval()

    model = Noro_VCmodel(cfg=cfg.model)
    print("Loading model")
    load_model(model, args.checkpoint_path)
    print("Model loaded")
    model.cuda(device)
    model.eval()

    # Source is padded to a multiple of 1600 samples, the reference to 200.
    source_wav = _load_padded_wav(args.source_path, 1600)
    ref_wav = _load_padded_wav(args.ref_path, 200)
    audio = torch.from_numpy(source_wav).to(device)[None, :]
    ref_audio = torch.from_numpy(ref_wav).to(device)[None, :]

    with torch.no_grad():
        ref_mel = mel_spectrogram_torch(ref_audio, cfg)
        # NOTE(review): presumably (batch, frames, n_mels) after transpose — confirm.
        ref_mel = ref_mel.transpose(1, 2).to(device=device)
        # Mask marks every reference frame as valid.
        ref_mask = torch.ones(
            ref_mel.shape[0], ref_mel.shape[1], dtype=torch.bool, device=device
        )

        _, content_feature = w2v.extract_content_features(audio)
        content_feature = content_feature.to(device=device)

        # Pitch is taken from the same padded samples fed to the extractor.
        pitch = _normalized_pitch(source_wav, cfg, device)

        x0 = model.inference(
            content_feature=content_feature,
            pitch=pitch,
            x_ref=ref_mel,
            x_ref_mask=ref_mask,
            inference_steps=args.inference_steps,
            sigma=args.sigma,
        )

        recon_path = os.path.join(args.output_dir, "recon_mel.npy")
        np.save(recon_path, x0.transpose(1, 2).detach().cpu().numpy())
        print(f"Mel spectrogram saved to: {recon_path}")


if __name__ == "__main__":
    # Run inference only when executed as a script, not when imported.
    main()