import os
import ffmpeg
import torch
import random
import copy
import logging
import torch.distributed as dist
import shutil
import csv
import torchaudio
import glob
import numpy as np
from data.tokenizer import TextTokenizer, tokenize_text, AudioTokenizer
def find_files(root_dir, endswith=".wav"):
    files = []
    # os.walk generates the file names in a directory tree
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename.lower().endswith(endswith):
                # os.path.join combines one or more path names into a single path
                full_path = os.path.join(dirpath, filename)
                files.append(full_path)
    return files
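# A quick usage sketch (the path is hypothetical):
#   find_files("/data/emilia", endswith=".mp3")
# walks /data/emilia recursively and returns the full path of every .mp3 file.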
class dataset(torch.utils.data.Dataset):
    def __init__(self, args, split):
        super().__init__()
        self.args = args
        self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0)
        self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1)
        self.split = split
        assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}"
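        # dataset_dir and manifest_name each accept either a single value or a
        # python-style list literal such as "['/data/emilia', '/data/librilight']"
        # (hypothetical paths); a single value is wrapped into a one-element list,
        # and a single manifest name is broadcast across all dataset dirs.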
if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir:
self.dataset_dir = f"['{self.args.dataset_dir}']"
else:
self.dataset_dir = copy.deepcopy(self.args.dataset_dir)
self.dataset_dir = eval(self.dataset_dir)
data = []
if "[" not in self.args.manifest_name or "]" not in self.args.manifest_name:
self.args.manifest_name = f"['{self.args.manifest_name}']"
else:
self.args.manifest_name = copy.deepcopy(self.args.manifest_name)
self.manifest_name = eval(self.args.manifest_name)
if len(self.manifest_name) != len(self.dataset_dir):
assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}"
self.manifest_name = self.manifest_name * len(self.dataset_dir)
for i_data, dataset_dir in enumerate(self.dataset_dir):
if getattr(self.args, "no_libri_in_training", None) != None and ("librilight" in dataset_dir) and self.split == "train":
if not dist.is_initialized() or dist.get_rank() == 0:
logging.info(f"skipping librilight in training split")
continue
n_datapoints = 0
manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt")
if not os.path.isfile(manifest_fn):
all_manifest_fn = glob.glob(manifest_fn.replace(".txt", "_*=*.txt"))
if len(all_manifest_fn) == 0:
logging.info(f"no manifest file found for {split} split in {dataset_dir}")
continue
if self.args.debug:
logging.info(f"debugging mode, only using the frist found manifest file: {all_manifest_fn[0]}")
all_manifest_fn = all_manifest_fn[:1]
else:
if dist.is_initialized() and dist.get_rank() == 0:
logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}")
for cur_manifest_fn in all_manifest_fn:
with open(cur_manifest_fn, "r") as rf:
tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()] # i_data is the index of the dataset
n_datapoints += len(tmp)
data += tmp
else:
with open(manifest_fn, "r") as rf:
tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()]
data += tmp
n_datapoints += len(tmp)
if dist.is_initialized() and dist.get_rank() == 0:
logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}")
assert len(data) > 0, f"no data found for {split} split"
lengths_list = [int(item[1]) for item in data] # use 1 because there might be more than 1 columns (for gigaspeech we have 3 columns: path, duration, selfsim)
self.data = []
self.lengths_list = []
total_duration = 0
for d, l in zip(data, lengths_list):
if l >= self.args.encodec_sr*self.args.audio_min_length:
if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
continue
self.data.append(d)
self.lengths_list.append(l)
total_duration += l / self.args.encodec_sr / 3600
# logging.info(f"for now cut the dataset to only have 500 examples for debugging")
# self.data = self.data[:1000]
# self.lengths_list = self.lengths_list[:1000]
if dist.is_initialized() and dist.get_rank() == 0:
logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}")
logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours")
# phoneme vocabulary
phn_set = set()
for dataset_dir in self.dataset_dir:
vocab_fn = os.path.join(dataset_dir, "vocab.txt")
with open(vocab_fn, "r") as f:
temp = [l.strip().split("\t") for l in f.readlines() if len(l) != 0]
phn_set.update([item[-1] for item in temp])
self.phn2num = {item:i for i, item in enumerate(phn_set)}
assert self.args.text_vocab_size > len(self.phn2num), f"need self.args.text_vocab_size to be bigger than number of phns in vocab to handle OOD phn, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}"
if (self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0) or self.args.target_time_stretch_prob > 0:
userdir = os.path.expanduser("~")
encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th"))
self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True)
assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.encodec_sr: {self.args.encodec_sr}"
if dist.is_initialized() and dist.get_rank() == 0:
logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}")
    def __len__(self):
        return len(self.lengths_list)
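    # Load the phoneme tokens and encodec codes for one item. With probability
    # target_time_stretch_prob the raw audio is time-stretched via ffmpeg's atempo
    # filter and re-encoded on the fly; otherwise the pre-computed codes are read
    # from disk. Returns (x, y, dataset_dir, audio_ext); a failed load is signaled
    # by x == [] and y == [[]].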
    def _load_phn_enc(self, index):
        item = self.data[index]
        dataset_dir = self.dataset_dir[item[-1]]
        pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0] + ".txt")
        ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0] + ".txt")
        # with some probability we load the raw audio and time-stretch it; note that we must not exceed self.args.audio_max_length
        if "/librilight" in dataset_dir:
            audio_ext = ".flac"
        elif "/emilia" in dataset_dir:
            audio_ext = ".mp3"
        else:
            raise NotImplementedError(f"dataset_dir: {dataset_dir}")
        audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "") + audio_ext)
        speed_factor = random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound) + 1
        length_ok = (float(item[1]) / self.args.encodec_sr) / speed_factor < self.args.audio_max_length  # NOTE atempo divides the duration by speed_factor, so the maximal duration after stretching is orig/(1-bound), not orig*(1+bound)
        if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok:
            try:
                with open(pf, "r") as p:
                    phns = [l.strip() for l in p.readlines()]
                    assert len(phns) == 1, phns
                    all_phns = phns[0].split(" ")
                    x = [self.phn2num[item] for item in all_phns if item in self.phn2num]
            except Exception:
                logging.info(f"loading failed for {pf}, maybe the file doesn't exist or is corrupted")
                return [], [[]], dataset_dir, audio_ext
            # time stretch
            try:
                process = (
                    ffmpeg.input(audio_fn, ss=0, t=float(item[1]) / self.args.encodec_sr)
                    .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
                    .run_async(pipe_stdout=True, pipe_stderr=True)
                )
                # read the processed audio from ffmpeg stdout
                output, _ = process.communicate()
                # convert the output to a numpy array
                output_np = np.frombuffer(output, dtype=np.float32).copy()
                # reshape back to the shape the tokenizer expects: [1, 1, samples] for mono
                waveform = torch.from_numpy(output_np)
                waveform = waveform.unsqueeze(0).unsqueeze(0)
                assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
                with torch.no_grad():
                    encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
                assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
                encos = encos.cpu().squeeze(0).numpy().tolist()  # [K, T]
                if self.args.special_first:
                    raise NotImplementedError
                    # y = [[int(n)+self.args.n_special for n in l] for l in encos]
                else:
                    y = [[int(n) for n in l] for l in encos]
                return x, y, dataset_dir, audio_ext
            except Exception as e:
                logging.info(f"failed with time stretch and codec encode for {audio_fn}")
                logging.info(f"error: {e}")
        try:
            with open(pf, "r") as p, open(ef, "r") as e:
                phns = [l.strip() for l in p.readlines()]
                assert len(phns) == 1, phns
                all_phns = phns[0].split(" ")
                x = [self.phn2num[item] for item in all_phns if item in self.phn2num]  # we assume OOD phns will not occur, because the phn vocab is small
                encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
                assert len(encos) == self.args.n_codebooks, ef
                if self.args.special_first:
                    raise NotImplementedError
                    # y = [[int(n)+self.args.n_special for n in l] for l in encos]
                else:
                    y = [[int(n) for n in l] for l in encos]
        except Exception:
            logging.info(f"loading failed for {pf} and {ef}, maybe the files don't exist or are corrupted")
            return [], [[]], dataset_dir, audio_ext
        return x, y, dataset_dir, audio_ext
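    # Pick a random neighbor utterance to serve as the voice prompt. Returns
    # (phn_token, neighbor_enc) on success, or (None, None) when the candidate
    # is missing on disk, too short or too long for the remaining length budget,
    # or fails to encode; the caller retries with another candidate.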
    # this uses the output of step7_ipa_alignment.py
    def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext):
        neighbor = random.choice(neighbors)
        neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0])
        if not os.path.isfile(neighbor_enc_fn):
            return None, None
        neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext))
        if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile(neighbor_audio_path):
            logging.info(f"audio file not found: {neighbor_audio_path}")
            return None, None
        if random.random() < getattr(self.args, "time_stretch_prob", 0):
            time_stretch_flag = True
            speed_factor = random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound) + 1
            duration_factor = 1 / speed_factor  # atempo=speed_factor scales the duration by 1/speed_factor
        else:
            time_stretch_flag = False
            duration_factor = 1
        # TODO for now, always use the entire neighbor for emilia:
        # for gigaspeech and emilia we did not run MFA forced alignment (and therefore
        # have no ipa alignment), so we just use the entire neighbor as the prompt
        if "/emilia" in dataset_dir:
            # get neighbor duration
            neighbor_dur = float(neighbor[2])
            if neighbor_dur * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length or neighbor_dur * duration_factor < self.args.min_prompt_len:
                return None, None
            try:
                neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0])
                with open(neighbor_pf, "r") as p:
                    phns = [l.strip() for l in p.readlines()]
                    assert len(phns) == 1, phns
                    all_phns = phns[0].split(" ")
                    phn_token = [self.phn2num[item] for item in all_phns if item in self.phn2num]
            except Exception:
                logging.info(f"loading failed for {neighbor_pf}, maybe the file doesn't exist")
                return None, None
            if not time_stretch_flag:
                # read the pre-computed encodec codes from storage
                with open(neighbor_enc_fn, "r") as f:
                    neighbor_enc = [l.strip().split() for l in f.readlines()]
                if len(neighbor_enc) != self.args.n_codebooks:
                    return None, None
                if self.args.special_first:
                    raise NotImplementedError
                    # neighbor_enc = [[int(n)+self.args.n_special for n in l] for l in neighbor_enc]
                else:
                    neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
                return phn_token, neighbor_enc
            else:  # stretch the audio with ffmpeg-python
                process = (
                    ffmpeg.input(neighbor_audio_path, ss=0, t=neighbor_dur)
                    .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
                    .run_async(pipe_stdout=True, pipe_stderr=True)
                )
                # read the processed audio from ffmpeg stdout
                output, _ = process.communicate()
                # convert the output to a numpy array
                output_np = np.frombuffer(output, dtype=np.float32).copy()
                # reshape back to the shape the tokenizer expects: [1, 1, samples] for mono
                waveform = torch.from_numpy(output_np)
                waveform = waveform.unsqueeze(0).unsqueeze(0)
                assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
                with torch.no_grad():
                    encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
                assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
                neighbor_enc = encos.cpu().squeeze(0).numpy().tolist()  # [K, T]
                return phn_token, neighbor_enc
        ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0])
        if not os.path.isfile(ipa_alignment_fn):
            return None, None
        with open(ipa_alignment_fn, "r") as f:
            alignments = [l.strip().split("\t") for l in f.readlines()]
        alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3]
        alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len]
        if len(alignments) == 0:
            return None, None
        idx = random.choice(range(len(alignments)))
        while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length:
            idx -= 1
            if idx < 0:
                return None, None
        if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len:
            return None, None
        start_time, end_time = alignments[idx][:2]
        phn = alignments[idx][2].split(" ")
        phn_token = [self.phn2num[item] for item in phn if item in self.phn2num]
        if len(phn_token) == 0:
            return None, None
        if time_stretch_flag:
            duration = end_time - start_time
            process = (
                ffmpeg.input(neighbor_audio_path, ss=start_time, t=duration)
                .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor))
                .run_async(pipe_stdout=True, pipe_stderr=True)
            )
            # read the processed audio from ffmpeg stdout
            output, _ = process.communicate()
            # convert the output to a numpy array
            output_np = np.frombuffer(output, dtype=np.float32).copy()
            # reshape back to the shape the tokenizer expects: [1, 1, samples] for mono
            waveform = torch.from_numpy(output_np)
            waveform = waveform.unsqueeze(0).unsqueeze(0)
            assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape
            try:
                with torch.no_grad():
                    encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device))
            except Exception:
                logging.info(f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds")
                return None, None
            assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}"
            neighbor_enc = encos.cpu().squeeze(0).numpy().tolist()  # [K, T]
            return phn_token, neighbor_enc
        else:
            # get encodec codes from storage
            with open(neighbor_enc_fn, "r") as f:
                neighbor_enc = [l.strip().split() for l in f.readlines()]
            if len(neighbor_enc) != self.args.n_codebooks:
                return None, None
            # trim the encodec codes to the selected segment
            start_enc_frame = int(start_time * self.args.encodec_sr)
            end_enc_frame = int(end_time * self.args.encodec_sr)
            neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc]
            if len(neighbor_enc[0]) == 0:
                return None, None
            if self.args.special_first:
                raise NotImplementedError
            else:
                neighbor_enc = [[int(n) for n in l] for l in neighbor_enc]
            return phn_token, neighbor_enc
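    # Assemble one training example: phoneme tokens x (optionally prefixed with a
    # neighbor prompt and bos/eos tokens) and codec codes y of shape [K, T]; on a
    # load failure every tensor field is None and collate() drops the item.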
    def __getitem__(self, index):
        x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
        x_len, y_len = len(x), len(y[0])
        extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0}
        if x_len == 0 or y_len == 0:
            ret = {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
            }
            ret.update(extra_ret)
            return ret
        while y_len < self.args.encodec_sr * self.args.audio_min_length:
            assert not self.args.dynamic_batching
            index = random.choice(range(len(self)))  # regenerate an index
            x, y, dataset_dir, audio_ext = self._load_phn_enc(index)
            x_len, y_len = len(x), len(y[0])
        # maybe use a neighbor as the prompt
        x_neighbor, y_neighbor = None, None
        use_neighbor_prob = random.random()
        neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0] + ".txt")
        if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn):  # the file might not exist simply because we didn't find a neighbor for this file (other than itself, which is common for emilia)
            with open(neighbor_fn, "r") as f:
                neighbors = [l.strip().split("\t") for l in f.readlines()]
            # select neighbors
            if "maxdist" in self.args.neighbor_selection_method:
                maxdist = int(self.args.neighbor_selection_method.split("_")[-1])
                # only keep neighbors whose distance is within maxdist
                neighbors = [n for n in neighbors if float(n[1]) <= maxdist]
            else:
                raise NotImplementedError
            if len(neighbors) > 0:
                x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
                i_trial = 0
                while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors):
                    x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext)
                    i_trial += 1
        if x_neighbor is not None:
            if self.args.x_sep_token is not None:
                x = x_neighbor + [self.args.x_sep_token] + x
            else:
                x = x_neighbor + x
            if self.args.y_sep_token is not None:
                y = [y_neighbor[i] + [self.args.y_sep_token] + y[i] for i in range(len(y))]
            else:
                y = [y_neighbor[i] + y[i] for i in range(len(y))]
            extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1  # if using y_sep_token, this is actually the position of the token right before the y_sep_token, but since y_sep_token is ignored in the loss computation, using that position is fine
            extra_ret['x_sep_token_position'] = len(x_neighbor) + 1
            x_len, y_len = len(x), len(y[0])
        # consider adding eos to the end of the text
        if self.args.add_eos_to_text != 0:
            x.append(self.args.add_eos_to_text)
            x_len += 1
        if getattr(self.args, "add_bos_to_text", 0) != 0:
            x = [self.args.add_bos_to_text] + x
            x_len += 1
        ### padding and cropping ###
        # adjust the length of the encodec codes: pad to max_len or randomly crop
        orig_y_len = copy.copy(y_len)
        max_len = int(self.args.audio_max_length * self.args.encodec_sr)
        if y_len > max_len + 10:  # give it some margin for rounding error
            raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}")
        else:
            audio_start = 0
        if not self.args.dynamic_batching:
            pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
            for i in range(len(y)):
                y[i] = y[i] + pad
        if self.args.pad_x and x_len <= self.args.text_max_length:
            pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
            x = x + pad
        ret = {
            "x": torch.LongTensor(x),
            "x_len": x_len,
            "y": torch.LongTensor(y),
            "y_len": y_len,
        }
        ret.update(extra_ret)
        return ret
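    # Pad a list of examples into batch tensors plus boolean padding masks;
    # items whose x is None (load failures) are silently skipped.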
    def collate(self, batch):
        # make sure the keys in every item of the batch are the same
        for batch1, batch2 in zip(batch[:-1], batch[1:]):
            assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different"
        out = {key: [] for key in batch[0]}
        for item in batch:
            if item['x'] is None:  # deal with load failure
                continue
            for key, val in item.items():
                out[key].append(val)
        res = {}
        if self.args.pad_x:
            res["x"] = torch.stack(out["x"], dim=0)
        else:
            res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.args.dynamic_batching:
            res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1, 0) for item in out['y']], padding_value=self.args.audio_pad_token)
            res['y'] = res['y'].permute(1, 2, 0)  # T B K -> B K T
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.LongTensor(out["y_len"])
        res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
        res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
        if "y_sep_token_position" in out:
            res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"])
        if "x_sep_token_position" in out:
            res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"])
        return res
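

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training code). Every value
# below is a hypothetical placeholder: the real training config defines these
# fields (and more) elsewhere, and /data/emilia is a made-up dataset root that
# must contain the manifest, phoneme, encodec, and vocab.txt files this class
# reads. Neighbor prompts and time stretching are disabled so no codec model
# checkpoint is needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    from torch.utils.data import DataLoader

    args = argparse.Namespace(
        dataset_dir="/data/emilia",         # hypothetical root; must match a known audio_ext branch
        manifest_name="manifest",           # reads <dataset_dir>/manifest/train.txt
        phn_folder_name="phoneme",          # hypothetical folder names
        encodec_folder_name="encodec",
        audio_folder_name="audio",
        neighbor_folder_name="neighbors",
        ipa_alignment_folder_name="ipa_alignment",
        encodec_sr=50,                      # codec frames per second (assumed)
        codec_audio_sr=16000,               # codec input sample rate (assumed)
        audio_min_length=1.0,               # seconds
        audio_max_length=20.0,              # seconds
        drop_long=True,
        debug=False,
        text_vocab_size=500,
        n_codebooks=4,
        special_first=False,
        neighbor_prompt_prob=0.0,           # disable neighbor prompts for this sketch
        neighbor_selection_method="maxdist_10",
        time_stretch_prob=0.0,              # disable time stretching for this sketch
        time_stretch_bound=0.1,
        target_time_stretch_prob=0.0,
        target_time_stretch_bound=0.1,
        min_prompt_len=1.0,
        max_prompt_len=10.0,
        num_trial=5,
        x_sep_token=None,
        y_sep_token=None,
        add_eos_to_text=0,
        add_bos_to_text=0,
        dynamic_batching=True,              # pad per batch in collate()
        sep_special_token=False,
        audio_pad_token=1024,
        text_pad_token=0,
        pad_x=False,
    )
    train_set = dataset(args, "train")
    loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=train_set.collate)
    for batch in loader:
        # e.g. x: [B, T_text], y: [B, K, T_audio], plus lengths and padding masks
        print({k: tuple(v.shape) for k, v in batch.items() if isinstance(v, torch.Tensor)})
        break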