Spaces:
Running on T4
import json
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import login
from pyannote.audio import Pipeline
from pydub import AudioSegment
# Hugging Face access token: pyannote's diarization models are gated and
# require authentication (supplied via the HF_TOKEN environment variable).
AUTH_TOKEN = os.getenv("HF_TOKEN")
# Run the diarization pipeline on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): AUTH_TOKEN is None when HF_TOKEN is unset; from_pretrained would
# then fail to download this gated model — consider failing fast with a clearer
# error message. Also, `login` is imported above but never called — verify
# whether an explicit login is needed elsewhere.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0", use_auth_token = AUTH_TOKEN).to(device)
def preprocess_audio(audio_path, target_sample_rate=16000):
    """Convert an audio file to a mono WAV at ``target_sample_rate`` Hz.

    Produces the single-channel, 16 kHz (by default) WAV format that the
    pyannote diarization pipeline expects.

    Args:
        audio_path: Path to the input audio file (any format pydub/ffmpeg
            can decode).
        target_sample_rate: Output sample rate in Hz. Defaults to 16000,
            matching pyannote's expected input.

    Returns:
        Path to a newly created temporary WAV file. The caller owns this
        file and is responsible for deleting it.

    Raises:
        ValueError: If the file cannot be read or converted.
    """
    try:
        audio = AudioSegment.from_file(audio_path)
        # Downmix to mono and resample to the target rate.
        audio = audio.set_channels(1).set_frame_rate(target_sample_rate)
        # Use a unique temp file rather than a fixed "temp_audio.wav" so
        # concurrent Gradio requests cannot clobber each other's audio.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # pydub reopens the path itself; close our descriptor
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e
def diarize_audio(audio_path, num_speakers):
    """Perform speaker diarization and return formatted results.

    Args:
        audio_path: Filesystem path to the uploaded audio file (Gradio
            supplies this via the ``filepath``-typed Audio component).
        num_speakers: Expected number of speakers. Gradio sliders can
            deliver floats (e.g. ``2.0``), so the value is coerced to int.

    Returns:
        A ``(text, json)`` tuple: a human-readable transcript of speaker
        turns and the same segments serialized as a JSON string. On error,
        returns ``("Error: <message>", "")`` so the UI shows the failure
        instead of crashing.
    """
    wav_path = None
    try:
        # Validate inputs.
        if not audio_path or not os.path.exists(audio_path):
            raise ValueError("Audio file not found.")
        # Coerce before validating: the slider may hand us a float.
        num_speakers = int(num_speakers)
        if num_speakers < 1:
            raise ValueError("Number of speakers must be a positive integer.")
        # Convert to the mono/16kHz WAV format pyannote expects.
        wav_path = preprocess_audio(audio_path)
        # Feed pyannote an in-memory waveform dict to avoid a second decode.
        waveform, sample_rate = torchaudio.load(wav_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}
        diarization = pipeline(audio_dict, num_speakers=num_speakers)
        # Format results as both plain text and JSON-serializable records.
        results = []
        text_output = ""
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segment = {
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker_id": speaker,
            }
            results.append(segment)
            text_output += f"Speaker {speaker}: {segment['start']}s - {segment['end']}s\n"
        json_output = json.dumps(results, indent=2)
        return text_output, json_output
    except Exception as e:
        # Surface the error in the UI rather than raising into Gradio.
        return f"Error: {str(e)}", ""
    finally:
        # Always remove the temp WAV — previously it leaked whenever the
        # pipeline raised, since cleanup ran only on the success path.
        if wav_path and os.path.exists(wav_path):
            os.remove(wav_path)
# Gradio interface: file upload + speaker-count slider in, text + JSON out.
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with Pyannote 3.0")
    gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")
    with gr.Row():
        # type="filepath" makes Gradio pass a path string, which is what
        # diarize_audio expects.
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        num_speakers = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Speakers", value=2)
    submit_btn = gr.Button("Diarize")
    with gr.Row():
        text_output = gr.Textbox(label="Diarization Results (Text)")
        json_output = gr.Textbox(label="Diarization Results (JSON)")
    # Wire the button to the diarization function; outputs map positionally
    # to diarize_audio's (text, json) return tuple.
    submit_btn.click(
        fn=diarize_audio,
        inputs=[audio_input, num_speakers],
        outputs=[text_output, json_output]
    )
# Launch the Gradio app (module-level side effect: starts the web server).
demo.launch()