import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import torchaudio
import torchaudio.transforms as T
import time
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

from model.tinyvad import TinyVAD

# Font configuration
font_path = './fonts/Times_New_Roman.ttf'
font_prop = FontProperties(fname=font_path, size=18)

# Model and Processing Parameters
WINDOW_SIZE = 0.63
SINC_CONV = False
SSM = False
TARGET_SAMPLE_RATE = 16000
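# Note: WINDOW_SIZE is the length of each analysis chunk in seconds; SINC_CONV and SSM
# toggle optional TinyVAD variants (interpretation assumed from the constructor flags)
# and are both disabled for this checkpoint.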

# Model Initialization
model = TinyVAD(1, 32, 64, patch_size=8, num_blocks=2, 
                sinc_conv=SINC_CONV, ssm=SSM)
checkpoint_path = './sincvad.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(checkpoint, strict=False)
model.eval()
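# Loading notes: map_location keeps inference on CPU, weights_only=True restricts the
# checkpoint to tensor data, strict=False tolerates key mismatches between the checkpoint
# and the current TinyVAD definition, and eval() puts the model in inference mode.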

# Audio Processing Transforms
mel_spectrogram = T.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=64, win_length=400, hop_length=160)
log_mel_spectrogram = T.AmplitudeToDB()
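# Front-end: 64-bin log-Mel spectrogram with a 400-sample (25 ms) window and a
# 160-sample (10 ms) hop at 16 kHz; AmplitudeToDB converts power to a dB (log) scale.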

# Chunking Parameters
chunk_duration = WINDOW_SIZE
shift_duration = WINDOW_SIZE * 0.875  # Shift by 87.5% of the window, i.e. ~12.5% overlap between consecutive chunks

def predict(audio_input, threshold):
    """
    Predict voice activity in an audio file with detailed processing and visualization.
    
    Args:
        audio_input (str): Path to the uploaded or recorded audio file
        threshold (float): Decision threshold for speech/non-speech classification

    Yields:
        Tuples of (prediction, probability, inference time, figure); intermediate
        yields report progress and the final yield carries the result
    """
    start_time = time.time()

    try:
        # Load and preprocess audio
        waveform, orig_sample_rate = torchaudio.load(audio_input)
        
        # Resample if necessary
        if orig_sample_rate != TARGET_SAMPLE_RATE:
            print(f"Resampling from {orig_sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz")
            resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)
        
        # Ensure mono channel
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
    except Exception as e:
        print(f"Error loading audio file: {e}")
        yield "Error loading audio file.", None, None, None
        return
    
    # Audio duration checks and padding
    audio_duration = waveform.size(1) / TARGET_SAMPLE_RATE
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Original sample rate: {orig_sample_rate} Hz")
    print(f"Current sample rate: {TARGET_SAMPLE_RATE} Hz")

    if audio_duration < chunk_duration:
        required_length = int(chunk_duration * TARGET_SAMPLE_RATE)
        padding_length = required_length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding_length))
    
    # Chunk processing parameters
    chunk_size = int(chunk_duration * TARGET_SAMPLE_RATE)
    shift_size = int(shift_duration * TARGET_SAMPLE_RATE)
    num_chunks = (waveform.size(1) - chunk_size) // shift_size + 1
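    # Sliding-window chunking: each window covers chunk_size samples and advances by
    # shift_size samples. With these settings chunk_size = 10080 and shift_size is
    # about 8820, so a 5-second clip (80,000 samples) yields 8 overlapping chunks.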

    predictions = []
    time_stamps = []
    detailed_predictions = []

    # Initialize plot
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
    ax.set_ylabel('Probability', fontproperties=font_prop)
    ax.set_title('Voice Activity Detection Probability Over Time', fontproperties=font_prop)
    ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
    ax.grid(True)
    ax.set_ylim([-0.05, 1.05])

    # Process audio in chunks
    for i in range(num_chunks):
        start_idx = i * shift_size
        end_idx = start_idx + chunk_size
        chunk = waveform[:, start_idx:end_idx]

        if chunk.size(1) < chunk_size:
            break

        # Feature extraction
        inputs = mel_spectrogram(chunk)
        inputs = log_mel_spectrogram(inputs).unsqueeze(0)
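        # Each chunk becomes a log-Mel patch of shape (1, n_mels, frames), roughly
        # 64x64 here assuming torchaudio's default centered framing; unsqueeze(0)
        # adds the batch dimension the model expects.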

        # Model inference
        with torch.no_grad():
            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)
        
        # Process outputs
        predictions.append(outputs.item())
        time_stamps.append(start_idx / TARGET_SAMPLE_RATE)
        
        detailed_predictions.append({
            'start_time': start_idx / TARGET_SAMPLE_RATE,
            'output': outputs.item(),
        })

        # Update plot dynamically
        ax.clear()
        ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
        ax.set_ylabel('Probability', fontproperties=font_prop)
        ax.set_title('Speech Probability Over Time', fontproperties=font_prop)
        ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
        ax.grid(True)
        ax.set_ylim([-0.05, 1.05])
        ax.plot(time_stamps, predictions, label='Speech Probability', color='tab:blue')
        ax.legend(loc='upper right', prop=font_prop)
        plt.tight_layout()

        # Yield intermediate progress
        yield "Processing...", None, None, fig

    # Detailed logging
    print("Detailed Predictions:")
    for pred in detailed_predictions:
        print(f"Start Time: {pred['start_time']:.2f}s, Output: {pred['output']:.4f}")

    # Final prediction processing
    avg_output = max(0, min(1, np.mean(predictions)))
    prediction_time = time.time() - start_time

    prediction = "Speech" if avg_output > threshold else "Non-speech"
    probability = round(float(avg_output) * 100, 2)  # numeric value, as expected by the gr.Number output
    inference_time = f'{prediction_time:.4f}'

    print(f"Final Prediction: {prediction}")
    print(f"Average Probability: {probability}%")
    print(f"Number of chunks processed: {num_chunks}")

    # Final result
    yield prediction, probability, inference_time, fig

# Gradio Interface
with gr.Blocks() as demo:
    gr.Image("./img/logo.png", elem_id="logo", height=100)
    # Title and Description
    gr.Markdown("<h1 style='text-align: center; color: black;'>Voice Activity Detection using SincVAD</h1>")
    gr.Markdown("<h3 style='text-align: center; color: black;'>Upload or record audio to predict speech activity and view the probability curve.</h3>")
    
    # Interface Layout
    with gr.Row():
        with gr.Column():
            # Separate recording and file upload
            audio_input = gr.Audio(type="filepath", label="Upload or Record Audio")
            threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Threshold")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
            probability_output = gr.Number(label="Average Probability (%)")
            time_output = gr.Textbox(label="Inference Time (seconds)")
        
    plot_output = gr.Plot(label="Probability Curve")

    # Prediction Trigger
    predict_btn = gr.Button("Start Prediction")
    predict_btn.click(
        predict, 
        [audio_input, threshold_input], 
        [prediction_output, probability_output, time_output, plot_output],
        api_name="predict"
    )

# Launch Configuration
if __name__ == "__main__":
    demo.queue()  # Enable queue to support generators
    demo.launch()
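
# Minimal sketch of consuming the predict() generator outside the UI, e.g. for a quick
# scripted check. The audio path is illustrative, not a file shipped with the repo:
#
#   results = list(predict("sample.wav", threshold=0.5))
#   prediction, probability, inference_time, fig = results[-1]
#   print(prediction, probability, inference_time)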