import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import torchaudio
import torchaudio.transforms as T
import time
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

from model.tinyvad import TinyVAD

# Font configuration
font_path = './fonts/Times_New_Roman.ttf'
font_prop = FontProperties(fname=font_path, size=18)

# Model and Processing Parameters
WINDOW_SIZE = 0.63
SINC_CONV = False
SSM = False
TARGET_SAMPLE_RATE = 16000

# Model Initialization
model = TinyVAD(1, 32, 64, patch_size=8, num_blocks=2, sinc_conv=SINC_CONV, ssm=SSM)
checkpoint_path = './sincvad.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(checkpoint, strict=False)
model.eval()

# Audio Processing Transforms
mel_spectrogram = T.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=64,
                                   win_length=400, hop_length=160)
log_mel_spectrogram = T.AmplitudeToDB()

# Chunking Parameters
chunk_duration = WINDOW_SIZE
shift_duration = WINDOW_SIZE * 0.875  # shift is 87.5% of the window, so consecutive chunks overlap by 12.5%


def predict(audio_input, threshold):
    """
    Predict voice activity in an audio file with detailed processing and visualization.

    Args:
        audio_input (str): Path to the audio file
        threshold (float): Decision threshold for speech/non-speech classification

    Yields:
        Intermediate and final prediction results
    """
    start_time = time.time()

    try:
        # Load and preprocess audio
        waveform, orig_sample_rate = torchaudio.load(audio_input)

        # Resample if necessary
        if orig_sample_rate != TARGET_SAMPLE_RATE:
            print(f"Resampling from {orig_sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz")
            resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # Ensure mono channel
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
    except Exception as e:
        print(f"Error loading audio file: {e}")
        yield "Error loading audio file.", None, None, None
        return

    # Audio duration checks and padding
    audio_duration = waveform.size(1) / TARGET_SAMPLE_RATE
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Original sample rate: {orig_sample_rate} Hz")
    print(f"Current sample rate: {TARGET_SAMPLE_RATE} Hz")

    if audio_duration < chunk_duration:
        required_length = int(chunk_duration * TARGET_SAMPLE_RATE)
        padding_length = required_length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding_length))

    # Chunk processing parameters
    chunk_size = int(chunk_duration * TARGET_SAMPLE_RATE)
    shift_size = int(shift_duration * TARGET_SAMPLE_RATE)
    num_chunks = (waveform.size(1) - chunk_size) // shift_size + 1

    predictions = []
    time_stamps = []
    detailed_predictions = []

    # Initialize plot
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
    ax.set_ylabel('Probability', fontproperties=font_prop)
    ax.set_title('Voice Activity Detection Probability Over Time', fontproperties=font_prop)
    ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
    ax.grid(True)
    ax.set_ylim([-0.05, 1.05])

    # Process audio in chunks
    for i in range(num_chunks):
        start_idx = i * shift_size
        end_idx = start_idx + chunk_size
        chunk = waveform[:, start_idx:end_idx]

        if chunk.size(1) < chunk_size:
            break

        # Feature extraction
        inputs = mel_spectrogram(chunk)
        inputs = log_mel_spectrogram(inputs).unsqueeze(0)
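        # Shape bookkeeping (assuming torchaudio's default center padding):
        # a 0.63 s chunk is 10,080 samples at 16 kHz, and hop_length=160 gives
        # 10080 // 160 + 1 = 64 frames, so `inputs` enters the model as
        # (batch=1, channel=1, n_mels=64, frames=64).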
        # Model inference
        with torch.no_grad():
            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)

        # Process outputs
        predictions.append(outputs.item())
        time_stamps.append(start_idx / TARGET_SAMPLE_RATE)
        detailed_predictions.append({
            'start_time': start_idx / TARGET_SAMPLE_RATE,
            'output': outputs.item(),
        })

        # Update plot dynamically
        ax.clear()
        ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
        ax.set_ylabel('Probability', fontproperties=font_prop)
        ax.set_title('Speech Probability Over Time', fontproperties=font_prop)
        ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
        ax.grid(True)
        ax.set_ylim([-0.05, 1.05])
        ax.plot(time_stamps, predictions, label='Speech Probability', color='tab:blue')
        ax.legend(loc='upper right', prop=font_prop)  # render the Threshold / Speech Probability labels
        plt.tight_layout()

        # Yield intermediate progress
        yield "Processing...", None, None, fig

    # Detailed logging
    print("Detailed Predictions:")
    for pred in detailed_predictions:
        print(f"Start Time: {pred['start_time']:.2f}s, Output: {pred['output']:.4f}")

    # Final prediction processing
    avg_output = max(0, min(1, np.mean(predictions)))
    prediction_time = time.time() - start_time

    prediction = "Speech" if avg_output > threshold else "Non-speech"
    probability = round(float(avg_output) * 100, 2)  # numeric value, as expected by gr.Number
    inference_time = f'{prediction_time:.4f}'

    print(f"Final Prediction: {prediction}")
    print(f"Average Probability: {probability}%")
    print(f"Number of chunks processed: {num_chunks}")

    # Final result
    yield prediction, probability, inference_time, fig
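
# Hypothetical helper (not wired into the demo): a sketch of how the per-chunk
# probabilities collected in predict() could be merged into (start, end) speech
# segments. The function name and its placement here are illustrative only.
def probabilities_to_segments(time_stamps, predictions, threshold,
                              chunk_duration=WINDOW_SIZE):
    """Merge consecutive above-threshold chunks into speech segments."""
    segments = []
    for t, p in zip(time_stamps, predictions):
        if p <= threshold:
            continue
        end = t + chunk_duration
        # Chunks overlap (shift < window), so extend the last segment when
        # the current chunk starts before the previous one ends.
        if segments and t <= segments[-1][1]:
            segments[-1] = (segments[-1][0], end)
        else:
            segments.append((t, end))
    return segments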

# Gradio Interface
with gr.Blocks() as demo:
    gr.Image("./img/logo.png", elem_id="logo", height=100)

    # Title and Description
    gr.Markdown("# Voice Activity Detection using SincVAD")
    gr.Markdown("Upload or record audio to predict speech activity and view the probability curve.")
") # Interface Layout with gr.Row(): with gr.Column(): # Separate recording and file upload audio_input = gr.Audio(type="filepath", label="Upload or Record Audio") threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Threshold") with gr.Column(): prediction_output = gr.Textbox(label="Prediction") probability_output = gr.Number(label="Average Probability (%)") time_output = gr.Textbox(label="Inference Time (seconds)") plot_output = gr.Plot(label="Probability Curve") # Prediction Trigger predict_btn = gr.Button("Start Prediction") predict_btn.click( predict, [audio_input, threshold_input], [prediction_output, probability_output, time_output, plot_output], api_name="predict" ) # Launch Configuration if __name__ == "__main__": demo.queue() # Enable queue to support generators demo.launch()