Spaces:
Running
Running
File size: 1,770 Bytes
c094356 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import numpy as np
import pretty_midi as pm
import mir_eval
"""# Evaluation function"""
def extract_midi(midi: pm.PrettyMIDI, program=0): # MIDIデータを読み込んだ PrettyMIDI オブジェクト, MIDIチャンネル(楽器番号)を指定。
intervals = [] # 音符ごとの開始時間と終了時間のペアを格納したNumPy配列
pitches = [] # 音符ごとの音高(MIDIノート番号)のNumPy配列。
pm_notes = midi.instruments[program].notes # programで指定された対象楽器に含まれる全てのノート情報を取得。
"""
例;
instruments = [
Instrument 0 (Piano): [Note(start=0.5, end=1.0, pitch=60), ...],
Instrument 1 (Violin): [Note(start=1.0, end=1.5, pitch=62), ...]
]
"""
# ノートを順番に処理
for note in pm_notes:
intervals.append((note.start, note.end)) # 音符の開始・終了時間のペアを intervals に追加。
pitches.append(note.pitch) # 音符の音高を pitches に追加。
return np.array(intervals), np.array(pitches) # intervals: 2D配列(各行が1つの音符の開始・終了時間を表す。), pitches: 1D配列(各要素が1つの音符の音高(ピッチ)を表す。)
def evaluate_midi(est_midi: pm.PrettyMIDI, ref_midi: pm.PrettyMIDI, program=0):
est_intervals, est_pitches = extract_midi(est_midi, program)
ref_intervals, ref_pitches = extract_midi(ref_midi, program)
# mir_eval ライブラリの transcription モジュールを使って、音符の一致度を評価します。
dict_eval = mir_eval.transcription.evaluate(
ref_intervals, ref_pitches, est_intervals, est_pitches, onset_tolerance=0.05)
return dict_eval # dict_eval: 評価結果の辞書。 |