File size: 3,216 Bytes
a3e05e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19b7021
 
a3e05e8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import json
import logging

import librosa
import torch

from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN
from modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint

logger = logging.getLogger(__name__)


class BigVGANWrapper:
    """Inference wrapper around a BigVGAN vocoder.

    Holds the vocoder module, the device it was moved to, and the config ``h``
    whose mel settings (sampling_rate, n_fft, num_mels, hop_size, win_size,
    fmin, fmax) are used to compute mel spectrograms for it.
    """

    def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:
        """
        params:
            vocoder: BigVGAN module with weights already loaded
            device: device the vocoder is moved to; mels are sent here before decoding
            h: vocoder config providing the mel-spectrogram settings
            dtype: optional dtype the vocoder parameters are cast to
        """
        self.vocoder = vocoder.to(device)
        if dtype is not None:
            self.vocoder = self.vocoder.to(dtype)
        self.vocoder = self.vocoder.eval()
        self.device = device
        self.h = h

    def to_dtype(self, dtype):
        """Cast the vocoder parameters to ``dtype`` in place."""
        self.vocoder = self.vocoder.to(dtype)

    def _vocoder_dtype(self):
        """Return the dtype of the vocoder's first parameter, or None if it has none."""
        for param in self.vocoder.parameters():
            return param.dtype
        return None

    def _to_vocoder_input(self, mel):
        """Move ``mel`` to the vocoder device and cast it to the vocoder dtype.

        Without the cast, mels produced by the extractors in this class (float32)
        would not match a vocoder cast via ``dtype=`` / ``to_dtype``.
        """
        mel = mel.to(self.device)
        dtype = self._vocoder_dtype()
        if dtype is not None and mel.dtype != dtype:
            mel = mel.to(dtype)
        return mel

    def _melspec(self, wav):
        """Compute a mel spectrogram of ``wav`` with the settings stored in ``self.h``."""
        h = self.h
        # The waveform is passed positionally: the two original call sites
        # disagreed on its keyword name (``y=`` vs ``wav=``), so one of them
        # could not match get_melspec's signature.
        return get_melspec(
            wav,
            n_fft=h["n_fft"],
            num_mels=h["num_mels"],
            sampling_rate=h["sampling_rate"],
            hop_size=h["hop_size"],
            win_size=h["win_size"],
            fmin=h["fmin"],
            fmax=h["fmax"],
        )

    def extract_mel_from_wav(self, wav_path=None, wav_data=None):
        """
        params:
            wav_path: str, path of the wav, should be 24k; used only when wav_data is None
            wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
        return:
            mel: [T, num_mels], torch.tensor
        raises:
            ValueError: if neither wav_path nor wav_data is given
        """
        if wav_data is None:
            if wav_path is None:
                raise ValueError("either wav_path or wav_data must be provided")
            # librosa resamples to the vocoder's rate while loading.
            wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])

        # as_tensor avoids the copy/UserWarning torch.tensor() gives for tensor input.
        wav = torch.as_tensor(wav_data).unsqueeze(0)
        mel = self._melspec(wav)
        return mel.squeeze(0).transpose(0, 1)

    @torch.inference_mode()
    def extract_mel_from_wav_batch(self, wav_data):
        """
        params:
            wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
        return:
            mel: [Batch, T, num_mels], torch.tensor
        """
        wav = torch.as_tensor(wav_data)
        return self._melspec(wav).transpose(1, 2)

    # no_grad (rather than inference_mode) so callers may still modify the
    # returned waveform in place outside an inference-mode context.
    @torch.no_grad()
    def decode_mel(self, mel):
        """
        params:
            mel: [T, num_mels], torch.tensor
        return:
            wav: [1, T], torch.tensor
        """
        mel = self._to_vocoder_input(mel.transpose(0, 1).unsqueeze(0))
        wav = self.vocoder(mel)
        return wav.squeeze(0)

    @torch.no_grad()
    def decode_mel_batch(self, mel):
        """
        params:
            mel: [B, T, num_mels], torch.tensor
        return:
            wav: [B, 1, T], torch.tensor
        """
        mel = self._to_vocoder_input(mel.transpose(1, 2))
        return self.vocoder(mel)

    @classmethod
    def from_pretrained(cls, model_config, ckpt_path, device, dtype=None):
        """Build a wrapper from a JSON config file and a generator checkpoint.

        params:
            model_config: path to the vocoder JSON config
            ckpt_path: checkpoint path; its "generator" entry holds the vocoder weights
            device: device to place the vocoder on
            dtype: optional dtype to cast the vocoder to (default: leave unchanged)
        return:
            BigVGANWrapper
        """
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(model_config, encoding="utf-8") as f:
            h = AttrDict(json.load(f))
        # NOTE(review): second argument presumably toggles BigVGAN's custom CUDA
        # kernel; kept False for the Hugging Face demo — confirm against BigVGAN.__init__.
        vocoder = BigVGAN(h, False)
        # Load on CPU first; __init__ moves the module to the target device.
        state_dict_g = load_checkpoint(ckpt_path, "cpu")
        vocoder.load_state_dict(state_dict_g["generator"])

        logger.info(">>> Load vocoder from %s", ckpt_path)
        return cls(vocoder, device, h, dtype=dtype)