import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import torchaudio
import torchaudio.transforms as T
import time
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from model.tinyvad import TinyVAD
# Font configuration
font_path = './fonts/Times_New_Roman.ttf'
font_prop = FontProperties(fname=font_path, size=18)
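# A FontProperties built from a font file path is not registered with
# matplotlib globally, so it is passed explicitly to each text-drawing call below.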
# Model and Processing Parameters
WINDOW_SIZE = 0.63
SINC_CONV = False
SSM = False
TARGET_SAMPLE_RATE = 16000
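# WINDOW_SIZE is the analysis window length in seconds. SINC_CONV and SSM
# select optional TinyVAD variants (presumably a SincNet-style learnable
# front-end and a state-space block); both are disabled here.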
# Model Initialization
model = TinyVAD(1, 32, 64, patch_size=8, num_blocks=2,
                sinc_conv=SINC_CONV, ssm=SSM)
checkpoint_path = './sincvad.ckpt'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
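# strict=False tolerates key mismatches between the checkpoint and the
# instantiated architecture (missing/unexpected tensors are skipped, not raised).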
model.load_state_dict(checkpoint, strict=False)
model.eval()
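# The model stays in eval mode for the lifetime of the app; inference below
# additionally runs under torch.no_grad().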
# Audio Processing Transforms
mel_spectrogram = T.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=64, win_length=400, hop_length=160)
log_mel_spectrogram = T.AmplitudeToDB()
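# At 16 kHz, win_length=400 and hop_length=160 correspond to 25 ms analysis
# frames with a 10 ms hop (400 / 16000 = 0.025 s, 160 / 16000 = 0.01 s).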
# Chunking Parameters
chunk_duration = WINDOW_SIZE
shift_duration = WINDOW_SIZE * 0.875  # stride is 87.5% of the window, leaving 12.5% overlap between consecutive chunks
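# With WINDOW_SIZE = 0.63 s this yields 10,080-sample chunks (0.63 * 16000)
# advanced by 8,820 samples (0.63 * 0.875 * 16000) per step.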

def predict(audio_input, threshold):
    """
    Predict voice activity in an audio file with detailed processing and visualization.

    Args:
        audio_input (str): Path to the audio file
        threshold (float): Decision threshold for speech/non-speech classification

    Yields:
        Intermediate and final prediction results
    """
    start_time = time.time()

    try:
        # Load and preprocess audio
        waveform, orig_sample_rate = torchaudio.load(audio_input)

        # Resample if necessary
        if orig_sample_rate != TARGET_SAMPLE_RATE:
            print(f"Resampling from {orig_sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz")
            resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # Ensure mono channel
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
    except Exception as e:
        print(f"Error loading audio file: {e}")
        yield "Error loading audio file.", None, None, None
        return
    # Audio duration checks and padding
    audio_duration = waveform.size(1) / TARGET_SAMPLE_RATE
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Original sample rate: {orig_sample_rate} Hz")
    print(f"Current sample rate: {TARGET_SAMPLE_RATE} Hz")

    if audio_duration < chunk_duration:
        required_length = int(chunk_duration * TARGET_SAMPLE_RATE)
        padding_length = required_length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding_length))
    # Chunk processing parameters
    chunk_size = int(chunk_duration * TARGET_SAMPLE_RATE)
    shift_size = int(shift_duration * TARGET_SAMPLE_RATE)
    num_chunks = (waveform.size(1) - chunk_size) // shift_size + 1
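    # floor((num_samples - chunk_size) / shift_size) + 1 full windows fit in the
    # (possibly padded) waveform; the loop below drops any trailing partial window.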
    predictions = []
    time_stamps = []
    detailed_predictions = []
    # Initialize plot
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
    ax.set_ylabel('Probability', fontproperties=font_prop)
    ax.set_title('Voice Activity Detection Probability Over Time', fontproperties=font_prop)
    ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
    ax.grid(True)
    ax.set_ylim([-0.05, 1.05])
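    # The figure is redrawn after every chunk and yielded to the Gradio Plot
    # output, so the probability curve updates while processing is still running.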
    # Process audio in chunks
    for i in range(num_chunks):
        start_idx = i * shift_size
        end_idx = start_idx + chunk_size
        chunk = waveform[:, start_idx:end_idx]

        if chunk.size(1) < chunk_size:
            break

        # Feature extraction
        inputs = mel_spectrogram(chunk)
        inputs = log_mel_spectrogram(inputs).unsqueeze(0)

        # Model inference
        with torch.no_grad():
            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)

        # Process outputs
        predictions.append(outputs.item())
        time_stamps.append(start_idx / TARGET_SAMPLE_RATE)
        detailed_predictions.append({
            'start_time': start_idx / TARGET_SAMPLE_RATE,
            'output': outputs.item(),
        })

        # Update plot dynamically
        ax.clear()
        ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
        ax.set_ylabel('Probability', fontproperties=font_prop)
        ax.set_title('Speech Probability Over Time', fontproperties=font_prop)
        ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
        ax.grid(True)
        ax.set_ylim([-0.05, 1.05])
        ax.plot(time_stamps, predictions, label='Speech Probability', color='tab:blue')
        ax.legend(loc='upper right', prop=font_prop)  # without this, the line labels are never shown
        fig.tight_layout()

        # Yield intermediate progress
        yield "Processing...", None, None, fig
    # Detailed logging
    print("Detailed Predictions:")
    for pred in detailed_predictions:
        print(f"Start Time: {pred['start_time']:.2f}s, Output: {pred['output']:.4f}")

    # Final prediction processing
    avg_output = max(0, min(1, np.mean(predictions)))
    prediction_time = time.time() - start_time

    prediction = "Speech" if avg_output > threshold else "Non-speech"
    probability = round(float(avg_output) * 100, 2)  # numeric, to match the gr.Number output
    inference_time = f'{prediction_time:.4f}'

    print(f"Final Prediction: {prediction}")
    print(f"Average Probability: {probability}%")
    print(f"Number of chunks processed: {num_chunks}")

    # Final result
    yield prediction, probability, inference_time, fig
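
# Because predict() is a generator, it can also be exercised outside Gradio,
# e.g. (with 'sample.wav' as a placeholder path):
#   for prediction, probability, elapsed, fig in predict('sample.wav', 0.5):
#       print(prediction, probability, elapsed)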

# Gradio Interface
with gr.Blocks() as demo:
    gr.Image("./img/logo.png", elem_id="logo", height=100)

    # Title and Description
    gr.Markdown("<h1 style='text-align: center; color: black;'>Voice Activity Detection using SincVAD</h1>")
    gr.Markdown("<h3 style='text-align: center; color: black;'>Upload or record audio to predict speech activity and view the probability curve.</h3>")

    # Interface Layout
    with gr.Row():
        with gr.Column():
            # Accepts either an uploaded file or a microphone recording;
            # both reach predict() as a file path
            audio_input = gr.Audio(type="filepath", label="Upload or Record Audio")
            threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Threshold")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
            probability_output = gr.Number(label="Average Probability (%)")
            time_output = gr.Textbox(label="Inference Time (seconds)")
            plot_output = gr.Plot(label="Probability Curve")

    # Prediction Trigger
    predict_btn = gr.Button("Start Prediction")
    predict_btn.click(
        predict,
        [audio_input, threshold_input],
        [prediction_output, probability_output, time_output, plot_output],
        api_name="predict"
    )
# Launch Configuration
if __name__ == "__main__":
    demo.queue()  # Enable queue to support generators
    demo.launch()