import os import ffmpeg import torch import random import copy import logging import torch.distributed as dist import shutil import csv import torchaudio import glob import numpy as np from data.tokenizer import TextTokenizer, tokenize_text, AudioTokenizer def find_files(root_dir, endswith=".wav"): files = [] # os.walk generates the file names in a directory tree for dirpath, dirnames, filenames in os.walk(root_dir): for filename in filenames: # os.path.splitext splits the file name into a base and extension # base, ext = os.path.splitext(filename) if filename.lower().endswith(endswith): # os.path.join combines one or more path names into a single path full_path = os.path.join(dirpath, filename) files.append(full_path) return files class dataset(torch.utils.data.Dataset): def __init__(self, args, split): super().__init__() self.args = args self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0) self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1) self.split = split assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}" if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir: self.dataset_dir = f"['{self.args.dataset_dir}']" else: self.dataset_dir = copy.deepcopy(self.args.dataset_dir) self.dataset_dir = eval(self.dataset_dir) data = [] if "[" not in self.args.manifest_name or "]" not in self.args.manifest_name: self.args.manifest_name = f"['{self.args.manifest_name}']" else: self.args.manifest_name = copy.deepcopy(self.args.manifest_name) self.manifest_name = eval(self.args.manifest_name) if len(self.manifest_name) != len(self.dataset_dir): assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}" self.manifest_name = self.manifest_name * len(self.dataset_dir) for i_data, dataset_dir in enumerate(self.dataset_dir): if getattr(self.args, "no_libri_in_training", None) != None and ("librilight" in dataset_dir) and self.split == "train": if not dist.is_initialized() or dist.get_rank() == 0: logging.info(f"skipping librilight in training split") continue n_datapoints = 0 manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt") if not os.path.isfile(manifest_fn): all_manifest_fn = glob.glob(manifest_fn.replace(".txt", "_*=*.txt")) if len(all_manifest_fn) == 0: logging.info(f"no manifest file found for {split} split in {dataset_dir}") continue if self.args.debug: logging.info(f"debugging mode, only using the frist found manifest file: {all_manifest_fn[0]}") all_manifest_fn = all_manifest_fn[:1] else: if dist.is_initialized() and dist.get_rank() == 0: logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}") for cur_manifest_fn in all_manifest_fn: with open(cur_manifest_fn, "r") as rf: tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()] # i_data is the index of the dataset n_datapoints += len(tmp) data += tmp else: with open(manifest_fn, "r") as rf: tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()] data += tmp n_datapoints += len(tmp) if dist.is_initialized() and dist.get_rank() == 0: logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}") assert len(data) > 0, f"no data found for {split} split" lengths_list = [int(item[1]) for item in data] # use 1 because there might be more than 1 columns (for gigaspeech we have 3 columns: path, duration, selfsim) self.data = [] self.lengths_list = [] total_duration = 0 for d, l in zip(data, lengths_list): if l >= self.args.encodec_sr*self.args.audio_min_length: if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length: continue self.data.append(d) self.lengths_list.append(l) total_duration += l / self.args.encodec_sr / 3600 # logging.info(f"for now cut the dataset to only have 500 examples for debugging") # self.data = self.data[:1000] # self.lengths_list = self.lengths_list[:1000] if dist.is_initialized() and dist.get_rank() == 0: logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}") logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours") # phoneme vocabulary phn_set = set() for dataset_dir in self.dataset_dir: vocab_fn = os.path.join(dataset_dir, "vocab.txt") with open(vocab_fn, "r") as f: temp = [l.strip().split("\t") for l in f.readlines() if len(l) != 0] phn_set.update([item[-1] for item in temp]) self.phn2num = {item:i for i, item in enumerate(phn_set)} assert self.args.text_vocab_size > len(self.phn2num), f"need self.args.text_vocab_size to be bigger than number of phns in vocab to handle OOD phn, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}" if (self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0) or self.args.target_time_stretch_prob > 0: userdir = os.path.expanduser("~") encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th")) self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True) assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.encodec_sr: {self.args.encodec_sr}" if dist.is_initialized() and dist.get_rank() == 0: logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}") def __len__(self): return len(self.lengths_list) def _load_phn_enc(self, index): item = self.data[index] dataset_dir = self.dataset_dir[item[-1]] pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0]+".txt") ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0]+".txt") # with certain probability, we load the audio, and time stretch it, note that we should not hit self.args.audio_max_length if "/librilight" in dataset_dir: audio_ext = ".flac" elif "/emilia" in dataset_dir: audio_ext = ".mp3" else: raise NotImplementedError(f"dataset_dir: {dataset_dir}") audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "")+audio_ext) speed_factor = random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound) + 1 length_ok = (float(item[1]) / self.args.encodec_sr) / speed_factor < self.args.audio_max_length # NOTE to calculate the maximal duration after time stretching, we should be used as orig/(1-bound), rather than orig*(1+bound) if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok: try: with open(pf, "r") as p: phns = [l.strip() for l in p.readlines()] assert len(phns) == 1, phns all_phns = phns[0].split(" ") x = [self.phn2num[item] for item in all_phns if item in self.phn2num] except: logging.info(f"loading failed for {pf}, maybe files don't exist or are corrupted") return [], [[]], dataset_dir, audio_ext # time stretch try: process = ( ffmpeg.input(audio_fn, ss=0, t=float(item[1]) / self.args.encodec_sr) .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor)) .run_async(pipe_stdout=True, pipe_stderr=True) ) # Read the processed audio from ffmpeg stdout output, _ = process.communicate() # Convert the output to a numpy array output_np = np.frombuffer(output, dtype=np.float32).copy() # Reshape the numpy array back to the expected shape (1, samples for mono) waveform = torch.from_numpy(output_np) waveform = waveform.unsqueeze(0).unsqueeze(0) assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape with torch.no_grad(): encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device)) assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}" encos = encos.cpu().squeeze(0).numpy().tolist() # [K, T] if self.args.special_first: raise NotImplementedError # y = [[int(n)+self.args.n_special for n in l] for l in encos] else: y = [[int(n) for n in l] for l in encos] return x, y, dataset_dir, audio_ext except Exception as e: logging.info(f"failed with time stretch and codec encode for {audio_fn}") logging.info(f"error: {e}") pass try: with open(pf, "r") as p, open(ef, "r") as e: phns = [l.strip() for l in p.readlines()] assert len(phns) == 1, phns all_phns = phns[0].split(" ") x = [self.phn2num[item] for item in all_phns if item in self.phn2num] # we assume that OOD will not happen, because phn vocab is small encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks] assert len(encos) == self.args.n_codebooks, ef if self.args.special_first: raise NotImplementedError # y = [[int(n)+self.args.n_special for n in l] for l in encos] else: y = [[int(n) for n in l] for l in encos] except: logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted") return [], [[]], dataset_dir, audio_ext return x, y, dataset_dir, audio_ext # this uses the output of step7_ipa_alignment.py def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): neighbor = random.choice(neighbors) neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0]) if not os.path.isfile(neighbor_enc_fn): return None, None neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext)) if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile(neighbor_audio_path): logging.info(f"audio file not found: {neighbor_audio_path}") return None, None if random.random() < getattr(self.args, "time_stretch_prob", 0): time_stretch_flag = True speed_factor = random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound) + 1 duration_factor = 1 / speed_factor else: time_stretch_flag = False duration_factor = 1 ####################### TODO for now always use the entire neighbor for emilia ####################### TODO for now always use the entire neighbor for emilia # if it's gigaspeech or emilia, we did not run MFA forced alignment, and therefore no ipa alignment, and will just use the entire neighbor as the prompt if "/emilia" in dataset_dir: # get neighbor duration neighbor_dur = float(neighbor[2]) if neighbor_dur * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length or neighbor_dur * duration_factor < self.args.min_prompt_len: return None, None try: neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0]) with open(neighbor_pf, "r") as p: phns = [l.strip() for l in p.readlines()] assert len(phns) == 1, phns all_phns = phns[0].split(" ") phn_token = [self.phn2num[item] for item in all_phns if item in self.phn2num] except: logging.info(f"loading failed for {neighbor_pf}, maybe files don't exist") return None, None # if do not stretch the audio if not time_stretch_flag: with open(neighbor_enc_fn, "r") as f: neighbor_enc = [l.strip().split() for l in f.readlines()] if len(neighbor_enc) != self.args.n_codebooks: return None, None # if too long else: if self.args.special_first: raise NotImplementedError # neighbor_enc = [[int(n)+self.args.n_special for n in l] for l in neighbor_enc] else: neighbor_enc = [[int(n) for n in l] for l in neighbor_enc] return phn_token, neighbor_enc else: # stretch the audio with ffmpeg-python process = ( ffmpeg.input(neighbor_audio_path, ss=0, t=neighbor_dur) .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor)) .run_async(pipe_stdout=True, pipe_stderr=True) ) # Read the processed audio from ffmpeg stdout output, _ = process.communicate() # Convert the output to a numpy array output_np = np.frombuffer(output, dtype=np.float32).copy() # Reshape the numpy array back to the expected shape (1, samples for mono) waveform = torch.from_numpy(output_np) waveform = waveform.unsqueeze(0).unsqueeze(0) assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape with torch.no_grad(): encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device)) assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}" neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T] return phn_token, neighbor_enc ####################### TODO for now always use the entire neighbor for emilia ####################### TODO for now always use the entire neighbor for emilia ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0]) if not os.path.isfile(ipa_alignment_fn): # print(f"file not found: {ipa_alignment_fn}", flush=True) return None, None with open(ipa_alignment_fn, "r") as f: alignments = [l.strip().split("\t") for l in f.readlines()] alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3] alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len] if len(alignments) == 0: # print(f"no valid alignment found for {ipa_alignment_fn}") return None, None idx = random.choice(range(len(alignments))) while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length: idx -= 1 if idx < 0: # print(f"too long combined with y_len {ipa_alignment_fn=}, and {y_len=}") return None, None if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len: return None, None start_time, end_time = alignments[idx][:2] phn = alignments[idx][2].split(" ") phn_token = [self.phn2num[item] for item in phn if item in self.phn2num] if len(phn_token) == 0: return None, None if time_stretch_flag: duration = end_time - start_time process = ( ffmpeg.input(neighbor_audio_path, ss=start_time, t=duration) .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor)) .run_async(pipe_stdout=True, pipe_stderr=True) ) # Read the processed audio from ffmpeg stdout output, _ = process.communicate() # Convert the output to a numpy array output_np = np.frombuffer(output, dtype=np.float32).copy() # Reshape the numpy array back to the expected shape (1, samples for mono) waveform = torch.from_numpy(output_np) waveform = waveform.unsqueeze(0).unsqueeze(0) assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape try: with torch.no_grad(): encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device)) except: logging.info(f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds") return None, None assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}" neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T] return phn_token, neighbor_enc else: # get encodec codes from storage with open(neighbor_enc_fn, "r") as f: neighbor_enc = [l.strip().split() for l in f.readlines()] if len(neighbor_enc) != self.args.n_codebooks: # print(f"wrong number of codebooks for {neighbor_enc_fn}") return None, None else: # trim the encodec codes to the segment start_enc_frame = int(start_time * self.args.encodec_sr) end_enc_frame = int(end_time * self.args.encodec_sr) neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc] if len(neighbor_enc[0]) == 0: # print(f"no valid encodec codes found for {neighbor_enc_fn}") return None, None if self.args.special_first: raise NotImplementedError else: neighbor_enc = [[int(n) for n in l] for l in neighbor_enc] return phn_token, neighbor_enc def __getitem__(self, index): x, y, dataset_dir, audio_ext = self._load_phn_enc(index) x_len, y_len = len(x), len(y[0]) extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0} if x_len == 0 or y_len == 0: ret = { "x": None, "x_len": None, "y": None, "y_len": None, } ret.update(extra_ret) return ret while y_len < self.args.encodec_sr*self.args.audio_min_length: assert not self.args.dynamic_batching index = random.choice(range(len(self))) # regenerate an index x, y, dataset_dir, audio_ext = self._load_phn_enc(index) x_len, y_len = len(x), len(y[0]) # if use neighbor prompt x_neighbor, y_neighbor = None, None use_neighbor_prob = random.random() neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0]+".txt") if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn): # it might not exist, just because we didn't find neighbor for this file (other than itself, which is common for emilia) with open(neighbor_fn, "r") as f: neighbors = [l.strip().split("\t") for l in f.readlines()] # select neighbors if "maxdist" in self.args.neighbor_selection_method: maxdist = int(self.args.neighbor_selection_method.split("_")[-1]) # only keep neighbors with distance within maxdist neighbors = [n for n in neighbors if float(n[1]) <= maxdist] else: raise NotImplementedError x_neighbor, y_neighbor = None, None if len(neighbors) > 0: x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext) i_trial = 0 while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors): x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext) i_trial += 1 if x_neighbor != None: if self.args.x_sep_token != None: x = x_neighbor + [self.args.x_sep_token] + x else: x = x_neighbor + x if self.args.y_sep_token != None: y = [y_neighbor[i] + [self.args.y_sep_token] + y[i] for i in range(len(y))] else: y = [y_neighbor[i] + y[i] for i in range(len(y))] extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1 # if using y_sep_token, this is actually the position of the token right before the y_sep_token, but since y_sep_token is ignored in loss computation, it's fine that we use the position of the token right before it extra_ret['x_sep_token_position'] = len(x_neighbor) + 1 x_len, y_len = len(x), len(y[0]) # consider adding eos to the end of the text if self.args.add_eos_to_text != 0: x.append(self.args.add_eos_to_text) x_len += 1 if getattr(self.args, "add_bos_to_text", 0) != 0: x = [self.args.add_bos_to_text] + x x_len += 1 ### padding and cropping ### ### padding and cropping ### # adjust the length of encodec codes, pad to max_len or randomly crop orig_y_len = copy.copy(y_len) max_len = int(self.args.audio_max_length * self.args.encodec_sr) if y_len > max_len + 10: # give it some margin for rounding error raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}") else: audio_start = 0 if not self.args.dynamic_batching: pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len) for i in range(len(y)): y[i] = y[i] + pad if self.args.pad_x and x_len <= self.args.text_max_length: pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len) x = x + pad ret = { "x": torch.LongTensor(x), "x_len": x_len, "y": torch.LongTensor(y), "y_len": y_len, } ret.update(extra_ret) return ret def collate(self, batch): # make sure keys in every batch is the same for batch1, batch2 in zip(batch[:-1], batch[1:]): assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different" out = {key:[] for key in batch[0]} for item in batch: if item['x'] == None: # deal with load failure continue for key, val in item.items(): out[key].append(val) res = {} if self.args.pad_x: res["x"] = torch.stack(out["x"], dim=0) else: res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token) res["x_lens"] = torch.LongTensor(out["x_len"]) if self.args.dynamic_batching: res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token) res['y'] = res['y'].permute(1,2,0) # T B K -> B K T else: res['y'] = torch.stack(out['y'], dim=0) res["y_lens"] = torch.LongTensor(out["y_len"]) res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1) res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1) if "y_sep_token_position" in out: res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"]) if "x_sep_token_position" in out: res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"]) return res