# gemma-3-4b-it-speech/examples/finetune_speech.py
import argparse
import json
import os
from pathlib import Path
import numpy as np
import torch
import sacrebleu
from datasets import load_dataset
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from transformers import (
AutoProcessor,
AutoModel,
BatchFeature,
Trainer,
TrainingArguments,
StoppingCriteria,
StoppingCriteriaList,
)
from collections import defaultdict
import soundfile as sf
from datasets import Audio
import random
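# Overview: this script fine-tunes the speech path of junnei/gemma-3-4b-it-speech
# on a mix of ASR and AST data (CoVoST2, Zeroth-Korean, LibriSpeech, FLEURS).
# The dataset classes below turn each sample into (prompt, audio, answer) model
# inputs, covost_collate_fn batches them with left padding, and a standard
# Hugging Face Trainer runs the training loop.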
class BaseAudioDataset(Dataset):
def __init__(self, processor, split, sampling_rate=16000, debug=False):
self.processor = processor
self.training = "train" in split
self.debug = debug
self.sampling_rate = sampling_rate
self.name = ""
def set_dataset_name(self, name):
self.name = name
@staticmethod
def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
original_size = len(data)
data = data.cast_column(audio_field, Audio(decode=False))
def identify_corrupted_files(example):
try:
sf.read(example[audio_field]["path"])
for field in text_fields:
if field in example and example[field].replace('"', '') == "":
return False
return True
except Exception:
return False
data = data.filter(identify_corrupted_files, num_proc=16)
validated_size = len(data)
# Audio Decoding
data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))
if debug:
print(f"Dataset: {dataset_name}")
print(f"Original data nums: {original_size}")
print(f"After filtering data nums: {validated_size}")
print(f"Filtering ratio: {validated_size/original_size:.2%}")
return data
@staticmethod
def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
original_size = len(data)
def filter_audio_by_length(example):
try:
audio = example[audio_field]['array']
channel = 1
if hasattr(audio, 'ndim') and audio.ndim > 1:
channel = audio.ndim
audio = audio.squeeze()
audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
return min_sec <= audio_length <= max_sec
except Exception as e:
if debug:
print(f"Error : {str(e)[:100]}... - sample excluded")
return False
data = data.filter(filter_audio_by_length, num_proc=16)
filtered_size = len(data)
if debug:
print(f"Before Length Filtering data nums: {original_size}")
print(f"After Length Filtering data nums: {filtered_size}")
print(f"Filtering ratio: {filtered_size/original_size:.2%}")
return data
def prepare_model_inputs(self, audio_array, instruction, answer_text):
user_message = {
'role': 'user',
'content': '<start_of_audio>' + instruction,
}
prompt = self.processor.tokenizer.apply_chat_template(
[user_message], tokenize=False, add_generation_prompt=True, add_bos=True
)
inputs = self.processor(
text=prompt,
audio=[audio_array],
add_special_tokens=False,
return_tensors='pt'
)
answer = f"{answer_text}{ANSWER_SUFFIX}"
answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids
if self.debug:
self.debug = False
task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
print(f"INPUT_MODE: {inputs.input_modes[0].item()}")
if self.training:
input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
labels = torch.full_like(input_ids, _IGNORE_INDEX)
labels[:, -answer_ids.shape[1]:] = answer_ids
            # Extend token_type_ids over the answer span, keeping the original integer dtype
            padding = torch.zeros(
                (inputs.token_type_ids.shape[0], answer_ids.shape[1]),
                dtype=inputs.token_type_ids.dtype,
            )
token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
else:
input_ids = inputs.input_ids
labels = answer_ids
token_type_ids = inputs.token_type_ids
return {
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'input_audio_embeds': inputs.input_audio_embeds,
'audio_embed_sizes': inputs.audio_embed_sizes,
'input_modes': inputs.input_modes,
}
# CoVoST2 Dataset Class
class CoVoSTDataset(BaseAudioDataset):
def __init__(self, processor, data_dir, split, ast=False,
lang=("en_ko", "Korean"), sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name("CoVoST")
self.ast = ast
self.lang = lang[0]
self.data = load_dataset("junnei/covost2",
lang[0],
data_dir=data_dir,
split=split,
trust_remote_code=True
)
text_fields = ["sentence", "translation"] if ast else ["sentence"]
self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST")
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["ast"]).format(lang[1]) if ast else random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
if self.ast:
answer_text = data["translation"]
else:
answer_text = data["sentence"].replace('"', '')
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# Zeroth Korean Dataset Class
class ZerothKoreanDataset(BaseAudioDataset):
def __init__(self, processor, split, sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name("Zeroth")
# only ASR
self.ast = False
self.lang = "ko"
# load dataset
self.data = load_dataset("Bingsu/zeroth-korean",
split=split,
trust_remote_code=True
)
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
# Zeroth Korean is only for ASR
answer_text = data["text"].replace('"', '')
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# Libri Speech Dataset Class
class LibriSpeechDataset(BaseAudioDataset):
def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name(f"LibriSpeech_{subset}")
# only ASR
self.ast = False
self.lang = "en"
# load dataset
self.data = load_dataset("fixie-ai/librispeech_asr",
subset,
split=split,
trust_remote_code=True
)
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
# Instruction Setting
self.instruction = random.choice(INSTRUCTION["asr"])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
# Libri Speech is only for ASR
answer_text = data["text"].replace('"', '')
return self.prepare_model_inputs(
data["audio"]["array"],
self.instruction,
answer_text
)
# Fleurs Dataset Class
class FleursDataset(BaseAudioDataset):
def __init__(self, processor, split, source_lang, target_lang=None,
mode="asr", sampling_rate=16000, debug=False):
super().__init__(processor, split, sampling_rate, debug)
self.set_dataset_name("Fleurs")
# Mode Setting (ASR or AST)
if mode not in ["asr", "ast"]:
raise ValueError("mode must be 'asr' or 'ast'.")
self.mode = mode
self.ast = (mode == "ast")
self.source_lang = source_lang
# Language name mapping (expand if needed)
self.lang_names = {
'en_us': 'English', 'ko_kr': 'Korean'
}
# load dataset - source language dataset
self.data = load_dataset("google/fleurs",
source_lang,
split=split,
trust_remote_code=True
)
# (Optional) Audio length Filtering
self.data = self.filter_by_audio_length(self.data, "audio")
        # In AST mode, also load the matching target-language split (for translations).
if self.ast:
if target_lang is None:
raise ValueError("AST mode requires target_lang.")
self.target_lang = target_lang
self.lang = f"{source_lang}_{target_lang}"
# load dataset - target language dataset (for translation)
target_data = load_dataset("google/fleurs",
target_lang,
split=split,
trust_remote_code=True
)
source_dict = {item['id']: item for item in self.data}
target_dict = {item['id']: item for item in target_data}
            # Keep only IDs present in both splits and attach the translation field
common_ids = set(source_dict.keys()) & set(target_dict.keys())
print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
self.data = [
{**source_dict[id], 'translation': target_dict[id]['transcription']}
for id in common_ids
]
# Instruction Setting - use target language name
target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
self.instruction = random.choice(INSTRUCTION["ast"]).format(target_lang_name)
else:
# ASR mode
self.lang = source_lang
self.instruction = random.choice(INSTRUCTION["asr"])
if self.debug:
print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
if self.ast:
print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
print(f"dataset size: {len(self.data)}")
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
audio_array = data["audio"]["array"]
if self.ast:
answer_text = data["translation"]
else:
answer_text = data["transcription"]
return self.prepare_model_inputs(
audio_array,
self.instruction,
answer_text
)
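# Collate function: stacks per-sample features into a batch, left-padding
# input_ids / labels / token_type_ids to the longest sequence and concatenating
# the variable-length audio embeddings with cat_with_pad.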
def covost_collate_fn(batch):
input_ids_list = []
labels_list = []
token_type_ids_list = []
input_audio_embeds_list = []
audio_embed_sizes_list = []
audio_attention_mask_list = []
input_modes_list = []
for inputs in batch:
input_ids_list.append(inputs['input_ids'][0])
labels_list.append(inputs['labels'][0])
token_type_ids_list.append(inputs['token_type_ids'][0])
input_audio_embeds_list.append(inputs['input_audio_embeds'])
audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
audio_attention_mask_list.append(
inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
)
input_modes_list.append(inputs['input_modes'])
try:
token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        # Pad labels with the ignore index so padded positions are excluded from the loss
        labels = pad_sequence(labels_list, padding_side='left', padding_value=_IGNORE_INDEX)
audio_attention_mask = (
pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
if len(audio_attention_mask_list) > 1
else None
)
except Exception as e:
print(e)
print(input_ids_list)
print(labels_list)
raise
attention_mask = (input_ids != 0).long()
input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
audio_embed_sizes = torch.cat(audio_embed_sizes_list)
input_modes = torch.cat(input_modes_list)
return BatchFeature(
{
'input_ids': input_ids,
'labels': labels,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'input_audio_embeds': input_audio_embeds,
'audio_embed_sizes': audio_embed_sizes,
'audio_attention_mask': audio_attention_mask,
'input_modes': input_modes,
}
)
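# Illustrative example of the helper below:
#   pad_sequence([torch.tensor([1, 2]), torch.tensor([3, 4, 5])], padding_side='left')
#   -> tensor([[0, 1, 2],
#              [3, 4, 5]])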
def pad_sequence(sequences, padding_side='left', padding_value=0):
"""
Pad a list of sequences to the same length.
sequences: list of tensors in [seq_len, *] shape
"""
assert padding_side in ['right', 'left']
max_size = sequences[0].size()
trailing_dims = max_size[1:]
max_len = max(len(seq) for seq in sequences)
batch_size = len(sequences)
output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
for i, seq in enumerate(sequences):
length = seq.size(0)
if padding_side == 'right':
output.data[i, :length] = seq
else:
output.data[i, -length:] = seq
return output
def cat_with_pad(tensors, dim, padding_value=0):
"""
cat along dim, while pad to max for all other dims
"""
ndim = tensors[0].dim()
assert all(
t.dim() == ndim for t in tensors[1:]
), 'All tensors must have the same number of dimensions'
out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
out_size[dim] = sum(t.shape[dim] for t in tensors)
output = tensors[0].new_full(out_size, padding_value)
index = 0
for t in tensors:
# Create a slice list where every dimension except dim is full slice
slices = [slice(0, t.shape[d]) for d in range(ndim)]
# Update only the concat dimension slice
slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t
index += t.shape[dim]
return output
def count_parameters_by_module(model):
    # Parameter counts per top-level module
    module_params = defaultdict(lambda: {"total": 0, "trainable": 0})
    # Totals across the whole model
    total_params = 0
    total_trainable_params = 0
    # Masks for embedding weights that are only partially trainable (see embedding_grad_mask_hook)
    embedding_masks = {}
for name, param in model.named_parameters():
if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
            # Check whether this param has an embedding_grad_mask_hook attached
            for hook_id, hook_fn in param._backward_hooks.items():
                if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
                    # Pull the boolean freeze mask out of the hook function's closure
                    for cell in hook_fn.__closure__ or []:
                        if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
                            # Invert the freeze mask: True means trainable
                            embedding_masks[name] = ~cell.cell_contents
    # Count parameters per module
    for name, param in model.named_parameters():
        # Use the top-level attribute name as the module key
        module_name = name.split('.')[0]
        param_count = param.numel()
        module_params[module_name]["total"] += param_count
        total_params += param_count
        if param.requires_grad:
            # Count only the rows that are actually trainable (respect the embedding mask)
if name in embedding_masks:
trainable_count = embedding_masks[name].sum().item()
module_params[module_name]["trainable"] += trainable_count
total_trainable_params += trainable_count
else:
module_params[module_name]["trainable"] += param_count
total_trainable_params += param_count
print(f"All Params: {total_params:,}")
print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
print("\nParams by Module:")
for module_name, counts in sorted(module_params.items()):
trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
total_percentage = counts["total"] / total_params * 100
print(f"- {module_name}:")
print(f" Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
print(f" Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")
return module_params
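# create_model: load the model in bfloat16, freeze all parameters, enable the
# 'speech' LoRA adapter, then unfreeze only the audio projector and (via a
# gradient-masking hook) the two newly added speech token embeddings.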
def create_model(model_name_or_path, revision="main", use_flash_attention=False):
model = AutoModel.from_pretrained(
model_name_or_path,
revision=revision,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2" if use_flash_attention else "eager",
trust_remote_code=True,
)
    # Disable the KV cache after the model is loaded (not needed during training)
model.config.use_cache = False
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
model.set_lora_adapter('speech')
model.to(torch.bfloat16)
# (Optional) unfreeze audio_tower parameters
#for param in model.audio_tower.parameters():
# param.requires_grad = True
# Only unfreeze audio_projector parameters
for param in model.audio_projector.parameters():
param.requires_grad = True
# (Optional) unfreeze audio embed_tokens
    train_embed = True
    if train_embed:
        embed_tokens = model.language_model.model.model.embed_tokens
        # Newly added speech token IDs (only these embedding rows stay trainable)
        trainable_token_ids = [256001, 256002]
        embed_tokens.weight.requires_grad = True
        mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
        mask[trainable_token_ids] = False  # trainable tokens are False (unfrozen); all other rows stay True (frozen)
# backward hook, with gradient masking
def embedding_grad_mask_hook(grad):
return grad.masked_fill(mask, 0)
embed_tokens.weight.register_hook(embedding_grad_mask_hook)
model.language_model.model.model.embed_tokens = embed_tokens
count_parameters_by_module(model)
return model
os.environ["TOKENIZERS_PARALLELISM"] = "false"
INSTRUCTION = {
"ast": [
"Translate the audio to {0}.",
"Translate the audio clip into {0}.",
"Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
"Translate the provided audio file into {0}.",
"Convert the audio speech to {0} text.",
"Write an {0} translation of the audio file.",
"Translate spoken words from the audio into {0}.",
"Create an {0} version of the audio content.",
"Produce an accurate {0} translation of the audio.",
"Extract speech from the audio and translate it to {0}.",
"Turn the audio into readable {0} text.",
"Write all spoken content from the audio in {0}.",
"Generate an {0} translation of the speech in the file.",
"Convert the recording into {0} text.",
"Accurately translate the audio recording to {0}.",
"Write down dialogue from the given audio in {0}.",
"Translate all speech in this audio file to {0}.",
"Create an accurate {0} version of the speech.",
"Perform a complete {0} translation of the audio."
],
"asr": [
"Transcribe the audio clip into text.",
"Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
"Transcribe the provided audio file into text.",
"Convert the audio speech to text.",
"Write a transcript of the audio file.",
"Transcribe spoken words from the audio.",
"Create a text version of the audio content.",
"Produce a verbatim transcript of the audio.",
"Extract and transcribe speech from the audio.",
"Turn the audio into readable text.",
"Write all spoken words from the audio.",
"Generate a transcript of the speech in the file.",
"Convert the recording into a text transcript.",
"Accurately transcribe the audio recording.",
"Write down dialogue from the given audio.",
"Transcribe all speech in this audio file.",
"Create an accurate text version of the speech.",
"Perform a complete transcription of the audio."
],
}
ANSWER_SUFFIX = "<end_of_turn>"
_IGNORE_INDEX = -100
model_name_or_path = 'junnei/gemma-3-4b-it-speech'
use_flash_attention = True
output_dir = '/workspace/output'
batch_size = 128
batch_size_per_gpu = 32
learning_rate = 4.0e-5 # 1.0e-4 for fine-tuning
wd = 0.01
num_train_epochs = 5
revision = "main" #"v1.0"
processor = AutoProcessor.from_pretrained(
model_name_or_path,
revision=revision,
trust_remote_code=True,
)
model = create_model(
model_name_or_path,
revision=revision,
use_flash_attention=use_flash_attention,
)
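# Training mixture: English and Korean ASR plus EN<->KO AST, drawn from
# CoVoST2, LibriSpeech, Zeroth-Korean, and FLEURS.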
train_datasets = []
# Covost ASR mode (English -> English text)
covost_asr_dataset = CoVoSTDataset(
processor=processor,
data_dir="/workspace/CommonVoice/EN",
split="train",
ast=False,
lang=("en_ko", "Korean")
)
train_datasets.append(covost_asr_dataset)
# Covost AST mode (English -> Korean text)
covost_dataset = CoVoSTDataset(
processor=processor,
data_dir="/workspace/CommonVoice/EN",
split="train",
ast=True,
lang=("en_ko", "Korean")
)
train_datasets.append(covost_dataset)
# Libri Speech Clean ASR mode (English -> English text)
libri_speech_clean = LibriSpeechDataset(
processor=processor,
subset="clean",
split="train.360"
)
train_datasets.append(libri_speech_clean)
# Libri Speech Other ASR mode (English -> English text)
libri_speech_other = LibriSpeechDataset(
processor=processor,
subset="other",
split="train.500"
)
train_datasets.append(libri_speech_other)
# Fleurs ASR mode (English -> English text)
en_asr_fleurs = FleursDataset(
processor=processor,
split="train",
source_lang="en_us", # English
mode="asr"
)
train_datasets.append(en_asr_fleurs)
# Fleurs AST mode (English -> Korean text)
en_ko_ast_fleurs = FleursDataset(
processor=processor,
split="train",
source_lang="en_us", # English
target_lang="ko_kr", # Korean
mode="ast"
)
train_datasets.append(en_ko_ast_fleurs)
# Covost ASR mode (Korean -> Korean text)
covost_ko_asr_dataset = CoVoSTDataset(
processor=processor,
data_dir="/workspace/CommonVoice/ko",
split="train",
ast=False,
lang=("ko_en", "English")
)
train_datasets.append(covost_ko_asr_dataset)
# Covost AST mode (Korean -> English text)
covost_ko_dataset = CoVoSTDataset(
processor=processor,
data_dir="/workspace/CommonVoice/ko",
split="train",
ast=True,
lang=("ko_en", "English")
)
train_datasets.append(covost_ko_dataset)
# Zeroth ASR mode (Korean -> Korean text)
ko_asr_zeroth = ZerothKoreanDataset(
processor=processor,
split="train"
)
train_datasets.append(ko_asr_zeroth)
# Fleurs ASR mode (Korean -> Korean text)
ko_asr_fleurs = FleursDataset(
processor=processor,
split="train",
source_lang="ko_kr", # Korean
mode="asr"
)
train_datasets.append(ko_asr_fleurs)
# Fleurs AST mode (Korean -> English text)
ko_en_ast_fleurs = FleursDataset(
processor=processor,
split="train",
source_lang="ko_kr", # Korean
target_lang="en_us", # English
mode="ast"
)
train_datasets.append(ko_en_ast_fleurs)
print("Count Num of Datasets", len(train_datasets))
print([len(dataset) for dataset in train_datasets])
# ConcatDataset
train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
print("Count Length of Datas", len(train_dataset))
# Check GPUs
num_gpus = torch.cuda.device_count()
print(f'training on {num_gpus} GPUs')
assert (
    batch_size % (num_gpus * batch_size_per_gpu) == 0
), 'batch_size must be divisible by num_gpus * batch_size_per_gpu'
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
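# Example: with batch_size=128 and batch_size_per_gpu=32, a single GPU uses
# gradient_accumulation_steps = 128 // (1 * 32) = 4, while 4 GPUs need no
# accumulation (128 // (4 * 32) = 1).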
# hard coded training args
training_args = TrainingArguments(
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size_per_gpu,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={'use_reentrant': False},
gradient_accumulation_steps=gradient_accumulation_steps,
optim='adamw_torch',
adam_beta1=0.9,
adam_beta2=0.95,
adam_epsilon=1e-7,
learning_rate=learning_rate,
weight_decay=wd,
max_grad_norm=1.0,
lr_scheduler_type='cosine',
warmup_steps=50,
logging_steps=50,
output_dir=output_dir,
save_strategy='no',
save_total_limit=10,
save_only_model=True,
bf16=True,
fp16=False,
remove_unused_columns=False,
report_to='none',
deepspeed=None,
disable_tqdm=False,
dataloader_num_workers=4,
ddp_find_unused_parameters=True,
)
out_path = Path(training_args.output_dir)
out_path.mkdir(parents=True, exist_ok=True)
# create optimizer only for trainable params
optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=wd,
betas=(0.9, 0.95),
eps=1e-7,
)
# Trainer Setting
trainer = Trainer(
model=model,
args=training_args,
data_collator=covost_collate_fn,
train_dataset=train_dataset,
optimizers=(optimizer, None),
)
trainer.train()
import shutil
# 1. Save LoRA Adapter
model.language_model.model.save_pretrained(output_dir)
# 1-1. Delete Markdown file
markdown_file = os.path.join(output_dir, "README.md")
if os.path.exists(markdown_file):
os.remove(markdown_file)
# 2. Save entire model
model.save_pretrained(output_dir)
# 3. Cleanup Memory
del model
del trainer
import gc
gc.collect()
torch.cuda.empty_cache()
from huggingface_hub import HfApi, login, create_repo, Repository, upload_folder
upload_dir = "/workspace/upload"
# 4. Clone Repo
repo_id = "junnei/gemma-3-4b-it-speech"
branch_name = "main" # 새 브랜치 이름
repo = Repository(local_dir=upload_dir, clone_from = repo_id)
repo.git_checkout(branch_name, create_branch_ok=True)
# 4-1. Copy the trained model files into the repo
for item in os.listdir(output_dir):
s = os.path.join(output_dir, item)
d = os.path.join(upload_dir, item)
if os.path.isdir(s):
shutil.copytree(s, d, dirs_exist_ok=True)
else:
shutil.copy2(s, d)
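# 4-2. (Sketch, not part of the original flow) The copied files would still need
# to be committed and pushed, e.g. with the Repository helper used above:
# repo.push_to_hub(commit_message="Upload fine-tuned speech adapter")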