Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from utils import split_audio_into_segments | |
from utils import LogMelspec | |
from utils import LogCQT | |
from transformers import T5Config, T5ForConditionalGeneration | |
class Seq2SeqTranscriber(nn.Module): | |
def __init__( | |
self, n_mels: int, sample_rate: int, n_fft: int, hop_length: int, voc_dict: dict | |
): | |
super().__init__() | |
self.infer_max_len = 200 # 推論時の最大シーケンス長。 | |
self.voc_dict = voc_dict # トークン辞書と | |
self.n_voc_token = voc_dict["n_voc"] # トークンの数を保持。 | |
self.t5config = T5Config.from_pretrained("google/t5-v1_1-small") # Googleの事前学習済み T5 モデル(小型バージョン)の設定をロード。 | |
# カスタム設定を T5 の設定に追加: | |
custom_configs = { | |
"vocab_size": self.n_voc_token, # トークン辞書のサイズ。 | |
"pad_token_id": voc_dict["pad"], # パディングトークンID。 | |
"d_model": 96, # モデルの隠れ次元数(ここではメルバンド数に設定)。 | |
} | |
for k, v in custom_configs.items(): | |
self.t5config.__setattr__(k, v) | |
self.transformer = T5ForConditionalGeneration(self.t5config) # カスタム設定を適用した T5 モデルをロード。 | |
# self.melspec = LogMelspec(sample_rate, n_fft, n_mels, hop_length) # LogMelspec クラスを使用して、音声波形を対数メルスペクトログラムに変換するモジュールを作成。 | |
# CQT モデルインスタンス作成 | |
self.log_cqt = LogCQT(16000, 84, 128, 12) | |
self.sr = sample_rate # サンプルレートをインスタンス変数として保存。 | |
# モデルの学習時に呼び出され、損失値を計算します。 | |
def forward(self, wav, labels): | |
# spec = self.melspec(wav).transpose(-1, -2) # 音声波形(wav)をメルスペクトログラム(spec)に変換。LogMelspec クラスのforwardを実行 | |
spec = self.log_cqt(wav).transpose(-1, -2) | |
# ※ .transpose(-1, -2): T5 モデルは通常 [バッチ, 時間ステップ, 次元] の形状を期待するため、周波数軸(メルバンド)と時間軸を入れ替えます。 | |
# T5 モデルのフォワードパス | |
print("sepc.shape: ", spec.shape) # (1, n_bins, time) | |
outs = self.transformer.forward( | |
inputs_embeds=spec, return_dict=True, labels=labels | |
) | |
return outs # outs は辞書形式で損失値や出力トークン列を含む。 | |
# 入力音声波形(wav)から推定トークン列を生成する関数 | |
def infer(self, wav): | |
""" | |
Infer the transcription of a single audio file. | |
The input audio file is split into segments of 2 seconds | |
before passing to the transformer. | |
""" | |
wav_segs = split_audio_into_segments(wav, self.sr) # 音声波形を固定長(例: 2秒)に分割。 | |
#spec = self.melspec(wav_segs).transpose(-1, -2) # 各セグメントをメルスペクトログラムに変換。 | |
spec = self.log_cqt(wav_segs).transpose(-1, -2) | |
# generate: T5 モデルの推論モードを使用して、トークン列を生成。 | |
outs = self.transformer.generate( | |
inputs_embeds=spec, | |
max_length=self.infer_max_len, # 推論時の最大出力長。 | |
num_beams=5, # ビームサーチを無効化し、単純なグリーディーサーチ。 | |
do_sample=False, # サンプリングを無効化。 | |
return_dict_in_generate=False, | |
) | |
return outs #推論結果として生成されたトークン列を返します。 #形状: (セグメント数, 最大トークン長) | |