import os
import torch
import torchaudio
import numpy as np
import random
import whisper
import fire
from argparse import Namespace
from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from models import voice_star
from inference_tts_utils import inference_one_sample
############################################################
# Utility Functions
############################################################
def seed_everything(seed=1):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def estimate_duration(ref_audio_path, text):
    """
    Estimate duration based on seconds per character from the reference audio.
    """
    info = torchaudio.info(ref_audio_path)
    audio_duration = info.num_frames / info.sample_rate
    length_text = max(len(text), 1)
    spc = audio_duration / length_text  # seconds per character
    # Note: spc is derived from the target text itself, so for any non-empty text
    # this returns the reference clip's duration.
    return len(text) * spc

############################################################
# Main Inference Function
############################################################
def run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long.",
    # Model
    model_name="VoiceStar_840M_30s",  # or VoiceStar_840M_40s; the latter model is trained on speech up to 40 s long
    model_root="./pretrained",
    # Additional optional
    reference_text=None,    # if None => run whisper on reference_speech
    target_duration=None,   # if None => estimate from reference_speech and target_text
    # Default hyperparameters from snippet
    codec_audio_sr=16000,   # do not change
    codec_sr=50,            # do not change
    top_k=10,               # try 10, 20, 30, 40
    top_p=1,                # do not change
    min_p=1,                # do not change
    temperature=1,
    silence_tokens=None,    # do not change
    kvcache=1,              # if OOM, set to 0
    multi_trial=None,       # do not change
    repeat_prompt=1,        # increase this to improve speaker similarity, but if the repeated reference speech plus the target duration exceeds the maximal training duration, quality may drop
    stop_repetition=3,      # not used
    sample_batch_size=1,    # do not change
    # Others
    seed=1,
    output_dir="./generated_tts",
    # Some snippet-based defaults
    cut_off_sec=100,  # do not adjust this; the entire reference speech is always used. If you do change it, also change reference_text so that it matches only the portion of the speech that remains
):
"""
Inference script using Fire.
Example:
python inference_commandline.py \
--reference_speech "./demo/5895_34622_000026_000002.wav" \
--target_text "I cannot believe ... this audio is 10 seconds long." \
--reference_text "(optional) text to use as prefix" \
--target_duration (optional float)
"""
    # Seed everything
    seed_everything(seed)

    # Load model, phn2num, and args
    torch.serialization.add_safe_globals([Namespace])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_fn = os.path.join(model_root, model_name + ".pth")
    if not os.path.exists(ckpt_fn):
        # use wget to download (URL is quoted so the shell does not mangle the "?download=true" suffix)
        print(f"[Info] Downloading {model_name} checkpoint...")
        os.system(f"wget 'https://huggingface.co./pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true' -O {ckpt_fn}")
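        # Alternative (sketch, not used here): if the huggingface_hub package is installed,
        # the checkpoint could be fetched without shelling out to wget, e.g.:
        #   from huggingface_hub import hf_hub_download
        #   hf_hub_download(repo_id="pyp1/VoiceStar", filename=f"{model_name}.pth", local_dir=model_root)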
    bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
    args = bundle["args"]
    phn2num = bundle["phn2num"]

    model = voice_star.VoiceStar(args)
    model.load_state_dict(bundle["model"])
    model.to(device)
    model.eval()

    # If reference_text not provided, use whisper large-v3-turbo
    if reference_text is None:
        print("[Info] No reference_text provided, transcribing reference_speech with Whisper.")
        wh_model = whisper.load_model("large-v3-turbo")
        result = wh_model.transcribe(reference_speech)
        prefix_transcript = result["text"]
        print(f"[Info] Whisper transcribed text: {prefix_transcript}")
    else:
        prefix_transcript = reference_text

    # If target_duration not provided, estimate from reference speech + target_text
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, target_text)
        print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration.")
    else:
        target_generation_length = float(target_duration)

    # signature from snippet
    if args.n_codebooks == 4:
        signature = "./pretrained/encodec_6f79c6a8.th"
    elif args.n_codebooks == 8:
        signature = "./pretrained/encodec_8cb1024_giga.th"
    else:
        # fallback: just use encodec_6f79c6a8
        signature = "./pretrained/encodec_6f79c6a8.th"

    if silence_tokens is None:
        # default from snippet
        silence_tokens = []

    if multi_trial is None:
        # default from snippet
        multi_trial = []

    delay_pattern_increment = args.n_codebooks + 1  # from snippet

    # We can compute prompt_end_frame if we want, from snippet
    info = torchaudio.info(reference_speech)
    prompt_end_frame = int(cut_off_sec * info.sample_rate)

    # Prepare tokenizers
    audio_tokenizer = AudioTokenizer(signature=signature)
    text_tokenizer = TextTokenizer(backend="espeak")
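    # Note (assumption): the "espeak" backend typically relies on the system espeak / espeak-ng
    # library being installed; if it is missing, constructing TextTokenizer is expected to fail.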
    # decode_config from snippet
    decode_config = {
        'top_k': top_k,
        'top_p': top_p,
        'min_p': min_p,
        'temperature': temperature,
        'stop_repetition': stop_repetition,
        'kvcache': kvcache,
        'codec_audio_sr': codec_audio_sr,
        'codec_sr': codec_sr,
        'silence_tokens': silence_tokens,
        'sample_batch_size': sample_batch_size
    }
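    # Per the parameter comments above, top_k (together with repeat_prompt) is the main knob
    # worth tuning; the codec sampling rates and sample_batch_size are expected to stay at their defaults.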
    # Run inference
    print("[Info] Running TTS inference...")
    concated_audio, gen_audio = inference_one_sample(
        model, args, phn2num, text_tokenizer, audio_tokenizer,
        reference_speech, target_text,
        device, decode_config,
        prompt_end_frame=prompt_end_frame,
        target_generation_length=target_generation_length,
        delay_pattern_increment=delay_pattern_increment,
        prefix_transcript=prefix_transcript,
        multi_trial=multi_trial,
        repeat_prompt=repeat_prompt,
    )

    # The model returns a list of waveforms; pick the first
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()

    # Save the audio (just the generated portion, as the snippet does)
    os.makedirs(output_dir, exist_ok=True)
    out_filename = "generated.wav"
    out_path = os.path.join(output_dir, out_filename)
    torchaudio.save(out_path, gen_audio, codec_audio_sr)
    print(f"[Success] Generated audio saved to {out_path}")
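    # Optional (sketch): the reference + generated concatenation could be saved alongside it, e.g.
    #   torchaudio.save(os.path.join(output_dir, "concatenated.wav"), concated_audio, codec_audio_sr)
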
def main():
    fire.Fire(run_inference)


if __name__ == "__main__":
    main()