# NOTE: HuggingFace Spaces file-viewer chrome (badges, file size, commit hash,
# line-number gutter) removed from this scraped copy; code begins below.
import argparse
from email.policy import default
def parse_args():
    """Build and parse the command-line arguments for codec encoding.

    Returns:
        argparse.Namespace with paths, codec settings, batching and
        partitioning options.
    """
    p = argparse.ArgumentParser(description="encode the dataset using codec model")
    p.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="Path to the directory")
    p.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory")
    p.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
    p.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
    p.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
    p.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    p.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    p.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    p.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    p.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number')
    p.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number')
    p.add_argument('--partition', type=str, default="1/1", help='split for parallel processing')
    p.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    return p.parse_args()
if __name__ == "__main__":
    import logging
    # Configure the root logger before anything else logs.
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s",
        level=logging.INFO,
    )
    import os, sys
    import numpy as np
    import torch
    import torchaudio
    import tqdm
    import time
    args = parse_args()
def sort_by_audio_len(lens):
    """Return sample indices sorted longest-first, logging length statistics.

    `lens` holds per-sample lengths in audio samples; stats are reported
    both as codec frames (len / downsample_rate) and seconds (len / model_sr).
    """
    order = np.argsort(lens).tolist()

    def _report(label, idx):
        n_codes = lens[idx] / args.downsample_rate
        n_secs = lens[idx] / args.model_sr
        logging.info(f"{label}: {n_codes} encodec codes, {n_secs:.2f} sec.")

    _report("longest", order[-1])
    _report("shortest", order[0])
    _report("median", order[len(order) // 2])
    _report("95 percentile longest", order[int(len(order) * 0.95)])
    return order[::-1]
def write_array_to_txt_file(array, filename):
    """Write a 2D array to `filename`: one row per line, values separated by
    spaces, with no trailing newline after the last row.

    Fix: the original indexed `array[-1]` unconditionally and raised
    IndexError for an empty array; the join form handles that case (writes
    an empty file) and is byte-identical otherwise.
    """
    with open(filename, 'w') as f:
        f.write('\n'.join(' '.join(map(str, row)) for row in array))
class mydataset(torch.utils.data.Dataset):
    """Dataset over a tab-separated manifest of audio file paths.

    Each item is (waveform[T], length_in_samples, segment_id); unreadable
    audio yields (None, None, None) and is filtered out by `collate`.
    Relies on module-level globals set in __main__: `audio_dir`,
    `encodec_manifest_dir`, `args`.
    """

    def __init__(self, split):
        super().__init__()
        self.split = split
        self.audio_dir = audio_dir  # module-level global set in __main__
        manifest_fn = os.path.join(encodec_manifest_dir, split + ".txt")
        # Partition string "k/N": this worker handles shard k-1 of N via
        # strided slicing over the manifest lines.
        cur_sp = int(args.partition.split("/")[0]) - 1
        total_sp = int(args.partition.split("/")[1])
        with open(manifest_fn, "r") as rf:
            self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, ind):
        try:
            afn = self.data[ind][0]
            fn = os.path.join(self.audio_dir, afn)
            audio, sr = torchaudio.load(fn)
            if sr != args.model_sr:
                audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
                sr = args.model_sr
            assert sr == args.model_sr, sr
        except Exception:
            # Best-effort: corrupt/missing audio is dropped later by collate.
            return None, None, None
        assert audio.ndim == 2 and audio.shape[0] == 1, audio.shape
        return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]

    def collate(self, batch):
        """Filter failed loads and return (audios, lens, segment_ids) lists."""
        lens, audios, segment_ids = [], [], []
        for item in batch:
            # Fix: use identity check instead of `!= None` — equality
            # operators on tensors are overloaded and hazardous here.
            if item[0] is not None:
                audios.append(item[0])
                lens.append(item[1])
                segment_ids.append(item[2])
        return audios, lens, segment_ids
# roots: derive every working directory from --root/--sub_root.
sub_root = args.sub_root
encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
audio_dir = os.path.join(args.root, sub_root, "audio")
save_manifest_dir = os.path.join(args.root, sub_root, "manifest_final_encodec")
if args.encodec_name == "encodec_6f79c6a8.th":
    save_codes_dir = os.path.join(args.root, sub_root, "encodec_4cb")
