jethrowang committed
Commit 14c8893 · verified · 1 Parent(s): 230c11f

Create app.py

Files changed (1): app.py (+203 -0)
app.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import gradio as gr
+ import numpy as np
+ import torchaudio
+ import torchaudio.transforms as T
+ import time
+ import matplotlib.pyplot as plt
+ from matplotlib.font_manager import FontProperties
+
+ from model.tinyvad import TinyVAD
+
+ # Configuration
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Font configuration
+ font_path = '/share/nas169/jethrowang/fonts/Times_New_Roman.ttf'
+ font_prop = FontProperties(fname=font_path, size=18)
+
+ # Model and Processing Parameters
+ WINDOW_SIZE = 0.63
+ SINC_CONV = False
+ SSM = False
+ TARGET_SAMPLE_RATE = 16000
+
+ # Model Initialization
+ model = TinyVAD(1, 32, 64, patch_size=8, num_blocks=2,
+                 sinc_conv=SINC_CONV, ssm=SSM).to(device)
+ checkpoint_path = '/share/nas169/jethrowang/SincVAD/exp/exp_0.63_tinyvad_psq_0.05/model_epoch_37_val_auroc=0.8894.ckpt'
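+ # The checkpoint name suggests the 0.63 s window configuration, epoch 37, validation AUROC 0.8894.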
+ model.load_state_dict(torch.load(checkpoint_path, weights_only=True))
+ model.eval()
+
+ # Audio Processing Transforms
+ mel_spectrogram = T.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=64, win_length=400, hop_length=160)
+ log_mel_spectrogram = T.AmplitudeToDB()
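+ # At 16 kHz, win_length=400 and hop_length=160 correspond to 25 ms analysis frames with a 10 ms hop, giving 64-bin log-mel features.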
+
+ # Chunking Parameters
+ chunk_duration = WINDOW_SIZE
+ shift_duration = WINDOW_SIZE * 0.875  # Shift is 87.5% of the window, i.e. 12.5% overlap between consecutive chunks
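+ # With WINDOW_SIZE = 0.63 s this gives a 0.55125 s shift, so consecutive chunks share 0.07875 s of audio.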
+
+ def predict(audio_record, audio_upload, threshold):
+     """
+     Predict voice activity in an audio file with detailed processing and visualization.
+
+     Args:
+         audio_record (str): Path to the recorded audio file (microphone input)
+         audio_upload (str): Path to the uploaded audio file
+         threshold (float): Decision threshold for speech/non-speech classification
+
+     Yields:
+         Intermediate and final prediction results
+     """
+     start_time = time.time()
+
+     audio_input = audio_record if audio_record else audio_upload
+     if not audio_input:
+         yield "No audio provided!", 0.0, "N/A", None
+         return
+
+     try:
+         # Load and preprocess audio
+         waveform, orig_sample_rate = torchaudio.load(audio_input)
+
+         # Resample if necessary
+         if orig_sample_rate != TARGET_SAMPLE_RATE:
+             print(f"Resampling from {orig_sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz")
+             resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=TARGET_SAMPLE_RATE)
+             waveform = resampler(waveform)
+
+         # Ensure mono channel
+         if waveform.size(0) > 1:
+             waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+     except Exception as e:
+         print(f"Error loading audio file: {e}")
+         yield "Error loading audio file.", None, None, None
+         return
+
+     # Audio duration checks and padding
+     audio_duration = waveform.size(1) / TARGET_SAMPLE_RATE
+     print(f"Audio duration: {audio_duration:.2f} seconds")
+     print(f"Original sample rate: {orig_sample_rate} Hz")
+     print(f"Current sample rate: {TARGET_SAMPLE_RATE} Hz")
+
+     if audio_duration < chunk_duration:
+         required_length = int(chunk_duration * TARGET_SAMPLE_RATE)
+         padding_length = required_length - waveform.size(1)
+         waveform = torch.nn.functional.pad(waveform, (0, padding_length))
+
+     # Chunk processing parameters
+     chunk_size = int(chunk_duration * TARGET_SAMPLE_RATE)
+     shift_size = int(shift_duration * TARGET_SAMPLE_RATE)
+     num_chunks = (waveform.size(1) - chunk_size) // shift_size + 1
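+     # Number of full windows that fit in the (possibly padded) waveform; trailing samples shorter than one window are skipped by the loop below.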
+
+     predictions = []
+     time_stamps = []
+     detailed_predictions = []
+
+     # Initialize plot
+     fig, ax = plt.subplots(figsize=(12, 5))
+     ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
+     ax.set_ylabel('Probability', fontproperties=font_prop)
+     ax.set_title('Voice Activity Detection Probability Over Time', fontproperties=font_prop)
+     ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
+     ax.grid(True)
+     ax.set_ylim([-0.05, 1.05])
+
+     # Process audio in chunks
+     for i in range(num_chunks):
+         start_idx = i * shift_size
+         end_idx = start_idx + chunk_size
+         chunk = waveform[:, start_idx:end_idx]
+
+         if chunk.size(1) < chunk_size:
+             break
+
+         # Feature extraction
+         inputs = mel_spectrogram(chunk)
+         inputs = log_mel_spectrogram(inputs).to(device).unsqueeze(0)
+
+         # Model inference
+         with torch.no_grad():
+             outputs = model(inputs)
+             outputs = torch.sigmoid(outputs)
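+         # The model produces a single score per chunk; sigmoid maps it to a speech probability in [0, 1].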
+
+         # Process outputs
+         predictions.append(outputs.item())
+         time_stamps.append(start_idx / TARGET_SAMPLE_RATE)
+
+         detailed_predictions.append({
+             'start_time': start_idx / TARGET_SAMPLE_RATE,
+             'output': outputs.item(),
+         })
+
+         # Update plot dynamically
+         ax.clear()
+         ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
+         ax.set_ylabel('Probability', fontproperties=font_prop)
+         ax.set_title('Speech Probability Over Time', fontproperties=font_prop)
+         ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
+         ax.grid(True)
+         ax.set_ylim([-0.05, 1.05])
+         ax.plot(time_stamps, predictions, label='Speech Probability', color='tab:blue')
+         plt.tight_layout()
+
+         # Yield intermediate progress
+         yield "Processing...", None, None, fig
+
+     # Detailed logging
+     print("Detailed Predictions:")
+     for pred in detailed_predictions:
+         print(f"Start Time: {pred['start_time']:.2f}s, Output: {pred['output']:.4f}")
+
+     # Final prediction processing
+     avg_output = max(0, min(1, np.mean(predictions)))
+     prediction_time = time.time() - start_time
+
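+     # The clip-level score is the mean of the per-chunk probabilities, clamped to [0, 1], and is compared against the user-selected threshold.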
+     prediction = "Speech" if avg_output > threshold else "Non-speech"
+     probability = f'{(float(avg_output) * 100):.2f}'
+     inference_time = f'{prediction_time:.4f}'
+
+     print(f"Final Prediction: {prediction}")
+     print(f"Average Probability: {probability}%")
+     print(f"Number of chunks processed: {num_chunks}")
+
+     # Final result
+     yield prediction, probability, inference_time, fig
+
+ # Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Image("./img/logo.png", elem_id="logo", height=100)
+     # Title and Description
+     gr.Markdown("<h1 style='text-align: center; color: black;'>Voice Activity Detection using SincVAD</h1>")
+     gr.Markdown("<h3 style='text-align: center; color: black;'>Record or upload audio to predict speech activity and view the probability curve.</h3>")
+
+     # Interface Layout
+     with gr.Row():
+         with gr.Column():
+             # Separate recording and file upload
+             record_input = gr.Audio(source="microphone", type="filepath", label="Record Audio")
+             upload_input = gr.Audio(source="upload", type="filepath", label="Upload Audio")
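+             # predict() uses the recorded clip when one is present and otherwise falls back to the uploaded file.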
+             threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Threshold")
+         with gr.Column():
+             prediction_output = gr.Textbox(label="Prediction")
+             probability_output = gr.Number(label="Average Probability (%)")
+             time_output = gr.Textbox(label="Inference Time (seconds)")
+
+     plot_output = gr.Plot(label="Probability Curve")
+
+     # Prediction Trigger
+     predict_btn = gr.Button("Start Prediction")
+     predict_btn.click(
+         predict,
+         [record_input, upload_input, threshold_input],
+         [prediction_output, probability_output, time_output, plot_output],
+         api_name="predict"
+     )
+
+
+ # Launch Configuration
+ if __name__ == "__main__":
+     demo.queue()  # Enable queue to support generators
+     demo.launch(share=True)