import os
import gradio as gr
import torch
import torchaudio
from pydub import AudioSegment
from pyannote.audio import Pipeline
import numpy as np
import json
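
# Gradio app: upload an audio file, choose the expected number of speakers, and get
# pyannote speaker-diarization segments back as plain text and as JSON.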
# Hugging Face access token (required for the gated pyannote model)
AUTH_TOKEN = os.getenv("HF_TOKEN")

# Load the diarization pipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.0", use_auth_token=AUTH_TOKEN
).to(device)
def preprocess_audio(audio_path):
    """Convert audio to mono, 16 kHz WAV format suitable for pyannote."""
    try:
        # Load audio with pydub
        audio = AudioSegment.from_file(audio_path)
        # Convert to mono and set the sample rate to 16 kHz
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Export to a temporary WAV file
        temp_wav = "temp_audio.wav"
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        raise ValueError(f"Error preprocessing audio: {str(e)}")
def diarize_audio(audio_path, num_speakers):
    """Perform speaker diarization and return formatted results."""
    try:
        # Validate inputs
        if audio_path is None or not os.path.exists(audio_path):
            raise ValueError("Audio file not found.")
        num_speakers = int(num_speakers)  # Gradio sliders may pass floats
        if num_speakers < 1:
            raise ValueError("Number of speakers must be a positive integer.")

        # Preprocess audio
        wav_path = preprocess_audio(audio_path)

        # Load audio for pyannote
        waveform, sample_rate = torchaudio.load(wav_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}

        # Run the pipeline with the requested number of speakers
        diarization = pipeline(audio_dict, num_speakers=num_speakers)

        # Format results
        results = []
        text_output = ""
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            result = {
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker_id": speaker
            }
            results.append(result)
            text_output += f"{speaker}: {result['start']}s - {result['end']}s\n"

        # Clean up the temporary file
        if os.path.exists(wav_path):
            os.remove(wav_path)

        # Return text and JSON results
        json_output = json.dumps(results, indent=2)
        return text_output, json_output
    except Exception as e:
        return f"Error: {str(e)}", ""
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with Pyannote 3.0")
    gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        num_speakers = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Speakers", value=2)
    submit_btn = gr.Button("Diarize")
    with gr.Row():
        text_output = gr.Textbox(label="Diarization Results (Text)")
        json_output = gr.Textbox(label="Diarization Results (JSON)")
    submit_btn.click(
        fn=diarize_audio,
        inputs=[audio_input, num_speakers],
        outputs=[text_output, json_output]
    )

# Launch the Gradio app
demo.launch()
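
# Minimal local-run sketch (assumptions: gradio, torch, torchaudio, pydub, and
# pyannote.audio are installed, ffmpeg is available for pydub, and HF_TOKEN is an
# access token that has been granted access to pyannote/speaker-diarization-3.0):
#
#   export HF_TOKEN=hf_your_token_here
#   python app.py
#
# Gradio then serves the UI on http://127.0.0.1:7860 by default.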