Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
from email.policy import default | |
def parse_args(argv=None):
    """Parse the command-line options of the codec-encoding script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script did before this parameter existed.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse for unknown or invalid options,
            including a ``--partition`` value that is not ``i/n`` with
            ``1 <= i <= n``.
    """
    parser = argparse.ArgumentParser(description="encode the dataset using codec model")
    parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="Path to the directory")
    parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
    parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
    parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
    parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number')
    parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number')
    parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    args = parser.parse_args(argv)

    # The dataset turns "i/n" into the slice [i-1::n] over the manifest lines.
    # i == 0 would silently select only the last line, n == 0 is an invalid
    # slice step, and i > n overlaps another partition, so reject them here.
    cur, sep, total = args.partition.partition("/")
    if not (sep and cur.isdecimal() and total.isdecimal() and 1 <= int(cur) <= int(total)):
        parser.error(f"--partition must look like 'i/n' with 1 <= i <= n, got {args.partition!r}")
    return args
if __name__ == "__main__":
    # Script entry point: configure logging first, then pull in the heavy
    # runtime dependencies, and finally read the command-line options.
    import logging

    formatter = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    import os
    import sys
    import numpy as np
    import torch
    import torchaudio
    import tqdm
    import time

    args = parse_args()
def sort_by_audio_len(lens):
    """Log length statistics and return sample indices ordered longest-first.

    The statistics (longest, shortest, median, 95th-percentile entry) are
    reported both in codec frames (``length / args.downsample_rate``) and in
    seconds (``length / args.model_sr``).

    Args:
        lens: Sequence (list or 1-D numpy array) of audio lengths in samples.

    Returns:
        list[int]: indices into ``lens`` sorted by descending length. An
        empty list is returned, and nothing is logged, when ``lens`` is empty.
    """
    if len(lens) == 0:
        # Nothing to summarise; the indexing below would raise IndexError.
        return []
    inds = np.argsort(lens).tolist()
    n_items = len(inds)
    stats = (
        ("longest", n_items - 1),
        ("shortest", 0),
        ("median", n_items // 2),
        ("95 percentile longest", int(n_items * 0.95)),
    )
    for label, pos in stats:
        cur_len = lens[inds[pos]]
        logging.info(f"{label}: {cur_len/args.downsample_rate} encodec codes, {cur_len/args.model_sr:.2f} sec.")
    # argsort is ascending; callers want the longest utterances first.
    return inds[::-1]
def write_array_to_txt_file(array, filename):
    """Write a 2-D sequence to ``filename`` as text, one row per line.

    Values inside a row are converted with ``str`` and separated by single
    spaces; rows are separated by ``\\n`` with no trailing newline after the
    last row. An empty ``array`` produces an empty file instead of raising
    ``IndexError``.

    Args:
        array: Sequence of rows, each row an iterable of values.
        filename: Path of the file to create or overwrite.
    """
    with open(filename, 'w') as f:
        f.write('\n'.join(' '.join(map(str, row)) for row in array))
class mydataset(torch.utils.data.Dataset):
    """Audio dataset over one partition of a split's codec manifest.

    The manifest ``<encodec_manifest_dir>/<split>.txt`` is read as
    tab-separated lines whose first column is an audio path relative to
    ``audio_dir``. ``args.partition`` ("i/n") keeps every n-th line starting
    from line i-1, so that n jobs can process disjoint parts of the split.
    """

    def __init__(self, split):
        """Load the manifest lines that belong to the current partition.

        Args:
            split: Name of the split; selects ``<split>.txt`` in the
                manifest directory.
        """
        super().__init__()
        self.split = split
        self.audio_dir = audio_dir
        manifest_fn = os.path.join(encodec_manifest_dir, split+".txt")
        partition = args.partition.split("/")
        cur_sp = int(partition[0])-1  # 1-based partition index -> 0-based offset
        total_sp = int(partition[1])
        with open(manifest_fn, "r") as rf:
            self.data = [line.strip().split("\t") for line in rf][cur_sp::total_sp]

    def __len__(self):
        """Return the number of manifest entries in this partition."""
        return len(self.data)

    def __getitem__(self, ind):
        """Load one utterance, resampled to ``args.model_sr`` if needed.

        Returns:
            ``(waveform, num_samples, segment_id)`` where ``waveform`` is a
            1-D float32 tensor and ``segment_id`` is the manifest path
            without its extension, or ``(None, None, None)`` when the file
            cannot be loaded or resampled.

        Raises:
            AssertionError: if the loaded audio is not shaped ``[1, T]``.
        """
        try:
            afn = self.data[ind][0]
            fn = os.path.join(self.audio_dir, afn)
            audio, sr = torchaudio.load(fn)
            if sr != args.model_sr:
                audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
        except Exception as e:
            # Best effort: unreadable files are dropped later by collate().
            # Logged at debug level so the default INFO output stays quiet.
            logging.debug("skipping item %s: %s", ind, e)
            return None, None, None
        assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
        return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]

    def collate(self, batch):
        """Gather a batch into parallel lists, dropping failed items.

        Args:
            batch: Iterable of ``__getitem__`` results.

        Returns:
            ``(audios, lens, segment_ids)`` lists of equal length; items
            whose waveform is ``None`` are skipped.
        """
        lens, audios, segment_ids = [], [], []
        for item in batch:
            # Identity check: ``!=`` on a tensor is not a plain None test.
            if item[0] is not None:
                audios.append(item[0])
                lens.append(item[1])
                segment_ids.append(item[2])
        return audios, lens, segment_ids
