import os
import torch
import torchaudio
import numpy as np
import random
import whisper
import fire
from argparse import Namespace

from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)

from models import voice_star
from inference_tts_utils import inference_one_sample

############################################################
# Utility Functions
############################################################

def seed_everything(seed=1):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def estimate_duration(ref_audio_path, ref_text, target_text):
    """
    Estimate the target duration by applying the reference audio's
    seconds-per-character rate to the target text.
    """
    info = torchaudio.info(ref_audio_path)
    audio_duration = info.num_frames / info.sample_rate
    spc = audio_duration / max(len(ref_text), 1)  # seconds per character
    return len(target_text) * spc
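# Worked example with illustrative numbers: an 8 s reference whose transcript
# has 100 characters gives spc = 0.08 s/char, so a 50-character target text
# is estimated at 50 * 0.08 = 4.0 s.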

############################################################
# Main Inference Function
############################################################

def run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long.",
    # Model
    model_name="VoiceStar_840M_30s", # or "VoiceStar_840M_40s"; the latter is trained on speech up to 40s long
    model_root="./pretrained",
    # Additional optional
    reference_text=None,  # if None => run whisper on reference_speech
    target_duration=None, # if None => estimate from reference_speech and target_text
    # Default hyperparameters from snippet
    codec_audio_sr=16000, # do not change
    codec_sr=50, # do not change
    top_k=10, # try 10, 20, 30, 40
    top_p=1, # do not change
    min_p=1, # do not change
    temperature=1,
    silence_tokens=None, # do not change it
    kvcache=1, # if OOM, set to 0
    multi_trial=None, # do not change it
    repeat_prompt=1, # increase to improve speaker similarity; but if the repeated reference plus the target duration exceeds the model's maximal training duration, quality may drop
    stop_repetition=3, # not used in this setup
    sample_batch_size=1, # do not change
    # Others
    seed=1,
    output_dir="./generated_tts",
    # Some snippet-based defaults
    cut_off_sec=100, # do not adjust; we always use the entire reference speech. If you do change it, also change reference_text so that it only transcribes the portion of speech that remains
):
    """
    Inference script using Fire.

    Example:
        python inference_commandline.py \
            --reference_speech "./demo/5895_34622_000026_000002.wav" \
            --target_text "I cannot believe ... this audio is 10 seconds long." \
            --reference_text "(optional) text to use as prefix" \
            --target_duration (optional float) 
    """

    # Seed everything
    seed_everything(seed)

    # Load model, phn2num, and args
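    # torch.load(..., weights_only=True) only unpickles allowlisted classes;
    # the checkpoint stores its training config as an argparse.Namespace, so
    # it must be registered as a safe global first (requires PyTorch >= 2.4).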
    torch.serialization.add_safe_globals([Namespace])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_fn = os.path.join(model_root, model_name + ".pth")
    if not os.path.exists(ckpt_fn):
        # download the checkpoint with wget (assumes wget is on PATH)
        print(f"[Info] Downloading {model_name} checkpoint...")
        os.makedirs(model_root, exist_ok=True)
        os.system(f'wget "https://huggingface.co./pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true" -O "{ckpt_fn}"')
    bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
    args = bundle["args"]
    phn2num = bundle["phn2num"]
    model = voice_star.VoiceStar(args)
    model.load_state_dict(bundle["model"])
    model.to(device)
    model.eval()

    # If reference_text not provided, use whisper large-v3-turbo
    if reference_text is None:
        print("[Info] No reference_text provided, transcribing reference_speech with Whisper.")
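        # "large-v3-turbo" weights are downloaded on first use; a smaller
        # Whisper model (e.g. "base") should also work here if resources are tight.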
        wh_model = whisper.load_model("large-v3-turbo")
        result = wh_model.transcribe(reference_speech)
        prefix_transcript = result["text"]
        print(f"[Info] Whisper transcribed text: {prefix_transcript}")
    else:
        prefix_transcript = reference_text

    # If target_duration not provided, estimate from reference speech + target_text
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, prefix_transcript, target_text)
        print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If this is not desired, please provide a target_duration.")
    else:
        target_generation_length = float(target_duration)

    # signature from snippet
    if args.n_codebooks == 4:
        signature = "./pretrained/encodec_6f79c6a8.th"
    elif args.n_codebooks == 8:
        signature = "./pretrained/encodec_8cb1024_giga.th"
    else:
        # fallback: default to the 4-codebook encodec_6f79c6a8 signature
        signature = "./pretrained/encodec_6f79c6a8.th"
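    # "signature" is the path to the EnCodec codec checkpoint; its codebook
    # count has to match the model's n_codebooks.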

    if silence_tokens is None:
        # default from snippet
        silence_tokens = []

    if multi_trial is None:
        # default from snippet
        multi_trial = []

    delay_pattern_increment = args.n_codebooks + 1  # from snippet

    # Compute prompt_end_frame from cut_off_sec (from snippet)
    info = torchaudio.info(reference_speech)
    prompt_end_frame = int(cut_off_sec * info.sample_rate)
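    # prompt_end_frame caps how many reference samples are used as the prompt;
    # with the default cut_off_sec=100 this keeps the entire reference speech.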

    # Prepare tokenizers
    audio_tokenizer = AudioTokenizer(signature=signature)
    text_tokenizer = TextTokenizer(backend="espeak")
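    # The espeak backend phonemizes the input text; it typically requires the
    # espeak-ng library to be installed on the system.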

    # decode_config from snippet
    decode_config = {
        'top_k': top_k,
        'top_p': top_p,
        'min_p': min_p,
        'temperature': temperature,
        'stop_repetition': stop_repetition,
        'kvcache': kvcache,
        'codec_audio_sr': codec_audio_sr,
        'codec_sr': codec_sr,
        'silence_tokens': silence_tokens,
        'sample_batch_size': sample_batch_size
    }

    # Run inference
    print("[Info] Running TTS inference...")
    concated_audio, gen_audio = inference_one_sample(
        model, args, phn2num, text_tokenizer, audio_tokenizer,
        reference_speech, target_text,
        device, decode_config,
        prompt_end_frame=prompt_end_frame,
        target_generation_length=target_generation_length,
        delay_pattern_increment=delay_pattern_increment,
        prefix_transcript=prefix_transcript,
        multi_trial=multi_trial,
        repeat_prompt=repeat_prompt,
    )

    # The model returns a list of waveforms, pick the first
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
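    # concated_audio is the reference prompt followed by the generation;
    # gen_audio is the generated continuation only.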

    # Save the audio (just the generated portion, as the snippet does)
    os.makedirs(output_dir, exist_ok=True)
    out_filename = "generated.wav"
    out_path = os.path.join(output_dir, out_filename)
    torchaudio.save(out_path, gen_audio, codec_audio_sr)

    print(f"[Success] Generated audio saved to {out_path}")


def main():
    fire.Fire(run_inference)

if __name__ == "__main__":
    main()