import argparse
def parse_args():
    parser = argparse.ArgumentParser(description="encode the dataset using codec model")
    parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="path to the dataset root directory")
    parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub-directory under root that holds the preprocessed data")
    parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model")
    parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes")
    parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding; decrease it if you hit OOM. This is the total batch size summed over all GPUs, so increase it if you are using more GPUs")
    parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=1000, help='drop audio files longer than this many seconds')
    parser.add_argument('--min_len', type=float, default=0.5, help='drop audio files shorter than this many seconds')
    parser.add_argument('--partition', type=str, default="1/1", help='i-th of N partitions (e.g. "2/8") for parallel processing')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    return parser.parse_args()
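
# Example invocation (illustrative only: the script filename, paths, and the
# partition value are placeholders that depend on your setup and job layout):
#   python encode_with_codec.py --root /data/scratch/pyp/datasets/emilia \
#       --sub_root preprocessed --encodec_name encodec_6f79c6a8.th \
#       --n_workers 16 --batch_size 16 --partition 1/8 --split train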

if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    import os, sys
    import numpy as np
    import torch
    import torchaudio
    import tqdm
    import time

    args = parse_args()

    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()
        
        logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
        return inds[::-1]
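
    # NOTE: indices are returned longest-first, presumably so the most
    # memory-hungry batches get encoded first and any OOM surfaces early.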

    def write_array_to_txt_file(array, filename):
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a))+'\n')
            f.write(' '.join(map(str, array[-1])))
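
    # Each per-utterance .txt produced this way contains one line per codebook:
    # line k holds the space-separated code indices of codebook k over time.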

    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            self.split = split
            self.audio_dir = audio_dir
            manifest_fn = os.path.join(encodec_manifest_dir, split+".txt")
            cur_sp = int(args.partition.split("/")[0])-1
            total_sp = int(args.partition.split("/")[1])
            with open(manifest_fn, "r") as rf:
                self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]
        def __len__(self):
            return len(self.data)
        def __getitem__(self, ind):
            try:
                afn = self.data[ind][0]
                fn = os.path.join(self.audio_dir, afn)
                audio, sr = torchaudio.load(fn)
                if sr != args.model_sr:
                    audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
                    sr = args.model_sr
                assert sr == args.model_sr, sr
            except Exception as e:
                # logging.info(f"{e}")
                return None, None, None
            assert audio.ndim==2 and audio.shape[0] == 1, audio.shape
            return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]
        def collate(self, batch):
            # drop items whose audio failed to load in __getitem__ (returned as (None, None, None))
            lens, audios, segment_ids = [], [], []
            for item in batch:
                if item[0] is not None:
                    audios.append(item[0])
                    lens.append(item[1])
                    segment_ids.append(item[2])
            return audios, lens, segment_ids
    
    # roots
    sub_root = args.sub_root
    encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
    audio_dir = os.path.join(args.root, sub_root, "audio")
    save_manifest_dir = os.path.join(args.root, sub_root,"manifest_final_encodec")
    if args.encodec_name == "encodec_6f79c6a8.th":
        save_codes_dir = os.path.join(args.root, sub_root, "encodec_4cb")
    elif args.encodec_name == "encodec_8cb1024_giga.th":
        save_codes_dir = os.path.join(args.root, sub_root, "encodec_8cb")
    else:
        raise ValueError(f"unrecognized --encodec_name: {args.encodec_name}")

    os.makedirs(save_manifest_dir, exist_ok=True)
    os.makedirs(save_codes_dir, exist_ok=True)
    
    def import_encodec():
        from encodec import get_compression_model
        userdir = os.path.expanduser("~")
        model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda")
        model = torch.nn.DataParallel(model)
        return model
    model = import_encodec()
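    # NOTE: the checkpoint is expected at ~/VoiceStar/pretrained/<encodec_name>;
    # DataParallel splits each forward batch across all visible GPUs, which is
    # why --batch_size above is the total summed over GPUs.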
    
    # setup dataloader
    mega_batch_size = 2048
    batch_size = args.batch_size
    
    dataset = mydataset(args.split)
    if len(dataset) == 0:
        logging.info(f"no data found for split {args.split} partition {args.partition}")
        sys.exit(0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
    split = args.split

    skip = 0
    logging.info(f"now processing split {split} partition {args.partition}...")
    mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
    # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size))
    logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")
    mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
    logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}")
    with open(mani_fn, "w") as mani_wf:
    # with open(mani_fn, "a") as mani_wf: # resume from where we failed
        for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):

            logging.info(f"====================================")
            logging.info(f"====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")

            try:
                lengths = np.array(mega_batch[1])
                sorted_inds = sort_by_audio_len(lengths)
                for j in range(len(sorted_inds))[::-1]:
                    if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples shorter than min_len seconds or longer than len_cap seconds
                        skip += 1
                        del sorted_inds[j]
                
                n_steps = int(np.ceil(len(sorted_inds) / batch_size))
                for n in tqdm.tqdm(range(n_steps), disable=True):
                    inds_used = sorted_inds[n*batch_size:(n+1)*batch_size]
                    wav_batch = [mega_batch[0][idx] for idx in inds_used]
                    all_lens = [mega_batch[1][idx] for idx in inds_used]
                    segment_id_batch = [mega_batch[2][idx] for idx in inds_used]
                    padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
                    # Extract discrete codes from EnCodec
                    with torch.no_grad():
                        if max(all_lens) > 300000 and len(all_lens) > 1: # if utterances are long, simply pass half of them at a time
                            codes = []
                            inwav = padded_wav.cuda()
                            codes.append(model(inwav[:len(inwav)//2])[0].cpu())
                            codes.append(model(inwav[len(inwav)//2:])[0].cpu())
                            codes = torch.cat(codes, dim=0)
                        else:
                            encoded_frames = model(padded_wav.cuda()) 
                            codes = encoded_frames[0].cpu() # [B, n_codebook, T]

                    for i, length in enumerate(all_lens):
                        save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt")
                        actual_len = round(length / args.downsample_rate) # number of code frames = samples / downsample_rate (320 by default for this model)
                        cur_code = codes[i].tolist() if isinstance(codes, list) else codes[i, :, :actual_len].tolist()
                        os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                        write_array_to_txt_file(cur_code, save_fn)

                        mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file
                        # if i == 10:
                        #    raise
            except Exception as e:
                logging.error(f"exception at mega step {m+1}: {e}")
                continue

            # break
    logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
        # break