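"""Data pipeline for the VLSP 2020 Vietnamese speech corpus.

Pairs transcript/audio files into a map-style ``VLSP2020Dataset``, converts it
into a WebDataset tar shard via ``VLSP2020TarDataset``, and exposes
``get_dataloader`` for batching either representation.
"""
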
import os
from typing import Callable, Optional, Union

import torch
import torchaudio
import torchaudio.functional as F
import webdataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class VLSP2020Dataset(Dataset):
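    """Map-style dataset over ``<name>.txt`` transcript / ``<name>.*`` audio pairs.

    Each item is a ``(transcript, audio)`` tuple, with audio resampled to
    ``sample_rate``.
    """
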
    def __init__(self, root: str, sample_rate: int = 16000):
        super().__init__()

        self.sample_rate = sample_rate
        self.memory = self._prepare_data(root)
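        # Flatten into an index-addressable tuple of (transcript_path, audio_path).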
        self._memory = tuple(
            (v["transcript"], v["audio"]) for v in self.memory.values()
        )

    @staticmethod
    def _prepare_data(root: str):
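        """Pair every ``<name>.txt`` transcript in ``root`` with the audio file
        sharing its base name, raising ``ValueError`` on duplicate or missing halves."""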
        memory = {}

        for f in os.scandir(root):
            # Ignore anything that is not a regular file (e.g. subdirectories).
            if not f.is_file():
                continue

            file_name, file_ext = os.path.splitext(f.name)

            if file_ext == ".txt":
                if file_name not in memory:
                    memory[file_name] = {"transcript": f.path}
                elif "transcript" not in memory[file_name]:
                    memory[file_name]["transcript"] = f.path
                else:
                    raise ValueError(f"Duplicate transcript for {f.path}")
            else:
                if file_name not in memory:
                    memory[file_name] = {"audio": f.path}
                elif "audio" not in memory[file_name]:
                    memory[file_name]["audio"] = f.path
                else:
                    raise ValueError(f"Duplicate audio for {f.path}")

        for key, value in memory.items():
            if "audio" not in value:
                raise ValueError(f"Missing audio for {key}")
            elif "transcript" not in value:
                raise ValueError(f"Missing transcript for {key}")

        return memory

    def __len__(self):
        return len(self.memory)

    def __getitem__(self, index: int):
        transcript, audio = self._memory[index]

        # Transcript files are UTF-8 text; pass the encoding explicitly so the
        # read does not depend on the platform's locale default.
        with open(transcript, "r", encoding="utf-8") as f:
            transcript = f.read()

        audio, sample_rate = torchaudio.load(audio)
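        # Resample to the configured rate (a no-op when the file already matches it).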
        audio = F.resample(audio, sample_rate, self.sample_rate)

        return transcript, audio


class VLSP2020TarDataset:
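    """Converts a ``VLSP2020Dataset`` into a single WebDataset tar shard and loads it back."""
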
    def __init__(self, outpath: str):
        self.outpath = outpath

    def convert(self, dataset: VLSP2020Dataset):
        writer = webdataset.TarWriter(self.outpath)
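        # "__key__" groups a sample's files; the other keys' extensions select
        # the TarWriter encoders ("txt" as UTF-8 text, "pth" via torch.save).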

        for idx, (transcript, audio) in enumerate(tqdm(dataset, colour="green")):
            writer.write(
                {
                    "__key__": f"{idx:08d}",
                    "txt": transcript,
                    "pth": audio,
                }
            )

        writer.close()

    def load(self) -> webdataset.WebDataset:
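        """Open the written shard as a streaming WebDataset yielding (transcript, audio)."""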
        self.data = (
            webdataset.WebDataset(self.outpath)
            .decode(
                webdataset.handle_extension("txt", lambda x: x.decode("utf-8")),
                webdataset.torch_audio,
            )
            .to_tuple("txt", "pth")
        )

        return self.data


def get_dataloader(
    dataset: Union[VLSP2020Dataset, webdataset.WebDataset],
    return_transcript: bool = False,
    target_transform: Optional[Callable] = None,
    batch_size: int = 32,
    num_workers: int = 2,
):
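    """Build a ``DataLoader`` over either dataset flavor.

    The collate function keeps variable-length clips as a tuple of 1-D tensors
    instead of padding them; ``target_transform``, if given, is applied to each
    transcript when ``return_transcript`` is set.
    """
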
    def collate_fn(batch):
        def get_audio(item):
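            # Each clip is expected to be mono, shaped (1, time); squeeze to (time,).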
            audio = item[1]

            assert (
                isinstance(audio, torch.Tensor)
                and audio.ndim == 2
                and audio.size(0) == 1
            )

            return audio.squeeze(0)

        audio = tuple(get_audio(item) for item in batch)

        if return_transcript:
            if target_transform is not None:
                transcript = tuple(target_transform(item[0]) for item in batch)
            else:
                transcript = tuple(item[0] for item in batch)

            return transcript, audio
        else:
            return audio

    return DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
    )
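

if __name__ == "__main__":
    # Illustrative usage sketch only; the paths below are placeholders.
    dataset = VLSP2020Dataset("data/vlsp2020")

    # One-time conversion of the directory dataset into a tar shard.
    tar = VLSP2020TarDataset("data/vlsp2020.tar")
    tar.convert(dataset)

    # Stream the shard back and batch it; with a single shard, one worker suffices.
    loader = get_dataloader(
        tar.load(), return_transcript=True, batch_size=8, num_workers=0
    )
    for transcripts, audios in loader:
        print(len(transcripts), audios[0].shape)
        break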