import gradio as gr
import torch
import torchaudio
from speechbrain.pretrained import EncoderASR

# Load the model
try:
    asr_model = EncoderASR.from_hparams(
        source="speechbrain/asr-wav2vec2-dvoice-darija",
        savedir="tmp_model",
        run_opts={"device": "cpu"}  # Ensures compatibility with CPU environments
    )
except Exception as e:
    print(f"Error loading model: {str(e)}")


def transcribe(audio):
    """Transcribe uploaded audio to text using SpeechBrain ASR."""
    if audio is None:
        return "No audio file uploaded. Please upload a valid file."

    try:
        # Load audio
        waveform, sample_rate = torchaudio.load(audio)

        # Convert stereo to mono if needed
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample if sample rate is not 16 kHz
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)

        # Waveform is (1, time_steps); the channel dimension doubles as the batch dimension.
        # wav_lens holds the relative length of each utterance; a single full-length clip is 1.0.
        wav_lens = torch.tensor([1.0], dtype=torch.float32)

        # transcribe_batch returns (predicted_words, predicted_tokens);
        # take the transcription string for the first (and only) batch item.
        predicted_words, _ = asr_model.transcribe_batch(waveform, wav_lens)
        return predicted_words[0]
    except Exception as e:
        return f"Error processing audio: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Reconnaissance Vocale Darija",
    description="Parlez en Darija et obtenez la transcription.",
)

# Launch the app
if __name__ == "__main__":
    iface.launch()