import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.fft
import math


class GlobalLayerNorm(nn.Module):
    '''
    Calculate Global Layer Normalization (gLN).
    dim: (int) number of channels; statistics are computed over both the
        channel and time dimensions of an N x C x L input.
    eps: a value added to the denominator for numerical stability.
    elementwise_affine: a boolean value that when set to True, this module
        has learnable per-channel affine parameters initialized to ones
        (for weights) and zeros (for biases).
    '''

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(self.dim, 1))
            self.bias = nn.Parameter(torch.zeros(self.dim, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        # x = N x C x L
        # gLN: mean, var are N x 1 x 1 (cLN would use N x 1 x L)
        if x.dim() != 3:
            raise RuntimeError("{} accepts 3D tensor as input".format(
                self.__class__.__name__))

        mean = torch.mean(x, (1, 2), keepdim=True)
        var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
        # N x C x L
        if self.elementwise_affine:
            x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
        else:
            x = (x - mean) / torch.sqrt(var + self.eps)
        return x
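# A minimal usage sketch (illustrative, not part of the original module):
# gLN computes one mean/variance per sample over both the channel and time
# axes, unlike nn.LayerNorm, which normalizes over trailing dims only.
#
#   gln = GlobalLayerNorm(dim=64)     # 64 channels, per-channel affine
#   x = torch.randn(8, 64, 400)       # N x C x L
#   y = gln(x)                        # same shape, zero mean / unit var per sample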
class TimeSincExtractor(nn.Module):
    """Sinc-based convolution

    Parameters
    ----------
    in_channels : `int`
        Number of input channels. Must be 1.
    out_channels : `int`
        Number of filters.
    kernel_size : `int`
        Filter length.
    sample_rate : `int`, optional
        Sample rate. Defaults to 16000.
    triangular : `bool`
        Squared sinc -> triangular filter.
    freq_nml : `bool`
        Normalize each filter to a gain of 1 in the frequency domain.
    range_constraint : `bool`
        Project the learned band within the Nyquist frequency manually.

    Usage
    -----
    See `torch.nn.Conv1d`
    """

    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def swap_(self, x, y, sort=False):
        mini = torch.minimum(x, y)
        maxi = torch.maximum(x, y)
        if sort:
            mini, idx = torch.sort(mini)
            maxi = maxi[idx].view(mini.shape)
        return mini, maxi

    def __init__(self, out_channels, kernel_size, triangular=False, freq_nml=False,
                 range_constraint=False, freq_init='uniform', norm_after=True,
                 sample_rate=16000, in_channels=1, stride=1, padding=0, dilation=1,
                 bias=False, groups=1, min_low_hz=50, min_band_hz=50, bi_factor=False,
                 frame_length=400, hop_length=160):
        super(TimeSincExtractor, self).__init__()

        if in_channels != 1:
            msg = (f'SincConv only supports one input channel '
                   f'(here, in_channels = {in_channels:d}).')
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.triangular = triangular
        self.freq_nml = freq_nml

        # Force the filters to be odd-length (i.e., perfectly symmetric)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.frame_length = frame_length
        self.hop_length = hop_length

        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        self.sample_rate = sample_rate
        self.nyquist_rate = sample_rate / 2
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        self.range_constraint = range_constraint
        self.bi_factor = bi_factor

        if self.range_constraint:
            if freq_init == "uniform":
                low_freq, high_freq = torch.rand(out_channels * 2).chunk(2)
            elif freq_init == "formant":
                p = np.load('/share/nas165/Jasonho610/SincNet/exp/formant_distribution.npy')
                # cast to float32 so the resulting Parameters match the model dtype
                low_freq, high_freq = torch.from_numpy(
                    np.random.choice(8000, out_channels * 2, p=p)).float().chunk(2)
                low_freq = low_freq / self.nyquist_rate
                high_freq = high_freq / self.nyquist_rate
            elif freq_init == "mel":
                low_hz = 30
                high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
                mel = np.linspace(self.to_mel(low_hz),
                                  self.to_mel(high_hz),
                                  self.out_channels + 1)
                hz = self.to_hz(mel)
                low_freq = torch.Tensor(hz[:-1]) / self.nyquist_rate
                high_freq = torch.Tensor(hz[1:]) / self.nyquist_rate
            else:
                raise ValueError('SincConv must specify the freq initialization method.')
            low_freq, high_freq = self.swap_(low_freq, high_freq)
            if self.bi_factor:
                self.band_imp = nn.Parameter(torch.ones(out_channels))
            self.low_f_ = nn.Parameter(low_freq.view(-1, 1))
            self.high_f_ = nn.Parameter(high_freq.view(-1, 1))
        else:
            # initialize filterbanks such that they are equally spaced on the mel scale
            low_hz = 30
            high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
            mel = np.linspace(self.to_mel(low_hz),
                              self.to_mel(high_hz),
                              self.out_channels + 1)
            hz = self.to_hz(mel)
            # filter lower frequency (out_channels, 1)
            self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
            # filter frequency band (out_channels, 1)
            self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Hamming window: compute only half of it, since the filter is symmetric
        n_lin = torch.linspace(0, (self.kernel_size / 2) - 1,
                               steps=int(self.kernel_size / 2))
        self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)  # (1, kernel_size/2)

        n = (self.kernel_size - 1) / 2.0
        # Due to symmetry, only half of the time axis is needed
        self.n_ = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate

        self.norm_after = norm_after
        if self.norm_after:
            self.ln = GlobalLayerNorm(out_channels)
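# A minimal sketch of the mel-spaced initialization used above (illustrative
# only; the edge count 65 assumes out_channels=64 and the defaults
# min_low_hz = min_band_hz = 50 at 16 kHz):
#
#   mel_edges = np.linspace(TimeSincExtractor.to_mel(30),
#                           TimeSincExtractor.to_mel(8000 - 100), 65)
#   hz_edges = TimeSincExtractor.to_hz(mel_edges)        # 65 band edges in Hz
#   low_hz, band_hz = hz_edges[:-1], np.diff(hz_edges)   # 64 lower edges / widths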
    def forward(self, waveforms, embedding):
        """
        Parameters
        ----------
        waveforms : `torch.Tensor` (batch_size, 1, n_samples)
            Batch of waveforms.

        Returns
        -------
        features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
            Batch of sinc filter activations.
        """
        self.n_ = self.n_.to(waveforms.device)
        self.window_ = self.window_.to(waveforms.device)

        # Frame the waveforms: pad so the last frame is complete, then unfold
        framing_padding = self.frame_length - (waveforms.shape[-1] % self.hop_length)
        waveforms = F.pad(waveforms, (0, framing_padding))
        frames = waveforms.unfold(-1, self.frame_length, self.hop_length)
        batch_size = frames.shape[0]
        n_frames = frames.shape[2]

        if self.range_constraint:
            low_f_, high_f_ = self.swap_(torch.abs(self.low_f_), torch.abs(self.high_f_))
            low = self.min_low_hz + low_f_ * self.nyquist_rate
            high = torch.clamp(self.min_band_hz + high_f_ * self.nyquist_rate,
                               self.min_low_hz, self.nyquist_rate)
            band = (high - low)[:, 0]
        else:
            low = self.min_low_hz + torch.abs(self.low_hz_)
            high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                               self.min_low_hz, self.nyquist_rate)
            band = (high - low)[:, 0]

        self.low = low
        self.high = high
        self.band = band

        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)

        # Equivalent to Eq. 4 of the reference paper (SPEAKER RECOGNITION FROM
        # RAW WAVEFORM WITH SINCNET): the sinc is expanded and the terms
        # simplified, which avoids several useless computations.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
                          / (self.n_ / 2)) * self.window_
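# A quick inspection sketch (illustrative, not part of the original module):
# after one forward pass, `self.filters` holds the time-domain kernels, so
# their magnitude responses can be checked with the same torch.fft.rfft used
# by the freq_nml branch above.
#
#   model = TimeSincExtractor(out_channels=64, kernel_size=101, range_constraint=True)
#   _ = model(torch.rand(2, 1, 10080), embedding=None)
#   resp = torch.fft.rfft(model.filters.squeeze(1), n=512).abs()   # (64, 257)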
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        band_pass = band_pass / (2 * band[:, None])

        if self.triangular:
            band_pass = band_pass**2
        if self.freq_nml:
            mag_resp = torch.fft.rfft(band_pass).abs()
            mag_max = torch.max(mag_resp, dim=-1)[0]
            band_pass = band_pass / mag_max.unsqueeze(-1)
        if self.bi_factor:
            band_imp = F.relu(self.band_imp)
            band_pass = band_pass * band_imp.unsqueeze(-1)

        self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)

        # Filter every frame independently: collapse (batch, frame) into one axis
        rs_frames = frames.reshape(batch_size * n_frames, 1, self.frame_length)
        filtered = F.conv1d(rs_frames, self.filters, stride=self.stride,
                            padding=self.padding, dilation=self.dilation,
                            bias=None, groups=1)
        if self.norm_after:
            filtered = self.ln(filtered)
        filtered = filtered.reshape(batch_size, n_frames, self.out_channels, -1)

        # Log energy per frame and per filter
        energy = torch.mean(filtered**2, dim=-1)
        log_filtered_energy = torch.log10(energy + 1e-6)
        # (batch_size, n_frames(time), out_channels(frequency))
        log_filtered_energy = log_filtered_energy.unsqueeze(1)
        # (batch_size, channels, n_frames(time), out_channels(frequency))
        log_filtered_energy = log_filtered_energy.permute(0, 1, 3, 2)
        # (batch_size, channels, out_channels(frequency), n_frames(time))

        return log_filtered_energy, self.filters, self.stride, self.padding


class FreqSincExtractor(nn.Module):

    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def swap_(self, x, y, sort=False):
        mini = torch.minimum(x, y)
        maxi = torch.maximum(x, y)
        if sort:
            mini, idx = torch.sort(mini)
            maxi = maxi[idx].view(mini.shape)
        return mini, maxi

    def __init__(self, out_channels, kernel_size, triangular=False, freq_nml=False,
                 range_constraint=False, freq_init='uniform', norm_after=True,
                 sample_rate=16000, in_channels=1, stride=1, padding=0, dilation=1,
                 bias=False, groups=1, min_low_hz=50, min_band_hz=50, bi_factor=False,
                 frame_length=400, hop_length=160, n_fft=400):
        super(FreqSincExtractor, self).__init__()

        if in_channels != 1:
            msg = (f'FreqSincExtractor only supports one input channel '
                   f'(here, in_channels = {in_channels:d}).')
            raise ValueError(msg)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.triangular = triangular
        self.freq_nml = freq_nml
        self.sample_rate = sample_rate
        self.nyquist_rate = sample_rate / 2
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        self.range_constraint = range_constraint
        self.bi_factor = bi_factor
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.stride = stride
        self.padding = padding
        self.output_size = 64

        # Initialize frequency bands
        if self.range_constraint:
            if freq_init == "uniform":
                low_freq, high_freq = torch.rand(out_channels * 2).chunk(2)
            elif freq_init == "mel":
                low_hz = 30
                high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
                mel = np.linspace(self.to_mel(low_hz),
                                  self.to_mel(high_hz),
                                  self.out_channels + 1)
                hz = self.to_hz(mel)
                low_freq = torch.Tensor(hz[:-1]) / self.nyquist_rate
                high_freq = torch.Tensor(hz[1:]) / self.nyquist_rate
            else:
                raise ValueError('FreqSincExtractor must specify the freq initialization method.')
            low_freq, high_freq = self.swap_(low_freq, high_freq)
            if self.bi_factor:
                self.band_imp = nn.Parameter(torch.ones(out_channels))
            self.low_f_ = nn.Parameter(low_freq.view(-1, 1))
            self.high_f_ = nn.Parameter(high_freq.view(-1, 1))
        else:
            low_hz = 30
            high_hz = self.nyquist_rate - (self.min_low_hz + self.min_band_hz)
            mel = np.linspace(self.to_mel(low_hz),
                              self.to_mel(high_hz),
                              self.out_channels + 1)
            hz = self.to_hz(mel)
            self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
            self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Frequency axis for the STFT bins
        self.freq_axis = torch.linspace(0, self.nyquist_rate, self.n_fft // 2 + 1)

        self.norm_after = norm_after
        if self.norm_after:
            self.ln = GlobalLayerNorm(out_channels)
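# A minimal sketch of the masking idea behind get_filters() below (standalone,
# with illustrative band edges): each STFT bin whose centre frequency falls
# inside [low, high] gets gain 1, everything else 0.
#
#   freq_axis = torch.linspace(0, 8000, 201)   # n_fft=400 -> 201 bins at 16 kHz
#   mask = (freq_axis >= 300.0) & (freq_axis <= 2500.0)
#   filt = mask.float()                        # one rectangular band filter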
    def get_filters(self):
        if self.range_constraint:
            low_f_, high_f_ = self.swap_(torch.abs(self.low_f_), torch.abs(self.high_f_))
            low = self.min_low_hz + low_f_ * self.nyquist_rate
            high = torch.clamp(self.min_low_hz + high_f_ * self.nyquist_rate,
                               self.min_low_hz, self.nyquist_rate)
        else:
            low = self.min_low_hz + torch.abs(self.low_hz_)
            high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                               self.min_low_hz, self.nyquist_rate)

        # Create frequency-domain filters: one rectangular (or triangular) mask per band
        freq_axis = self.freq_axis.to(low.device)
        filters = torch.zeros((self.out_channels, len(freq_axis))).to(low.device)

        for i in range(self.out_channels):
            mask = (freq_axis >= low[i]) & (freq_axis <= high[i])
            filters[i, mask] = 1.0
            if self.triangular:
                center_freq = (low[i] + high[i]) / 2
                bandwidth = high[i] - low[i]
                freq_response = 1.0 - torch.abs(freq_axis[mask] - center_freq) / (bandwidth / 2)
                filters[i, mask] = freq_response

        if self.freq_nml:
            filters = F.normalize(filters, p=2, dim=1)
        if self.bi_factor:
            band_imp = F.relu(self.band_imp)
            filters = filters * band_imp.unsqueeze(-1)

        return filters

    def forward(self, waveforms, embedding=None):
        batch_size = waveforms.shape[0]

        # Calculate the padding needed to reach the target number of frames
        target_length = self.hop_length * (self.output_size - 1) + self.frame_length
        current_length = waveforms.shape[-1]
        padding_needed = target_length - current_length

        # Pad the input if necessary
        if padding_needed > 0:
            waveforms = F.pad(waveforms, (0, padding_needed))

        # Compute STFT
        stft = torch.stft(waveforms.squeeze(1),
                          n_fft=self.n_fft,
                          hop_length=self.hop_length,
                          win_length=self.frame_length,
                          window=torch.hann_window(self.frame_length).to(waveforms.device),
                          return_complex=True)

        # Magnitude spectrogram
        mag_spec = torch.abs(stft)  # (batch_size, freq_bins, time_frames)

        # Get and apply filters
        filters = self.get_filters()                # (out_channels, freq_bins)
        filtered = torch.matmul(filters, mag_spec)  # (batch_size, out_channels, time_frames)

        if self.norm_after:
            filtered = self.ln(filtered)

        # Compute log energy
        energy = filtered ** 2
        log_energy = torch.log10(energy + 1e-6)

        # Ensure correct time dimension
        if log_energy.shape[-1] != self.output_size:
            log_energy = F.interpolate(
                log_energy,
                size=self.output_size,
                mode='linear',
                align_corners=False
            )

        # Add the channel dimension, giving (batch, channel, out_channels(frequency),
        # time), which matches TimeSincExtractor's output layout
        log_energy = log_energy.unsqueeze(1)

        return log_energy, filters, self.stride, self.padding
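# Shape walk-through for FreqSincExtractor.forward under the defaults
# (n_fft=400, hop_length=160, frame_length=400, out_channels=64,
# output_size=64); a sketch, assuming 16 kHz input:
#
#   stft:       (B, 201, T)     # n_fft//2 + 1 = 201 frequency bins
#   filtered:   (B, 64, T)      # (64, 201) @ (B, 201, T)
#   log_energy: (B, 1, 64, 64)  # after interpolation to output_size and unsqueeze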
if __name__ == "__main__":
    batch_size = 256
    n_samples = 10080
    waveforms = torch.rand(batch_size, 1, n_samples)
    # model = TimeSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
    model = FreqSincExtractor(out_channels=64, kernel_size=101, range_constraint=True, stride=2)
    print(model)
    outputs, _, _, _ = model(waveforms, embedding=None)
    print("Outputs:", outputs.shape)
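    # With the defaults above (out_channels=64, output_size=64), the printed
    # feature shape should be: Outputs: torch.Size([256, 1, 64, 64])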