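"""Gradio demo for whisper_bidec: biased decoding with Whisper.

The app loads a Whisper model from Hugging Face, takes an audio clip plus optional
biasing sentences (typed as comma-separated text or uploaded as a .txt / .md / .csv
file), and shows the transcription with and without the bias applied.
"""
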
import csv
import os
import tempfile

import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from whisper_bidec import decode_wav, get_logits_processor, load_corpus_from_sentences
from pydub import AudioSegment
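# Runtime dependencies (inferred from the imports above): gradio, transformers,
# pydub, and the package providing whisper_bidec. pydub typically needs ffmpeg
# installed on the system to read compressed formats such as MP3 or M4A.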


def _parse_file(file_path: str) -> list[str]:
    """Parse a .txt / .md / .csv file and return its content as a list of strings,
    one per line (or per row for CSV files)."""

    sentences: list[str] = []
    if file_path.endswith(".csv"):
        with open(file_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            for row in reader:
                # Join the cells of a row so each row yields a single sentence string.
                sentences.append(" ".join(cell.strip() for cell in row))
    else:
        with open(file_path, "r", encoding="utf-8") as f:
            # Strip trailing newlines and skip empty lines.
            sentences = [line.strip() for line in f if line.strip()]
    return sentences


def _convert_audio(input_audio_path: str) -> str:
    """Whisper decoder expects wav files with 16kHz sample rate and mono channel.
    Convert the audio file to this format, save it in a tmp file and return the path.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # Close file descriptor

    audio = AudioSegment.from_file(input_audio_path)
    audio = audio.set_channels(1).set_frame_rate(16000)
    audio.export(tmp_path, format="wav")
    return tmp_path


def transcribe(
    processor_name: str,
    audio_path: str,
    bias_strength: float,
    bias_text: str | None,
    bias_text_file: str | None,
) -> tuple[str, str]:
    """Transcribe the audio twice: once without biasing and once biased towards the
    provided sentences. Returns (unbiased text, biased text)."""
    processor = WhisperProcessor.from_pretrained(processor_name)
    model = WhisperForConditionalGeneration.from_pretrained(processor_name)

    sentences: list[str] = []

    if bias_text:
        # Comma-separated sentences typed into the textbox.
        sentences = [s.strip() for s in bias_text.split(",") if s.strip()]
    elif bias_text_file:
        sentences = _parse_file(bias_text_file)

    converted_audio_path = _convert_audio(audio_path)

    try:
        if sentences:
            corpus = load_corpus_from_sentences(sentences, processor)
            logits_processor = get_logits_processor(
                corpus=corpus, processor=processor, bias_towards_lm=bias_strength
            )
            text_with_bias = decode_wav(
                model, processor, converted_audio_path, logits_processor=logits_processor
            )
        else:
            text_with_bias = ""

        text_no_bias = decode_wav(
            model, processor, converted_audio_path, logits_processor=None
        )
    finally:
        # Clean up the temporary converted WAV file.
        os.remove(converted_audio_path)

    return text_no_bias, text_with_bias
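
# Illustrative call (the audio path and bias sentences below are placeholders):
#   text_no_bias, text_with_bias = transcribe(
#       processor_name="openai/whisper-tiny.en",
#       audio_path="recording.wav",
#       bias_strength=0.5,
#       bias_text="turn on the lights, turn off the lights",
#       bias_text_file=None,
#   )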


def setup_gradio_demo():
    css = """
    #centered-column {
        display: flex;
        justify-content: center;
        align-items: center;
        flex-direction: column; 
        text-align: center;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Whisper Bidec Demo")

        gr.Markdown("## Step 1: Select a Whisper model")
        processor = gr.Textbox(
            value="openai/whisper-tiny.en", label="Whisper Model from Hugging Face"
        )

        gr.Markdown("## Step 2: Upload your audio file")
        audio_clip = gr.Audio(type="filepath", label="Upload an audio file")

        gr.Markdown("## Step 3: Set your biasing text")
        with gr.Row():
            with gr.Column(scale=20):
                gr.Markdown(
                    "You can add multiple candidate sentences by separating them with a comma (,)."
                )
                bias_text = gr.Textbox(label="Write your biasing text here")
            with gr.Column(scale=1, elem_id="centered-column"):
                gr.Markdown("## OR")
            with gr.Column(scale=20):
                gr.Markdown(
                    "Note that each new line (.txt / .md) or row (.csv) will be treated as a separate sentence to bias towards."
                )
                bias_text_file = gr.File(
                    label="Upload a file with multiple lines of text",
                    file_types=[".txt", ".md", ".csv"],
                )

        gr.Markdown("## Step 4: Set how much you want to bias towards the LM")
        bias_amount = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.1,
            label="Bias strength",
            interactive=True,
        )

        gr.Markdown("## Step 5: Get your transcription before and after biasing")
        transcribe_button = gr.Button("Transcribe")

        with gr.Row():
            with gr.Column():
                output = gr.Text(label="Output")
            with gr.Column():
                biased_output = gr.Text(label="Biased output")

        transcribe_button.click(
            fn=transcribe,
            inputs=[
                processor,
                audio_clip,
                bias_amount,
                bias_text,
                bias_text_file,
            ],
            outputs=[output, biased_output],
        )
    demo.launch()


if __name__ == "__main__":
    setup_gradio_demo()