|
import json
import os
import random
import string
from difflib import SequenceMatcher

import gradio as gr
import spaces  # Hugging Face Spaces helper; kept for Spaces compatibility, not referenced directly here
import torchaudio
from jiwer import wer
from transformers import pipeline
|
with open("common_voice_en_validated_249_hf_ready.json") as f:
    data = json.load(f)
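# Each entry is expected to provide the fields used below: "path" (the clip's
# MP3 filename), "sentence" (the reference transcript), and the speaker
# metadata "age", "gender", and "accent".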

# Distinct metadata values. Only `accents` seeds a dropdown directly; the
# gender and age choices are filled in by the cascading callbacks in the UI.
ages = sorted(set(entry["age"] for entry in data))
genders = sorted(set(entry["gender"] for entry in data))
accents = sorted(set(entry["accent"] for entry in data))
|
|
def convert_to_wav(file_path):
    """Convert an MP3 clip to a mono WAV once, caching the result on disk."""
    wav_path = file_path.replace(".mp3", ".wav")
    if not os.path.exists(wav_path):
        waveform, sample_rate = torchaudio.load(file_path)
        # Downmix to a single channel so every model receives the same input.
        waveform = waveform.mean(dim=0, keepdim=True)
        torchaudio.save(wav_path, waveform, sample_rate)
    return wav_path
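# Note: decoding MP3 via torchaudio.load depends on the audio backend installed
# in this environment (e.g. an ffmpeg-enabled build); that is an assumption here.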
|
|
|
def highlight_differences(ref, hyp):
    """Wrap hypothesis words that differ from the reference in red <span> tags."""
    ref_words, hyp_words = ref.split(), hyp.split()
    sm = SequenceMatcher(None, ref_words, hyp_words)
    result = []
    for opcode, i1, i2, j1, j2 in sm.get_opcodes():
        if opcode == "equal":
            result.extend(hyp_words[j1:j2])
        else:
            # Insertions and substitutions are marked in red; pure deletions
            # contribute no hypothesis words to mark.
            result.extend(f"<span style='color:red'>{w}</span>" for w in hyp_words[j1:j2])
    return " ".join(result)
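# Example: highlight_differences("the cat sat", "the bat sat")
# -> "the <span style='color:red'>bat</span> sat"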
|
|
|
def normalize(text):
    """Lowercase, strip punctuation, and trim whitespace before comparison."""
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    return text.strip()
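# Example: normalize("Hello, World!  ") -> "hello world"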
|
|
def generate_audio(age, gender, accent):
    """Pick a random clip matching the selected demographics; return its WAV path."""
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        return None, "No matching sample."
    sample = random.choice(filtered)
    file_path = os.path.join("common_voice_en_validated_249", sample["path"])
    wav_file_path = convert_to_wav(file_path)
    # Return the path twice: once for the audio player and once for the hidden
    # textbox that transcribe_audio reads from.
    return wav_file_path, wav_file_path
|
|
def transcribe_audio(file_path):
    """Transcribe the clip with every model; return gold text plus per-model HTML."""
    model_ids = [
        "openai/whisper-tiny",
        "openai/whisper-tiny.en",
        "openai/whisper-base",
        "openai/whisper-base.en",
        "openai/whisper-medium",
        "openai/whisper-medium.en",
        "distil-whisper/distil-large-v3.5",
        "facebook/wav2vec2-base-960h",
        "facebook/wav2vec2-large-960h",
        "facebook/wav2vec2-large-960h-lv60-self",
        "facebook/hubert-large-ls960-ft",
    ]
    n_models = len(model_ids)

    # Early returns must match the arity of the wired outputs: the gold textbox
    # plus one HTML panel per model.
    if not file_path:
        return ("No file selected.", *[""] * n_models)

    # Recover the original MP3 filename to look up the reference sentence.
    filename_mp3 = os.path.basename(file_path).replace(".wav", ".mp3")
    gold = ""
    for entry in data:
        if entry["path"].endswith(filename_mp3):
            gold = normalize(entry["sentence"])
            break
    if not gold:
        return ("Reference not found.", *[""] * n_models)

    outputs = {}
    for model_id in model_ids:
        try:
            # A fresh pipeline is built on every call; see the caching sketch below.
            pipe = pipeline("automatic-speech-recognition", model=model_id)
            text = pipe(file_path)["text"].strip()
            clean = normalize(text)
            wer_score = wer(gold, clean)
            outputs[model_id] = (
                f"<b>{model_id} (WER: {wer_score:.2f}):</b><br>"
                f"{highlight_differences(gold, clean)}"
            )
        except Exception as e:
            outputs[model_id] = (
                f"<b>{model_id}:</b><br><span style='color:red'>Error: {e}</span>"
            )

    return (gold, *outputs.values())
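# Optional optimization (a sketch, not wired into the code above): cache the
# pipelines across calls so each checkpoint is downloaded and loaded only once
# per process. The names `_PIPELINES` and `get_pipe` are illustrative, not part
# of the original code.
#
#     _PIPELINES = {}
#
#     def get_pipe(model_id):
#         if model_id not in _PIPELINES:
#             _PIPELINES[model_id] = pipeline("automatic-speech-recognition", model=model_id)
#         return _PIPELINES[model_id]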
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Comparing ASR Models on Diverse English Speech Samples")
    gr.Markdown("""
    This demo compares the transcription performance of several automatic speech recognition (ASR) models.
    Select an accent, gender, and age to fetch a matching English audio sample, then transcribe it with
    every model at once; each transcript is scored by word error rate (WER) against the reference sentence.
    Data is sourced from 249 validated entries in the Common Voice English Delta Segment 21.0 release.
    """)

    with gr.Row():
        accent = gr.Dropdown(choices=accents, label="Accent", interactive=True)
        gender = gr.Dropdown(choices=[], label="Gender", interactive=True)
        age = gr.Dropdown(choices=[], label="Age", interactive=True)

    def update_gender_options(selected_accent):
        options = sorted(set(entry["gender"] for entry in data if entry["accent"] == selected_accent))
        return gr.update(choices=options, value=None)

    def update_age_options(selected_accent, selected_gender):
        options = sorted(set(
            entry["age"] for entry in data
            if entry["accent"] == selected_accent and entry["gender"] == selected_gender
        ))
        return gr.update(choices=options, value=None)
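    # The two handlers above cascade: choosing an accent repopulates the gender
    # choices, and choosing a gender then narrows the age choices.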
|
|
|
    accent.change(update_gender_options, inputs=[accent], outputs=[gender])
    gender.change(update_age_options, inputs=[accent, gender], outputs=[age])

    generate_btn = gr.Button("Get Audio")
    audio_output = gr.Audio(label="Audio", type="filepath", interactive=False)
    # Hidden textbox that carries the WAV path from generate_audio to transcribe_audio.
    file_path_output = gr.Textbox(label="Audio File Path", visible=False)

    generate_btn.click(generate_audio, [age, gender, accent], [audio_output, file_path_output])

    transcribe_btn = gr.Button("Transcribe with All Models")
    gold_text = gr.Textbox(label="Reference (Gold Standard)")

    # One HTML panel per model, in the same order as model_ids in transcribe_audio.
    whisper_tiny_html = gr.HTML(label="Whisper Tiny")
    whisper_tiny_en_html = gr.HTML(label="Whisper Tiny English")
    whisper_base_html = gr.HTML(label="Whisper Base")
    whisper_base_en_html = gr.HTML(label="Whisper Base English")
    whisper_medium_html = gr.HTML(label="Whisper Medium")
    whisper_medium_en_html = gr.HTML(label="Whisper Medium English")
    distil_html = gr.HTML(label="Distil-Whisper Large v3.5")
    wav2vec_base_html = gr.HTML(label="Wav2Vec2 Base")
    wav2vec_large_html = gr.HTML(label="Wav2Vec2 Large")
    wav2vec_lv60_html = gr.HTML(label="Wav2Vec2 Large LV-60 + Self-Training")
    hubert_html = gr.HTML(label="HuBERT Large")

    transcribe_btn.click(
        transcribe_audio,
        inputs=[file_path_output],
        outputs=[
            gold_text,
            whisper_tiny_html,
            whisper_tiny_en_html,
            whisper_base_html,
            whisper_base_en_html,
            whisper_medium_html,
            whisper_medium_en_html,
            distil_html,
            wav2vec_base_html,
            wav2vec_large_html,
            wav2vec_lv60_html,
            hubert_html,
        ],
    )

demo.launch()
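# Note: a single click loads and runs 11 models, so requests can be slow.
# Enabling Gradio's request queue (demo.queue() before launch) is one option if
# requests time out; it is left out above to preserve the original behavior.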