import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
from snac import SNAC
import time  # used to time each pipeline stage
from dotenv import load_dotenv

load_dotenv()
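# Picks up environment variables (e.g. a Hugging Face access token for gated
# repos) from a local .env file, if one is present.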


# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# 1. Load SNAC Model (for audio decoding)
print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)
snac_model.eval()  # Set SNAC to evaluation mode
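# SNAC decodes three hierarchical streams of codebook indices (4096 entries
# per codebook) into a 24 kHz waveform; redistribute_codes() below rebuilds
# that three-layer layout from Orpheus's flat audio-token stream.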

# 2. Load Orpheus Language Model (for text-to-token generation)
model_name = "canopylabs/orpheus-3b-0.1-ft"

# Download only necessary files (config and safetensors)
print("Downloading Orpheus model files...")
snapshot_download(
    repo_id=model_name,
    allow_patterns=[
        "config.json",
        ".safetensors",
        "model.safetensors.index.json",
    ],
    ignore_patterns=[
        "optimizer.pt",
        "pytorch_model.bin",
        "training_args.bin",
        "scheduler.pt",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer."
    ]
)
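# Tokenizer files are excluded above on purpose; AutoTokenizer.from_pretrained
# below fetches them from the same repo on demand.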

print("Loading Orpheus model...")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)

# --- Optional optimization: convert to BetterTransformer ---
try:
    from optimum.bettertransformer import BetterTransformer  # optional dependency
    model = BetterTransformer.transform(model)
    print("Model converted to BetterTransformer for faster inference.")
except Exception as e:
    print(f"BetterTransformer conversion failed: {e}. Proceeding without it.")

model.to(device)
model.eval()  # Set the Orpheus model to evaluation mode
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Orpheus model loaded to {device}")


# --- Function Definitions ---

def process_prompt(prompt, voice, tokenizer, device):
    """Processes the text prompt and converts it to input IDs."""
    prompt = f"{voice}: {prompt}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human

    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH

    # No padding needed for single input
    attention_mask = torch.ones_like(modified_input_ids)

    return modified_input_ids.to(device), attention_mask.to(device)

def parse_output(generated_ids):
    """Parses the generated token IDs to extract the audio codes."""
    token_to_find = 128257    # marks the start of the audio-token stream
    token_to_remove = 128258  # end-of-speech (EOS) token

    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)

    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids

    processed_rows = []
    for row in cropped_tensor:
        masked_row = row[row != token_to_remove]
        processed_rows.append(masked_row)

    code_lists = []
    for row in processed_rows:
        row_length = row.size(0)
        new_length = (row_length // 7) * 7  # Ensure divisibility by 7
        trimmed_row = row[:new_length]
        trimmed_row = (trimmed_row - 128266).tolist()  # shift IDs down to raw SNAC codes (audio tokens start at 128266)
        code_lists.append(trimmed_row)

    return code_lists[0]  # Return codes for the first (and only) sequence


def redistribute_codes(code_list, snac_model):
    """Redistributes the audio codes into the format required by SNAC."""
    device = next(snac_model.parameters()).device  # Get the device of SNAC model

    layer_1 = []
    layer_2 = []
    layer_3 = []
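    # Orpheus emits one 7-code frame per SNAC timestep. Within a frame,
    # position 0 feeds the coarse codebook (layer 1), positions 1 and 4 the
    # medium codebook (layer 2), and positions 2, 3, 5 and 6 the fine codebook
    # (layer 3); position p carries an offset of p * 4096 that is subtracted
    # to recover the raw codebook index.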
    for i in range(len(code_list) // 7):  # one frame per 7 codes
        layer_1.append(code_list[7*i])
        layer_2.append(code_list[7*i+1]-4096)
        layer_3.append(code_list[7*i+2]-(2*4096))
        layer_3.append(code_list[7*i+3]-(3*4096))
        layer_2.append(code_list[7*i+4]-(4*4096))
        layer_3.append(code_list[7*i+5]-(5*4096))
        layer_3.append(code_list[7*i+6]-(6*4096))

    # Move tensors to the same device as the SNAC model
    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0)
    ]

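    # decode() returns a batched waveform tensor at 24 kHz; squeeze() below
    # flattens it to the 1-D float array Gradio's Audio component expects.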
    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()  # Return CPU numpy array


def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
    """Generates speech from the given text using Orpheus and SNAC."""
    if not text.strip():
        return None

    try:
        start_time = time.time()  # Start timing

        progress(0.1, "Processing text...")
        input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
        process_time = time.time() - start_time
        print(f"Text processing time: {process_time:.2f} seconds")

        start_time = time.time()  # Reset timer
        progress(0.3, "Generating speech tokens...")
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=128258,  # end-of-speech token (stripped again in parse_output)
            )
        generation_time = time.time() - start_time
        print(f"Token generation time: {generation_time:.2f} seconds")

        start_time = time.time()  # Reset timer
        progress(0.6, "Processing speech tokens...")
        code_list = parse_output(generated_ids)
        parse_time = time.time() - start_time
        print(f"Token parsing time: {parse_time:.2f} seconds")
        start_time = time.time()  # Reset timer
        progress(0.8, "Converting to audio...")
        audio_samples = redistribute_codes(code_list, snac_model)
        audio_time = time.time() - start_time
        print(f"Audio conversion time: {audio_time:.2f} seconds")

        return (24000, audio_samples)  # Return sample rate and audio
    except Exception as e:
        print(f"Error generating speech: {e}")
        return None



# --- Gradio Interface Setup ---

examples = [
    ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
    ["I've also been taught to understand and produce paralinguistic things like sighing, or chuckling, or yawning!", "dan", 0.7, 0.95, 1.1, 1200],
    ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... well, lets just say a lot of parameters.", "emma", 0.6, 0.9, 1.2, 1200]
]

VOICES = ["tara", "dan", "josh", "emma"]

with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
    gr.Markdown("""
    # 🎵 Orpheus Text-to-Speech
    Enter text below to convert to speech.
    """)
    with gr.Row():
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Enter your text here...",
                lines=5
            )
            voice = gr.Dropdown(
                choices=VOICES,
                value="tara",
                label="Voice"
            )

            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top P"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty"
                )
                max_new_tokens = gr.Slider(
                    minimum=100, maximum=2000, value=1200, step=100,
                    label="Max Length"
                )

            with gr.Row():
                submit_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")

        with gr.Column(scale=2):
            audio_output = gr.Audio(label="Generated Speech", type="numpy")

    gr.Examples(
        examples=examples,
        inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output,
        fn=generate_speech,
        cache_examples=True,
    )

    submit_btn.click(
        fn=generate_speech,
        inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
        outputs=audio_output
    )

    clear_btn.click(
        fn=lambda: (None, None),
        inputs=[],
        outputs=[text_input, audio_output]
    )

if __name__ == "__main__":
    demo.queue().launch(share=False)
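
# Quick smoke test without the UI (hypothetical snippet; assumes the models
# above loaded successfully, and would have to run in place of launch(),
# since launch() blocks):
#
#     sr, samples = generate_speech(
#         "Hello from Orpheus.", "tara",
#         temperature=0.6, top_p=0.95, repetition_penalty=1.1,
#         max_new_tokens=1200,
#     )
#     # sr == 24000; `samples` is a 1-D float numpy array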