# VoiceStar / data / combined_dataset.py
# mrfakename's picture
# Upload 51 files
# 82bc972 verified
import os
import ffmpeg
import torch
import random
import copy
import logging
import torch.distributed as dist
import shutil
import csv
import torchaudio
import glob
import numpy as np
from data.tokenizer import TextTokenizer, tokenize_text, AudioTokenizer
def find_files(root_dir, endswith=".wav"):
    """Recursively list files under ``root_dir`` whose lower-cased name ends with ``endswith``.

    Args:
        root_dir: directory that is walked with ``os.walk``.
        endswith: suffix matched against the lower-cased file name, so it
            should itself be given in lower case.

    Returns:
        A list of paths (``dirpath`` joined with the file name) in the order
        ``os.walk`` yields them.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root_dir)
        for name in names
        if name.lower().endswith(endswith)
    ]
class dataset(torch.utils.data.Dataset):
    """Combined phoneme / codec-token dataset built from one or more dataset directories.

    Each item is a pair ``(x, y)`` where ``x`` is a list of phoneme ids and
    ``y`` is a list of ``n_codebooks`` rows of codec token ids.  Optionally a
    "neighbor" utterance is prepended as a prompt, and either the target or
    the neighbor can be time-stretched on the fly with ffmpeg and re-encoded
    with the audio tokenizer.

    Manifest rows are tab separated; column 0 is the utterance stem and
    column 1 is its length, which is compared against
    ``encodec_sr * audio_{min,max}_length``.  The index of the owning dataset
    directory is appended to every row as its last element.
    """

    def __init__(self, args, split):
        """Read manifests and vocabularies for ``split``.

        Args:
            args: configuration namespace.  ``dataset_dir`` and
                ``manifest_name`` are either a plain string or the string
                form of a Python list.  ``args`` is mutated:
                ``target_time_stretch_prob`` / ``target_time_stretch_bound``
                get defaults and ``manifest_name`` is normalised to the list
                string form.
            split: one of ``'train'``, ``'valid'``, ``'test'``.

        Raises:
            AssertionError: on an invalid split, mismatching manifest/dataset
                counts, no usable data, or a too small ``text_vocab_size``.
        """
        super().__init__()
        self.args = args
        self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0)
        self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1)
        self.split = split
        assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}"

        # NOTE(security): eval() executes whatever is in these config strings.
        # Only use with trusted configuration; ast.literal_eval would be the
        # safe replacement if nothing relies on arbitrary expressions here.
        if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir:
            self.dataset_dir = f"['{self.args.dataset_dir}']"
        else:
            self.dataset_dir = copy.deepcopy(self.args.dataset_dir)
        self.dataset_dir = eval(self.dataset_dir)

        if "[" not in self.args.manifest_name or "]" not in self.args.manifest_name:
            self.args.manifest_name = f"['{self.args.manifest_name}']"
        else:
            self.args.manifest_name = copy.deepcopy(self.args.manifest_name)
        self.manifest_name = eval(self.args.manifest_name)
        if len(self.manifest_name) != len(self.dataset_dir):
            assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}"
            # a single manifest name is shared by every dataset directory
            self.manifest_name = self.manifest_name * len(self.dataset_dir)

        data = []
        for i_data, dataset_dir in enumerate(self.dataset_dir):
            if getattr(self.args, "no_libri_in_training", None) is not None and ("librilight" in dataset_dir) and self.split == "train":
                if not dist.is_initialized() or dist.get_rank() == 0:
                    logging.info(f"skipping librilight in training split")
                continue
            n_datapoints = 0
            manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt")
            if not os.path.isfile(manifest_fn):
                # fall back to sharded manifests named "<split>_*=*.txt";
                # sorted so that every process sees the shards in the same order
                all_manifest_fn = sorted(glob.glob(manifest_fn.replace(".txt", "_*=*.txt")))
                if len(all_manifest_fn) == 0:
                    logging.info(f"no manifest file found for {split} split in {dataset_dir}")
                    continue
                if self.args.debug:
                    logging.info(f"debugging mode, only using the frist found manifest file: {all_manifest_fn[0]}")
                    all_manifest_fn = all_manifest_fn[:1]
                elif dist.is_initialized() and dist.get_rank() == 0:
                    logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}")
                for cur_manifest_fn in all_manifest_fn:
                    rows = self._read_manifest(cur_manifest_fn, i_data)
                    n_datapoints += len(rows)
                    data += rows
            else:
                rows = self._read_manifest(manifest_fn, i_data)
                n_datapoints += len(rows)
                data += rows
            if dist.is_initialized() and dist.get_rank() == 0:
                logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}")
        assert len(data) > 0, f"no data found for {split} split"

        # column 1 is the length; there may be more columns after it
        lengths_list = [int(item[1]) for item in data]
        self.data = []
        self.lengths_list = []
        total_duration = 0
        min_len = self.args.encodec_sr * self.args.audio_min_length
        max_len = self.args.encodec_sr * self.args.audio_max_length
        for d, l in zip(data, lengths_list):
            if l < min_len:
                continue
            if self.args.drop_long and l > max_len:
                continue
            self.data.append(d)
            self.lengths_list.append(l)
            total_duration += l / self.args.encodec_sr / 3600

        if dist.is_initialized() and dist.get_rank() == 0:
            logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}")
            logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours")

        # phoneme vocabulary: union of the last column of every vocab.txt
        phn_set = set()
        for dataset_dir in self.dataset_dir:
            vocab_fn = os.path.join(dataset_dir, "vocab.txt")
            with open(vocab_fn, "r") as f:
                # skip blank lines so that "" never becomes a phoneme
                phn_set.update(line.strip().split("\t")[-1] for line in f if line.strip())
        # sorted(): iteration order of a set of str depends on per-process hash
        # randomisation, which would give every process a different mapping
        self.phn2num = {item: i for i, item in enumerate(sorted(phn_set))}
        assert self.args.text_vocab_size > len(self.phn2num), f"need self.args.text_vocab_size to be bigger than number of phns in vocab to handle OOD phn, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}"

        # the audio tokenizer is only needed when some form of time stretching is enabled
        if (self.args.neighbor_prompt_prob > 0 and getattr(self.args, "time_stretch_prob", 0) > 0) or self.args.target_time_stretch_prob > 0:
            userdir = os.path.expanduser("~")
            encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th"))
            self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True)
            assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.codec_audio_sr: {self.args.codec_audio_sr}"
            if dist.is_initialized() and dist.get_rank() == 0:
                logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}")

    def __len__(self):
        """Return the number of utterances kept after length filtering."""
        return len(self.lengths_list)

    @staticmethod
    def _read_manifest(manifest_fn, i_data):
        """Parse one manifest file into rows of ``[col0, col1, ..., i_data]``.

        Blank lines and rows without a length column are skipped; otherwise
        the appended dataset index would be mistaken for the length.
        """
        rows = []
        with open(manifest_fn, "r") as rf:
            for line in rf:
                fields = line.strip().split("\t")
                if len(fields) < 2:
                    continue
                rows.append(fields + [i_data])
        return rows

    def _read_phn_tokens(self, phn_fn):
        """Read a single-line phoneme file and map it to ids, dropping unknown phonemes.

        Raises:
            OSError: if the file cannot be opened.
            AssertionError: if the file does not contain exactly one line.
        """
        with open(phn_fn, "r") as p:
            phns = [l.strip() for l in p.readlines()]
        assert len(phns) == 1, phns
        return [self.phn2num[phn] for phn in phns[0].split(" ") if phn in self.phn2num]

    def _time_stretch_encode(self, audio_fn, start, duration, speed_factor):
        """Decode a segment with ffmpeg, apply ``atempo``, and encode it with the audio tokenizer.

        Args:
            audio_fn: path of the audio file.
            start: segment start passed to ffmpeg as ``ss``.
            duration: segment duration passed to ffmpeg as ``t``.
            speed_factor: value of the ``atempo`` filter.

        Returns:
            Whatever ``self.audio_tokenizer.encode`` returns for the
            ``(1, 1, samples)`` waveform.

        Raises:
            RuntimeError: if ffmpeg produced no audio at all.
        """
        process = (
            ffmpeg.input(audio_fn, ss=start, t=duration)
            .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
            .run_async(pipe_stdout=True, pipe_stderr=True)
        )
        output, _ = process.communicate()
        if not output:
            raise RuntimeError(f"ffmpeg produced no audio for {audio_fn} (return code {process.returncode})")
        # raw float32 mono samples -> (1, 1, samples)
        waveform = torch.from_numpy(np.frombuffer(output, dtype=np.float32).copy())
        waveform = waveform.unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            return self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))

    def _codes_to_list(self, encos):
        """Check the codebook dimension and convert encoder output to nested lists ([K, T])."""
        assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
        return encos.cpu().squeeze(0).numpy().tolist()

    def _rows_to_int(self, rows):
        """Convert rows of token strings/numbers to ints.

        Raises:
            NotImplementedError: if ``args.special_first`` is set.
        """
        if self.args.special_first:
            raise NotImplementedError
        return [[int(n) for n in row] for row in rows]

    def _load_phn_enc(self, index):
        """Load phoneme ids and codec tokens for ``self.data[index]``.

        With probability ``args.target_time_stretch_prob`` (and only if the
        audio file exists and the stretched duration stays below
        ``args.audio_max_length``) the audio is time-stretched and encoded on
        the fly; if that fails, the stored codes are used instead.

        Returns:
            ``(x, y, dataset_dir, audio_ext)``; on a loading failure ``x`` is
            ``[]`` and ``y`` is ``[[]]``.

        Raises:
            NotImplementedError: if the dataset directory is neither a
                ``/librilight`` nor an ``/emilia`` path.
        """
        item = self.data[index]
        dataset_dir = self.dataset_dir[item[-1]]
        pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0]+".txt")
        ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0]+".txt")
        if "/librilight" in dataset_dir:
            audio_ext = ".flac"
        elif "/emilia" in dataset_dir:
            audio_ext = ".mp3"
        else:
            raise NotImplementedError(f"dataset_dir: {dataset_dir}")
        audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "")+audio_ext)
        speed_factor = random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound) + 1
        # stretched duration is orig / speed_factor, which must not reach audio_max_length
        orig_duration = float(item[1]) / self.args.encodec_sr
        length_ok = orig_duration / speed_factor < self.args.audio_max_length
        if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok:
            try:
                x = self._read_phn_tokens(pf)
            except Exception:
                logging.info(f"loading failed for {pf}, maybe files don't exist or are corrupted")
                return [], [[]], dataset_dir, audio_ext
            try:
                encos = self._time_stretch_encode(audio_fn, 0, orig_duration, speed_factor)
                y = self._rows_to_int(self._codes_to_list(encos))
                return x, y, dataset_dir, audio_ext
            except Exception as e:
                # best effort: fall through to the stored codes below
                logging.info(f"failed with time stretch and codec encode for {audio_fn}")
                logging.info(f"error: {e}")

        try:
            x = self._read_phn_tokens(pf)  # unknown phonemes are dropped
            with open(ef, "r") as e:
                encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
            assert len(encos) == self.args.n_codebooks, ef
            y = self._rows_to_int(encos)
        except Exception:
            logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
            return [], [[]], dataset_dir, audio_ext
        return x, y, dataset_dir, audio_ext

    # this uses the output of step7_ipa_alignment.py
    def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext):
        """Pick a random neighbor and return its phoneme ids and codec tokens as a prompt.

        Args:
            neighbors: rows of the neighbor file; ``row[0]`` is the neighbor's
                file name and, for ``/emilia`` directories, ``row[2]`` its
                duration (compared against ``audio_max_length``).
            y_len: length of the target codes, in codec frames.
            dataset_dir: dataset directory of the target utterance.
            audio_ext: audio file extension of that dataset.

        Returns:
            ``(phn_token, neighbor_enc)`` or ``(None, None)`` if this neighbor
            cannot be used (missing files, too long/short, failed encoding).
        """
        neighbor = random.choice(neighbors)
        neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0])
        if not os.path.isfile(neighbor_enc_fn):
            return None, None
        neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext))
        time_stretch_prob = getattr(self.args, "time_stretch_prob", 0)
        if time_stretch_prob > 0 and not os.path.isfile(neighbor_audio_path):
            logging.info(f"audio file not found: {neighbor_audio_path}")
            return None, None
        if random.random() < time_stretch_prob:
            time_stretch_flag = True
            speed_factor = random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound) + 1
            duration_factor = 1 / speed_factor
        else:
            time_stretch_flag = False
            speed_factor = 1
            duration_factor = 1

        ####################### TODO for now always use the entire neighbor for emilia
        # if it's gigaspeech or emilia, we did not run MFA forced alignment, and therefore no ipa alignment, and will just use the entire neighbor as the prompt
        if "/emilia" in dataset_dir:
            neighbor_dur = float(neighbor[2])
            stretched_dur = neighbor_dur * duration_factor
            if stretched_dur + y_len / self.args.encodec_sr > self.args.audio_max_length or stretched_dur < self.args.min_prompt_len:
                return None, None
            neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0])
            try:
                phn_token = self._read_phn_tokens(neighbor_pf)
            except Exception:
                logging.info(f"loading failed for {neighbor_pf}, maybe files don't exist")
                return None, None
            if not time_stretch_flag:
                # use the stored codes as they are
                with open(neighbor_enc_fn, "r") as f:
                    neighbor_enc = [l.strip().split() for l in f.readlines()]
                if len(neighbor_enc) != self.args.n_codebooks:
                    return None, None
                return phn_token, self._rows_to_int(neighbor_enc)
            # stretch the audio with ffmpeg-python; treat a failure like any
            # other unusable neighbor so the caller can retry with another one
            try:
                encos = self._time_stretch_encode(neighbor_audio_path, 0, neighbor_dur, speed_factor)
            except Exception:
                logging.info(f"failed with time stretch for {neighbor_audio_path}, from 0 to {neighbor_dur} with duration factor {duration_factor}, which leads to {stretched_dur} seconds")
                return None, None
            return phn_token, self._codes_to_list(encos)
        ####################### TODO for now always use the entire neighbor for emilia

        ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0])
        if not os.path.isfile(ipa_alignment_fn):
            return None, None
        with open(ipa_alignment_fn, "r") as f:
            alignments = [l.strip().split("\t") for l in f.readlines()]
        # rows are (start, end, phonemes); keep segments whose stretched length is a valid prompt length
        alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3]
        alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len]
        if len(alignments) == 0:
            return None, None
        idx = random.choice(range(len(alignments)))
        # walk towards index 0 until prompt + target fit into audio_max_length
        while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length:
            idx -= 1
            if idx < 0:
                return None, None
        if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len:
            return None, None
        start_time, end_time = alignments[idx][:2]
        phn = alignments[idx][2].split(" ")
        phn_token = [self.phn2num[item] for item in phn if item in self.phn2num]
        if len(phn_token) == 0:
            return None, None

        if time_stretch_flag:
            duration = end_time - start_time
            try:
                encos = self._time_stretch_encode(neighbor_audio_path, start_time, duration, speed_factor)
            except Exception:
                logging.info(f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds")
                return None, None
            return phn_token, self._codes_to_list(encos)

        # get encodec codes from storage
        with open(neighbor_enc_fn, "r") as f:
            neighbor_enc = [l.strip().split() for l in f.readlines()]
        if len(neighbor_enc) != self.args.n_codebooks:
            return None, None
        # trim the encodec codes to the segment
        start_enc_frame = int(start_time * self.args.encodec_sr)
        end_enc_frame = int(end_time * self.args.encodec_sr)
        neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc]
        if len(neighbor_enc[0]) == 0:
            return None, None
        return phn_token, self._rows_to_int(neighbor_enc)

    def __getitem__(self, index):
        """Build one training example.

        Returns:
            A dict with ``x`` / ``y`` (``torch.LongTensor``), ``x_len``,
            ``y_len`` and the two ``*_sep_token_position`` entries.  If the
            utterance could not be loaded, ``x``, ``x_len``, ``y`` and
            ``y_len`` are ``None`` (``collate`` drops such items).

        Raises:
            RuntimeError: if the final code sequence exceeds
                ``audio_max_length`` by more than 10 frames.
        """
        x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
        x_len, y_len = len(x), len(y[0])
        extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0}
        if x_len == 0 or y_len == 0:
            ret = {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
            }
            ret.update(extra_ret)
            return ret
        while y_len < self.args.encodec_sr*self.args.audio_min_length:
            assert not self.args.dynamic_batching
            index = random.choice(range(len(self)))  # regenerate an index
            x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
            x_len, y_len = len(x), len(y[0])

        # if use neighbor prompt
        use_neighbor_prob = random.random()
        neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0]+".txt")
        # the neighbor file might not exist, just because we didn't find a neighbor for this file
        if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn):
            with open(neighbor_fn, "r") as f:
                neighbors = [l.strip().split("\t") for l in f.readlines() if l.strip()]
            # select neighbors
            if "maxdist" in self.args.neighbor_selection_method:
                maxdist = int(self.args.neighbor_selection_method.split("_")[-1])
                # only keep neighbors with distance within maxdist
                neighbors = [n for n in neighbors if float(n[1]) <= maxdist]
            else:
                raise NotImplementedError
            x_neighbor, y_neighbor = None, None
            if len(neighbors) > 0:
                x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
                i_trial = 0
                while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors):
                    x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
                    i_trial += 1
            if x_neighbor is not None:
                if self.args.x_sep_token is not None:
                    x = x_neighbor + [self.args.x_sep_token] + x
                else:
                    x = x_neighbor + x
                if self.args.y_sep_token is not None:
                    y = [prompt_row + [self.args.y_sep_token] + row for prompt_row, row in zip(y_neighbor, y)]
                else:
                    y = [prompt_row + row for prompt_row, row in zip(y_neighbor, y)]
                # if using y_sep_token, this is actually the position of the token right before the y_sep_token, but since y_sep_token is ignored in loss computation, it's fine that we use the position of the token right before it
                extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1
                extra_ret['x_sep_token_position'] = len(x_neighbor) + 1
                x_len, y_len = len(x), len(y[0])

        # consider adding eos to the end of the text
        if self.args.add_eos_to_text != 0:
            x.append(self.args.add_eos_to_text)
            x_len += 1
        if getattr(self.args, "add_bos_to_text", 0) != 0:
            x = [self.args.add_bos_to_text] + x
            x_len += 1

        ### padding ###
        max_len = int(self.args.audio_max_length * self.args.encodec_sr)
        if y_len > max_len + 10:  # give it some margin for rounding error
            raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}")
        if not self.args.dynamic_batching:
            # without dynamic batching every example is padded to max_len
            pad_value = 0 if self.args.sep_special_token else self.args.audio_pad_token
            pad = [pad_value] * (max_len - y_len)
            y = [row + pad for row in y]
        # NOTE(review): text padding is applied independently of dynamic_batching;
        # the nesting was reconstructed from a source without indentation - confirm.
        if self.args.pad_x and x_len <= self.args.text_max_length:
            pad_value = 0 if self.args.sep_special_token else self.args.text_pad_token
            x = x + [pad_value] * (self.args.text_max_length - x_len)

        ret = {
            "x": torch.LongTensor(x),
            "x_len": x_len,
            "y": torch.LongTensor(y),
            "y_len": y_len,
        }
        ret.update(extra_ret)
        return ret

    def collate(self, batch):
        """Collate items from ``__getitem__`` into batched tensors.

        Items whose ``x`` is ``None`` (load failures) are dropped.

        Returns:
            A dict with ``x`` (B, T_x), ``x_lens``, ``y`` (B, K, T_y),
            ``y_lens``, ``text_padding_mask``, ``audio_padding_mask`` (True at
            padded positions) and the ``*_sep_token_position`` tensors.
        """
        # make sure keys in every batch is the same
        for batch1, batch2 in zip(batch[:-1], batch[1:]):
            assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different"
        out = {key: [] for key in batch[0]}
        for item in batch:
            if item['x'] is None:  # deal with load failure
                continue
            for key, val in item.items():
                out[key].append(val)

        res = {}
        if self.args.pad_x:
            res["x"] = torch.stack(out["x"], dim=0)
        else:
            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.args.dynamic_batching:
            res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
            res['y'] = res['y'].permute(1,2,0)  # T B K -> B K T
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.LongTensor(out["y_len"])
        res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
        res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
        if "y_sep_token_position" in out:
            res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"])
        if "x_sep_token_position" in out:
            res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"])
        return res