# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import random
import numpy as np
import librosa
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from multiprocessing import Pool, Lock
from tqdm import tqdm

from utils.data_utils import *

NUM_WORKERS = 64
lock = Lock()
SAMPLE_RATE = 16000


def get_metadata(file_path):
    """Return the file path together with its number of audio frames."""
    metadata = torchaudio.info(file_path)
    return file_path, metadata.num_frames


def get_speaker(file_path):
    """Derive a prefixed speaker label from the directory layout (the speaker
    id is assumed to be the third path component from the end)."""
    speaker_id = file_path.split(os.sep)[-3]
    if "mls" in file_path:
        speaker = "mls_" + speaker_id
    else:
        speaker = "libri_" + speaker_id
    return file_path, speaker


def safe_write_to_file(data, file_path, mode="w"):
    """Serialize `data` as JSON under a lock, flushing and fsyncing so the
    write survives a crash."""
    try:
        with lock, open(file_path, mode, encoding="utf-8") as f:
            json.dump(data, f)
            f.flush()
            os.fsync(f.fileno())
    except IOError as e:
        print(f"Error writing to {file_path}: {e}")


class VCDataset(Dataset):
    def __init__(self, args, TRAIN_MODE=True):
        print("Initializing VCDataset")
        if TRAIN_MODE:
            directory_list = args.directory_list
        else:
            directory_list = args.test_directory_list
        random.shuffle(directory_list)

        self.use_ref_noise = args.use_ref_noise
        print(f"use_ref_noise: {self.use_ref_noise}")

        # Number of workers
        print(f"Using {NUM_WORKERS} workers")
        self.directory_list = directory_list
        print(f"Loading {len(directory_list)} directories: {directory_list}")

        self.metadata_cache = {}
        self.speaker_cache = {}

        # Load all flac/wav files
        self.files = []
        for directory in directory_list:
            print(f"Loading {directory}")
            files = self.get_flac_files(directory)
            random.shuffle(files)
            print(f"Loaded {len(files)} files")
            self.files.extend(files)
            del files
            print(f"Now {len(self.files)} files")

        self.meta_data_cache = self.process_files()
        self.speaker_cache = self.process_speakers()
        print(f"Loaded {len(self.files)} files")
        random.shuffle(self.files)  # Shuffle the files.
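        # Build the filtered file list plus per-index bookkeeping. Only
        # utterances between 3 and 30 seconds (at 16 kHz) survive
        # filter_files(); the length-sorted arrays built next can feed
        # batch_by_size() below for length-bucketed batching.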
        self.filtered_files, self.all_num_frames, index2numframes, index2speakerid = (
            self.filter_files()
        )
        print(f"Loaded {len(self.filtered_files)} files")
        self.index2numframes = index2numframes
        self.index2speaker = index2speakerid
        self.speaker2id = self.create_speaker2id()
        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
            )
        )
        # Drop the caches to free memory; metadata_cache and meta_data_cache
        # alias the same dict, so both attributes are removed.
        del self.meta_data_cache, self.metadata_cache, self.speaker_cache

        if self.use_ref_noise:
            if TRAIN_MODE:
                self.noise_filenames = self.get_all_flac(args.noise_dir)
            else:
                self.noise_filenames = self.get_all_flac(args.test_noise_dir)

    def process_files(self):
        print("Processing metadata...")
        files_to_process = [
            file for file in self.files if file not in self.metadata_cache
        ]
        if files_to_process:
            with Pool(processes=NUM_WORKERS) as pool:
                results = list(
                    tqdm(
                        pool.imap_unordered(get_metadata, files_to_process),
                        total=len(files_to_process),
                    )
                )
            for file, num_frames in results:
                self.metadata_cache[file] = num_frames
        else:
            print(
                f"Skipping processing metadata, loaded {len(self.metadata_cache)} files"
            )
        return self.metadata_cache

    def process_speakers(self):
        print("Processing speakers...")
        files_to_process = [
            file for file in self.files if file not in self.speaker_cache
        ]
        if files_to_process:
            with Pool(processes=NUM_WORKERS) as pool:
                results = list(
                    tqdm(
                        pool.imap_unordered(get_speaker, files_to_process),
                        total=len(files_to_process),
                    )
                )
            for file, speaker in results:
                self.speaker_cache[file] = speaker
        else:
            print(
                f"Skipping processing speakers, loaded {len(self.speaker_cache)} files"
            )
        return self.speaker_cache

    def get_flac_files(self, directory):
        flac_files = []
        for root, dirs, files in os.walk(directory):
            for file in files:
                # Accept both flac and wav
                if file.endswith(".flac") or file.endswith(".wav"):
                    flac_files.append(os.path.join(root, file))
        return flac_files

    def get_all_flac(self, directory):
        directories = [
            os.path.join(directory, d)
            for d in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, d))
        ]
        if not directories:
            return self.get_flac_files(directory)
        with Pool(processes=NUM_WORKERS) as pool:
            results = []
            for result in tqdm(
                pool.imap_unordered(self.get_flac_files, directories),
                total=len(directories),
                desc="Processing",
            ):
                results.extend(result)
        print(f"Found {len(results)} waveform files")
        return results

    def get_num_frames(self, index):
        return self.index2numframes[index]

    def filter_files(self):
        # Keep only utterances between 3 and 30 seconds at SAMPLE_RATE,
        # recording the speaker and frame count for each filtered index.
        metadata_cache = self.meta_data_cache
        speaker_cache = self.speaker_cache
        filtered_files = []
        all_num_frames = []
        index2numframes = {}
        index2speaker = {}
        for file in self.files:
            num_frames = metadata_cache[file]
            if SAMPLE_RATE * 3 <= num_frames <= SAMPLE_RATE * 30:
                filtered_files.append(file)
                all_num_frames.append(num_frames)
                index2speaker[len(filtered_files) - 1] = speaker_cache[file]
                index2numframes[len(filtered_files) - 1] = num_frames
        return filtered_files, all_num_frames, index2numframes, index2speaker

    def create_speaker2id(self):
        speaker2id = {}
        unique_id = 0
        print(f"Creating speaker2id from {len(self.index2speaker)} utterances")
        for _, speaker in tqdm(self.index2speaker.items()):
            if speaker not in speaker2id:
                speaker2id[speaker] = unique_id
                unique_id += 1
        print(f"Created speaker2id with {len(speaker2id)} speakers")
        return speaker2id

    def snr_mixer(self, clean, noise, snr):
        # Normalize the clean signal to about -25 dBFS RMS
        rmsclean = (clean**2).mean() ** 0.5
        epsilon = 1e-10
        rmsclean = max(rmsclean, epsilon)
        scalarclean = 10 ** (-25 / 20) / rmsclean
        clean = clean * scalarclean
        rmsnoise = (noise**2).mean() ** 0.5
        rmsnoise = max(rmsnoise, epsilon)  # Guard against silent noise clips
        scalarnoise = 10 ** (-25 / 20) / rmsnoise
        noise = noise * scalarnoise
        rmsnoise = (noise**2).mean() ** 0.5

        # Set the noise level for a given SNR
        noisescalar = np.sqrt(rmsclean / (10 ** (snr / 20)) / rmsnoise)
        noisenewlevel = noise * noisescalar
        noisyspeech = clean + noisenewlevel
        noisyspeech_tensor = torch.tensor(noisyspeech, dtype=torch.float32)
        return noisyspeech_tensor

    def add_noise(self, clean):
        # self.noise_filenames: list of noise files
        random_idx = np.random.randint(0, np.size(self.noise_filenames))
        noise, _ = librosa.load(self.noise_filenames[random_idx], sr=SAMPLE_RATE)
        clean = clean.cpu().numpy()
        if len(noise) >= len(clean):
            noise = noise[0 : len(clean)]
        else:
            # Concatenate further noise clips (with 0.2 s of silence in
            # between) until the noise is at least as long as the speech.
            while len(noise) <= len(clean):
                random_idx = (random_idx + 1) % len(self.noise_filenames)
                newnoise, fs = librosa.load(
                    self.noise_filenames[random_idx], sr=SAMPLE_RATE
                )
                noiseconcat = np.append(noise, np.zeros(int(fs * 0.2)))
                noise = np.append(noiseconcat, newnoise)
            noise = noise[0 : len(clean)]
        snr = random.uniform(0.0, 20.0)
        noisyspeech = self.snr_mixer(clean=clean, noise=noise, snr=snr)
        del noise
        return noisyspeech

    def __len__(self):
        # __getitem__ indexes the filtered list, so report its length
        # (the unfiltered self.files would over-count).
        return len(self.filtered_files)

    def __getitem__(self, idx):
        file_path = self.filtered_files[idx]
        speech, _ = librosa.load(file_path, sr=SAMPLE_RATE)
        if len(speech) > 30 * SAMPLE_RATE:
            speech = speech[: 30 * SAMPLE_RATE]
        speech = torch.tensor(speech, dtype=torch.float32)
        inputs = self._get_reference_vc(speech, hop_length=200)
        speaker = self.index2speaker[idx]
        speaker_id = self.speaker2id[speaker]
        inputs["speaker_id"] = speaker_id
        return inputs

    def _get_reference_vc(self, speech, hop_length):
        # Pad so the length is a multiple of 1600 samples (8 frames at hop 200)
        pad_size = 1600 - speech.shape[0] % 1600
        speech = torch.nn.functional.pad(speech, (0, pad_size))

        # Clip 25-45% of the frames from the start as the reference segment
        frame_nums = speech.shape[0] // hop_length
        clip_frame_nums = np.random.randint(
            int(frame_nums * 0.25), int(frame_nums * 0.45)
        )
        clip_frame_nums += (frame_nums - clip_frame_nums) % 8
        start_frames, end_frames = 0, clip_frame_nums

        ref_speech = speech[start_frames * hop_length : end_frames * hop_length]
        new_speech = torch.cat(
            (speech[: start_frames * hop_length], speech[end_frames * hop_length :]), 0
        )

        ref_mask = torch.ones(len(ref_speech) // hop_length)
        mask = torch.ones(len(new_speech) // hop_length)

        if not self.use_ref_noise:
            # Clean reference only
            return {
                "speech": new_speech,
                "ref_speech": ref_speech,
                "ref_mask": ref_mask,
                "mask": mask,
            }
        else:
            # Additionally return a noisy version of the reference
            noisy_ref_speech = self.add_noise(ref_speech)
            return {
                "speech": new_speech,
                "ref_speech": ref_speech,
                "noisy_ref_speech": noisy_ref_speech,
                "ref_mask": ref_mask,
                "mask": mask,
            }


class BaseCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [1, T]
        # target_len: [1]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "phone_len":
                packed_batch_features["phone_len"] = torch.LongTensor(
                    [b["phone_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["phn_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "audio_len":
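                # Pack per-item audio lengths; the masks built in this branch
                # are not added to the batch dict.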
packed_batch_features["audio_len"] = torch.LongTensor( [b["audio_len"] for b in batch] ) masks = [ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch ] else: values = [torch.from_numpy(b[key]) for b in batch] packed_batch_features[key] = pad_sequence( values, batch_first=True, padding_value=0 ) return packed_batch_features class VCCollator(BaseCollator): def __init__(self, cfg): BaseCollator.__init__(self, cfg) self.use_ref_noise = self.cfg.trans_exp.use_ref_noise print(f"use_ref_noise: {self.use_ref_noise}") def __call__(self, batch): packed_batch_features = dict() # Function to handle tensor copying def process_tensor(data, dtype=torch.float32): if isinstance(data, torch.Tensor): return data.clone().detach() else: return torch.tensor(data, dtype=dtype) # Process 'speech' data speeches = [process_tensor(b["speech"]) for b in batch] packed_batch_features["speech"] = pad_sequence( speeches, batch_first=True, padding_value=0 ) # Process 'ref_speech' data ref_speeches = [process_tensor(b["ref_speech"]) for b in batch] packed_batch_features["ref_speech"] = pad_sequence( ref_speeches, batch_first=True, padding_value=0 ) # Process 'mask' data masks = [process_tensor(b["mask"]) for b in batch] packed_batch_features["mask"] = pad_sequence( masks, batch_first=True, padding_value=0 ) # Process 'ref_mask' data ref_masks = [process_tensor(b["ref_mask"]) for b in batch] packed_batch_features["ref_mask"] = pad_sequence( ref_masks, batch_first=True, padding_value=0 ) # Process 'speaker_id' data speaker_ids = [ process_tensor(b["speaker_id"], dtype=torch.int64) for b in batch ] packed_batch_features["speaker_id"] = torch.stack(speaker_ids, dim=0) if self.use_ref_noise: # Process 'noisy_ref_speech' data noisy_ref_speeches = [process_tensor(b["noisy_ref_speech"]) for b in batch] packed_batch_features["noisy_ref_speech"] = pad_sequence( noisy_ref_speeches, batch_first=True, padding_value=0 ) return packed_batch_features def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): if len(batch) == 0: return 0 if len(batch) == max_sentences: return 1 if num_tokens > max_tokens: return 1 return 0 def batch_by_size( indices, num_tokens_fn, max_tokens=None, max_sentences=None, required_batch_size_multiple=1, ): """ Yield mini-batches of indices bucketed by size. Batches may contain sequences of different lengths. Args: indices (List[int]): ordered list of dataset indices num_tokens_fn (callable): function that returns the number of tokens at a given index max_tokens (int, optional): max number of tokens in each batch (default: None). max_sentences (int, optional): max number of sentences in each batch (default: None). required_batch_size_multiple (int, optional): require batch size to be a multiple of N (default: 1). 
""" bsz_mult = required_batch_size_multiple sample_len = 0 sample_lens = [] batch = [] batches = [] for i in range(len(indices)): idx = indices[i] num_tokens = num_tokens_fn(idx) sample_lens.append(num_tokens) sample_len = max(sample_len, num_tokens) assert ( sample_len <= max_tokens ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format( idx, sample_len, max_tokens ) num_tokens = (len(batch) + 1) * sample_len if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): mod_len = max( bsz_mult * (len(batch) // bsz_mult), len(batch) % bsz_mult, ) batches.append(batch[:mod_len]) batch = batch[mod_len:] sample_lens = sample_lens[mod_len:] sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 batch.append(idx) if len(batch) > 0: batches.append(batch) return batches