humtrans / model.py
hayaton0005's picture
Upload 11 files
c094356 verified
import torch
import torch.nn as nn
from utils import split_audio_into_segments
from utils import LogMelspec
from utils import LogCQT
from transformers import T5Config, T5ForConditionalGeneration
class Seq2SeqTranscriber(nn.Module):
def __init__(
self, n_mels: int, sample_rate: int, n_fft: int, hop_length: int, voc_dict: dict
):
super().__init__()
self.infer_max_len = 200 # 推論時の最大シーケンス長。
self.voc_dict = voc_dict # トークン辞書と
self.n_voc_token = voc_dict["n_voc"] # トークンの数を保持。
self.t5config = T5Config.from_pretrained("google/t5-v1_1-small") # Googleの事前学習済み T5 モデル(小型バージョン)の設定をロード。
# カスタム設定を T5 の設定に追加:
custom_configs = {
"vocab_size": self.n_voc_token, # トークン辞書のサイズ。
"pad_token_id": voc_dict["pad"], # パディングトークンID。
"d_model": 96, # モデルの隠れ次元数(ここではメルバンド数に設定)。
}
for k, v in custom_configs.items():
self.t5config.__setattr__(k, v)
self.transformer = T5ForConditionalGeneration(self.t5config) # カスタム設定を適用した T5 モデルをロード。
# self.melspec = LogMelspec(sample_rate, n_fft, n_mels, hop_length) # LogMelspec クラスを使用して、音声波形を対数メルスペクトログラムに変換するモジュールを作成。
# CQT モデルインスタンス作成
self.log_cqt = LogCQT(16000, 84, 128, 12)
self.sr = sample_rate # サンプルレートをインスタンス変数として保存。
# モデルの学習時に呼び出され、損失値を計算します。
def forward(self, wav, labels):
# spec = self.melspec(wav).transpose(-1, -2) # 音声波形(wav)をメルスペクトログラム(spec)に変換。LogMelspec クラスのforwardを実行
spec = self.log_cqt(wav).transpose(-1, -2)
# ※ .transpose(-1, -2): T5 モデルは通常 [バッチ, 時間ステップ, 次元] の形状を期待するため、周波数軸(メルバンド)と時間軸を入れ替えます。
# T5 モデルのフォワードパス
print("sepc.shape: ", spec.shape) # (1, n_bins, time)
outs = self.transformer.forward(
inputs_embeds=spec, return_dict=True, labels=labels
)
return outs # outs は辞書形式で損失値や出力トークン列を含む。
# 入力音声波形(wav)から推定トークン列を生成する関数
def infer(self, wav):
"""
Infer the transcription of a single audio file.
The input audio file is split into segments of 2 seconds
before passing to the transformer.
"""
wav_segs = split_audio_into_segments(wav, self.sr) # 音声波形を固定長(例: 2秒)に分割。
#spec = self.melspec(wav_segs).transpose(-1, -2) # 各セグメントをメルスペクトログラムに変換。
spec = self.log_cqt(wav_segs).transpose(-1, -2)
# generate: T5 モデルの推論モードを使用して、トークン列を生成。
outs = self.transformer.generate(
inputs_embeds=spec,
max_length=self.infer_max_len, # 推論時の最大出力長。
num_beams=5, # ビームサーチを無効化し、単純なグリーディーサーチ。
do_sample=False, # サンプリングを無効化。
return_dict_in_generate=False,
)
return outs #推論結果として生成されたトークン列を返します。 #形状: (セグメント数, 最大トークン長)