File size: 3,361 Bytes
84035c8
c54eecb
 
84035c8
 
c54eecb
 
00e1f93
84035c8
c54eecb
84035c8
740f2f9
c54eecb
84035c8
00e1f93
740f2f9
c54eecb
84035c8
 
00e1f93
84035c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa783fb
84035c8
 
 
 
 
 
 
 
c54eecb
84035c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c54eecb
84035c8
00e1f93
84035c8
 
 
 
00e1f93
84035c8
 
 
00e1f93
84035c8
00e1f93
84035c8
 
 
00e1f93
84035c8
 
 
 
 
c54eecb
84035c8
05071ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import json
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import login
from pydub import AudioSegment
from pyannote.audio import Pipeline

# Authenticate with Huggingface.
# Token for the gated pyannote/speaker-diarization-3.0 model; None when
# HF_TOKEN is unset (from_pretrained below will then presumably fail —
# NOTE(review): confirm the intended behavior for anonymous runs).
AUTH_TOKEN = os.getenv("HF_TOKEN")

# Load the diarization pipeline once at import time; it is shared by every
# request handled by the Gradio app. Runs on GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.0", use_auth_token = AUTH_TOKEN).to(device)

def preprocess_audio(audio_path):
    """Convert an audio file to a mono, 16 kHz WAV suitable for pyannote.

    Parameters
    ----------
    audio_path : str
        Path to any audio file readable by pydub (via ffmpeg).

    Returns
    -------
    str
        Path to a newly created temporary WAV file. The caller is
        responsible for deleting it.

    Raises
    ------
    ValueError
        If the file cannot be read or converted.
    """
    try:
        # Load audio with pydub
        audio = AudioSegment.from_file(audio_path)
        # pyannote expects single-channel 16 kHz input
        audio = audio.set_channels(1).set_frame_rate(16000)
        # Use a unique temp file instead of a fixed "temp_audio.wav":
        # the fixed name raced/collided between concurrent requests.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # pydub reopens the path itself; release our handle
        audio.export(temp_wav, format="wav")
        return temp_wav
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise ValueError(f"Error preprocessing audio: {str(e)}") from e

def diarize_audio(audio_path, num_speakers):
    """Perform speaker diarization and return formatted results.

    Parameters
    ----------
    audio_path : str
        Path to the uploaded audio file (Gradio ``type="filepath"``).
    num_speakers : int | float
        Expected number of speakers; must be a positive whole number.
        Accepts integral floats because Gradio sliders may deliver floats.

    Returns
    -------
    tuple[str, str]
        (human-readable text, JSON string) of speaker turns; on failure,
        ("Error: ...", "").
    """
    wav_path = None
    try:
        # Validate inputs
        if not audio_path or not os.path.exists(audio_path):
            raise ValueError("Audio file not found.")
        # Coerce instead of isinstance(int): a slider value of 2.0 is valid.
        if num_speakers != int(num_speakers) or int(num_speakers) < 1:
            raise ValueError("Number of speakers must be a positive integer.")
        num_speakers = int(num_speakers)

        # Convert to mono 16 kHz WAV in a temp file (removed in finally)
        wav_path = preprocess_audio(audio_path)

        # Load audio for pyannote
        waveform, sample_rate = torchaudio.load(wav_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}

        # Run diarization with the requested speaker count
        diarization = pipeline(audio_dict, num_speakers=num_speakers)

        # Format results: one dict per speaker turn plus one text line
        results = []
        lines = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            result = {
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker_id": speaker
            }
            results.append(result)
            lines.append(f"Speaker {speaker}: {result['start']}s - {result['end']}s\n")

        # join avoids the quadratic += string build of the original
        text_output = "".join(lines)
        json_output = json.dumps(results, indent=2)
        return text_output, json_output

    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"Error: {str(e)}", ""
    finally:
        # Always remove the temp WAV — the original only cleaned up on the
        # success path and leaked the file when diarization raised.
        if wav_path and os.path.exists(wav_path):
            os.remove(wav_path)

# Gradio interface: layout is defined by the order in which components are
# created inside the Blocks/Row context managers, so statement order matters.
with gr.Blocks() as demo:
    gr.Markdown("# Speaker Diarization with Pyannote 3.0")
    gr.Markdown("Upload an audio file and specify the number of speakers to diarize the audio.")
    
    with gr.Row():
        # type="filepath" makes Gradio pass a path string to diarize_audio
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        # NOTE(review): slider values may arrive as floats even with step=1;
        # diarize_audio validates with isinstance(int) — verify against the
        # installed Gradio version.
        num_speakers = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Speakers", value=2)
    
    submit_btn = gr.Button("Diarize")
    
    with gr.Row():
        # Two views of the same result: readable text and raw JSON
        text_output = gr.Textbox(label="Diarization Results (Text)")
        json_output = gr.Textbox(label="Diarization Results (JSON)")
    
    # Wire the button to the handler; outputs map to (text, json) tuple
    submit_btn.click(
        fn=diarize_audio,
        inputs=[audio_input, num_speakers],
        outputs=[text_output, json_output]
    )

# Launch the Gradio app (blocking call; serves the UI)
demo.launch()