from io import BytesIO
import json
import os
import re
from datetime import datetime

import numpy as np
import sacrebleu
import soundfile as sf
import torch
from datasets import load_dataset, Audio
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor, BatchFeature
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer

# Text normalizers applied before scoring: Whisper's English normalizer for
# English, the language-agnostic basic normalizer for Korean.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "ko_kr": BasicTextNormalizer(),
}
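# e.g. normalizer["en_us"]("Hello, World!") normalizes to something like
# "hello world" (lower-cased, punctuation stripped), so WER/CER/BLEU are
# computed on normalized text rather than raw strings.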

model_id = "junnei/gemma-3-4b-it-speech"
revision = "main"

# Load the speech-capable Gemma model and its processor; device_map="auto"
# places the weights on the available device(s).
model = AutoModel.from_pretrained(
    model_id, device_map="auto", revision=revision, trust_remote_code=True
).eval()

processor = AutoProcessor.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)

# Timestamped directory that collects all evaluation outputs.
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}
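# e.g. INSTRUCTION["ast"].format("Korean") -> "Translate the audio to Korean."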


class BaseAudioDataset(Dataset):
    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        self.processor = processor
        self.training = "train" in split
        self.debug = debug
        self.sampling_rate = sampling_rate
        self.name = ""

    def set_dataset_name(self, name):
        self.name = name
    @staticmethod
    def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
        original_size = len(data)

        # Disable decoding so unreadable audio surfaces here instead of later.
        data = data.cast_column(audio_field, Audio(decode=False))

        def identify_corrupted_files(example):
            try:
                audio = example[audio_field]
                if audio.get("bytes"):
                    sf.read(BytesIO(audio["bytes"]))
                else:
                    sf.read(audio["path"])

                # Drop samples whose reference text is empty.
                for field in text_fields:
                    if example[field].replace('"', '') == "":
                        return False
                return True
            except Exception:
                return False

        data = data.filter(identify_corrupted_files, num_proc=16)
        validated_size = len(data)

        # Re-enable decoding, resampling to the target rate.
        data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))

        if debug:
            print(f"Dataset: {dataset_name}")
            print(f"Original sample count: {original_size}")
            print(f"Sample count after filtering: {validated_size}")
            print(f"Retention ratio: {validated_size / original_size:.2%}")

        return data
    @staticmethod
    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
        original_size = len(data)

        def filter_audio_by_length(example):
            try:
                audio = np.asarray(example[audio_field]["array"]).squeeze()
                sampling_rate = example[audio_field]["sampling_rate"]
                # `size` counts samples across all channels, so divide by the
                # channel count as well as the sampling rate to get seconds.
                num_channels = 1 if audio.ndim <= 1 else min(audio.shape)
                audio_length = audio.size / num_channels / sampling_rate
                return min_sec <= audio_length <= max_sec
            except Exception as e:
                if debug:
                    print(f"Error: {str(e)[:100]}... - sample excluded")
                return False

        data = data.filter(filter_audio_by_length, num_proc=16)
        filtered_size = len(data)

        if debug:
            print(f"Sample count before length filtering: {original_size}")
            print(f"Sample count after length filtering: {filtered_size}")
            print(f"Retention ratio: {filtered_size / original_size:.2%}")

        return data
    def prepare_model_inputs(self, audio_array, instruction, answer_text):
        # Single-turn chat prompt; <start_of_audio> marks where the audio goes.
        user_message = {
            'role': 'user',
            'content': '<start_of_audio>' + instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
        )

        inputs = self.processor(
            text=prompt,
            audio=[audio_array],
            add_special_tokens=False,
            return_tensors='pt'
        )

        return {
            'input_ids': inputs.input_ids,
            'token_type_ids': inputs.token_type_ids,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
            'input_modes': inputs.input_modes,
            'answer': answer_text,
        }


class CoVoSTDataset(BaseAudioDataset):
    def __init__(self, processor, data_dir, split, ast=False,
                 lang=("en_ko", "Korean"), sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("CoVoST")

        self.ast = ast
        self.lang = lang[0]

        self.data = load_dataset("junnei/covost2",
                                 lang[0],
                                 data_dir=data_dir,
                                 split=split,
                                 trust_remote_code=True)

        # AST additionally requires a non-empty translation field.
        text_fields = ["sentence", "translation"] if ast else ["sentence"]
        self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST")
        self.data = self.filter_by_audio_length(self.data, "audio")

        self.instruction = INSTRUCTION["ast"].format(lang[1]) if ast else INSTRUCTION["asr"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["sentence"].replace('"', '')

        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )


class LibriSpeechDataset(BaseAudioDataset):
    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name(f"LibriSpeech_{subset}")

        self.ast = False
        self.lang = "en"

        # The fixie-ai mirror names its training split "train.360".
        if split == "train":
            split = "train.360"

        self.data = load_dataset("fixie-ai/librispeech_asr",
                                 subset,
                                 split=split,
                                 trust_remote_code=True)

        self.data = self.filter_by_audio_length(self.data, "audio")

        self.instruction = INSTRUCTION["asr"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        answer_text = data["text"].replace('"', '')

        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )


class FleursDataset(BaseAudioDataset):
    def __init__(self, processor, split, source_lang, target_lang=None,
                 mode="asr", sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("Fleurs")

        if mode not in ["asr", "ast"]:
            raise ValueError("mode must be 'asr' or 'ast'.")

        self.mode = mode
        self.ast = (mode == "ast")
        self.source_lang = source_lang

        self.lang_names = {
            'en_us': 'English', 'ko_kr': 'Korean'
        }

        self.data = load_dataset("google/fleurs",
                                 source_lang,
                                 split=split,
                                 trust_remote_code=True)

        self.data = self.filter_by_audio_length(self.data, "audio")

        if self.ast:
            if target_lang is None:
                raise ValueError("AST mode requires target_lang.")

            self.target_lang = target_lang
            self.lang = f"{source_lang}_{target_lang}"

            # Load the same split in the target language and join on sample id;
            # the target-side transcription becomes the translation reference.
            target_data = load_dataset("google/fleurs",
                                       target_lang,
                                       split=split,
                                       trust_remote_code=True)

            source_dict = {item['id']: item for item in self.data}
            target_dict = {item['id']: item for item in target_data}

            common_ids = set(source_dict.keys()) & set(target_dict.keys())
            print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
            self.data = [
                {**source_dict[id], 'translation': target_dict[id]['transcription']}
                for id in common_ids
            ]

            target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
            self.instruction = INSTRUCTION["ast"].format(target_lang_name)
        else:
            self.lang = source_lang
            self.instruction = INSTRUCTION["asr"]

        if self.debug:
            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
            if self.ast:
                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
            print(f"dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        audio_array = data["audio"]["array"]

        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["transcription"]

        return self.prepare_model_inputs(
            audio_array,
            self.instruction,
            answer_text
        )


def pad_sequence(sequences, padding_side='left', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output[i, :length] = seq
        else:
            output[i, -length:] = seq
    return output
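
# A minimal worked example (hypothetical values): left-padding ID sequences of
# lengths 2 and 3 with padding_value=0 yields a (2, 3) tensor:
#   pad_sequence([torch.tensor([1, 2]), torch.tensor([3, 4, 5])])
#   -> tensor([[0, 1, 2],
#              [3, 4, 5]])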


def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Copy each tensor into its slot along `dim`, leaving padding elsewhere.
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t  # index with a tuple, not a list
        index += t.shape[dim]

    return output
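
# A minimal worked example (hypothetical shapes): concatenating a (1, 2) and a
# (2, 3) tensor along dim=0 gives shape (3, 3); the (1, 2) block keeps its two
# values and the remainder of its row is filled with padding_value.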


def covost_collate_fn(batch):
    input_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    answer_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        # All-True mask over this sample's audio frames; padding fills the rest.
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
        )
        input_modes_list.append(inputs['input_modes'])
        answer_list.append(inputs['answer'])

    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        print(e)
        print(input_ids_list)
        print(audio_attention_mask_list)
        raise
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    input_modes = torch.cat(input_modes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_modes': input_modes,
            'answer': answer_list,
        }
    )
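
# Note: with a batch of size 1 the collate function sets
# audio_attention_mask=None, so downstream code must treat tensor and None
# values differently (see the CUDA transfer in evaluate_task below).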


def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    """Save results to a JSON file."""
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"

    filepath = os.path.join(results_dir, f"{filename}.json")

    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"Results saved to {filepath}.")
    return filepath
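
# e.g. save_results(results, "CoVoST", "asr", "en_us") writes
# <results_dir>/asr_CoVoST_en_us.json with a "timestamp" field added.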


def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size=32, is_asr=True):
    """Evaluate ASR or speech-translation performance on a dataset."""
    task_type = "asr" if is_asr else "translation"
    eval_lang = source_lang if is_asr else target_lang
    eval_normalizer = normalizer[eval_lang]
    dataset_name = dataset.name  # capture before any subsetting
    sample_results = []

    # Optionally evaluate on a random subset.
    if 0 < num_samples < len(dataset):
        indices = np.random.choice(len(dataset), num_samples, replace=False)
        dataset = Subset(dataset, indices.tolist())

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)

    # Cache of per-sample metrics so periodic checkpoints are not re-scored.
    evaluated_samples = {}

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        batch_references = batch.pop("answer")

        if torch.cuda.is_available():
            batch = {
                k: (v.to("cuda") if isinstance(v, torch.Tensor) else v)
                for k, v in batch.items()
            }

        with torch.inference_mode():
            generate_ids = model.generate(**batch, max_new_tokens=256)

        # Strip the prompt tokens, keeping only the generated continuation.
        input_length = batch['input_ids'].shape[1]
        generate_ids = generate_ids[:, input_length:]

        batch_predictions = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            idx = batch_idx * batch_size + i
            sample_results.append({
                "id": idx,
                "reference": reference,
                "prediction": prediction,
            })

        # Save intermediate, per-sample-scored results every 10 batches.
        if (batch_idx + 1) % 10 == 0:
            temp_results = []
            for item in sample_results:
                sample_id = item["id"]
                temp_item = item.copy()
                if sample_id not in evaluated_samples:
                    try:
                        ref = eval_normalizer(item["reference"])
                        pred = eval_normalizer(item["prediction"])
                        evaluated_samples[sample_id] = {
                            "bleu": sacrebleu.sentence_bleu(pred, [ref]).score,
                            "cer": round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2),
                            "wer": round(wer(ref, pred) * 100, 2),
                        }
                    except Exception as e:
                        print(f"Error evaluating sample {sample_id}: {e}")
                        evaluated_samples[sample_id] = {
                            "bleu": 0,
                            "cer": 100,
                            "wer": 100,
                            "error": str(e),
                        }
                temp_item.update(evaluated_samples[sample_id])
                temp_results.append(temp_item)

            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(temp_results),
                "sample_results": temp_results,
            }
            save_results(partial_results, dataset_name, task_type, source_lang, target_lang)

    # Final pass: score any samples not already covered by the checkpoints.
    for item in sample_results:
        if item["id"] not in evaluated_samples:
            ref = eval_normalizer(item["reference"])
            pred = eval_normalizer(item["prediction"])
            evaluated_samples[item["id"]] = {
                "bleu": sacrebleu.sentence_bleu(pred, [ref]).score,
                "cer": round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2),
                "wer": round(wer(ref, pred) * 100, 2),
            }
        item.update(evaluated_samples[item["id"]])

    avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results)
    avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results)
    avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results)

    results = {
        "dataset": dataset_name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": len(sample_results),
        "metrics": {
            "bleu": avg_bleu,
            "cer": avg_cer,
            "wer": avg_wer,
        },
        "sample_results": sample_results,
    }

    save_results(results, dataset_name, task_type, source_lang, target_lang)
    return results
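
# A quick smoke test (hypothetical parameters): score a 16-sample random
# subset of FLEURS English ASR with small batches before a full run.
#   ds = FleursDataset(processor, split="test", source_lang="en_us", mode="asr")
#   evaluate_task(ds, "en_us", None, num_samples=16, batch_size=4, is_asr=True)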


if __name__ == "__main__":
    source_languages = [
        ("en_us", "English"),
    ]

    target_languages = [
        ("ko_kr", "Korean"),
    ]

    data_dir = {
        "en_us": "/workspace/CommonVoice/EN",
    }

    num_samples = -1  # -1 evaluates the full split
    batch_size = 32

    for source_lang, target_lang in zip(source_languages, target_languages):
        print(f"\n===== Starting {source_lang[0]} ASR evaluation =====")

        split = "test"
        datasets = []

        covost = CoVoSTDataset(
            processor=processor,
            data_dir=data_dir[source_lang[0]],
            split=split,
            ast=False,
            lang=("en_ko", "Korean")
        )
        datasets.append(covost)

        libri_speech_clean = LibriSpeechDataset(
            processor=processor,
            subset="clean",
            split=split
        )
        datasets.append(libri_speech_clean)

        libri_speech_other = LibriSpeechDataset(
            processor=processor,
            subset="other",
            split=split
        )
        datasets.append(libri_speech_other)

        fleurs = FleursDataset(
            processor=processor,
            split=split,
            source_lang="en_us",
            mode="asr"
        )
        datasets.append(fleurs)

        for dataset in datasets:
            asr_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr=True)

            print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR results ===")
            print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}")
            print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}")
            print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}")

        try:
            print(f"\n===== Starting {source_lang[0]} -> {target_lang[0]} translation evaluation =====")

            datasets = []

            covost = CoVoSTDataset(
                processor=processor,
                data_dir=data_dir[source_lang[0]],
                split=split,
                ast=True,
                lang=("en_ko", "Korean")
            )
            datasets.append(covost)

            fleurs = FleursDataset(
                processor=processor,
                split=split,
                source_lang="en_us",
                target_lang="ko_kr",
                mode="ast"
            )
            datasets.append(fleurs)

            for dataset in datasets:
                translation_results = evaluate_task(dataset, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr=False)

                print(f"\n=== {translation_results.get('dataset', 'Dataset')} | {source_lang[0]} -> {target_lang[0]} translation results ===")
                print(f"BLEU: {translation_results.get('metrics', {}).get('bleu', 'N/A')}")
                print(f"WER: {translation_results.get('metrics', {}).get('wer', 'N/A')}")
                print(f"CER: {translation_results.get('metrics', {}).get('cer', 'N/A')}")

        except Exception as e:
            error_info = {
                "error": str(e),
                "source_lang": source_lang[0],
                "target_lang": target_lang[0],
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
            error_file = os.path.join(results_dir, f"error_translation_{source_lang[0]}_to_{target_lang[0]}_global.json")
            with open(error_file, 'w') as f:
                json.dump(error_info, f, indent=2)
            print(f"Error during {source_lang[0]} -> {target_lang[0]} translation evaluation: {str(e)}")
            continue

    print(f"\nAll evaluations complete. Results saved in the {results_dir} directory.")