import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import torchaudio
import torchaudio.transforms as T
import time
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

from model.tinyvad import TinyVAD
# Font configuration
font_path = './fonts/Times_New_Roman.ttf'
font_prop = FontProperties(fname=font_path, size=18)
# Model and Processing Parameters
WINDOW_SIZE = 0.63
SINC_CONV = False
SSM = False
TARGET_SAMPLE_RATE = 16000
# Model Initialization
model = TinyVAD(1, 32, 64, patch_size=8, num_blocks=2,
                sinc_conv=SINC_CONV, ssm=SSM)
checkpoint_path = './sincvad.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
# strict=False tolerates checkpoint keys that do not match the model exactly
model.load_state_dict(checkpoint, strict=False)
model.eval()
# Audio Processing Transforms
mel_spectrogram = T.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=64, win_length=400, hop_length=160)
log_mel_spectrogram = T.AmplitudeToDB()
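# Expected feature shape for one 0.63 s chunk (assuming torchaudio's default
# center=True framing): 0.63 s * 16000 Hz = 10080 samples, and
# 10080 // 160 + 1 = 64 frames, so mel_spectrogram(chunk) is (1, 64, 64).
# The unsqueeze(0) in predict() below then yields a (1, 1, 64, 64) input,
# which patch_size=8 divides evenly.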
# Chunking Parameters
chunk_duration = WINDOW_SIZE
shift_duration = WINDOW_SIZE * 0.875  # shift is 87.5% of the window, i.e. 12.5% overlap between consecutive chunks
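# Concrete values at 16 kHz: window = 0.63 s -> 10080 samples,
# shift = 0.55125 s -> 8820 samples, so consecutive chunks share
# 1260 samples (~79 ms).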
def predict(audio_input, threshold):
    """
    Predict voice activity in an audio file with detailed processing and visualization.

    Args:
        audio_input (str): Path to the audio file
        threshold (float): Decision threshold for speech/non-speech classification

    Yields:
        Intermediate and final prediction results
    """
    start_time = time.time()

    try:
        # Load and preprocess audio
        waveform, orig_sample_rate = torchaudio.load(audio_input)

        # Resample if necessary
        if orig_sample_rate != TARGET_SAMPLE_RATE:
            print(f"Resampling from {orig_sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz")
            resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # Ensure mono channel
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
    except Exception as e:
        print(f"Error loading audio file: {e}")
        yield "Error loading audio file.", None, None, None
        return
    # Audio duration checks and padding
    audio_duration = waveform.size(1) / TARGET_SAMPLE_RATE
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Original sample rate: {orig_sample_rate} Hz")
    print(f"Current sample rate: {TARGET_SAMPLE_RATE} Hz")

    if audio_duration < chunk_duration:
        required_length = int(chunk_duration * TARGET_SAMPLE_RATE)
        padding_length = required_length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding_length))
    # Chunk processing parameters
    chunk_size = int(chunk_duration * TARGET_SAMPLE_RATE)
    shift_size = int(shift_duration * TARGET_SAMPLE_RATE)
    num_chunks = (waveform.size(1) - chunk_size) // shift_size + 1
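    # Worked example (hypothetical 3 s clip at 16 kHz, i.e. 48000 samples):
    # chunk_size = 10080 and shift_size = 8820, so
    # num_chunks = (48000 - 10080) // 8820 + 1 = 5; the last chunk ends at
    # 4 * 8820 + 10080 = 45360, leaving the final ~0.165 s tail unprocessed.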
    predictions = []
    time_stamps = []
    detailed_predictions = []
    # Initialize plot
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
    ax.set_ylabel('Probability', fontproperties=font_prop)
    ax.set_title('Voice Activity Detection Probability Over Time', fontproperties=font_prop)
    ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
    ax.grid(True)
    ax.set_ylim([-0.05, 1.05])
    # Process audio in chunks
    for i in range(num_chunks):
        start_idx = i * shift_size
        end_idx = start_idx + chunk_size
        chunk = waveform[:, start_idx:end_idx]

        if chunk.size(1) < chunk_size:
            break

        # Feature extraction
        inputs = mel_spectrogram(chunk)
        inputs = log_mel_spectrogram(inputs).unsqueeze(0)

        # Model inference
        with torch.no_grad():
            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)

        # Process outputs
        predictions.append(outputs.item())
        time_stamps.append(start_idx / TARGET_SAMPLE_RATE)
        detailed_predictions.append({
            'start_time': start_idx / TARGET_SAMPLE_RATE,
            'output': outputs.item(),
        })
        # Update plot dynamically
        ax.clear()
        ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
        ax.set_ylabel('Probability', fontproperties=font_prop)
        ax.set_title('Speech Probability Over Time', fontproperties=font_prop)
        ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
        ax.grid(True)
        ax.set_ylim([-0.05, 1.05])
        ax.plot(time_stamps, predictions, label='Speech Probability', color='tab:blue')
        ax.legend(loc='upper right', prop=font_prop)  # without this the line labels are never rendered
        plt.tight_layout()

        # Yield intermediate progress
        yield "Processing...", None, None, fig
    # Detailed logging
    print("Detailed Predictions:")
    for pred in detailed_predictions:
        print(f"Start Time: {pred['start_time']:.2f}s, Output: {pred['output']:.4f}")
    # Final prediction processing
    avg_output = max(0, min(1, np.mean(predictions)))
    prediction_time = time.time() - start_time

    prediction = "Speech" if avg_output > threshold else "Non-speech"
    probability = round(float(avg_output) * 100, 2)  # gr.Number expects a numeric value, not a formatted string
    inference_time = f'{prediction_time:.4f}'

    print(f"Final Prediction: {prediction}")
    print(f"Average Probability: {probability}%")
    print(f"Number of chunks processed: {num_chunks}")

    # Final result
    yield prediction, probability, inference_time, fig
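# Example of consuming the generator directly instead of through Gradio
# (hypothetical file path):
#
#   for prediction, probability, inference_time, fig in predict("speech.wav", 0.5):
#       pass  # intermediate yields report progress
#   print(prediction, probability, inference_time)  # values from the final yield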
# Gradio Interface
with gr.Blocks() as demo:
    gr.Image("./img/logo.png", elem_id="logo", height=100)

    # Title and Description
    gr.Markdown("<h1 style='text-align: center; color: black;'>Voice Activity Detection using SincVAD</h1>")
    gr.Markdown("<h3 style='text-align: center; color: black;'>Upload or record audio to predict speech activity and view the probability curve.</h3>")

    # Interface Layout
    with gr.Row():
        with gr.Column():
            # Single component that accepts either an uploaded file or a microphone recording
            audio_input = gr.Audio(type="filepath", label="Upload or Record Audio")
            threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Threshold")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
            probability_output = gr.Number(label="Average Probability (%)")
            time_output = gr.Textbox(label="Inference Time (seconds)")
            plot_output = gr.Plot(label="Probability Curve")

    # Prediction Trigger
    predict_btn = gr.Button("Start Prediction")
    predict_btn.click(
        predict,
        [audio_input, threshold_input],
        [prediction_output, probability_output, time_output, plot_output],
        api_name="predict"
    )
# Launch Configuration
if __name__ == "__main__":
    demo.queue()  # Enable queue to support generators
    demo.launch()