# VoiceStar/data/emilia_preprocessing/step6_encodec_encode.py
# mrfakename's picture
# Upload 51 files
# 82bc972 verified
import argparse
from email.policy import default
def parse_args(argv=None):
    """Parse the command-line options of the codec-encoding step.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, so existing zero-argument
            calls behave exactly as before. Passing a list allows the parser
            to be driven programmatically.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="encode the dataset using codec model")
    # Locations of the dataset and of the codec checkpoint.
    parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="Path to the directory")
    parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
    parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
    # Throughput knobs.
    parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
    # Audio / codec rates.
    parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    # Length filters; the script multiplies both by --model_sr before comparing.
    parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number')
    parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number')
    # Work selection: "k/n" picks the k-th (1-based) of n interleaved shards.
    parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    return parser.parse_args(argv)
# Script entry guard.
# NOTE(review): leading indentation is missing in this copy of the file; the
# statements that follow presumably form the body of this guard — confirm
# against the upstream source.
if __name__ == "__main__":
import logging
# Log record layout: timestamp, level, source file and line, then the message.
formatter = (
"%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
)
# The root logger is configured here, ahead of the third-party imports below;
# keep this order so the format above is the one that takes effect.
logging.basicConfig(format=formatter, level=logging.INFO)
import os, sys
import numpy as np
import torch
import torchaudio
import tqdm
import time
# Parsed command-line options; the helpers defined further down read this
# name as a global.
args = parse_args()
def sort_by_audio_len(lens):
    """Log length statistics and return positions ordered longest-first.

    Args:
        lens: Sequence (list or array) of audio lengths in samples. Values
            are divided by ``args.downsample_rate`` and ``args.model_sr``
            only for the log messages.

    Returns:
        List of integer positions into ``lens`` sorted by descending length.
        An empty input yields an empty list and logs nothing.
    """
    inds = np.argsort(lens).tolist()
    if not inds:
        # Nothing to rank; indexing inds[-1] below would raise IndexError.
        return []
    # (label, position in `lens`) for each statistic that gets reported.
    reported = (
        ("longest", inds[-1]),
        ("shortest", inds[0]),
        ("median", inds[len(inds)//2]),
        ("95 percentile longest", inds[int(len(inds)*0.95)]),
    )
    for label, ind in reported:
        n_samples = lens[ind]
        logging.info(f"{label}: {n_samples/args.downsample_rate} encodec codes, {n_samples/args.model_sr:.2f} sec.")
    # argsort is ascending; callers want the longest utterances first.
    return inds[::-1]
def write_array_to_txt_file(array, filename):
    """Write a sequence of rows to a text file, one row per line.

    Items within a row are converted with ``str`` and separated by single
    spaces; rows are separated by newlines and no trailing newline is written.
    An empty ``array`` produces an empty file instead of raising IndexError.

    Args:
        array: Iterable of rows, each row an iterable of values.
        filename: Path of the file to create or overwrite.
    """
    with open(filename, 'w') as f:
        f.write('\n'.join(' '.join(map(str, row)) for row in array))
class mydataset(torch.utils.data.Dataset):
    """Dataset over one shard of a split's manifest that yields waveforms.

    Relies on the script-level names ``args``, ``audio_dir`` and
    ``encodec_manifest_dir`` being defined before an instance is created.
    """

    def __init__(self, split):
        """Read ``<encodec_manifest_dir>/<split>.txt`` and keep this shard's rows.

        Args:
            split: Manifest base name (the file read is ``split + ".txt"``).

        Raises:
            ValueError: If ``args.partition`` ("k/n") does not satisfy
                ``1 <= k <= n``. Such values previously selected the wrong
                rows silently (e.g. "0/4" kept only the last row).
        """
        super().__init__()
        self.split = split
        self.audio_dir = audio_dir
        manifest_fn = os.path.join(encodec_manifest_dir, split+".txt")
        parts = args.partition.split("/")
        cur_sp = int(parts[0])
        total_sp = int(parts[1])
        if not 1 <= cur_sp <= total_sp:
            raise ValueError(f"invalid partition {args.partition!r}: expected 'k/n' with 1 <= k <= n")
        with open(manifest_fn, "r") as rf:
            # Each row is tab-separated; shard k takes every n-th row
            # starting from row k-1.
            self.data = [l.strip().split("\t") for l in rf][cur_sp-1::total_sp]

    def __len__(self):
        """Return the number of manifest rows assigned to this shard."""
        return len(self.data)

    def __getitem__(self, ind):
        """Load one audio file, resampling it to ``args.model_sr`` if needed.

        Returns:
            ``(waveform, n_samples, segment_id)`` where ``waveform`` is a 1-D
            float32 tensor and ``segment_id`` is the manifest path without its
            extension; or ``(None, None, None)`` when loading or resampling
            raised, so that ``collate`` can drop the item.
        """
        try:
            afn = self.data[ind][0]
            fn = os.path.join(self.audio_dir, afn)
            audio, sr = torchaudio.load(fn)
            if sr != args.model_sr:
                audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
        except Exception as e:
            # Best effort: an unreadable file must not stop the whole run.
            # Logged at debug level so default output stays unchanged.
            logging.debug("failed to load item %s: %s", ind, e)
            return None, None, None
        # Only single-channel audio is supported downstream.
        assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
        return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]

    def collate(self, batch):
        """Gather the items of a batch into parallel lists, dropping failures.

        Args:
            batch: Iterable of the 3-tuples produced by ``__getitem__``.

        Returns:
            ``(audios, lens, segment_ids)``: three lists of equal length that
            contain only the items whose waveform is not ``None``.
        """
        lens, audios, segment_ids = [], [], []
        for audio, length, segment_id in batch:
            # Identity test: comparing a tensor with `!= None` is fragile.
            if audio is not None:
                audios.append(audio)
                lens.append(length)
                segment_ids.append(segment_id)
        return audios, lens, segment_ids
# roots: every input and output directory lives under <root>/<sub_root>
sub_root = args.sub_root
encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
audio_dir = os.path.join(args.root, sub_root, "audio")
save_manifest_dir = os.path.join(args.root, sub_root, "manifest_final_encodec")
# The folder that receives the codes is selected by checkpoint file name.
codes_subdir_by_model = {
    "encodec_6f79c6a8.th": "encodec_4cb",
    "encodec_8cb1024_giga.th": "encodec_8cb",
}
if args.encodec_name not in codes_subdir_by_model:
    # Previously an unknown name left save_codes_dir undefined and the script
    # failed a few lines later with a NameError; fail early and say why.
    raise ValueError(
        f"unknown --encodec_name {args.encodec_name!r}; expected one of {sorted(codes_subdir_by_model)}"
    )
save_codes_dir = os.path.join(args.root, sub_root, codes_subdir_by_model[args.encodec_name])
os.makedirs(save_manifest_dir, exist_ok=True)
os.makedirs(save_codes_dir, exist_ok=True)
def import_encodec():
    """Build the codec named by ``args.encodec_name`` and wrap it for multi-GPU use.

    The checkpoint path is ``~/VoiceStar/pretrained/<args.encodec_name>``. The
    model is created with ``encode_only=True`` on device "cuda".

    Returns:
        The compression model wrapped in ``torch.nn.DataParallel``.
    """
    # Imported lazily so the dependency is only needed when a model is built.
    from encodec import get_compression_model

    home = os.path.expanduser("~")
    checkpoint_path = os.path.join(home, "VoiceStar", f"pretrained/{args.encodec_name}")
    codec = get_compression_model(checkpoint_path, encode_only=True, device="cuda")
    return torch.nn.DataParallel(codec)
# setup dataloader
# A "mega batch" is what the DataLoader hands back in one step; it is sorted by
# length and cut into GPU batches of `batch_size` by the encoding loop.
mega_batch_size = 2048
batch_size = args.batch_size
dataset = mydataset(args.split)
if len(dataset) == 0:
    logging.info(f"no data found for split {args.split} partition {args.partition}")
    sys.exit(0)
# Build the codec only once we know there is something to encode, so an empty
# partition exits without touching the GPU.
model = import_encodec()
loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
split = args.split
# Running count of utterances dropped by the length filter.
skip = 0
logging.info(f"now processing split {split} partition {args.partition}...")
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
# "/" cannot appear in a file name, so "k/n" is written as "k=n".
partition_tag = args.partition.replace('/', '=')
mani_fn = os.path.join(save_manifest_dir, f"{split}_{partition_tag}.txt")
logging.info(f"manifest for split {split} partition {partition_tag}.txt will be saved at {mani_fn}")
# Encode every mega batch and append one "<segment_id>\t<n_codes>" line per
# utterance to the manifest.
# NOTE: mode "w" starts the manifest from scratch; switch to "a" to resume
# after a failed run.
with open(mani_fn, "w") as mani_wf:
    for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):
        logging.info(f"====================================")
        logging.info(f"====================================")
        logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
        try:
            # Parallel lists produced by the dataset's collate function.
            audios, audio_lens, segment_ids = mega_batch
            if not audio_lens:
                # Every file of this mega batch failed to load.
                logging.warning(f"no audio could be loaded in mega step {m+1}, skipping it")
                continue
            lengths = np.array(audio_lens)
            sorted_inds = sort_by_audio_len(lengths)
            # Drop utterances shorter than --min_len or longer than --len_cap
            # (both in seconds, converted to samples here).
            min_samples = args.model_sr*args.min_len
            max_samples = args.model_sr*args.len_cap
            kept_inds = [ind for ind in sorted_inds if min_samples <= lengths[ind] <= max_samples]
            skip += len(sorted_inds) - len(kept_inds)
            for start in range(0, len(kept_inds), batch_size):
                inds_used = kept_inds[start:start+batch_size]
                wav_batch = [audios[ind] for ind in inds_used]
                all_lens = [audio_lens[ind] for ind in inds_used]
                segment_id_batch = [segment_ids[ind] for ind in inds_used]
                padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                # Extract discrete codes from EnCodec
                with torch.no_grad():
                    inwav = padded_wav.cuda()
                    if max(all_lens) > 300000 and len(all_lens) > 1:
                        # Long utterances: encode the batch in two halves and
                        # move each result to the CPU before the next pass.
                        half = len(inwav)//2
                        codes = torch.cat(
                            [model(inwav[:half])[0].cpu(), model(inwav[half:])[0].cpu()],
                            dim=0,
                        )
                    else:
                        codes = model(inwav)[0].cpu() # [B, n_codebook, T]
                for i, (segment_id, length) in enumerate(zip(segment_id_batch, all_lens)):
                    save_fn = os.path.join(save_codes_dir, segment_id+".txt")
                    # Trim the padding: keep only the frames that correspond
                    # to this utterance's own samples.
                    actual_len = round(length / args.downsample_rate)
                    cur_code = codes[i, :, :actual_len].tolist()
                    os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                    write_array_to_txt_file(cur_code, save_fn)
                    mani_wf.write(f"{segment_id}\t{len(cur_code[0])}\n") # write to manifest file
        except Exception:
            # Best effort: a failure abandons the remainder of this mega batch
            # and the run moves on; the traceback is kept for diagnosis.
            logging.exception(f"exception!! at {m+1}")
            continue
# NOTE(review): original nesting was lost in this copy; the summary is assumed
# to be emitted once, after the manifest has been closed.
logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")