import os
import ast
import copy
import glob
import logging
import random

import ffmpeg
import numpy as np
import torch
import torch.distributed as dist

from data.tokenizer import AudioTokenizer


def find_files(root_dir, endswith=".wav"):
    """Recursively collect all files under root_dir whose (lower-cased) names end with `endswith`."""
    files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.lower().endswith(endswith):
                files.append(os.path.join(dirpath, filename))
    return files
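
# example usage (hypothetical path): collect every FLAC file under a corpus root
#   flac_files = find_files("/data/librilight_small", endswith=".flac")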

class dataset(torch.utils.data.Dataset):
    """Dataset of (phoneme tokens, codec codes) pairs read from tab-separated manifests,
    with optional neighbor prompts and on-the-fly time stretching."""
    def __init__(self, args, split):
        super().__init__()
        self.args = args
        self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0)
        self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1)
        self.split = split
        
        assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}"
        
        if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir:
            self.dataset_dir = f"['{self.args.dataset_dir}']"
        else:
            self.dataset_dir = copy.deepcopy(self.args.dataset_dir)
        self.dataset_dir = eval(self.dataset_dir)
        data = []
        if "[" not in self.args.manifest_name or "]" not in self.args.manifest_name:
            self.args.manifest_name = f"['{self.args.manifest_name}']"
        else:
            self.args.manifest_name = copy.deepcopy(self.args.manifest_name)
        self.manifest_name = eval(self.args.manifest_name)
        if len(self.manifest_name) != len(self.dataset_dir):
            assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}"
            self.manifest_name = self.manifest_name * len(self.dataset_dir)
        for i_data, dataset_dir in enumerate(self.dataset_dir):
            if getattr(self.args, "no_libri_in_training", None) != None and ("librilight" in dataset_dir) and self.split == "train":
                if not dist.is_initialized() or dist.get_rank() == 0:
                    logging.info(f"skipping librilight in training split")
                continue
            n_datapoints = 0
            manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt")
            if not os.path.isfile(manifest_fn):
                all_manifest_fn = glob.glob(manifest_fn.replace(".txt", "_*=*.txt"))
                if len(all_manifest_fn) == 0:
                    logging.info(f"no manifest file found for {split} split in {dataset_dir}")
                    continue
                if self.args.debug:
                    logging.info(f"debugging mode, only using the frist found manifest file: {all_manifest_fn[0]}")
                    all_manifest_fn = all_manifest_fn[:1]
                else:
                    if dist.is_initialized() and dist.get_rank() == 0:
                        logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}")
                for cur_manifest_fn in all_manifest_fn:
                    with open(cur_manifest_fn, "r") as rf:
                        tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()] # i_data is the index of the dataset
                        n_datapoints += len(tmp)
                        data += tmp
            else:
                with open(manifest_fn, "r") as rf:
                    tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()]
                    data += tmp
                    n_datapoints += len(tmp)
            if dist.is_initialized() and dist.get_rank() == 0:
                logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}")
        assert len(data) > 0, f"no data found for {split} split"
        lengths_list = [int(item[1]) for item in data] # use index 1 because there may be more than one column (e.g. gigaspeech has 3: path, duration, selfsim)
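        # a manifest line is tab-separated, e.g. (hypothetical values) "spk1/utt_00042\t750\t0.91",
        # where 750 is the clip length in codec frames, i.e. 750 / encodec_sr seconds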
        self.data = []
        self.lengths_list = []
        total_duration = 0
        for d, l in zip(data, lengths_list):
            if l >= self.args.encodec_sr*self.args.audio_min_length:
                if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
                    continue
                self.data.append(d)
                self.lengths_list.append(l)
                total_duration += l / self.args.encodec_sr / 3600
        # logging.info(f"for now cut the dataset to only have 500 examples for debugging")
        # self.data = self.data[:1000]
        # self.lengths_list = self.lengths_list[:1000]
        if dist.is_initialized() and dist.get_rank() == 0:
            logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}")
            logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours")
        # phoneme vocabulary
        phn_set = set()
        for dataset_dir in self.dataset_dir:
            vocab_fn = os.path.join(dataset_dir, "vocab.txt")
            with open(vocab_fn, "r") as f:
                temp = [l.strip().split("\t") for l in f.readlines() if len(l.strip()) != 0]
                phn_set.update([item[-1] for item in temp])
        self.phn2num = {item:i for i, item in enumerate(phn_set)}
        assert self.args.text_vocab_size > len(self.phn2num), f"text_vocab_size must be bigger than the number of phones in the vocab to leave room for OOD phones, but it is {self.args.text_vocab_size} while the vocab has {len(self.phn2num)} phones"

        if (self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0) or self.args.target_time_stretch_prob > 0:
            userdir = os.path.expanduser("~")
            encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th"))
            self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True)
            assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.codec_audio_sr: {self.args.codec_audio_sr}"
            if dist.is_initialized() and dist.get_rank() == 0:
                logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}")
    
    def __len__(self):
        return len(self.lengths_list)
    
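    def _stretch_and_encode(self, audio_fn, start_sec, duration_sec, speed_factor):
        """Decode a segment with ffmpeg, time-stretch it with the atempo filter, and
        re-encode it with the codec. Returns the codes as a [K, T] list of lists;
        raises on failure so callers decide how to recover."""
        process = (
            ffmpeg.input(audio_fn, ss=start_sec, t=duration_sec)
            .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter=f'atempo={speed_factor}')
            .run_async(pipe_stdout=True, pipe_stderr=True)
        )
        # read the processed audio from ffmpeg stdout
        output, _ = process.communicate()
        # convert the raw float32 bytes to a [1, 1, T] (batch, channel, samples) waveform tensor
        waveform = torch.from_numpy(np.frombuffer(output, dtype=np.float32).copy()).unsqueeze(0).unsqueeze(0)
        assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
        with torch.no_grad():
            encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
        assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
        return encos.cpu().squeeze(0).numpy().tolist() # [K, T]
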
    def _load_phn_enc(self, index):
        item = self.data[index]
        dataset_dir = self.dataset_dir[item[-1]]
        pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0]+".txt")
        ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0]+".txt")
        # with some probability, load the raw audio and time-stretch it; the stretched duration must stay below self.args.audio_max_length
        if "/librilight" in dataset_dir:
            audio_ext = ".flac"
        elif "/emilia" in dataset_dir:
            audio_ext = ".mp3"
        else:
            raise NotImplementedError(f"dataset_dir: {dataset_dir}")
        
        audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "")+audio_ext)
        speed_factor = 1 + random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound)
        length_ok = (float(item[1]) / self.args.encodec_sr) / speed_factor < self.args.audio_max_length # NOTE the worst-case duration after time stretching is orig/(1-bound), not orig*(1+bound), hence the division
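        # e.g. with bound=0.1, an 18 s clip can stretch to 18/(1-0.1) = 20 s, whereas 18*(1+0.1) = 19.8 s would under-estimate the worst case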
        if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok: 
            try:
                with open(pf, "r") as p:
                    phns = [l.strip() for l in p.readlines()]
                    assert len(phns) == 1, phns
                    all_phns = phns[0].split(" ")
                    x = [self.phn2num[phn] for phn in all_phns if phn in self.phn2num]
            except Exception:
                logging.info(f"loading failed for {pf}, maybe the file doesn't exist or is corrupted")
                return [], [[]], dataset_dir, audio_ext
            # time stretch the audio and re-encode it with the codec on the fly
            try:
                encos = self._stretch_and_encode(audio_fn, start_sec=0, duration_sec=float(item[1]) / self.args.encodec_sr, speed_factor=speed_factor)
                if self.args.special_first:
                    raise NotImplementedError
                    # y = [[int(n)+self.args.n_special for n in l] for l in encos]
                else:
                    y = [[int(n) for n in l] for l in encos]
                return x, y, dataset_dir, audio_ext
            except Exception as e:
                logging.info(f"time stretch and codec encode failed for {audio_fn}")
                logging.info(f"error: {e}")
                # fall through to loading the pre-computed codes from disk

        try:
            with open(pf, "r") as p, open(ef, "r") as e:
                phns = [l.strip() for l in p.readlines()]
                assert len(phns) == 1, phns
                all_phns = phns[0].split(" ")
                x = [self.phn2num[phn] for phn in all_phns if phn in self.phn2num] # we assume OOD phones are rare, because the phone vocab is small
                encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]

                assert len(encos) == self.args.n_codebooks, ef

                if self.args.special_first:
                    raise NotImplementedError
                    # y = [[int(n)+self.args.n_special for n in l] for l in encos]
                else:
                    y = [[int(n) for n in l] for l in encos]
        except Exception:
            logging.info(f"loading failed for {pf} and {ef}, maybe the files don't exist or are corrupted")
            return [], [[]], dataset_dir, audio_ext

        return x, y, dataset_dir, audio_ext

    # this uses the output of step7_ipa_alignment.py
    def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext):
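        """Pick a random neighbor and return (phn_tokens, codec_codes) to use as a prompt.

        Each row in `neighbors` comes from the neighbor manifest and is assumed to look
        like (relative_path.txt, distance, duration_in_seconds, ...); only the first
        three columns are used here. Returns (None, None) on any failure so the caller
        can retry with a different neighbor.
        """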
        neighbor = random.choice(neighbors)
        neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0])
        if not os.path.isfile(neighbor_enc_fn):
            return None, None
        neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext))
        if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile(neighbor_audio_path):
            logging.info(f"audio file not found: {neighbor_audio_path}")
            return None, None
        if random.random() < getattr(self.args, "time_stretch_prob", 0):
            time_stretch_flag = True
            speed_factor = 1 + random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound)
            duration_factor = 1 / speed_factor
        else:
            time_stretch_flag = False
            duration_factor = 1

        ####################### TODO for now always use the entire neighbor for emilia
        # for gigaspeech/emilia we did not run MFA forced alignment (so there is no IPA alignment), and we just use the entire neighbor as the prompt

        if "/emilia" in dataset_dir:
            # get neighbor duration
            neighbor_dur = float(neighbor[2])
            if neighbor_dur * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length or neighbor_dur * duration_factor < self.args.min_prompt_len:
                return None, None
            try:
                neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0])
                with open(neighbor_pf, "r") as p:
                    phns = [l.strip() for l in p.readlines()]
                    assert len(phns) == 1, phns
                    all_phns = phns[0].split(" ")
                    phn_token = [self.phn2num[item] for item in all_phns if item in self.phn2num]
            except Exception:
                logging.info(f"loading failed for {neighbor_pf}, maybe the file doesn't exist")
                return None, None
            # no stretching: load the pre-computed codec codes from disk
            if not time_stretch_flag:
                with open(neighbor_enc_fn, "r") as f:
                    neighbor_enc = [l.strip().split() for l in f.readlines()]
                if len(neighbor_enc) != self.args.n_codebooks:
                    return None, None
                if self.args.special_first:
                    raise NotImplementedError
                    # neighbor_enc = [[int(n)+self.args.n_special for n in l] for l in neighbor_enc]
                else:
                    neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
                return phn_token, neighbor_enc
            else: # stretch the audio with ffmpeg-python and re-encode on the fly
                neighbor_enc = self._stretch_and_encode(neighbor_audio_path, start_sec=0, duration_sec=neighbor_dur, speed_factor=speed_factor)
                return phn_token, neighbor_enc
        ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0])
        if not os.path.isfile(ipa_alignment_fn):
            # print(f"file not found: {ipa_alignment_fn}", flush=True)
            return None, None
        with open(ipa_alignment_fn, "r") as f:
            alignments = [l.strip().split("\t") for l in f.readlines()]
        alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3]
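        # each row is (start_sec, end_sec, space-separated IPA phones), e.g. a hypothetical line "1.32\t2.87\tp l iy z"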
        alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len]
        if len(alignments) == 0:
            # print(f"no valid alignment found for {ipa_alignment_fn}")
            return None, None
        idx = random.choice(range(len(alignments)))
        while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length:
            idx -= 1
            if idx < 0:
                # print(f"too long combined with y_len {ipa_alignment_fn=}, and {y_len=}")
                return None, None
        if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len:
            return None, None
        
        start_time, end_time = alignments[idx][:2]
        phn = alignments[idx][2].split(" ")
        phn_token = [self.phn2num[item] for item in phn if item in self.phn2num]
        if len(phn_token) == 0:
            return None, None

        if time_stretch_flag:
            duration = end_time - start_time
            try:
                neighbor_enc = self._stretch_and_encode(neighbor_audio_path, start_sec=start_time, duration_sec=duration, speed_factor=speed_factor)
            except Exception:
                logging.info(f"time stretch failed for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds")
                return None, None
            return phn_token, neighbor_enc
        else:
            # get encodec codes from storage
            with open(neighbor_enc_fn, "r") as f:
                neighbor_enc = [l.strip().split() for l in f.readlines()]
                if len(neighbor_enc) != self.args.n_codebooks:
                    # print(f"wrong number of codebooks for {neighbor_enc_fn}")
                    return None, None
                else:
                    # trim the encodec codes to the segment
                    start_enc_frame = int(start_time * self.args.encodec_sr)
                    end_enc_frame = int(end_time * self.args.encodec_sr)
                    neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc]
                    if len(neighbor_enc[0]) == 0:
                        # print(f"no valid encodec codes found for {neighbor_enc_fn}")
                        return None, None
                    if self.args.special_first:
                        raise NotImplementedError
                    else:
                        neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
                    return phn_token, neighbor_enc

    def __getitem__(self, index):
        x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
        x_len, y_len = len(x), len(y[0])
        extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0}
        if x_len == 0 or y_len == 0:
            ret = {
                "x": None, 
                "x_len": None, 
                "y": None, 
                "y_len": None,
                }
            ret.update(extra_ret)
            return ret
        while y_len < self.args.encodec_sr*self.args.audio_min_length:
            assert not self.args.dynamic_batching
            index = random.choice(range(len(self))) # regenerate an index
            x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
            x_len, y_len = len(x), len(y[0])
        
        # if use neighbor prompt
        x_neighbor, y_neighbor = None, None
        use_neighbor_prob = random.random()
        neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0]+".txt")
        if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn): # the file may be missing simply because no neighbor (other than the file itself) was found, which is common for emilia
            with open(neighbor_fn, "r") as f:
                neighbors = [l.strip().split("\t") for l in f.readlines()]
            # select neighbors
            if "maxdist" in self.args.neighbor_selection_method:
                maxdist = int(self.args.neighbor_selection_method.split("_")[-1])
                # only keep neighbors with distance within maxdist
                neighbors = [n for n in neighbors if float(n[1]) <= maxdist]
            else:
                raise NotImplementedError
            if len(neighbors) > 0:
                x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
                i_trial = 0
                while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors):
                    x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
                    i_trial += 1
        
        if x_neighbor is not None:
            if self.args.x_sep_token is not None:
                x = x_neighbor + [self.args.x_sep_token] + x
            else:
                x = x_neighbor + x
            if self.args.y_sep_token is not None:
                y = [y_neighbor[i] + [self.args.y_sep_token] + y[i] for i in range(len(y))]
            else:
                y = [y_neighbor[i] + y[i] for i in range(len(y))]
            extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1 # with y_sep_token this actually points at the token adjacent to the sep token, but since the sep token is ignored in the loss computation, the off-by-one is harmless
            extra_ret['x_sep_token_position'] = len(x_neighbor) + 1
            x_len, y_len = len(x), len(y[0])
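            # e.g. a 100-frame neighbor prompt followed by a y_sep_token gives y_sep_token_position = 101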

        # optionally add eos (and bos) tokens to the text
        if self.args.add_eos_to_text != 0:
            x.append(self.args.add_eos_to_text)
            x_len += 1
        if getattr(self.args, "add_bos_to_text", 0) != 0:
            x = [self.args.add_bos_to_text] + x
            x_len += 1
        ### padding and cropping ###
        # adjust the length of the encodec codes: with dynamic batching, padding happens in collate; otherwise pad every item to max_len here
        max_len = int(self.args.audio_max_length * self.args.encodec_sr)
        if y_len > max_len + 10: # allow some margin for rounding error
            raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}")
        if not self.args.dynamic_batching:
            pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
            for i in range(len(y)):
                y[i] = y[i] + pad
        
        if self.args.pad_x and x_len <= self.args.text_max_length:
            pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
            x = x + pad
        
        ret = {
            "x": torch.LongTensor(x), 
            "x_len": x_len, 
            "y": torch.LongTensor(y), 
            "y_len": y_len,
            }
        ret.update(extra_ret)

        return ret
            

    def collate(self, batch):
        # make sure every item in the batch has the same keys
        for batch1, batch2 in zip(batch[:-1], batch[1:]):
            assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different"
        out = {key:[] for key in batch[0]}
        for item in batch:
            if item['x'] is None: # deal with load failure
                continue
            for key, val in item.items():
                out[key].append(val)
        res = {}
        if self.args.pad_x:
            res["x"] = torch.stack(out["x"], dim=0)
        else:
            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.args.dynamic_batching:
            res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
            res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.LongTensor(out["y_len"])
        res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
        res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
        if "y_sep_token_position" in out:
            res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"])
        if "x_sep_token_position" in out:
            res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"])
        return res
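

if __name__ == "__main__":
    # Minimal smoke-test sketch. All paths and hyper-parameter values below are
    # hypothetical placeholders; in real training these come from the experiment
    # config, and the required fields mirror whatever the training script passes
    # in (only a representative subset is shown here).
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(
        dataset_dir="/data/emilia",            # hypothetical corpus root
        manifest_name="manifest",
        phn_folder_name="phonemes",
        encodec_folder_name="encodec_codes",
        audio_folder_name="audio",
        neighbor_folder_name="neighbors",
        ipa_alignment_folder_name="ipa_alignment",
        neighbor_selection_method="maxdist_60",
        neighbor_prompt_prob=0.0,              # disable neighbor prompts for the smoke test
        time_stretch_prob=0.0,
        target_time_stretch_prob=0.0,
        encodec_sr=50, codec_audio_sr=16000,
        n_codebooks=4, special_first=0,
        audio_min_length=2, audio_max_length=20, drop_long=True,
        text_vocab_size=100, text_pad_token=100, text_max_length=400,
        audio_pad_token=2048, sep_special_token=False,
        x_sep_token=None, y_sep_token=None,
        add_eos_to_text=0, add_bos_to_text=0,
        pad_x=False, dynamic_batching=True,
        num_trial=3, min_prompt_len=1, max_prompt_len=10,
        debug=True,
    )
    ds = dataset(args, "valid")
    dl = DataLoader(ds, batch_size=2, collate_fn=ds.collate)
    batch = next(iter(dl))
    print({k: (v.shape if torch.is_tensor(v) else v) for k, v in batch.items()})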