import os
from abc import ABC, abstractmethod
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.transforms import FrequencyMasking, TimeMasking
from torchvision.transforms import Compose
from transformers import ASTFeatureExtractor
class Transform(ABC):
"""Abstract base class for audio transformations."""
@abstractmethod
    def __call__(self, *args, **kwargs):
"""
Abstract method to apply the transformation.
:raises NotImplementedError: If the subclass does not implement this method.
"""
pass
class Preprocess(ABC):
"""Abstract base class for preprocessing data.
This class defines the interface for preprocessing data. Subclasses must implement the call method.
"""
@abstractmethod
    def __call__(self, *args, **kwargs):
"""Process the data.
This method must be implemented by subclasses.
:raises NotImplementedError: Subclasses must implement this method.
"""
pass
class OneHotEncode(Transform):
"""Transform labels to one-hot encoded tensor.
This class is a transform that takes a list of labels and returns a one-hot encoded tensor.
The labels are converted to a tensor with one-hot encoding using the specified classes.
    :param c: A list of classes to be used for one-hot encoding.
    :type c: list
    """
def __init__(self, c: list):
self.c = c
def __call__(self, labels):
"""
Transform labels to one-hot encoded tensor.
:param labels: A list of labels to be encoded.
:type labels: list
:return: A one-hot encoded tensor.
:rtype: torch.Tensor
"""
target = torch.zeros(len(self.c), dtype=torch.float)
for label in labels:
idx = self.c.index(label)
target[idx] = 1
return target
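# A minimal usage sketch (the class list and labels below are hypothetical,
# chosen only to illustrate the call signature):
def _example_one_hot_encode():
    encode = OneHotEncode(c=["guitar", "piano", "drums"])
    return encode(["guitar", "drums"])  # tensor([1., 0., 1.])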
class ParentMultilabel(Transform):
"""
A transform that extracts a list of labels from the parent directory name of a file path.
:param sep: The separator used to split the parent directory name into labels. Defaults to " ".
:type sep: str
"""
def __init__(self, sep=" "):
self.sep = sep
def __call__(self, path):
"""
Extract a list of labels from the parent directory name of a file path.
:param path: The file path from which to extract labels.
:type path: str
:return: A list of labels extracted from the parent directory name of the input file path.
:rtype: List[str]
"""
label = path.split(os.path.sep)[-2].split(self.sep)
return label
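# A minimal usage sketch (hypothetical, POSIX-style path), showing how labels
# are parsed from the parent directory name:
def _example_parent_multilabel():
    transform = ParentMultilabel(sep=" ")
    return transform("data/guitar piano/clip_001.wav")  # ["guitar", "piano"]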
class LabelsFromTxt(Transform):
    """
    Load labels from a text file paired with an audio file.
    This class is a transform that derives the label file path from the audio file
    path (swapping the wav extension for txt) and reads the labels from that file.
    :param delimiter: The delimiter used to split labels within the file. Defaults to
        None, which splits on whitespace.
    :type delimiter: str
    """
def __init__(self, delimiter=None):
self.delimiter = delimiter
    def __call__(self, path):
        """
        Load labels from the text file paired with the given audio file.
        :param path: The path of the audio file whose label file should be read.
        :type path: str
        :return: An array of labels read from the paired text file.
        :rtype: numpy.ndarray
        """
        # Swap only the file extension; replacing every "wav" substring would also
        # mangle directory names that happen to contain "wav".
        path = os.path.splitext(path)[0] + ".txt"
        label = np.loadtxt(path, dtype=str, ndmin=1, delimiter=self.delimiter)
        return label
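# A minimal usage sketch (hypothetical files): "data/clip_001.wav" is expected
# to have a sibling label file "data/clip_001.txt" with one label per line:
def _example_labels_from_txt():
    transform = LabelsFromTxt()
    return transform("data/clip_001.wav")  # e.g. array(["guitar", "piano"])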
class PreprocessPipeline(Preprocess):
"""A preprocessing pipeline for audio data.
This class is a preprocessing pipeline for audio data.
The pipeline includes resampling to a target sampling rate, mixing down stereo to mono,
and loading audio from a file.
:param target_sr: The target sampling rate to resample to.
:type target_sr: int
"""
def __init__(self, target_sr):
self.target_sr = target_sr
def __call__(self, path):
"""
Preprocess audio data using a pipeline.
:param path: The path to the audio file to load.
:type path: str
:return: A NumPy array of preprocessed audio data.
:rtype: numpy.ndarray
"""
signal, sr = torchaudio.load(path)
signal = self._resample(signal, sr)
signal = self._mix_down(signal)
return signal.numpy()
def _mix_down(self, signal):
"""
Mix down stereo to mono.
:param signal: The audio signal to mix down.
:type signal: torch.Tensor
:return: The mixed down audio signal.
:rtype: torch.Tensor
"""
if signal.shape[0] > 1:
signal = torch.mean(signal, dim=0, keepdim=True)
return signal
def _resample(self, signal, input_sr):
"""
Resample audio signal to a target sampling rate.
:param signal: The audio signal to resample.
:type signal: torch.Tensor
:param input_sr: The current sampling rate of the audio signal.
:type input_sr: int
:return: The resampled audio signal.
:rtype: torch.Tensor
"""
if input_sr != self.target_sr:
resampler = torchaudio.transforms.Resample(input_sr, self.target_sr)
signal = resampler(signal)
return signal
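# A minimal usage sketch (hypothetical file path): load a clip, resample it to
# 16 kHz, and mix it down to mono:
def _example_preprocess_pipeline():
    pipeline = PreprocessPipeline(target_sr=16000)
    return pipeline("data/clip_001.wav")  # numpy array of shape (1, num_samples)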
class SpecToImage(Transform):
    """Convert a spectrogram into a 3-channel uint8 image tensor.
    The spectrogram is stacked into three identical channels, standardized with
    the given (or per-sample) mean and std, then min-max scaled to the 0-255 range.
    """
def __init__(self, mean=None, std=None, eps=1e-6):
self.mean = mean
self.std = std
self.eps = eps
def __call__(self, spec):
spec = torch.stack([spec, spec, spec], dim=-1)
mean = torch.mean(spec) if self.mean is None else self.mean
std = torch.std(spec) if self.std is None else self.std
spec_norm = (spec - mean) / std
spec_min, spec_max = torch.min(spec_norm), torch.max(spec_norm)
spec_scaled = 255 * (spec_norm - spec_min) / (spec_max - spec_min)
return spec_scaled.type(torch.uint8)
class MinMaxScale(Transform):
    """Scale a spectrogram to the [0, 1] range via min-max normalization."""
def __call__(self, spec):
spec_min, spec_max = torch.min(spec), torch.max(spec)
return (spec - spec_min) / (spec_max - spec_min)
class Normalize(Transform):
    """Standardize a spectrogram with a fixed mean and standard deviation."""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, spec):
return (spec - self.mean) / self.std
class FeatureExtractor(Transform):
"""Extract features from audio signal using an AST feature extractor.
This class is a transform that extracts features from an audio signal using an AST feature extractor.
The features are returned as a PyTorch tensor.
:param sr: The sampling rate of the audio signal.
:type sr: int
"""
def __init__(self, sr):
self.transform = partial(ASTFeatureExtractor(), sampling_rate=sr, return_tensors="pt")
def __call__(self, signal):
"""
Extract features from audio signal using an AST feature extractor.
:param signal: The audio signal to extract features from.
:type signal: numpy.ndarray
:return: A tensor of extracted audio features.
:rtype: torch.Tensor
"""
return self.transform(signal.squeeze()).input_values.mT
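# A minimal usage sketch: one second of silence at 16 kHz. With the default AST
# settings this yields a (1, 128, 1024) tensor (num_mel_bins x max_length):
def _example_feature_extractor():
    extractor = FeatureExtractor(sr=16000)
    return extractor(np.zeros((1, 16000), dtype=np.float32))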
class Preemphasis(Transform):
    """Apply a preemphasis filter to the input signal.
    The filter computes y[t] = x[t] - coeff * x[t - 1], leaving the first sample
    unchanged, which boosts the signal's high-frequency content.
    :param coeff: The preemphasis coefficient. 0 means no filtering. Defaults to 0.97.
    :type coeff: float
    """
def __init__(self, coeff: float = 0.97):
self.coeff = coeff
def __call__(self, signal):
return torch.cat([signal[:, :1], signal[:, 1:] - self.coeff * signal[:, :-1]], dim=1)
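# A minimal sketch of the filter on a constant toy signal, where
# y[t] = x[t] - coeff * x[t - 1] and the first sample passes through unchanged:
def _example_preemphasis():
    x = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
    return Preemphasis(coeff=0.97)(x)  # tensor([[1.0000, 0.0300, 0.0300, 0.0300]])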
class Spectrogram(Transform):
    """Compute a mel spectrogram from a raw audio signal."""
def __init__(self, sample_rate, n_mels, hop_length, n_fft):
self.transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft, f_min=20, center=False
)
def __call__(self, signal):
return self.transform(signal)
class LogTransform(Transform):
    """Take the natural log of a spectrogram, with a small epsilon for numerical stability."""
    def __call__(self, signal):
        return torch.log(signal + 1e-8)
class PadCutToLength(Transform):
    """Pad with zeros or truncate the last dimension of a spectrogram to a fixed length.
    :param max_length: The target number of frames along the last dimension.
    :type max_length: int
    """
    def __init__(self, max_length):
        self.max_length = max_length
    def __call__(self, spec):
        seq_len = spec.shape[-1]
        if seq_len > self.max_length:
            return spec[..., : self.max_length]
        if seq_len < self.max_length:
            diff = self.max_length - seq_len
            return F.pad(spec, (0, diff), mode="constant", value=0)
        # Already the target length; return the input unchanged.
        return spec
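# A minimal sketch on a toy spectrogram: the time axis is zero-padded (or
# truncated) to exactly max_length frames:
def _example_pad_cut_to_length():
    spec = torch.rand(64, 90)
    return PadCutToLength(max_length=100)(spec).shape  # torch.Size([64, 100])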
class CustomFeatureExtractor(Transform):
    """Extract normalized log-mel spectrogram features from a raw audio signal.
    Applies preemphasis, a mel spectrogram, a log transform, padding/truncation
    to a fixed length, and normalization, in that order.
    """
def __init__(self, sample_rate, n_mels, hop_length, n_fft, max_length, mean, std):
self.extract = Compose(
[
Preemphasis(),
Spectrogram(sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft),
LogTransform(),
PadCutToLength(max_length=max_length),
Normalize(mean=mean, std=std),
]
)
def __call__(self, x):
return self.extract(x)
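# A minimal usage sketch; the parameter values below are hypothetical, not the
# settings used elsewhere in this project:
def _example_custom_feature_extractor():
    extractor = CustomFeatureExtractor(
        sample_rate=16000, n_mels=64, hop_length=160, n_fft=1024,
        max_length=1000, mean=0.0, std=1.0,
    )
    signal = torch.rand(1, 16000)  # one second of fake mono audio
    return extractor(signal).shape  # torch.Size([1, 64, 1000])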
class RepeatAudio(Transform):
"""A transform to repeat audio data.
This class is a transform that repeats audio data a random number of times up to a maximum specified value.
:param max_repeats: The maximum number of times to repeat the audio data.
:type max_repeats: int
"""
def __init__(self, max_repeats: int = 2):
self.max_repeats = max_repeats
def __call__(self, signal):
"""
Repeat audio data a random number of times up to a maximum specified value.
:param signal: The audio data to repeat.
:type signal: numpy.ndarray
:return: The repeated audio data.
:rtype: numpy.ndarray
"""
        # torch.randint's upper bound is exclusive, so add 1 so that max_repeats is reachable.
        num_repeats = torch.randint(1, self.max_repeats + 1, (1,)).item()
return np.tile(signal, reps=num_repeats)
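# A minimal sketch: tile a one-second clip between 1 and max_repeats times:
def _example_repeat_audio():
    signal = np.random.randn(1, 16000)
    return RepeatAudio(max_repeats=3)(signal).shape  # (1, 16000), (1, 32000), or (1, 48000)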
class MaskFrequency(Transform):
"""A transform to mask frequency of a spectrogram.
This class is a transform that masks out a random number of consecutive frequencies from a spectrogram.
:param max_mask_length: The maximum number of consecutive frequencies to mask out from the spectrogram.
:type max_mask_length: int
"""
def __init__(self, max_mask_length: int = 0):
self.aug = FrequencyMasking(max_mask_length)
def __call__(self, spec):
"""
Mask out a random number of consecutive frequencies from a spectrogram.
        :param spec: The input spectrogram.
        :type spec: torch.Tensor
        :return: The spectrogram with masked frequencies.
        :rtype: torch.Tensor
"""
return self.aug(spec)
class MaskTime(Transform):
"""A transform to mask time of a spectrogram.
This class is a transform that masks out a random number of consecutive time steps from a spectrogram.
:param max_mask_length: The maximum number of consecutive time steps to mask out from the spectrogram.
:type max_mask_length: int
"""
def __init__(self, max_mask_length: int = 0):
self.aug = TimeMasking(max_mask_length)
def __call__(self, spec):
"""
Mask out a random number of consecutive time steps from a spectrogram.
        :param spec: The input spectrogram.
        :type spec: torch.Tensor
        :return: The spectrogram with masked time steps.
        :rtype: torch.Tensor
"""
return self.aug(spec)
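# A minimal sketch combining both masks, SpecAugment-style; the mask lengths
# here are hypothetical:
def _example_spec_augment():
    spec = torch.rand(1, 64, 100)
    augment = Compose([MaskFrequency(max_mask_length=8), MaskTime(max_mask_length=20)])
    return augment(spec)  # same shape, with one frequency band and one time span zeroed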