elif args.encodec_name == "encodec_8cb1024_giga.th":
    save_codes_dir = os.path.join(args.root, sub_root, "encodec_8cb")
else:
    # Fix: fail fast with a clear message — previously an unrecognized
    # --encodec_name left save_codes_dir undefined and crashed below
    # with an opaque NameError at os.makedirs.
    raise ValueError(f"unrecognized --encodec_name: {args.encodec_name}")
os.makedirs(save_manifest_dir, exist_ok=True)
os.makedirs(save_codes_dir, exist_ok=True)
def import_encodec():
    """Load the pretrained EnCodec compression model (encode-only, on CUDA)
    and wrap it in DataParallel for multi-GPU batching."""
    from encodec import get_compression_model
    ckpt_path = os.path.join(os.path.expanduser("~"), "VoiceStar", f"pretrained/{args.encodec_name}")
    codec = get_compression_model(ckpt_path, encode_only=True, device="cuda")
    return torch.nn.DataParallel(codec)
model = import_encodec()

# Setup dataloader: a large "mega batch" is loaded at once, then sorted by
# audio length and re-chunked into GPU-sized batches of args.batch_size.
mega_batch_size = 2048
batch_size = args.batch_size
dataset = mydataset(args.split)
if len(dataset) == 0:
    logging.info(f"no data found for split {args.split} partition {args.partition}")
    sys.exit(0)
# Fix: was `torch.torch.utils.data.DataLoader` — an accidental double
# attribute access that only worked because `torch.torch` aliases the module.
loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
split = args.split
skip = 0  # counts samples dropped for being too short/long
logging.info(f"now processing split {split} partition {args.partition}...")
mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}")
with open(mani_fn, "w") as mani_wf:
    # To resume a failed run, change mode "w" -> "a" above.
    for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):
        logging.info(f"====================================")
        logging.info(f"====================================")
        logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
        try:
            lengths = np.array(mega_batch[1])
            sorted_inds = sort_by_audio_len(lengths)
            # Drop samples shorter than args.min_len sec or longer than
            # args.len_cap sec; iterate backwards so `del` keeps later
            # indices valid.
            for j in range(len(sorted_inds))[::-1]:
                if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap:
                    skip += 1
                    del sorted_inds[j]
            n_steps = int(np.ceil(len(sorted_inds) / batch_size))
            for n in tqdm.tqdm(range(n_steps), disable=True):
                inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
                # Fix: loop variable was `id`, shadowing the builtin.
                wav_batch = [mega_batch[0][idx] for idx in inds_used]
                all_lens = [mega_batch[1][idx] for idx in inds_used]
                segment_id_batch = [mega_batch[2][idx] for idx in inds_used]
                padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1)  # [B, T] -> [B, 1, T]
                # Extract discrete codes from EnCodec.
                with torch.no_grad():
                    if max(all_lens) > 300000 and len(all_lens) > 1:
                        # Long utterances: encode in two halves to avoid OOM.
                        codes = []
                        inwav = padded_wav.cuda()
                        codes.append(model(inwav[:len(inwav)//2])[0].cpu())
                        codes.append(model(inwav[len(inwav)//2:])[0].cpu())
                        codes = torch.cat(codes, dim=0)
                    else:
                        encoded_frames = model(padded_wav.cuda())
                        codes = encoded_frames[0].cpu()  # [B, n_codebook, T]
                for i, length in enumerate(all_lens):
                    save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
                    # Trim the padding frames that pad_sequence introduced.
                    actual_len = round(length / args.downsample_rate)
                    # Fix: isinstance instead of `type(codes) == list`.
                    cur_code = codes[i].tolist() if isinstance(codes, list) else codes[i, :, :actual_len].tolist()
                    os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                    write_array_to_txt_file(cur_code, save_fn)
                    mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n")  # write to manifest file
        except Exception as e:
            # Best-effort batch processing: report and move on to the next
            # mega batch rather than aborting the whole run.
            print(f'exception!! at {m+1}')
            print(e)
            continue
logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
# break
|