# roots: every input and output directory lives under <root>/<sub_root>
sub_root = args.sub_root
encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
audio_dir = os.path.join(args.root, sub_root, "audio")
save_manifest_dir = os.path.join(args.root, sub_root,"manifest_final_encodec")
# The codes directory depends on the checkpoint. The suffix presumably
# stands for the number of codebooks (4cb / 8cb) -- TODO confirm.
if args.encodec_name == "encodec_6f79c6a8.th":
    save_codes_dir = os.path.join(args.root, sub_root,"encodec_4cb")
elif args.encodec_name == "encodec_8cb1024_giga.th":
    save_codes_dir = os.path.join(args.root, sub_root,"encodec_8cb")
else:
    # Fail with a clear message instead of a NameError on save_codes_dir below.
    raise ValueError(
        f"unsupported --encodec_name {args.encodec_name!r}; expected "
        "'encodec_6f79c6a8.th' or 'encodec_8cb1024_giga.th'"
    )
os.makedirs(save_manifest_dir, exist_ok=True)
os.makedirs(save_codes_dir, exist_ok=True)
def import_encodec():
    """Build the codec model used for encoding.

    Loads the checkpoint ``~/VoiceStar/pretrained/<args.encodec_name>`` through
    ``encodec.get_compression_model`` with ``encode_only=True`` on the
    ``"cuda"`` device and wraps the result in ``torch.nn.DataParallel``.

    Returns:
        torch.nn.DataParallel wrapping the loaded model.
    """
    from encodec import get_compression_model

    ckpt_path = os.path.join(
        os.path.expanduser("~"), "VoiceStar", f"pretrained/{args.encodec_name}"
    )
    codec = get_compression_model(ckpt_path, encode_only=True, device="cuda")
    return torch.nn.DataParallel(codec)
# setup dataloader
# Each loader step yields a "mega batch" of up to mega_batch_size utterances,
# which is later sorted by length and cut into GPU batches of batch_size.
mega_batch_size = 2048
batch_size = args.batch_size
dataset = mydataset(args.split)
if len(dataset) == 0:
    logging.info(f"no data found for split {args.split} partition {args.partition}")
    sys.exit(0)
# Load the codec only once we know there is something to encode, so an empty
# partition exits without touching the GPU.
model = import_encodec()
loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
split = args.split
skip = 0  # number of utterances dropped by the length filter
logging.info(f"now processing split {split} partition {args.partition}...")
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
# '/' cannot appear in a file name, so partition "i/n" is written as "i=n".
mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}")
# Encode every mega batch and record "<segment_id>\t<number of code frames>"
# in the manifest. A failure inside one mega batch is logged and the loop
# moves on to the next one (best effort).
with open(mani_fn, "w") as mani_wf:
    # with open(mani_fn, "a") as mani_wf: # resume from where we failed
    # Length limits in samples, derived from --min_len / --len_cap (seconds).
    min_samples = args.model_sr * args.min_len
    max_samples = args.model_sr * args.len_cap
    for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):
        logging.info(f"====================================")
        logging.info(f"====================================")
        logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
        try:
            mega_wavs, mega_lens, mega_ids = mega_batch
            lengths = np.array(mega_lens)
            if len(lengths) == 0:
                # Every file of this mega batch failed to load.
                logging.info(f"mega step {m+1} contains no loadable audio, skipping")
                continue
            sorted_inds = sort_by_audio_len(lengths)
            # Drop utterances shorter than --min_len or longer than --len_cap,
            # keeping the longest-first order of the remaining ones.
            kept_inds = [ind for ind in sorted_inds if min_samples <= lengths[ind] <= max_samples]
            skip += len(sorted_inds) - len(kept_inds)
            n_steps = int(np.ceil(len(kept_inds) / batch_size))
            for n in range(n_steps):
                inds_used = kept_inds[n*batch_size:(n+1)*batch_size]
                wav_batch = [mega_wavs[ind] for ind in inds_used]
                all_lens = [mega_lens[ind] for ind in inds_used]
                segment_id_batch = [mega_ids[ind] for ind in inds_used]
                padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                # Extract discrete codes from EnCodec
                with torch.no_grad():
                    if max(all_lens) > 300000 and len(all_lens) > 1: # if utterances are long, simply pass half of them at a time
                        inwav = padded_wav.cuda()
                        half = len(inwav)//2
                        codes = torch.cat(
                            [model(inwav[:half])[0].cpu(), model(inwav[half:])[0].cpu()],
                            dim=0,
                        )
                    else:
                        codes = model(padded_wav.cuda())[0].cpu() # [B, n_codebook, T]
                for i, (segment_id, length) in enumerate(zip(segment_id_batch, all_lens)):
                    save_fn = os.path.join(save_codes_dir, segment_id+".txt")
                    # Trim the frames that only exist because of batch padding.
                    actual_len = round(length / args.downsample_rate)
                    cur_code = codes[i, :, :actual_len].tolist()
                    os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                    write_array_to_txt_file(cur_code, save_fn)
                    mani_wf.write(f"{segment_id}\t{len(cur_code[0])}\n") # write to manifest file
        except Exception:
            # Keep going with the next mega batch, but record the traceback.
            logging.exception(f"exception!! at {m+1}")
            continue
logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")