Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,216 Bytes
a3e05e8 19b7021 a3e05e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import json
import logging
import librosa
import torch
from modules.audio_detokenizer.vocoder.bigvgan import BigVGAN
from modules.audio_detokenizer.vocoder.utils import get_melspec, AttrDict, load_checkpoint
logger = logging.getLogger(__name__)
class BigVGANWrapper:
def __init__(self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None) -> None:
self.vocoder = vocoder.to(device)
if dtype is not None:
self.vocoder = self.vocoder.to(dtype)
self.vocoder = self.vocoder.eval()
self.device = device
self.h = h
def to_dtype(self, dtype):
self.vocoder = self.vocoder.to(dtype)
def extract_mel_from_wav(self, wav_path=None, wav_data=None):
"""
params:
wav_path: str, path of the wav, should be 24k
wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
return:
mel: [T, num_mels], torch.tensor
"""
if wav_data is None:
wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])
wav_data = torch.tensor(wav_data).unsqueeze(0)
mel = get_melspec(y=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
return mel.squeeze(0).transpose(0, 1)
@torch.inference_mode()
def extract_mel_from_wav_batch(self, wav_data):
"""
params:
wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
return:
mel: [Batch, T, num_mels], torch.tensor
"""
wav_data = torch.tensor(wav_data)
mel = get_melspec(wav=wav_data, n_fft=self.h["n_fft"], num_mels=self.h["num_mels"], sampling_rate=self.h["sampling_rate"],
hop_size=self.h["hop_size"], win_size=self.h["win_size"], fmin=self.h["fmin"], fmax=self.h["fmax"])
return mel.transpose(1, 2)
def decode_mel(self, mel):
"""
params:
mel: [T, num_mels], torch.tensor
return:
wav: [1, T], torch.tensor
"""
mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)
wav = self.vocoder(mel)
return wav.squeeze(0)
def decode_mel_batch(self, mel):
"""
params:
mel: [B, T, num_mels], torch.tensor
return:
wav: [B, 1, T], torch.tensor
"""
mel = mel.transpose(1, 2).to(self.device)
wav = self.vocoder(mel)
return wav
@classmethod
def from_pretrained(cls, model_config, ckpt_path, device):
with open(model_config) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
# vocoder = BigVGAN(h, True)
vocoder = BigVGAN(h, False) # for huggingface demo
state_dict_g = load_checkpoint(ckpt_path, "cpu")
vocoder.load_state_dict(state_dict_g["generator"])
logger.info(">>> Load vocoder from {}".format(ckpt_path))
return cls(vocoder, device, h)
|