import os
from typing import Callable, Optional, Union

import torch
import torchaudio
import torchaudio.functional as F
import webdataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, IterableDataset, random_split
from tqdm import tqdm


class VLSP2020Dataset(Dataset):
    """Map-style dataset of (transcript, audio) pairs from a VLSP2020 directory."""

    def __init__(self, root: str, sample_rate: int = 16000):
        super().__init__()
        self.sample_rate = sample_rate
        self.memory = self._prepare_data(root)
        # Freeze the (transcript path, audio path) pairs for integer indexing.
        self._memory = tuple(
            (v["transcript"], v["audio"]) for v in self.memory.values()
        )

    @staticmethod
    def _prepare_data(root: str):
        # Pair each <name>.txt transcript with the same-named audio file,
        # rejecting duplicates and incomplete pairs.
        memory = {}
        for f in os.scandir(root):
            file_name, file_ext = os.path.splitext(f.name)
            if file_ext == ".txt":
                if file_name not in memory:
                    memory[file_name] = {"transcript": f.path}
                elif "transcript" not in memory[file_name]:
                    memory[file_name]["transcript"] = f.path
                else:
                    raise ValueError(f"Duplicate transcript for {f.path}")
            else:
                if file_name not in memory:
                    memory[file_name] = {"audio": f.path}
                elif "audio" not in memory[file_name]:
                    memory[file_name]["audio"] = f.path
                else:
                    raise ValueError(f"Duplicate audio for {f.path}")
        for key, value in memory.items():
            if "audio" not in value:
                raise ValueError(f"Missing audio for {key}")
            elif "transcript" not in value:
                raise ValueError(f"Missing transcript for {key}")
        return memory

    def __len__(self):
        return len(self.memory)

    def __getitem__(self, index: int):
        transcript_path, audio_path = self._memory[index]
        with open(transcript_path, "r", encoding="utf-8") as f:
            transcript = f.read()
        # Load the waveform and resample it to the target rate.
        audio, sample_rate = torchaudio.load(audio_path)
        audio = F.resample(audio, sample_rate, self.sample_rate)
        return transcript, audio


class VLSP2020TarDataset:
    """Serializes a VLSP2020Dataset into a WebDataset tar archive and reloads it."""

    def __init__(self, outpath: str):
        self.outpath = outpath

    def convert(self, dataset: VLSP2020Dataset):
        # Write each sample as {__key__, txt, pth}; TarWriter's default
        # handlers take care of serializing the audio tensor under "pth".
        writer = webdataset.TarWriter(self.outpath)
        for idx, (transcript, audio) in enumerate(tqdm(dataset, colour="green")):
            writer.write(
                {
                    "__key__": f"{idx:08d}",
                    "txt": transcript,
                    "pth": audio,
                }
            )
        writer.close()

    def load(self) -> webdataset.WebDataset:
        # Decode "txt" entries back to str and tensor entries back to
        # torch.Tensor, yielding (transcript, audio) tuples.
        self.data = (
            webdataset.WebDataset(self.outpath)
            .decode(
                webdataset.handle_extension("txt", lambda x: x.decode("utf-8")),
                webdataset.torch_audio,
            )
            .to_tuple("txt", "pth")
        )
        return self.data


def get_dataloader(
    dataset: Union[VLSP2020Dataset, webdataset.WebDataset],
    return_transcript: bool = False,
    target_transform: Optional[Callable] = None,
    batch_size: int = 32,
    num_workers: int = 2,
):
    def collate_fn(batch):
        def get_audio(item):
            # Expect mono audio shaped (1, num_samples); drop the channel dim.
            audio = item[1]
            assert (
                isinstance(audio, torch.Tensor)
                and audio.ndim == 2
                and audio.size(0) == 1
            )
            return audio.squeeze(0)

        audio = tuple(get_audio(item) for item in batch)
        if return_transcript:
            if target_transform is not None:
                transcript = tuple(target_transform(item[0]) for item in batch)
            else:
                transcript = tuple(item[0] for item in batch)
            return transcript, audio
        return audio

    return DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
    )
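

# Minimal end-to-end usage sketch. The paths below ("data/vlsp2020",
# "vlsp2020.tar") are hypothetical, assumed only for illustration: index the
# raw corpus, pack it into a tar archive, reload it as a WebDataset, and draw
# one batch.
if __name__ == "__main__":
    dataset = VLSP2020Dataset("data/vlsp2020", sample_rate=16000)
    tar = VLSP2020TarDataset("vlsp2020.tar")
    tar.convert(dataset)  # one-time serialization to the tar archive
    loader = get_dataloader(
        tar.load(), return_transcript=True, batch_size=4, num_workers=0
    )
    transcripts, audios = next(iter(loader))  # tuples of str and 1-D tensors
    print(transcripts[0], audios[0].shape)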