File size: 2,993 Bytes
33ddae1
a09ccc4
33ddae1
a09ccc4
 
 
 
0c8c55f
a09ccc4
 
 
dbb9634
 
a09ccc4
 
 
 
 
 
 
 
 
 
 
 
0c8c55f
a09ccc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c39a9ba
 
a09ccc4
c39a9ba
a09ccc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b316611
 
 
 
a09ccc4
 
b316611
c39a9ba
a09ccc4
b316611
a09ccc4
33ddae1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import time

from huggingface_hub import hf_hub_download
import numpy as np
import sphn
import torch
import spaces

from moshi.models import loaders

import torch._dynamo

# Let TorchDynamo log compilation failures and carry on instead of raising.
torch._dynamo.config.suppress_errors = True

# Run on the GPU as soon as torch reports at least one CUDA device.
device = "cpu" if torch.cuda.device_count() == 0 else "cuda"
# Number of codebooks the Mimi tokenizer is instantiated with.
num_codebooks = 32

print("loading mimi")
model_file = hf_hub_download(
    repo_id=loaders.DEFAULT_REPO,
    filename="tokenizer-e351c8d8-checkpoint125.safetensors",
)

# The model is only used for encode/decode round trips, so keep it in eval mode.
mimi = loaders.get_mimi(model_file, device, num_codebooks=num_codebooks)
mimi.eval()
print("mimi loaded")


@spaces.GPU
def mimi_streaming_test(input_wave, max_duration_sec=10.0):
    """Round-trip an audio file through Mimi at several codebook depths.

    The file is read, resampled to Mimi's sample rate, truncated to at most
    ``max_duration_sec`` seconds and encoded once. The resulting codes are
    then decoded four times, keeping only the first 1, 2, 4 and 32 codebooks.

    Args:
        input_wave: Path of the audio file to process. Gradio passes ``None``
            when the user submits without recording or uploading anything.
        max_duration_sec: Maximum duration, in seconds, kept from the start
            of the audio.

    Returns:
        A list of ``(sample_rate, samples)`` tuples, one per codebook count,
        in increasing codebook order; ``samples`` is a numpy array.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not input_wave:
        # Fail with a readable message in the UI rather than an opaque
        # error from the audio reader.
        raise gr.Error("Please record or upload an audio clip first.")

    # Codebook prefixes to decode; must stay aligned with the audio outputs
    # declared on the Gradio interface.
    codebook_counts = (1, 2, 4, 32)

    sample_rate = mimi.sample_rate
    sample_pcm, sample_sr = sphn.read(input_wave)
    print("loaded pcm", sample_pcm.shape, sample_sr)
    sample_pcm = sphn.resample(
        sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
    )
    sample_pcm = torch.tensor(sample_pcm, device=device)

    if sample_pcm.dim() == 2 and sample_pcm.shape[0] > 1:
        # NOTE(review): Mimi appears to work on a single audio channel (only
        # channel 0 of the decoded output is read below) - confirm.
        # Multi-channel files are averaged down to mono.
        sample_pcm = sample_pcm.mean(dim=0, keepdim=True)

    # Keep only the leading part of the clip; slicing is a no-op when the
    # clip is already short enough.
    max_duration_len = int(sample_rate * max_duration_sec)
    sample_pcm = sample_pcm[..., :max_duration_len]
    # Log the rate the samples are actually at now, not the source rate.
    print("resampled pcm", sample_pcm.shape, sample_rate)

    # Add the leading batch dimension expected by encode().
    sample_pcm = sample_pcm[None]

    pcm_list = []
    print("streaming encoding...")
    with torch.no_grad():
        all_codes_th = mimi.encode(sample_pcm)
        print(f"codes {all_codes_th.shape}")

        for count in codebook_counts:
            # Keep only the first `count` codebooks along dimension 1.
            codes = all_codes_th[:, :count, :]
            print(f"decoding {codes.shape[1]} codebooks, {codes.shape}")
            pcm = mimi.decode(codes)
            pcm_list.append((sample_rate, pcm[0, 0].cpu().numpy()))
    return pcm_list


# Gradio front end: one audio input, and one audio player per reconstruction
# returned by mimi_streaming_test (same order as the returned list).
demo = gr.Interface(
    title="Mimi tokenizer playground",
    description="Explore the quality of reconstruction when audio is tokenized using various number of code books in the Mimi model.",
    fn=mimi_streaming_test,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="Input audio",
    ),
    outputs=[
        gr.Audio(type="numpy", label=label)
        for label in (
            "Reconstructed with 1 codebook",
            "Reconstructed with 2 codebooks",
            "Reconstructed with 4 codebooks",
            "Reconstructed with 32 codebooks",
        )
    ],
    examples=[["./hello.mp3"]],
)

demo.launch()