# SoundingStreet / audio_mixer.py
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
import os
from typing import List, Tuple
from config import LOGS_DIR
# Utility functions
def load_audio_files(file_paths: List[str]) -> List[Tuple[torch.Tensor, int]]:
"""
    Load multiple audio files and verify they share the same length and sample rate.
Args:
file_paths: List of paths to audio files
Returns:
List of tuples containing audio data and sample rate
"""
audio_data = []
for path in file_paths:
# Load audio file
waveform, sample_rate = torchaudio.load(path)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
audio_data.append((waveform.squeeze(), sample_rate))
# Verify all audio files have the same length and sample rate
lengths = [len(audio) for audio, _ in audio_data]
sample_rates = [sr for _, sr in audio_data]
if len(set(lengths)) > 1:
raise ValueError(f"Audio files have different lengths: {lengths}")
if len(set(sample_rates)) > 1:
raise ValueError(f"Audio files have different sample rates: {sample_rates}")
return audio_data
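# A hedged usage sketch for load_audio_files; the .wav paths below are
# hypothetical, and any set of equal-length, equal-rate files would work:
#
#   clips = load_audio_files(["stem_a.wav", "stem_b.wav"])
#   waveform, sample_rate = clips[0]  # mono 1-D tensor and its sample rate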
def normalize_audio_volumes(audio_data: List[Tuple[torch.Tensor, int]]) -> List[Tuple[torch.Tensor, int]]:
"""
Normalize the volume of each audio file to have the same energy level.
Args:
audio_data: List of tuples containing audio data and sample rate
Returns:
List of tuples containing normalized audio data and sample rate
"""
normalized_data = []
# Calculate RMS (Root Mean Square) for each audio
rms_values = []
for audio, sr in audio_data:
# Calculate energy (squared amplitude)
energy = torch.mean(audio ** 2)
# Calculate RMS (square root of mean energy)
rms = torch.sqrt(energy)
rms_values.append(rms)
# Find the target RMS (we'll use the median to avoid outliers)
    # Use torch.stack (not torch.tensor) to combine the 0-dim RMS tensors
    target_rms = torch.median(torch.stack(rms_values))
# Normalize each audio to the target RMS
for (audio, sr), rms in zip(audio_data, rms_values):
if rms > 0: # Avoid division by zero
# Calculate scaling factor
scaling_factor = target_rms / rms
# Apply scaling
normalized_audio = audio * scaling_factor
else:
normalized_audio = audio
normalized_data.append((normalized_audio, sr))
return normalized_data
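# Worked example of the normalization above (values are illustrative): for
# per-clip RMS values [0.1, 0.2, 0.6] the median target is 0.2, so the clips
# are scaled by 2.0, 1.0, and 1/3. Note that torch.median on an even-length
# tensor returns the lower of the two middle values, not their average.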
def plot_energy_comparison(original_metrics: List[dict], normalized_metrics: List[dict], file_names: List[str], output_path: str = "./logs/energy_comparison.png") -> None:
"""
Plot a comparison of energy metrics before and after normalization.
Args:
original_metrics: List of dictionaries containing metrics for original audio
normalized_metrics: List of dictionaries containing metrics for normalized audio
file_names: List of audio file names
output_path: Path to save the plot
"""
fig, axs = plt.subplots(2, 2, figsize=(14, 10))
# Extract metrics
orig_rms = [m['rms'] for m in original_metrics]
norm_rms = [m['rms'] for m in normalized_metrics]
orig_peak = [m['peak'] for m in original_metrics]
norm_peak = [m['peak'] for m in normalized_metrics]
orig_dr = [m['dynamic_range_db'] for m in original_metrics]
norm_dr = [m['dynamic_range_db'] for m in normalized_metrics]
orig_cf = [m['crest_factor'] for m in original_metrics]
norm_cf = [m['crest_factor'] for m in normalized_metrics]
# Prepare x-axis
x = np.arange(len(file_names))
width = 0.35
# Plot RMS (volume)
axs[0, 0].bar(x - width/2, orig_rms, width, label='Original')
axs[0, 0].bar(x + width/2, norm_rms, width, label='Normalized')
axs[0, 0].set_title('RMS Energy (Volume)')
axs[0, 0].set_xticks(x)
axs[0, 0].set_xticklabels(file_names, rotation=45, ha='right')
axs[0, 0].set_ylabel('RMS Value')
axs[0, 0].legend()
# Plot Peak Amplitude
axs[0, 1].bar(x - width/2, orig_peak, width, label='Original')
axs[0, 1].bar(x + width/2, norm_peak, width, label='Normalized')
axs[0, 1].set_title('Peak Amplitude')
axs[0, 1].set_xticks(x)
axs[0, 1].set_xticklabels(file_names, rotation=45, ha='right')
axs[0, 1].set_ylabel('Peak Value')
axs[0, 1].legend()
# Plot Dynamic Range
axs[1, 0].bar(x - width/2, orig_dr, width, label='Original')
axs[1, 0].bar(x + width/2, norm_dr, width, label='Normalized')
axs[1, 0].set_title('Dynamic Range (dB)')
axs[1, 0].set_xticks(x)
axs[1, 0].set_xticklabels(file_names, rotation=45, ha='right')
axs[1, 0].set_ylabel('dB')
axs[1, 0].legend()
# Plot Crest Factor
axs[1, 1].bar(x - width/2, orig_cf, width, label='Original')
axs[1, 1].bar(x + width/2, norm_cf, width, label='Normalized')
axs[1, 1].set_title('Crest Factor (Peak-to-RMS Ratio)')
axs[1, 1].set_xticks(x)
axs[1, 1].set_xticklabels(file_names, rotation=45, ha='right')
axs[1, 1].set_ylabel('Ratio')
axs[1, 1].legend()
plt.tight_layout()
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
# Save the plot
plt.savefig(output_path)
plt.close()
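# A hedged usage sketch pairing this plot with calculate_audio_metrics
# (defined just below); 'clips' is assumed to come from load_audio_files:
#
#   before = calculate_audio_metrics(clips)
#   after = calculate_audio_metrics(normalize_audio_volumes(clips))
#   plot_energy_comparison(before, after, ["stem_a.wav", "stem_b.wav"])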
def calculate_audio_metrics(audio_data: List[Tuple[torch.Tensor, int]]) -> List[dict]:
"""
Calculate various audio metrics for each audio file.
Args:
audio_data: List of tuples containing audio data and sample rate
Returns:
List of dictionaries containing metrics
"""
metrics = []
for audio, sr in audio_data:
# Calculate RMS (Root Mean Square)
energy = torch.mean(audio ** 2)
rms = torch.sqrt(energy)
# Calculate peak amplitude
peak = torch.max(torch.abs(audio))
        # Calculate dynamic range (guarding the all-zero case, where
        # audio[audio != 0] is empty and torch.min would raise)
        non_zero = torch.abs(audio[audio != 0])
        if non_zero.numel() > 0:
            min_non_zero = torch.min(non_zero)
            dynamic_range_db = 20 * torch.log10(peak / min_non_zero)
        else:
            dynamic_range_db = torch.tensor(float('inf'))
# Calculate crest factor (peak to RMS ratio)
crest_factor = peak / rms if rms > 0 else torch.tensor(float('inf'))
metrics.append({
'rms': rms.item(),
'peak': peak.item(),
'dynamic_range_db': dynamic_range_db.item() if not torch.isinf(dynamic_range_db) else float('inf'),
'crest_factor': crest_factor.item() if not torch.isinf(crest_factor) else float('inf')
})
return metrics
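# Sanity check for these metrics: a full-scale sine has peak 1.0 and RMS
# 1/sqrt(2) ~= 0.707, so its crest factor should be sqrt(2) ~= 1.414.
#
#   t = torch.linspace(0, 1, 16000)
#   sine = torch.sin(2 * torch.pi * 440 * t)
#   m = calculate_audio_metrics([(sine, 16000)])[0]  # m['crest_factor'] ~= 1.414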
def create_weighted_composite(
audio_data: List[Tuple[torch.Tensor, int]],
weights: List[float]
) -> torch.Tensor:
"""
Create a weighted composite of multiple audio files.
Args:
audio_data: List of tuples containing audio data and sample rate
weights: List of weights for each audio file
Returns:
Weighted composite audio data
"""
if len(audio_data) != len(weights):
raise ValueError("Number of audio files and weights must match")
    # Normalize weights to sum to 1 (reject an all-zero weight vector)
    weight_sum = sum(weights)
    if weight_sum == 0:
        raise ValueError("Weights must not sum to zero")
    weights = torch.tensor(weights, dtype=torch.float32) / weight_sum
# Initialize composite audio with zeros
composite = torch.zeros_like(audio_data[0][0])
# Add weighted audio data
for (audio, _), weight in zip(audio_data, weights):
composite += audio * weight
# Normalize to prevent clipping
max_val = torch.max(torch.abs(composite))
if max_val > 1.0:
composite = composite / max_val
return composite
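# Worked example of the weight handling above: raw weights [2, 1, 1]
# normalize to [0.5, 0.25, 0.25], so the first clip contributes half the
# mix; the final peak rescale only fires if |composite| exceeds 1.0.
#
#   mix = create_weighted_composite(clips, [2, 1, 1])  # 'clips' as loaded above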
def create_melspectrograms(
audio_data: List[Tuple[torch.Tensor, int]],
composite: torch.Tensor,
sr: int
) -> List[torch.Tensor]:
"""
Create melspectrograms for individual audio files and the composite.
Args:
audio_data: List of tuples containing audio data and sample rate
composite: Composite audio data
sr: Sample rate
Returns:
List of melspectrogram data
"""
specs = []
# Create mel spectrogram transform
mel_transform = T.MelSpectrogram(
sample_rate=sr,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
f_max=8000
)
# Generate spectrograms for individual audio files
for audio, _ in audio_data:
melspec = mel_transform(audio)
specs.append(melspec)
# Generate spectrogram for composite audio
composite_melspec = mel_transform(composite)
specs.append(composite_melspec)
return specs
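# Frame-count intuition for the transform above: with hop_length=512 and
# torchaudio's default center padding, N samples yield about N // 512 + 1
# frames, e.g. 10 s at 22050 Hz gives 220500 // 512 + 1 = 431 frames.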
def plot_melspectrograms(
specs: List[torch.Tensor],
sr: int,
file_names: List[str],
weights: List[float],
output_path: str = "melspectrograms.png"
) -> None:
"""
Plot melspectrograms for individual audio files and the composite.
Args:
specs: List of melspectrogram data
sr: Sample rate
file_names: List of audio file names
weights: List of weights for each audio file
output_path: Path to save the plot
"""
fig, axs = plt.subplots(len(specs), 1, figsize=(12, 4 * len(specs)))
# Create labels for the plots
labels = [f"{name} (weight: {weight:.2f})" for name, weight in zip(file_names, weights)]
labels.append("Composite.wav")
# Convert to dB scale (similar to librosa's power_to_db)
def power_to_db(spec):
return 10 * torch.log10(spec + 1e-10)
# Plot each melspectrogram
for i, (spec, label) in enumerate(zip(specs, labels)):
spec_db = power_to_db(spec).numpy().squeeze()
# For single subplot case
if len(specs) == 1:
ax = axs
else:
ax = axs[i]
        # Mel bins are nonlinearly spaced (f_max=8000 here), so label the
        # y-axis in mel bins rather than a linear 0..sr/2 Hz range
        img = ax.imshow(
            spec_db,
            aspect='auto',
            origin='lower',
            interpolation='none',
            extent=[0, spec_db.shape[1], 0, spec_db.shape[0]]
        )
        ax.set_title(label)
        ax.set_ylabel('Mel Bin')
        ax.set_xlabel('Time Frames')
# No colorbar as requested
plt.tight_layout()
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
# Save the plot
    plt.savefig(output_path, dpi=300)
plt.close()
def compose_audio(
file_paths: List[str],
weights: List[float],
output_audio_path: str = os.path.join(LOGS_DIR, "composite.wav"),
output_plot_path: str = os.path.join(LOGS_DIR, "plot/melspectrograms.png"),
energy_plot_path: str = os.path.join(LOGS_DIR, "plot/energy_comparison.png")
) -> None:
"""
Main function to process audio files and create visualizations.
Args:
        file_paths: List of paths to audio files (any number may be supplied)
weights: List of weights for each audio file
output_audio_path: Path to save the composite audio
output_plot_path: Path to save the melspectrogram plot
energy_plot_path: Path to save the energy comparison plot
"""
# Load audio files
audio_data = load_audio_files(file_paths)
    # Calculate metrics for original audio
print("Calculating metrics for original audio...")
original_metrics = calculate_audio_metrics(audio_data)
# Normalize audio volumes to have same energy level
print("Normalizing audio volumes...")
normalized_audio_data = normalize_audio_volumes(audio_data)
# Calculate metrics for normalized audio
print("Calculating metrics for normalized audio...")
normalized_metrics = calculate_audio_metrics(normalized_audio_data)
# Print energy comparison
print("\nAudio Energy Comparison (RMS values):")
print("-" * 50)
print(f"{'File':<20} {'Original':<15} {'Normalized':<15} {'Scaling Factor':<15}")
print("-" * 50)
for i, path in enumerate(file_paths):
        file_name = os.path.basename(path)
orig_rms = original_metrics[i]['rms']
norm_rms = normalized_metrics[i]['rms']
scaling = norm_rms / orig_rms if orig_rms > 0 else float('inf')
print(f"{file_name[:20]:<20} {orig_rms:<15.6f} {norm_rms:<15.6f} {scaling:<15.6f}")
# Create energy comparison plot
print("\nCreating energy comparison plot...")
    file_names = [os.path.basename(path) for path in file_paths]
plot_energy_comparison(original_metrics, normalized_metrics, file_names, energy_plot_path)
# Get sample rate (all files have the same sample rate)
sr = normalized_audio_data[0][1]
# Create weighted composite
print("\nCreating weighted composite...")
composite = create_weighted_composite(normalized_audio_data, weights)
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(output_audio_path) or '.', exist_ok=True)
# Save composite audio
print("Saving composite audio...")
torchaudio.save(output_audio_path, composite.unsqueeze(0), sr)
# Create melspectrograms for normalized audio (not original)
print("Creating melspectrograms for normalized audio...")
specs = create_melspectrograms(normalized_audio_data, composite, sr)
    # Plot melspectrograms (reusing the basenames computed above)
    print("Plotting melspectrograms...")
    plot_melspectrograms(specs, sr, file_names, weights, output_plot_path)
print(f"\nComposite audio saved to {output_audio_path}")
print(f"Melspectrograms saved to {output_plot_path}")
print(f"Energy comparison saved to {energy_plot_path}")
print(f"Composite audio saved to {output_audio_path}")
print(f"Melspectrograms saved to {output_plot_path}")
# if __name__ == "__main__":
# import argparse
# parser = argparse.ArgumentParser(description="Mix audio files with weights and create melspectrograms")
# parser.add_argument("--files", nargs="+", required=True, help="Paths to audio files")
# parser.add_argument("--weights", nargs="+", type=float, required=True, help="Weights for each audio file")
# parser.add_argument("--output-audio", default="./logs/composite.wav", help="Path to save the composite audio")
# parser.add_argument("--output-plot", default="./logs/melspectrograms.png", help="Path to save the melspectrogram plot")
# args = parser.parse_args()
# os.makedirs("./logs", exist_ok=True)
# compose_audio(args.files, args.weights, args.output_audio, args.output_plot)
# Example usage:
# python audio_mixer.py --files audio1.wav audio2.wav audio3.wav audio4.wav --weights 0.4 0.3 0.2 0.1
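# Programmatic usage (a sketch; these file names and weights are hypothetical):
# compose_audio(
#     ["traffic.wav", "voices.wav", "birds.wav", "rain.wav"],
#     [0.4, 0.3, 0.2, 0.1],
# )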