import gradio as gr
import torch
import torchaudio
from speechbrain.pretrained import EncoderASR
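
# Note: on SpeechBrain >= 1.0 the pretrained interfaces moved to
# speechbrain.inference; if the import above warns or fails, this
# version-dependent alternative may apply:
# from speechbrain.inference.ASR import EncoderASR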
# Load the model
asr_model = None
try:
    asr_model = EncoderASR.from_hparams(
        source="speechbrain/asr-wav2vec2-dvoice-darija",
        savedir="tmp_model",
        run_opts={"device": "cpu"},  # ensures compatibility with CPU-only environments
    )
except Exception as e:
    print(f"Error loading model: {e}")

def transcribe(audio):
    """Transcribe uploaded audio to text using the SpeechBrain ASR model."""
    if audio is None:
        return "No audio file uploaded. Please upload a valid file."
    if asr_model is None:
        return "The ASR model failed to load. Check the application logs for details."
    try:
        # Load audio from the file path provided by Gradio
        waveform, sample_rate = torchaudio.load(audio)
        # Convert stereo to mono if needed
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample to the 16 kHz rate expected by the model
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
        # Ensure waveform is shaped (batch, time_steps)
        waveform = waveform.squeeze(0).unsqueeze(0)
        # Relative lengths: a single, full-length utterance
        wav_lens = torch.tensor([1.0], dtype=torch.float32)
        # transcribe_batch returns (predicted_words, predicted_tokens)
        predicted_words, _ = asr_model.transcribe_batch(waveform, wav_lens)
        return predicted_words[0]
    except Exception as e:
        return f"Error processing audio: {e}"

# Create the Gradio interface
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Darija Speech Recognition",
    description="Speak in Darija and get the transcription.",
)
# Launch the app
if __name__ == "__main__":
    iface.launch()