# gemma-3-4b-it-speech/examples/evaluate_speech.py
import json
import os
import re
from datetime import datetime

import numpy as np
import sacrebleu
import soundfile as sf
import torch
from datasets import Audio, load_dataset
from jiwer import cer, wer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor, BatchFeature
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer
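# Whisper text normalizers: the English-specific normalizer for en_us references,
# and the basic language-agnostic normalizer for ko_kr references.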
normalizer = {
"en_us" : EnglishTextNormalizer(),
"ko_kr" : BasicTextNormalizer()
}
# λͺ¨λΈ 및 ν”„λ‘œμ„Έμ„œ λ‘œλ“œ
model_id = "junnei/gemma-3-4b-it-speech"
revision = "main" #"v1.0"
model = AutoModel.from_pretrained(
model_id, device_map="auto", revision = revision, trust_remote_code=True
).eval()
processor = AutoProcessor.from_pretrained(
model_id, revision = revision, trust_remote_code=True
)
# Create the results directory
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)
INSTRUCTION = {
"ast": "Translate the audio to {0}.",
"asr": "Transcribe the audio clip into text.",
}
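# e.g. INSTRUCTION["ast"].format("Korean") -> "Translate the audio to Korean."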
class BaseAudioDataset(Dataset):
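    """Base class that wraps an audio corpus and builds model inputs.

    Subclasses load a Hugging Face dataset and return, per sample, the prepared
    prompt/audio tensors plus the reference answer via `prepare_model_inputs`.
    """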
def __init__(self, processor, split, sampling_rate=16000, debug=False):
self.processor = processor
self.training = "train" in split
self.debug = debug
self.sampling_rate = sampling_rate
self.name = ""
def set_dataset_name(self, name):
self.name = name
@staticmethod
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
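        """Drop samples whose audio file cannot be read or whose text fields are empty,
        then re-cast the audio column so it is decoded at `sampling_rate`."""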
original_size = len(data)
data = data.cast_column(audio_field, Audio(decode=False))
def identify_corrupted_files(example):
try:
sf.read(example[audio_field]["path"])
for field in text_fields:
if example[field].replace('"', '') == "":
return False
return True
except Exception:
return False
data = data.filter(identify_corrupted_files, num_proc=16)
validated_size = len(data)
        # Decode the audio column
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
if debug:
print(f"데이터셋: {dataset_name}")
print(f"원본 데이터 개수: {original_size}")
print(f"필터링 ν›„ 데이터 개수: {validated_size}")
print(f"필터링 λΉ„μœ¨: {validated_size/original_size:.2%}")
return data
@staticmethod
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
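        """Keep only samples whose audio duration is between `min_sec` and `max_sec` seconds."""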
original_size = len(data)
def filter_audio_by_length(example):
try:
audio = example[audio_field]['array']
channel = 1
if hasattr(audio, 'ndim') and audio.ndim > 1:
channel = audio.ndim
audio = audio.squeeze()
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
return min_sec <= audio_length <= max_sec
except Exception as e:
if debug:
print(f"였λ₯˜ λ°œμƒ: {str(e)[:100]}... - μƒ˜ν”Œ μ œμ™Έλ¨")
return False
data = data.filter(filter_audio_by_length, num_proc=16)
filtered_size = len(data)
if debug:
print(f"길이 필터링 μ „ 데이터 개수: {original_size}")
print(f"길이 필터링 ν›„ 데이터 개수: {filtered_size}")
print(f"필터링 λΉ„μœ¨: {filtered_size/original_size:.2%}")
return data
def prepare_model_inputs(self, audio_array, instruction, answer_text):
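        """Build a chat prompt (instruction preceded by `<start_of_audio>`), run the
        processor on the prompt and audio, and return the tensors plus the reference answer."""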
user_message = {
'role': 'user',
'content': '<start_of_audio>' + instruction,
}
prompt = self.processor.tokenizer.apply_chat_template(
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
)
inputs = self.processor(
text=prompt,
audio=[audio_array],
add_special_tokens=False,
return_tensors='pt'
)
input_ids = inputs.input_ids
token_type_ids = inputs.token_type_ids
return {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'input_audio_embeds': inputs.input_audio_embeds,
'audio_embed_sizes': inputs.audio_embed_sizes,
'input_modes': inputs.input_modes,
'answer': answer_text,
}
# CoVoST2 Dataset Class
class CoVoSTDataset(BaseAudioDataset):
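    """CoVoST 2 dataset (junnei/covost2) for ASR or, when `ast=True`, speech translation."""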
def __init__(self, processor, data_dir, split, ast=False,
lang=("en_ko", "Korean"), sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name("CoVoST")
self.ast = ast
self.lang = lang[0]
self.data = load_dataset("junnei/covost2",
lang[0],
data_dir=data_dir,
split=split,
trust_remote_code=True
)
text_fields = ["sentence", "translation"] if ast else ["sentence"]
self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST")
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# Instruction Setting
self.instruction = INSTRUCTION["ast"].format(lang[1]) if ast else INSTRUCTION["asr"]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
if self.ast:
answer_text = data["translation"]
else:
answer_text = data["sentence"].replace('"', '')
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# Libri Speech Dataset Class
class LibriSpeechDataset(BaseAudioDataset):
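    """LibriSpeech dataset (fixie-ai/librispeech_asr); ASR only."""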
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name(f"LibriSpeech_{subset}")
# only ASR
self.ast = False
self.lang = "en"
if split == "train":
split = "train.360"
# load dataset
self.data = load_dataset("fixie-ai/librispeech_asr",
subset,
split=split,
trust_remote_code=True
)
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# Instruction Setting
self.instruction = INSTRUCTION["asr"]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
# Libri Speech is only for ASR
answer_text = data["text"].replace('"', '')
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# Fleurs Dataset Class
class FleursDataset(BaseAudioDataset):
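    """FLEURS dataset (google/fleurs) for ASR or AST.

    In AST mode the source- and target-language splits are joined on `id`, and the
    target-language transcription is used as the translation reference.
    """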
def __init__(self, processor, split, source_lang, target_lang=None,
mode="asr", sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name("Fleurs")
# Mode Setting (ASR or AST)
if mode not in ["asr", "ast"]:
raise ValueError("mode must be 'asr' or 'ast'.")
self.mode = mode
self.ast = (mode == "ast")
self.source_lang = source_lang
# Language name mapping (expand if needed)
self.lang_names = {
'en_us': 'English', 'ko_kr': 'Korean'
}
# load dataset - source language dataset
self.data = load_dataset("google/fleurs",
source_lang,
split=split,
trust_remote_code=True
)
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# When AST mode, load target language dataset.
if self.ast:
if target_lang is None:
raise ValueError("AST mode requires target_lang.")
self.target_lang = target_lang
self.lang = f"{source_lang}_{target_lang}"
# load dataset - target language dataset (for translation)
target_data = load_dataset("google/fleurs",
target_lang,
split=split,
trust_remote_code=True
)
source_dict = {item['id']: item for item in self.data}
target_dict = {item['id']: item for item in target_data}
# only Common ID, add translation fields
common_ids = set(source_dict.keys()) & set(target_dict.keys())
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
self.data = [
{**source_dict[id], 'translation': target_dict[id]['transcription']}
for id in common_ids
]
# Instruction Setting - use target language name
target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
self.instruction = INSTRUCTION["ast"].format(target_lang_name)
else:
# ASR mode
self.lang = source_lang
self.instruction = INSTRUCTION["asr"]
if self.debug:
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
if self.ast:
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
print(f"dataset size: {len(self.data)}")
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
audio_array = data["audio"]["array"]
if self.ast:
answer_text = data["translation"]
else:
answer_text = data["transcription"]
return self.prepare_model_inputs(
audio_array,
self.instruction,
answer_text
)
def pad_sequence(sequences, padding_side='left', padding_value=0):
"""
Pad a list of sequences to the same length.
sequences: list of tensors in [seq_len, *] shape
"""
assert padding_side in ['right', 'left']
max_size = sequences[0].size()
trailing_dims = max_size[1:]
max_len = max(len(seq) for seq in sequences)
batch_size = len(sequences)
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
for i, seq in enumerate(sequences):
length = seq.size(0)
if padding_side == 'right':
output.data[i, :length] = seq
else:
output.data[i, -length:] = seq
return output
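# Illustrative example (not executed): left-padding two 1-D tensors of lengths 2 and 3
#   pad_sequence([torch.tensor([1, 2]), torch.tensor([3, 4, 5])], padding_side='left')
#   -> tensor([[0, 1, 2],
#              [3, 4, 5]])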
def cat_with_pad(tensors, dim, padding_value=0):
"""
cat along dim, while pad to max for all other dims
"""
ndim = tensors[0].dim()
assert all(
t.dim() == ndim for t in tensors[1:]
), 'All tensors must have the same number of dimensions'
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
out_size[dim] = sum(t.shape[dim] for t in tensors)
output = tensors[0].new_full(out_size, padding_value)
index = 0
for t in tensors:
# Create a slice list where every dimension except dim is full slice
slices = [slice(0, t.shape[d]) for d in range(ndim)]
# Update only the concat dimension slice
slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t
index += t.shape[dim]
return output
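# Illustrative example (not executed): concatenating a (2, 3) and a (1, 5) tensor along
# dim=0 yields a (3, 5) tensor; the first two rows are zero-padded in the last two columns.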
def covost_collate_fn(batch):
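    """Collate prepared samples into a single left-padded batch wrapped in a BatchFeature."""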
input_ids_list = []
input_audio_embeds_list = []
audio_embed_sizes_list = []
audio_attention_mask_list = []
input_modes_list = []
answer_list = []
for inputs in batch:
input_ids_list.append(inputs['input_ids'][0])
input_audio_embeds_list.append(inputs['input_audio_embeds'])
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
audio_attention_mask_list.append(
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
)
input_modes_list.append(inputs['input_modes'])
answer_list.append(inputs['answer'])
try:
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
audio_attention_mask = (
pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
if len(audio_attention_mask_list) > 1
else None
)
except Exception as e:
print(e)
print(input_ids_list)
        print(audio_attention_mask_list)
raise
attention_mask = (input_ids != 0).long()
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
audio_embed_sizes = torch.cat(audio_embed_sizes_list)
input_modes = torch.cat(input_modes_list)
return BatchFeature(
{
'input_ids': input_ids,
'attention_mask': attention_mask,
'input_audio_embeds': input_audio_embeds,
'audio_embed_sizes': audio_embed_sizes,
'audio_attention_mask': audio_attention_mask,
'input_modes': input_modes,
'answer': answer_list,
}
)
def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
"""κ²°κ³Όλ₯Ό JSON 파일둜 μ €μž₯"""
filename = f"{task}_{dataset_name}_{source_lang}"
if target_lang:
filename += f"_to_{target_lang}"
if sample_idx is not None:
filename += f"_sample_{sample_idx}"
filepath = os.path.join(results_dir, f"{filename}.json")
    # Add a timestamp to the results
results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"κ²°κ³Όκ°€ {filepath}에 μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
return filepath
def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size=32, is_asr=True):
    """Evaluate ASR (automatic speech recognition) or speech-translation performance."""
    task_type = "asr" if is_asr else "translation"
    eval_lang = source_lang if is_asr else target_lang
    eval_normalizer = normalizer[eval_lang]
    sample_results = []
    # Optionally evaluate on a random subset of the samples
    if 0 < num_samples < len(dataset):
        indices = np.random.choice(len(dataset), num_samples, replace=False)
        subset = torch.utils.data.Subset(dataset, indices.tolist())
        subset.name = dataset.name  # keep the dataset name for the result files
        dataset = subset
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)
evaluated_samples = {}
    # Process the dataset batch by batch
for batch_idx, batch in enumerate(tqdm(dataloader)):
batch_references = batch.pop("answer")
        # Move tensors to GPU (audio_attention_mask may be None for single-sample batches)
        if torch.cuda.is_available():
            batch = {k: (v.to("cuda") if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        # Batched inference
with torch.inference_mode():
generate_ids = model.generate(**batch,
max_new_tokens=256,
#temperature = 1.0, top_p = 0.95, top_k = 64, do_sample=True
)
input_lengths = batch['input_ids'].shape[1]
generate_ids = generate_ids[:, input_lengths:]
        # Decode generated tokens
batch_predictions = processor.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
        # Collect per-sample results
for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
idx = batch_idx * batch_size + i
sample_result = {
"id": idx,
"reference": reference,
"prediction": prediction
}
sample_results.append(sample_result)
        # Save intermediate results every 10 batches
if (batch_idx + 1) % 10 == 0:
temp_results = []
            # Score every sample collected so far
for item in sample_results:
sample_id = item["id"]
                # Reuse metrics for samples that were already scored
if sample_id in evaluated_samples:
temp_item = item.copy()
temp_item.update(evaluated_samples[sample_id])
temp_results.append(temp_item)
else:
                    # Score samples that have not been evaluated yet
temp_item = item.copy()
try:
ref = eval_normalizer(item["reference"])
pred = eval_normalizer(item["prediction"])
                        # Compute BLEU, CER and WER
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
utt_wer = round(wer(ref, pred) * 100, 2)
metrics = {
"bleu": utt_bleu,
"cer": utt_cer,
"wer": utt_wer
}
                        # Cache the metrics
evaluated_samples[sample_id] = metrics
temp_item.update(metrics)
except Exception as e:
print(f"Error evaluating sample {sample_id}: {e}")
                        # Fall back to default values on error
metrics = {
"bleu": 0,
"cer": 100,
"wer": 100,
"error": str(e)
}
evaluated_samples[sample_id] = metrics
temp_item.update(metrics)
temp_results.append(temp_item)
partial_results = {
"task": task_type,
"source_lang": source_lang,
"target_lang": target_lang,
"num_samples": len(temp_results),
"sample_results": temp_results
}
save_results(partial_results, dataset.name, task_type, source_lang, target_lang)
for item in sample_results:
ref = eval_normalizer(item["reference"])
pred = eval_normalizer(item["prediction"])
        # Compute BLEU, CER and WER
utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
utt_wer = round(wer(ref, pred) * 100, 2)
item.update({
"bleu": utt_bleu,
"cer": utt_cer,
"wer": utt_wer
})
avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)
results = {
"dataset": dataset.name,
"task": task_type,
"source_lang": source_lang,
"target_lang": target_lang,
"num_samples": len(sample_results),
"metrics": {
"bleu": avg_bleu,
"cer": avg_cer,
"wer": avg_wer
},
"sample_results": sample_results
}
    # Save final results
save_results(results, dataset.name, task_type, source_lang, target_lang)
return results
# Main script
if __name__ == "__main__":
    # Source languages to evaluate
source_languages = [
#("ko_kr", "Korean"),
("en_us", "English"), # μ˜μ–΄ (λ―Έκ΅­)
]
    # Target languages for translation (code, name)
target_languages = [
#("en_us", "English"),
("ko_kr", "Korean"),
]
data_dir = {
#"ko_kr" : "/workspace/CommonVoice/ko",
"en_us" : "/workspace/CommonVoice/EN",
}
    # Number of samples to evaluate (-1 uses the full dataset)
num_samples = -1
batch_size = 32
    # Run ASR evaluation for every source language
for source_lang, target_lang in zip(source_languages, target_languages):
print(f"\n===== {source_lang[0]} ASR 평가 μ‹œμž‘ =====")
        # Load datasets
split = "test"
datasets = []
# Covost ASR mode (English -> English text)
covost = CoVoSTDataset(
processor=processor,
data_dir="/workspace/CommonVoice/EN",
split=split,
ast=False,
lang=("en_ko", "Korean")
)
datasets.append(covost)
# Libri Speech Clean ASR mode (English -> English text)
libri_speech_clean = LibriSpeechDataset(
processor=processor,
subset="clean",
split=split
)
datasets.append(libri_speech_clean)
# Libri Speech Other ASR mode (English -> English text)
libri_speech_other = LibriSpeechDataset(
processor=processor,
subset="other",
split=split
)
datasets.append(libri_speech_other)
# Fleurs ASR mode (English -> English text)
fleurs = FleursDataset(
processor=processor,
split=split,
source_lang="en_us", # English
mode="asr"
)
datasets.append(fleurs)
for dataset in datasets:
            # ASR evaluation
asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True)
print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR κ²°κ³Ό ===")
print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")
try:
print(f"\n===== {source_lang[0]} -> {target_lang[0]} λ²ˆμ—­ 평가 μ‹œμž‘ =====")
datasets = []
# Covost AST mode (English -> Korean text)
covost = CoVoSTDataset(
processor=processor,
data_dir="/workspace/CommonVoice/EN",
split=split,
ast=True,
lang=("en_ko", "Korean")
)
datasets.append(covost)
# Fleurs AST mode (English -> Korean text)
fleurs = FleursDataset(
processor=processor,
split=split,
source_lang="en_us", # English
target_lang="ko_kr", # Korean
mode="ast"
)
datasets.append(fleurs)
for dataset in datasets:
                # Translation evaluation
translation_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = False)
print(f"\n=== {translation_results.get('dataset', 'Dataset')} | {source_lang[0]} -> {target_lang[0]} λ²ˆμ—­ κ²°κ³Ό ===")
print(f"BLEU: {translation_results.get('metrics', {}).get('bleu', 'N/A')}")
print(f"WER: {translation_results.get('metrics', {}).get('wer', 'N/A')}")
print(f"CER: {translation_results.get('metrics', {}).get('cer', 'N/A')}")
except Exception as e:
error_info = {
"error": str(e),
"source_lang": source_lang[0],
"target_lang": target_lang[0],
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
error_file = os.path.join(results_dir, f"error_translation_{source_lang[0]}_to_{target_lang[0]}_global.json")
with open(error_file, 'w') as f:
json.dump(error_info, f, indent=2)
print(f"{source_lang[0]} -> {target_lang[0]} λ²ˆμ—­ 평가 쀑 였λ₯˜ λ°œμƒ: {str(e)}")
continue
print(f"\nλͺ¨λ“  평가가 μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€. κ²°κ³ΌλŠ” {results_dir} 디렉토리에 μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")