Mattral committed · Commit 6f56f73 (verified) · Parent: 906d632

Update app.py

Files changed (1): app.py +279 −118
app.py CHANGED
@@ -12,26 +12,47 @@ import matplotlib.pyplot as plt
 import plotly.express as px
 import soundfile as sf
 from scipy.signal import stft
+import math
 
-# Dummy CNN Model for Audio
+# -------------------------------
+# CNN Model for Audio Analysis
+# -------------------------------
 class AudioCNN(nn.Module):
     def __init__(self):
         super(AudioCNN, self).__init__()
+        # Convolutional layers
         self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
         self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
-        self.fc1 = nn.Linear(32 * 32 * 8, 128)  # Adjusted for typical spectrogram size
-        self.fc2 = nn.Linear(128, 10)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
+        # Pooling layer
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        # Fully connected layers (with dynamic sizing)
+        self.fc1 = None
+        self.fc2 = nn.Linear(256, 128)
+        self.fc3 = nn.Linear(128, 10)
+        # Dropout for regularization
+        self.dropout = nn.Dropout(0.5)
 
     def forward(self, x):
-        x1 = F.relu(self.conv1(x))  # First conv layer activation
-        x2 = F.relu(self.conv2(x1))
-        x3 = F.adaptive_avg_pool2d(x2, (8, 32))
-        x4 = x3.view(x3.size(0), -1)
-        x5 = F.relu(self.fc1(x4))
-        x6 = self.fc2(x5)
-        return x6, x1
-
-# Audio processing functions
+        x1 = F.relu(self.conv1(x))
+        x2 = self.pool(x1)
+        x3 = F.relu(self.conv2(x2))
+        x4 = self.pool(x3)
+        x5 = F.relu(self.conv3(x4))
+        x6 = self.pool(x5)
+        if self.fc1 is None:
+            fc1_input_size = x6.numel() // x6.size(0)
+            self.fc1 = nn.Linear(fc1_input_size, 256)
+        x7 = x6.view(x6.size(0), -1)
+        x8 = F.relu(self.fc1(x7))
+        x9 = self.dropout(x8)
+        x10 = F.relu(self.fc2(x9))
+        x11 = self.fc3(x10)
+        return x11, [x2, x4, x6], x8
+
+# -------------------------------
+# Audio Processing Functions
+# -------------------------------
 def load_audio(file):
     audio, sr = librosa.load(file, sr=None, mono=True)
     return audio, sr
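
A note on the dynamically sized `fc1`: creating the layer inside `forward` lets the flattened feature count follow whatever spectrogram shape is fed in, which works for the inference-only use in this app, though the layer only comes into existence (with random weights) on the first call. PyTorch ships the same idea as `nn.LazyLinear` (added in PyTorch 1.8), which infers `in_features` on the first forward pass. A minimal sketch; the class name here is illustrative, not from the repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LazyAudioCNN(nn.Module):
    # Same layer stack as AudioCNN, but fc1's input size is inferred
    # automatically from the first batch instead of built by hand.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.LazyLinear(256)  # in_features resolved on first call
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)  # (batch, channels * h * w)
        return self.fc3(F.relu(self.fc2(F.relu(self.fc1(x)))))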
@@ -53,11 +74,13 @@ def filter_fft(fft, percentage):
 def create_spectrogram(audio, sr):
     n_fft = 2048
     hop_length = 512
-    stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
-    spectrogram = np.abs(stft)
+    S = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
+    spectrogram = np.abs(S)
     return spectrogram, n_fft, hop_length
 
-# Visualization functions
+# -------------------------------
+# Visualization Functions
+# -------------------------------
 def plot_waveform(audio, sr, title):
     fig = go.Figure()
     time = np.arange(len(audio)) / sr
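
The rename from `stft` to `S` is a real fix, not cosmetics: the old local variable shadowed `from scipy.signal import stft` at the top of the file, making the imported function unreachable anywhere in that function body. A standalone reproduction of the pitfall (not code from the repo):

import numpy as np
from scipy.signal import stft  # module-level import, as at the top of app.py

def create_spectrogram_shadowed(audio, sr):
    # Assigning to the name 'stft' makes it local to this function, so the
    # scipy function can no longer be called from inside this body.
    stft = np.abs(audio)
    # A call such as stft(audio, fs=sr) here would raise:
    # TypeError: 'numpy.ndarray' object is not callable
    return stft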
@@ -65,41 +88,111 @@ def plot_waveform(audio, sr, title):
     fig.update_layout(title=title, xaxis_title='Time (s)', yaxis_title='Amplitude')
     return fig
 
+def create_waveform_table(audio, sr, num_samples=100):
+    time = np.arange(len(audio)) / sr
+    indices = np.linspace(0, len(audio)-1, num_samples, dtype=int)
+    df = pd.DataFrame({"Time (s)": time[indices], "Amplitude": audio[indices]})
+    return df
+
 def plot_fft(magnitude, phase, sr):
     fig = make_subplots(rows=2, cols=1, subplot_titles=('Magnitude Spectrum', 'Phase Spectrum'))
     freq = np.fft.fftfreq(len(magnitude), 1/sr)
-
     fig.add_trace(go.Scatter(x=freq, y=magnitude, mode='lines', name='Magnitude'), row=1, col=1)
     fig.add_trace(go.Scatter(x=freq, y=phase, mode='lines', name='Phase'), row=2, col=1)
-
     fig.update_xaxes(title_text='Frequency (Hz)', row=1, col=1)
     fig.update_xaxes(title_text='Frequency (Hz)', row=2, col=1)
     fig.update_yaxes(title_text='Magnitude', row=1, col=1)
     fig.update_yaxes(title_text='Phase (radians)', row=2, col=1)
-
     return fig
 
-def plot_3d_fft(magnitude, phase, sr):
+def plot_fft_bands(magnitude, phase, sr):
     freq = np.fft.fftfreq(len(magnitude), 1/sr)
-    fig = go.Figure(data=[go.Scatter3d(
-        x=freq,
-        y=magnitude,
-        z=phase,
+    pos_mask = freq >= 0
+    freq, magnitude, phase = freq[pos_mask], magnitude[pos_mask], phase[pos_mask]
+    bass_mask = (freq >= 20) & (freq < 250)
+    mid_mask = (freq >= 250) & (freq < 4000)
+    treble_mask = (freq >= 4000) & (freq <= sr/2)
+    fig = make_subplots(rows=2, cols=1, subplot_titles=('Magnitude Spectrum by Bands', 'Phase Spectrum by Bands'))
+    fig.add_trace(go.Scatter(x=freq[bass_mask], y=magnitude[bass_mask], mode='lines', name='Bass'), row=1, col=1)
+    fig.add_trace(go.Scatter(x=freq[mid_mask], y=magnitude[mid_mask], mode='lines', name='Mid'), row=1, col=1)
+    fig.add_trace(go.Scatter(x=freq[treble_mask], y=magnitude[treble_mask], mode='lines', name='Treble'), row=1, col=1)
+    fig.add_trace(go.Scatter(x=freq[bass_mask], y=phase[bass_mask], mode='lines', name='Bass'), row=2, col=1)
+    fig.add_trace(go.Scatter(x=freq[mid_mask], y=phase[mid_mask], mode='lines', name='Mid'), row=2, col=1)
+    fig.add_trace(go.Scatter(x=freq[treble_mask], y=phase[treble_mask], mode='lines', name='Treble'), row=2, col=1)
+    fig.update_xaxes(title_text='Frequency (Hz)', row=1, col=1)
+    fig.update_xaxes(title_text='Frequency (Hz)', row=2, col=1)
+    fig.update_yaxes(title_text='Magnitude', row=1, col=1)
+    fig.update_yaxes(title_text='Phase (radians)', row=2, col=1)
+    return fig
+
+def create_fft_table(magnitude, phase, sr, num_samples=100):
+    freq = np.fft.fftfreq(len(magnitude), 1/sr)
+    pos_mask = freq >= 0
+    freq, magnitude, phase = freq[pos_mask], magnitude[pos_mask], phase[pos_mask]
+    indices = np.linspace(0, len(freq)-1, num_samples, dtype=int)
+    df = pd.DataFrame({
+        "Frequency (Hz)": freq[indices],
+        "Magnitude": magnitude[indices],
+        "Phase (radians)": phase[indices]
+    })
+    return df
+
+def plot_3d_polar_fft(magnitude, phase, sr):
+    # Get positive frequencies
+    freq = np.fft.fftfreq(len(magnitude), 1/sr)
+    pos_mask = freq >= 0
+    freq, mag, ph = freq[pos_mask], magnitude[pos_mask], phase[pos_mask]
+    # Convert polar to Cartesian coordinates
+    x = mag * np.cos(ph)
+    y = mag * np.sin(ph)
+    z = freq  # Use frequency as z-axis
+
+    # Downsample the data to avoid huge message sizes.
+    # Compute a decimation factor so that approximately 500 points are plotted.
+    step = max(1, len(x) // 500)
+    x, y, z, ph = x[::step], y[::step], z[::step], ph[::step]
+
+    # Create a coarser grid for the contour surface.
+    n_rep = 10
+    X_surface = np.tile(x, (n_rep, 1))
+    Y_surface = np.tile(y, (n_rep, 1))
+    Z_surface = np.tile(z, (n_rep, 1))
+
+    surface = go.Surface(
+        x=X_surface,
+        y=Y_surface,
+        z=Z_surface,
+        colorscale='Viridis',
+        opacity=0.6,
+        showscale=False,
+        contours={
+            "x": {"show": True, "start": float(np.min(x)), "end": float(np.max(x)), "size": float((np.max(x)-np.min(x))/10)},
+            "y": {"show": True, "start": float(np.min(y)), "end": float(np.max(y)), "size": float((np.max(y)-np.min(y))/10)},
+            "z": {"show": True, "start": float(np.min(z)), "end": float(np.max(z)), "size": float((np.max(z)-np.min(z))/10)},
+        },
+    )
+
+    scatter = go.Scatter3d(
+        x=x,
+        y=y,
+        z=z,
         mode='markers',
         marker=dict(
-            size=5,
-            color=phase,  # Color by phase
-            colorscale='Viridis',  # Choose a colorscale
-            opacity=0.8
+            size=3,
+            color=ph,  # color by phase
+            colorscale='Viridis',
+            opacity=0.8,
+            colorbar=dict(title='Phase (radians)')
         )
-    )])
+    )
 
+    fig = go.Figure(data=[surface, scatter])
     fig.update_layout(scene=dict(
-        xaxis_title='Frequency (Hz)',
-        yaxis_title='Magnitude',
-        zaxis_title='Phase (radians)'
-    ))
-
+        xaxis_title='Real Component',
+        yaxis_title='Imaginary Component',
+        zaxis_title='Frequency (Hz)',
+        camera=dict(eye=dict(x=1.5, y=1.5, z=0.5))
+    ), margin=dict(l=0, r=0, b=0, t=0))
     return fig
 
 def plot_spectrogram(spectrogram, sr, hop_length):
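
`plot_3d_polar_fft` earns its new axis labels: for each bin, `mag * np.cos(ph)` and `mag * np.sin(ph)` are exactly the real and imaginary parts of the complex FFT value, so the markers live in the complex plane with frequency on the z-axis. A standalone check of that identity and of the ~500-marker decimation, using an assumed toy signal (not code from the repo):

import numpy as np

rng = np.random.default_rng(0)
signal = rng.standard_normal(4096)  # assumed toy signal
fft = np.fft.fft(signal)
mag, ph = np.abs(fft), np.angle(fft)

# Polar -> Cartesian recovers the rectangular form of each bin.
assert np.allclose(mag * np.cos(ph), fft.real)
assert np.allclose(mag * np.sin(ph), fft.imag)

# The decimation used before plotting keeps roughly 500 markers:
step = max(1, len(fft) // 500)
print(len(fft[::step]))  # 512 for a 4096-sample signal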
@@ -110,117 +203,185 @@ def plot_spectrogram(spectrogram, sr, hop_length):
     plt.title('Spectrogram')
     return fig
 
-def create_fft_table(magnitude, phase, sr):
-    freq = np.fft.fftfreq(len(magnitude), 1/sr)
-    df = pd.DataFrame({
-        'Frequency (Hz)': freq,
-        'Magnitude': magnitude,
-        'Phase (radians)': phase
-    })
+def create_spectrogram_table(spectrogram, num_rows=10, num_cols=10):
+    sub_spec = spectrogram[:num_rows, :num_cols]
+    df = pd.DataFrame(sub_spec,
+                      index=[f'Freq Bin {i}' for i in range(sub_spec.shape[0])],
+                      columns=[f'Time Bin {j}' for j in range(sub_spec.shape[1])])
     return df
 
-# Streamlit UI
+def create_activation_table(activation, num_rows=10, num_cols=10):
+    sub_act = activation[:num_rows, :num_cols]
+    df = pd.DataFrame(sub_act,
+                      index=[f'Row {i}' for i in range(sub_act.shape[0])],
+                      columns=[f'Col {j}' for j in range(sub_act.shape[1])])
+    return df
+
+# -------------------------------
+# Streamlit UI & Main App
+# -------------------------------
 st.set_page_config(layout="wide")
-st.title("Audio Frequency Analysis with CNN")
+st.title("Audio Frequency Analysis with CNN and FFT")
 
-# Initialize session state
-if 'audio_data' not in st.session_state:
-    st.session_state.audio_data = None
-if 'sr' not in st.session_state:
-    st.session_state.sr = None
-if 'fft' not in st.session_state:
-    st.session_state.fft = None
+st.markdown("""
+### Welcome to the Audio Frequency Analysis Tool!
+This application allows you to:
+- **Upload an audio file** and visualize its waveform along with a data table.
+- **Analyze frequency components** using FFT (with both 2D and enhanced 3D polar plots).
+- **Highlight frequency bands:** Bass (20–250 Hz), Mid (250–4000 Hz), Treble (4000 Hz to Nyquist).
+- **Filter frequency components** and reconstruct the waveform.
+- **Generate a spectrogram** for time-frequency analysis with a sample data table.
+- **Inspect CNN activations** (pooling and dense layers) arranged in grid layouts.
+- **Final Audio Classification:** Classify the audio for gender (Male/Female) and tone.
+""")
 
 # File uploader
-uploaded_file = st.file_uploader("Upload an audio file", type=['wav', 'mp3', 'ogg'])
+uploaded_file = st.file_uploader("Upload an audio file (WAV, MP3, OGG)", type=['wav', 'mp3', 'ogg'])
 
 if uploaded_file is not None:
-    # Load and process audio
     audio, sr = load_audio(uploaded_file)
-    st.session_state.audio_data = audio
-    st.session_state.sr = sr
 
-    # Display original waveform
-    st.subheader("Original Audio Waveform")
-    st.plotly_chart(plot_waveform(audio, sr, "Original Waveform"), use_container_width=True)
+    # --- Section 1: Raw Audio Waveform ---
+    st.header("1. Raw Audio Waveform")
+    st.markdown("""
+    The waveform represents the amplitude over time.
+    **Graph:** Amplitude vs. Time.
+    **Data Table:** Sampled values.
+    """)
+    waveform_fig = plot_waveform(audio, sr, "Original Waveform")
+    st.plotly_chart(waveform_fig, use_container_width=True)
+    st.dataframe(create_waveform_table(audio, sr))
 
-    # Apply FFT
+    # --- Section 2: Frequency Domain Analysis ---
+    st.header("2. Frequency Domain Analysis")
+    st.markdown("""
+    **FFT Analysis:** Decompose the audio into frequency components.
+    - **Magnitude Spectrum:** Strength of frequencies.
+    - **Phase Spectrum:** Phase angles.
+    """)
     fft, magnitude, phase = apply_fft(audio)
-    st.session_state.fft = fft
-
-    # Display FFT results
-    st.subheader("Frequency Domain Analysis")
-    st.plotly_chart(plot_fft(magnitude, phase, sr), use_container_width=True)
-
-    # 3D FFT Plot
-    st.subheader("3D Frequency Domain Analysis")
-    st.plotly_chart(plot_3d_fft(magnitude, phase, sr), use_container_width=True)
+    col1, col2 = st.columns(2)
+    with col1:
+        st.subheader("2D FFT Plot")
+        st.plotly_chart(plot_fft(magnitude, phase, sr), use_container_width=True)
+    with col2:
+        st.subheader("Enhanced 3D Polar FFT Plot with Contours")
+        st.plotly_chart(plot_3d_polar_fft(magnitude, phase, sr), use_container_width=True)
+    st.subheader("FFT Data Table (Sampled)")
+    st.dataframe(create_fft_table(magnitude, phase, sr))
+    st.subheader("Frequency Bands: Bass, Mid, Treble")
+    st.plotly_chart(plot_fft_bands(magnitude, phase, sr), use_container_width=True)
 
-    # FFT Table
-    st.subheader("FFT Values Table")
-    fft_table = create_fft_table(magnitude, phase, sr)
-    st.dataframe(fft_table)
-
-    # Frequency filtering
+    # --- Section 3: Frequency Filtering ---
+    st.header("3. Frequency Filtering")
+    st.markdown("""
+    Filter the audio signal by retaining a percentage of the strongest frequencies.
+    Adjust the slider for retention percentage.
+    **Graph:** Filtered waveform.
+    **Data Table:** Sampled values.
+    """)
     percentage = st.slider("Percentage of frequencies to retain:", 0.1, 100.0, 10.0, 0.1)
-
     if st.button("Apply Frequency Filter"):
-        filtered_fft = filter_fft(st.session_state.fft, percentage)
+        filtered_fft = filter_fft(fft, percentage)
         reconstructed = np.fft.ifft(filtered_fft).real
-
-        # Display reconstructed waveform
-        st.subheader("Reconstructed Audio")
-        st.plotly_chart(plot_waveform(reconstructed, sr, "Filtered Waveform"), use_container_width=True)
-
-        # Play audio
-        st.audio(reconstructed, sample_rate=sr)
+        col1, col2 = st.columns(2)
+        with col1:
+            st.plotly_chart(plot_waveform(reconstructed, sr, "Filtered Waveform"), use_container_width=True)
+        with col2:
+            st.audio(reconstructed, sample_rate=sr)
+            st.dataframe(create_waveform_table(reconstructed, sr))
 
-    # Spectrogram creation
-    st.subheader("Spectrogram Analysis")
+    # --- Section 4: Spectrogram Analysis ---
+    st.header("4. Spectrogram Analysis")
+    st.markdown("""
+    A spectrogram shows how frequency content evolves over time.
+    **Graph:** Spectrogram (log-frequency scale).
+    **Data Table:** A subsection of the spectrogram matrix.
+    """)
     spectrogram, n_fft, hop_length = create_spectrogram(audio, sr)
     st.pyplot(plot_spectrogram(spectrogram, sr, hop_length))
+    st.dataframe(create_spectrogram_table(spectrogram))
 
-    # CNN Processing
-    if st.button("Process with CNN"):
-        # Convert spectrogram to tensor
+    # --- Section 5: CNN Analysis (Pooling & Dense Activations) ---
+    st.header("5. CNN Analysis: Pooling and Dense Activations")
+    st.markdown("""
+    Instead of classification probabilities, inspect internal activations:
+    - **Pooling Layer Outputs:** Arranged in a grid layout.
+    - **Dense Layer Activation:** Feature vector from the dense layer.
+    """)
+    if st.button("Run CNN Analysis"):
        spec_tensor = torch.tensor(spectrogram[np.newaxis, np.newaxis, ...], dtype=torch.float32)
-
        model = AudioCNN()
        with torch.no_grad():
-            output, activations = model(spec_tensor)
-
-        # Visualize activations
-        st.subheader("CNN Layer Activations")
-
-        # Input spectrogram
-        st.write("### Input Spectrogram")
-        fig_input, ax = plt.subplots()
-        ax.imshow(spectrogram, aspect='auto', origin='lower')
-        st.pyplot(fig_input)
+            output, pooling_outputs, dense_activation = model(spec_tensor)
+        for idx, activation in enumerate(pooling_outputs):
+            st.subheader(f"Pooling Layer {idx+1} Output")
+            act = activation[0].cpu().numpy()
+            num_channels = act.shape[0]
+            ncols = 4
+            nrows = math.ceil(num_channels / ncols)
+            fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
+            axes = axes.flatten()
+            for i in range(nrows * ncols):
+                if i < num_channels:
+                    axes[i].imshow(act[i], aspect='auto', origin='lower', cmap='viridis')
+                    axes[i].set_title(f'Channel {i+1}', fontsize=8)
+                    axes[i].axis('off')
+                else:
+                    axes[i].axis('off')
+            st.pyplot(fig)
+            st.markdown("**Data Table for Pooling Layer Activation (Channel 1, Sampled)**")
+            df_act = create_activation_table(act[0])
+            st.dataframe(df_act)
+        st.subheader("Dense Layer Activation")
+        dense_act = dense_activation[0].cpu().numpy()
+        df_dense = pd.DataFrame({
+            "Feature Index": np.arange(len(dense_act)),
+            "Activation Value": dense_act
+        })
+        st.plotly_chart(px.bar(df_dense, x="Feature Index", y="Activation Value"), use_container_width=True)
+        st.dataframe(df_dense)
+
+    # --- Section 6: Final Audio Classification (Gender & Tone) ---
+    st.header("6. Final Audio Classification: Gender and Tone")
+    st.markdown("""
+    In this final step, a pretrained model classifies the audio as Male or Female,
+    and determines its tone (High Tone vs. Low Tone).
 
-        # First conv layer activations
-        st.write("### First Convolution Layer Activations")
-        activation = activations.detach().numpy()[0]
+    **Note:** This example uses a placeholder model. Replace the dummy model and random outputs with your actual pretrained model.
+    """)
+    if st.button("Run Final Classification"):
+        # Extract MFCC features as an example (adjust as needed)
+        mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=40)
+        features = np.mean(mfccs, axis=1)  # average over time
+        features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
 
-        cols = 4
-        rows = 4
-        fig, axs = plt.subplots(rows, cols, figsize=(20, 20))
-        for i in range(16):
-            ax = axs[i//cols, i%cols]
-            ax.imshow(activation[i], aspect='auto', origin='lower')
-            ax.set_title(f'Channel {i+1}')
-        plt.tight_layout()
-        st.pyplot(fig)
+        # Dummy classifier model for demonstration
+        class GenderToneClassifier(nn.Module):
+            def __init__(self):
+                super(GenderToneClassifier, self).__init__()
+                self.fc = nn.Linear(40, 4)  # 4 outputs: [Male, Female, High Tone, Low Tone]
+            def forward(self, x):
+                return self.fc(x)
 
-        # Classification results
-        st.write("### Classification Output")
-        probabilities = F.softmax(output, dim=1).numpy()[0]
-        classes = [f"Class {i}" for i in range(10)]
-        df = pd.DataFrame({"Class": classes, "Probability": probabilities})
-        fig = px.bar(df, x="Class", y="Probability", color="Probability")
-        st.plotly_chart(fig)
-
-# Add some styling
+        classifier = GenderToneClassifier()
+        # In practice, load your pretrained weights here.
+        with torch.no_grad():
+            output = classifier(features_tensor)
+        probs = F.softmax(output, dim=1).numpy()[0]
+        # Interpret outputs: assume first 2 are gender, next 2 are tone.
+        gender = "Male" if probs[0] > probs[1] else "Female"
+        tone = "High Tone" if probs[2] > probs[3] else "Low Tone"
+        st.markdown(f"**Predicted Gender:** {gender}")
+        st.markdown(f"**Predicted Tone:** {tone}")
+        categories = ["Male", "Female", "High Tone", "Low Tone"]
+        df_class = pd.DataFrame({"Category": categories, "Probability": probs})
+        st.plotly_chart(px.bar(df_class, x="Category", y="Probability"), use_container_width=True)
+        st.dataframe(df_class)
+
+# -------------------------------
+# Style Enhancements
+# -------------------------------
 st.markdown("""
 <style>
 .stButton>button {
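
The Section 6 classifier is, as the in-app note says, a placeholder with untrained weights, so its softmax output is arbitrary. A sketch of what "load your pretrained weights here" could look like, assuming a checkpoint saved earlier with torch.save(model.state_dict(), ...); the file name below is hypothetical:

import torch
import torch.nn as nn

class GenderToneClassifier(nn.Module):
    # Same placeholder head as in app.py: 40 mean MFCCs -> 4 logits.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(40, 4)

    def forward(self, x):
        return self.fc(x)

classifier = GenderToneClassifier()
state = torch.load("gender_tone.pt", map_location="cpu")  # hypothetical checkpoint
classifier.load_state_dict(state)
classifier.eval()  # inference mode

A separate design point: a single 4-way softmax over [Male, Female, High Tone, Low Tone] makes the gender and tone outputs compete for probability mass; two independent 2-way heads would better match the two binary decisions the UI reports.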