File size: 10,461 Bytes
c094356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

import torch.utils.data as data

from utils import MIDITokenExtractor
from config import voc_single_track
from config import FRAME_PER_SEC, FRAME_STEP_SIZE_SEC, AUDIO_SEGMENT_SEC, SEGMENT_N_FRAMES  

"""# Dataset
Uses MAESTRO v3.0.0 dataset.
"""

class AMTDatasetBase(data.Dataset):
    def __init__(
        self,
        flist_audio, # オーディオファイルのパスをリスト形式で渡す
        flist_midi, # MIDIファイルのパスをリスト形式で渡す
        sample_rate, # オーディオファイルのサンプリングレートを指定。全てのオーディオがこれにリサンプリングされる。
        voc_dict, # トークン定義を渡す
        apply_pedal=True,
        whole_song=False,
    ):
        super().__init__()
        self.midi_filelist = flist_midi
        self.audio_filelist = flist_audio
        self.audio_metalist = [torchaudio.info(f) for f in flist_audio] # 各オーディオファイルのメタ情報(サンプルレート、フレーム数など)を収集します。
        self.voc_dict = voc_dict
        # 各MIDIファイルを MIDITokenExtractor を使ってトークン化し、その結果をリストとして保持します。
        self.midi_list = [
            MIDITokenExtractor(f, voc_dict, apply_pedal)
            for f in tqdm.tqdm(self.midi_filelist, desc="load dataset")
        ]
        self.sample_rate = sample_rate
        self.whole_song = whole_song

    def __len__(self):
        return len(self.audio_filelist)

    def __getitem__(self, index):
        """
        Return a pair of (audio, tokens) for the given index.
        On the training stage, return a random segment from the song.
        On the test stage, return the audio and MIDI of the whole song.
        """
        if not self.whole_song:
            return self.getitem_segment(index)
        else:
            return self.getitem_wholesong(index)

    def getitem_segment(self, index, start_pos=None): # 対象ファイルを指定するindexとセグメントの開始位置(フレーム単位)。Noneの場合はランダムに選択
        metadata = self.audio_metalist[index]
        num_frames = metadata.num_frames # オーディオの全体の「サンプル数」。
        sample_rate = metadata.sample_rate
        duration_y = round(num_frames / float(sample_rate) * FRAME_PER_SEC) # オーディオ全体の長さをフレーム単位に変換
        midi_item = self.midi_list[index]

        # セグメントの開始位置と終了位置(フレーム単位)を決定。
        if start_pos is None: # np.random.randint を使用して、オーディオ全体からランダムに開始位置を選択。
            segment_start = np.random.randint(duration_y - SEGMENT_N_FRAMES)
        else: # start_pos が指定されている場合
            segment_start = start_pos
        segment_end = segment_start + SEGMENT_N_FRAMES
        # オーディオセグメントのサンプル単位の開始位置
        segment_start_sample = round(
            segment_start * FRAME_STEP_SIZE_SEC * sample_rate
        )

        # セグメント範囲(segment_start ~ segment_end)に対応するMIDIトークン列を抽出。
        segment_tokens = midi_item.get_segment_tokens(segment_start, segment_end)
        segment_tokens = torch.from_numpy(segment_tokens).long() # NumPy配列をPyTorchテンソルに変換。long()でテンソルのデータ型を64ビット整数(long)に設定。

        # 指定されたセグメント範囲のオーディオデータを読み込む。
        # frame_offset から始まる範囲を num_frames サンプル分読み込む。
        y_segment, _ = torchaudio.load(
            self.audio_filelist[index],
            frame_offset=segment_start_sample,
            num_frames=round(AUDIO_SEGMENT_SEC * sample_rate),
        )
        y_segment = y_segment.mean(0) # オーディオが複数チャンネルの場合(例: ステレオ)、チャンネルを平均してモノラルに変換。

        # サンプルレートのリサンプリング
        # オーディオデータのサンプルレートが self.sample_rate と異なる場合、指定されたサンプルレートにリサンプリング。
        if sample_rate != self.sample_rate:
            y_segment = torchaudio.functional.resample(
                y_segment,
                sample_rate,
                self.sample_rate,
                resampling_method="kaiser_window",  # Kaiserウィンドウによるリサンプリングアルゴリズムを適用。
            )
        return y_segment, segment_tokens

    def getitem_wholesong(self, index):
        """
        Return a pair of (audio, midi) for the given index.
        """
        y, sr = torchaudio.load(self.audio_filelist[index]) # 読み込まれた波形データ(テンソル形式)。形状は (チャンネル数, サンプル数)。
        y = y.mean(0) # モノラル化
        # サンプルレートのリサンプリング
        if sr != self.sample_rate:
            y = torchaudio.functional.resample(
                y, sr, self.sample_rate,
                resampling_method="kaiser_window"
            )
        midi = self.midi_list[index].pm
        return y, midi

    # collateはバッチにまとめる役割の関数
    def collate_wholesong(self, batch): # batch: データセットから取り出された複数のデータ(オーディオとMIDIのペア)のリスト。
        # b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。
        # torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。
        # 出力: テンソルの形状は (バッチサイズ, サンプル数)。
        batch_audio = torch.stack([b[0] for b in batch], dim=0)
        midi = [b[1] for b in batch] # バッチ内の各曲のMIDIデータをリストとしてまとめる。
        return batch_audio, midi # テンソル, リスト

    def collate_batch(self, batch): # データセットから取り出されたセグメント化されたオーディオテンソルとセグメント化されたMIDIトークン列のリスト。
        # b[0]で各データペアの0番目の要素、つまりオーディオデータを取り出す。
        # torch.stack([...], dim=0): 複数のテンソルを新しい次元(バッチ次元)で結合。
        # 出力: テンソルの形状は (バッチサイズ, サンプル数)。
        batch_audio = torch.stack([b[0] for b in batch], dim=0)
        batch_tokens = [b[1] for b in batch] # バッチ内の各セグメントのトークン列をテンソル?リスト形式で取得。

        # バッチ内のMIDIトークン列の長さを揃えるためにパディング
        # torch.nn.utils.rnn.pad_sequence は、異なる長さのシーケンス(テンソルリスト)をパディングして同じ長さに揃えるためのPyTorchユーティリティ(すべてのテンソルは同じ次元数である必要があります(長さ以外は一致)。)
        # batch_first = True: パディング後のテンソル形状を (バッチサイズ, 最大長さ) に設定
        batch_tokens_pad = torch.nn.utils.rnn.pad_sequence(
            batch_tokens, batch_first=True, padding_value=self.voc_dict["pad"]
        )
        return batch_audio, batch_tokens_pad # テンソル, テンソル (バッチサイズ, サンプル数), (バッチサイズ, 最大トークンの長さ)


