File size: 2,058 Bytes
c094356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torchaudio
import os
import soundfile as sf
import librosa

from utils import unpack_sequence, token_seg_list_to_midi
from train import LitTranscriber
from utils import rms_normalize_wav  

# Directory that contains this module (backend/src); the checkpoint lives next to it.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
# Model weights, loaded by load_model() via load_state_dict.
PTH_PATH = os.path.join(BASE_DIR, "model.pth")

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_model():
    """Build a LitTranscriber, load its weights from PTH_PATH, and return it ready for inference.

    The model is moved to the module-level ``device`` (CUDA when available,
    otherwise CPU) and switched to eval mode.

    Returns:
        The LitTranscriber instance with weights loaded.
    """
    transcriber_args = {
        "n_mels": 128,
        "sample_rate": 16000,
        "n_fft": 1024,
        "hop_length": 128,
    }
    # lr / lr_decay are presumably only relevant for training; they are passed
    # because the constructor takes them (TODO confirm they are unused at inference).
    model = LitTranscriber(transcriber_args=transcriber_args, lr=1e-4, lr_decay=0.99)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(PTH_PATH, map_location=device)
    model.load_state_dict(state_dict)

    # load_state_dict copies values into the model's existing parameters without
    # changing where they live, so the module must be moved explicitly; otherwise
    # inputs placed on `device` would not match the weights on a CUDA host.
    model.to(device)
    model.eval()
    return model



def convert_to_pcm_wav(input_path, output_path, sample_rate=16000):
    """Decode an audio file and re-encode it as mono 16-bit PCM WAV.

    Args:
        input_path: Path to any audio file librosa can read.
        output_path: Destination path for the WAV file (overwritten if present).
        sample_rate: Target sample rate in Hz; the audio is resampled to it.
            Defaults to 16000, the rate the transcription model is configured for.
    """
    # librosa decodes to floating-point samples (mono mixdown, resampled); it does
    # NOT produce PCM. The integer PCM encoding is chosen explicitly at write time
    # instead of relying on soundfile's default WAV subtype.
    samples, sr = librosa.load(input_path, sr=sample_rate, mono=True)
    sf.write(output_path, samples, sr, subtype="PCM_16")


def infer_midi_from_wav(input_wav_path: str) -> str:
    """Transcribe an audio file to MIDI and return the path of the written .mid file.

    Pipeline: decode/resample the input to 16 kHz mono PCM WAV, RMS-normalize it,
    run the transcription model, then convert the predicted token sequences to MIDI.

    Args:
        input_wav_path: Path to the input audio file (any format librosa can read).

    Returns:
        Path of the MIDI file, always ``<BASE_DIR>/output.mid``. Note that
        concurrent calls write to that same file.
    """
    import tempfile  # local import: only this function needs scratch files

    # NOTE(review): the model is rebuilt and reloaded on every call; consider
    # caching it if this runs per request.
    model = load_model()
    target_sr = model.transcriber.sr

    # Keep intermediate audio in a private scratch directory so concurrent calls
    # don't overwrite each other's files and nothing is left behind in BASE_DIR.
    with tempfile.TemporaryDirectory() as scratch_dir:
        converted_path = os.path.join(scratch_dir, "converted_input.wav")
        convert_to_pcm_wav(input_wav_path, converted_path)

        normalized_path = os.path.join(scratch_dir, "tmp_normalized.wav")
        rms_normalize_wav(converted_path, normalized_path, target_rms=0.1)

        waveform, sr = torchaudio.load(normalized_path)

    # torchaudio.load returns (channels, samples); average the channels to mono.
    waveform = waveform.mean(0)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    # Send the input to wherever the model's weights actually live. Using the
    # global `device` here breaks on CUDA hosts if the model stayed on the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device
    waveform = waveform.to(model_device)

    with torch.no_grad():
        output_tokens = model(waveform)

    unpadded_tokens = unpack_sequence(output_tokens.cpu().numpy())
    # Drop the first token of each segment (presumably a start-of-sequence
    # marker -- TODO confirm against the tokenizer).
    unpadded_tokens = [t[1:] for t in unpadded_tokens]
    est_midi = token_seg_list_to_midi(unpadded_tokens)

    midi_path = os.path.join(BASE_DIR, "output.mid")
    est_midi.write(midi_path)
    print(f"MIDI saved at: {midi_path}")
    return midi_path