class CustomDataset(AMTDatasetBase):
    def __init__(
        self,
        midi_root: str = "/content/drive/MyDrive/B4/Humtrans/midi",
        wav_root: str = "/content/wav_rms",
        split: str = "train",
        sample_rate: int = 16000,
        apply_pedal: bool = True,
        whole_song: bool = False,
    ):
        """
        MIDIとWAVのペアをロードするデータセットクラス

        Args:
            midi_root (str): MIDIファイルが保存されているルートフォルダ
            wav_root (str): WAVファイルが保存されているフォルダ
            split (str): 使用するデータセットの分割 ('train', 'valid', 'test')
            sample_rate (int): サンプルレート
            apply_pedal (bool): ペダルの適用
            whole_song (bool): 曲全体をロードするか
        """
        # MIDIフォルダのパスを設定
        self.midi_root = f"/content/filtered_{split}_midi"
        self.wav_root = wav_root
        self.sample_rate = sample_rate
        self.split = split

        # MIDIとWAVのペアを見つける
        flist_midi, flist_audio = self._get_paired_files()

        # 親クラスのコンストラクタを呼び出し
        super().__init__(
            flist_audio,
            flist_midi,
            sample_rate,
            voc_dict=voc_single_track,
            apply_pedal=apply_pedal,
            whole_song=whole_song,
        )


    def _get_paired_files(self):
        """
        MIDIフォルダとWAVフォルダからペアとなるファイルリストを作成する

        Returns:
            flist_midi (list): 対応するMIDIファイルのリスト
            flist_audio (list): 対応するWAVファイルのリスト
        """
        flist_midi = []
        flist_audio = []

        # MIDIフォルダからMIDIファイルを取得
        midi_files = [f for f in os.listdir(self.midi_root) if f.endswith(".mid")]

        for midi_file in midi_files:
            # MIDIファイルのパスを構築
            midi_path = os.path.join(self.midi_root, midi_file)

            # WAVファイルのパスを構築 (拡張子を変更)
            wav_file = os.path.splitext(midi_file)[0] + ".wav"
            wav_path = os.path.join(self.wav_root, wav_file)

            # WAVファイルが存在するか確認
            if os.path.exists(wav_path):
                flist_midi.append(midi_path)
                flist_audio.append(wav_path)
            else:
                print(f"対応するWAVファイルが見つかりません: {midi_file}")

        print(f"{self.split}データセット: {len(flist_midi)} ペアのMIDI-WAVが見つかりました。")
        return flist_midi, flist_audio