import argparse
import json
import os
from pathlib import Path

import numpy as np
import torch
import sacrebleu

from datasets import load_dataset
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    AutoModel,
    BatchFeature,
    Trainer,
    TrainingArguments,
    StoppingCriteria,
    StoppingCriteriaList,
)
from collections import defaultdict

import soundfile as sf
from datasets import Audio
import random
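
# Shared base class for the ASR/AST datasets below: subclasses load a Hugging Face
# dataset, filter out corrupted or out-of-range clips, and turn each example into
# model inputs (prompt + audio features + label ids) via prepare_model_inputs().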
class BaseAudioDataset(Dataset):
    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        self.processor = processor
        self.training = "train" in split
        self.debug = debug
        self.sampling_rate = sampling_rate
        self.name = ""

    def set_dataset_name(self, name):
        self.name = name

    @staticmethod
    def filter_corrupted_files(data, audio_field, text_fields, dataset_name, sampling_rate=16000, debug=True):
        original_size = len(data)

        # Disable decoding so corrupted files can be detected without loading full audio arrays.
        data = data.cast_column(audio_field, Audio(decode=False))

        def identify_corrupted_files(example):
            try:
                sf.read(example[audio_field]["path"])

                # Drop samples whose reference text is empty once quotes are stripped.
                for field in text_fields:
                    if field in example and example[field].replace('"', '') == "":
                        return False
                return True
            except Exception:
                return False

        data = data.filter(identify_corrupted_files, num_proc=16)
        validated_size = len(data)

        # Re-enable decoding at the target sampling rate for training.
        data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))

        if debug:
            print(f"Dataset: {dataset_name}")
            print(f"Original data nums: {original_size}")
            print(f"After filtering data nums: {validated_size}")
            print(f"Filtering ratio: {validated_size/original_size:.2%}")

        return data

    @staticmethod
    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True):
        original_size = len(data)

        def filter_audio_by_length(example):
            try:
                audio = example[audio_field]['array']
                channel = 1
                if hasattr(audio, 'ndim') and audio.ndim > 1:
                    # Approximate handling for non-1D audio: treat ndim as the channel
                    # count and squeeze singleton dimensions.
                    channel = audio.ndim
                    audio = audio.squeeze()
                audio_length = len(audio) / example[audio_field]['sampling_rate'] / channel
                return min_sec <= audio_length <= max_sec
            except Exception as e:
                if debug:
                    print(f"Error : {str(e)[:100]}... - sample excluded")
                return False

        data = data.filter(filter_audio_by_length, num_proc=16)
        filtered_size = len(data)

        if debug:
            print(f"Before Length Filtering data nums: {original_size}")
            print(f"After Length Filtering data nums: {filtered_size}")
            print(f"Filtering ratio: {filtered_size/original_size:.2%}")

        return data

    def prepare_model_inputs(self, audio_array, instruction, answer_text):
        user_message = {
            'role': 'user',
            'content': '<start_of_audio>' + instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
        )

        inputs = self.processor(
            text=prompt,
            audio=[audio_array],
            add_special_tokens=False,
            return_tensors='pt'
        )

        answer = f"{answer_text}{ANSWER_SUFFIX}"
        answer_ids = self.processor.tokenizer(answer, add_special_tokens=False, return_tensors='pt').input_ids

        if self.debug:
            # Print one fully decoded example per dataset, then switch debugging off.
            self.debug = False
            task_type = 'AST' if hasattr(self, 'ast') and self.ast else 'ASR'
            lang_info = f" - {self.lang}" if hasattr(self, 'lang') else ""
            print(f"{task_type}{lang_info}\nPROMPT: {prompt}\nINPUT: {self.processor.decode(inputs.input_ids[0], skip_special_tokens=False)}\nANSWER: {self.processor.decode(answer_ids[0], skip_special_tokens=False)}\n")
            print(f"INPUT_MODE: {inputs.input_modes[0].item()}")

        if self.training:
            # Teacher forcing: append the answer tokens and mask the prompt portion of the labels.
            input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
            labels = torch.full_like(input_ids, _IGNORE_INDEX)
            labels[:, -answer_ids.shape[1]:] = answer_ids
            padding = torch.zeros((inputs.token_type_ids.shape[0], answer_ids.shape[1]))
            token_type_ids = torch.cat([inputs.token_type_ids, padding], dim=1)
        else:
            input_ids = inputs.input_ids
            labels = answer_ids
            token_type_ids = inputs.token_type_ids

        return {
            'input_ids': input_ids,
            'labels': labels,
            'token_type_ids': token_type_ids,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
            'input_modes': inputs.input_modes,
        }

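
# CoVoST 2 (speech translation corpus built on Common Voice audio), loaded via the
# junnei/covost2 loader with local Common Voice data under data_dir. Provides both
# ASR examples (sentence) and AST examples (translation) for en_ko and ko_en.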
class CoVoSTDataset(BaseAudioDataset):
    def __init__(self, processor, data_dir, split, ast=False,
                 lang=("en_ko", "Korean"), sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("CoVoST")
        self.ast = ast
        self.lang = lang[0]

        self.data = load_dataset("junnei/covost2",
                                 lang[0],
                                 data_dir=data_dir,
                                 split=split,
                                 trust_remote_code=True
                                 )

        text_fields = ["sentence", "translation"] if ast else ["sentence"]
        self.data = self.filter_corrupted_files(self.data, "audio", text_fields, "CoVoST")

        self.data = self.filter_by_audio_length(self.data, "audio")

        # One instruction is sampled per dataset instance, not per example.
        self.instruction = random.choice(INSTRUCTION["ast"]).format(lang[1]) if ast else random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["sentence"].replace('"', '')

        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )

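
# Zeroth-Korean: a Korean ASR corpus, loaded from Bingsu/zeroth-korean on the Hub.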
class ZerothKoreanDataset(BaseAudioDataset):
    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("Zeroth")

        self.ast = False
        self.lang = "ko"

        self.data = load_dataset("Bingsu/zeroth-korean",
                                 split=split,
                                 trust_remote_code=True
                                 )

        self.data = self.filter_by_audio_length(self.data, "audio")

        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        answer_text = data["text"].replace('"', '')

        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )

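
# LibriSpeech English ASR, loaded from fixie-ai/librispeech_asr
# (the "clean" and "other" training subsets are used below).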
class LibriSpeechDataset(BaseAudioDataset):
    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name(f"LibriSpeech_{subset}")

        self.ast = False
        self.lang = "en"

        self.data = load_dataset("fixie-ai/librispeech_asr",
                                 subset,
                                 split=split,
                                 trust_remote_code=True
                                 )

        self.data = self.filter_by_audio_length(self.data, "audio")

        self.instruction = random.choice(INSTRUCTION["asr"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        answer_text = data["text"].replace('"', '')

        return self.prepare_model_inputs(
            data["audio"]["array"],
            self.instruction,
            answer_text
        )

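
# FLEURS: multilingual speech corpus. In "asr" mode the transcription is the target;
# in "ast" mode the matching utterance from the target-language config is paired by
# its shared id and that transcription serves as the translation reference.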
class FleursDataset(BaseAudioDataset):
    def __init__(self, processor, split, source_lang, target_lang=None,
                 mode="asr", sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)

        self.set_dataset_name("Fleurs")

        if mode not in ["asr", "ast"]:
            raise ValueError("mode must be 'asr' or 'ast'.")

        self.mode = mode
        self.ast = (mode == "ast")
        self.source_lang = source_lang

        self.lang_names = {
            'en_us': 'English', 'ko_kr': 'Korean'
        }

        self.data = load_dataset("google/fleurs",
                                 source_lang,
                                 split=split,
                                 trust_remote_code=True
                                 )

        self.data = self.filter_by_audio_length(self.data, "audio")

        if self.ast:
            if target_lang is None:
                raise ValueError("AST mode requires target_lang.")

            self.target_lang = target_lang
            self.lang = f"{source_lang}_{target_lang}"

            target_data = load_dataset("google/fleurs",
                                       target_lang,
                                       split=split,
                                       trust_remote_code=True
                                       )

            # FLEURS is n-way parallel: pair source and target utterances by id and
            # attach the target-language transcription as the translation reference.
            source_dict = {item['id']: item for item in self.data}
            target_dict = {item['id']: item for item in target_data}

            common_ids = set(source_dict.keys()) & set(target_dict.keys())
            print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
            self.data = [
                {**source_dict[id], 'translation': target_dict[id]['transcription']}
                for id in common_ids
            ]

            target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
            self.instruction = random.choice(INSTRUCTION["ast"]).format(target_lang_name)
        else:
            self.lang = source_lang
            self.instruction = random.choice(INSTRUCTION["asr"])

        if self.debug:
            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
            if self.ast:
                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
            print(f"dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        audio_array = data["audio"]["array"]

        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["transcription"]

        return self.prepare_model_inputs(
            audio_array,
            self.instruction,
            answer_text
        )

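
# Collate function for the DataLoader: left-pads input_ids / labels / token_type_ids
# to the longest sequence in the batch, derives attention_mask from the non-zero ids,
# and concatenates the per-sample audio embeddings with cat_with_pad().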
def covost_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    token_type_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        token_type_ids_list.append(inputs['token_type_ids'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool)
        )
        input_modes_list.append(inputs['input_modes'])

    try:
        token_type_ids = pad_sequence(token_type_ids_list, padding_side='left', padding_value=0)
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        # Pad labels with the ignore index so padded positions do not contribute to the loss.
        labels = pad_sequence(labels_list, padding_side='left', padding_value=_IGNORE_INDEX)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='left', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        print(e)
        print(input_ids_list)
        print(labels_list)
        raise
    # The attention mask marks the non-padded (non-zero id) positions.
    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    input_modes = torch.cat(input_modes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_modes': input_modes,
        }
    )

def pad_sequence(sequences, padding_side='left', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output


def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Cover the tensor's full extent in every dim, placed at the running
        # offset along the concatenation dim.
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t
        index += t.shape[dim]

    return output

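
# Shape sketch for the two helpers above (illustrative sizes, not real data):
#   pad_sequence([ids_a (7,), ids_b (4,)], padding_side='left') -> (2, 7),
#   with ids_b preceded by three padding positions.
#   cat_with_pad([emb_a (1, 120, 80), emb_b (1, 95, 80)], dim=0) -> (2, 120, 80),
#   with emb_b zero-padded along dim 1.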
def count_parameters_by_module(model):
    module_params = defaultdict(lambda: {"total": 0, "trainable": 0})

    total_params = 0
    total_trainable_params = 0

    # The embedding matrix is trained through a gradient-mask hook (see create_model),
    # so recover the mask from the registered hook and count only the entries that
    # actually receive gradients.
    embedding_masks = {}
    for name, param in model.named_parameters():
        if 'embed_tokens.weight' in name and hasattr(param, '_backward_hooks') and param._backward_hooks:
            for hook_id, hook_fn in param._backward_hooks.items():
                if hook_fn.__code__.co_name == 'embedding_grad_mask_hook':
                    for cell in hook_fn.__closure__ or []:
                        if isinstance(cell.cell_contents, torch.Tensor) and cell.cell_contents.dtype == torch.bool:
                            embedding_masks[name] = ~cell.cell_contents

    for name, param in model.named_parameters():
        module_name = name.split('.')[0]
        param_count = param.numel()

        module_params[module_name]["total"] += param_count
        total_params += param_count

        if param.requires_grad:
            if name in embedding_masks:
                trainable_count = embedding_masks[name].sum().item()
                module_params[module_name]["trainable"] += trainable_count
                total_trainable_params += trainable_count
            else:
                module_params[module_name]["trainable"] += param_count
                total_trainable_params += param_count

    print(f"All Params: {total_params:,}")
    print(f"Trainable Params: {total_trainable_params:,} ({total_trainable_params/total_params*100:.2f}%)")
    print("\nParams by Module:")

    for module_name, counts in sorted(module_params.items()):
        trainable_percentage = counts["trainable"] / counts["total"] * 100 if counts["total"] > 0 else 0
        total_percentage = counts["total"] / total_params * 100

        print(f"- {module_name}:")
        print(f"  Total: {counts['total']:,} ({total_percentage:.2f}% of model)")
        print(f"  Trainable: {counts['trainable']:,} ({trainable_percentage:.2f}% of module)")

    return module_params

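
# Model setup: load the pretrained checkpoint, freeze it, enable the 'speech' LoRA
# adapter, unfreeze the audio projector, and restrict embedding-matrix updates to
# the rows for token ids 256001 and 256002 via a gradient-mask hook.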
def create_model(model_name_or_path, revision="main", use_flash_attention=False):
    model = AutoModel.from_pretrained(
        model_name_or_path,
        revision=revision,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2" if use_flash_attention else "eager",
        trust_remote_code=True,
    )

    model.config.use_cache = False

    # Freeze the whole model first, then selectively re-enable the parts to train.
    for param in model.parameters():
        param.requires_grad = False

    model.set_lora_adapter('speech')
    model.to(torch.bfloat16)

    # The audio projector is fully trainable.
    for param in model.audio_projector.parameters():
        param.requires_grad = True

    train_embed = True
    if train_embed:
        embed_tokens = model.language_model.model.model.embed_tokens

        # Only the embedding rows for these token ids should receive updates.
        trainable_token_ids = [256001, 256002]

        embed_tokens.weight.requires_grad = True
        mask = torch.ones_like(embed_tokens.weight, dtype=torch.bool)
        mask[trainable_token_ids] = False

        # Zero the gradient everywhere except the trainable rows.
        def embedding_grad_mask_hook(grad):
            return grad.masked_fill(mask, 0)

        embed_tokens.weight.register_hook(embedding_grad_mask_hook)

        model.language_model.model.model.embed_tokens = embed_tokens

    count_parameters_by_module(model)

    return model

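
# Script-level configuration: instruction prompt pools for ASR/AST, label constants,
# training hyperparameters, and processor/model initialization.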
os.environ["TOKENIZERS_PARALLELISM"] = "false"

INSTRUCTION = {
    "ast": [
        "Translate the audio to {0}.",
        "Translate the audio clip into {0}.",
        "Based on the attached audio, generate a comprehensive {0} translation of the spoken content.",
        "Translate the provided audio file into {0}.",
        "Convert the audio speech to {0} text.",
        "Write an {0} translation of the audio file.",
        "Translate spoken words from the audio into {0}.",
        "Create an {0} version of the audio content.",
        "Produce an accurate {0} translation of the audio.",
        "Extract speech from the audio and translate it to {0}.",
        "Turn the audio into readable {0} text.",
        "Write all spoken content from the audio in {0}.",
        "Generate an {0} translation of the speech in the file.",
        "Convert the recording into {0} text.",
        "Accurately translate the audio recording to {0}.",
        "Write down dialogue from the given audio in {0}.",
        "Translate all speech in this audio file to {0}.",
        "Create an accurate {0} version of the speech.",
        "Perform a complete {0} translation of the audio."
    ],
    "asr": [
        "Transcribe the audio clip into text.",
        "Based on the attached audio, generate a comprehensive text transcription of the spoken content.",
        "Transcribe the provided audio file into text.",
        "Convert the audio speech to text.",
        "Write a transcript of the audio file.",
        "Transcribe spoken words from the audio.",
        "Create a text version of the audio content.",
        "Produce a verbatim transcript of the audio.",
        "Extract and transcribe speech from the audio.",
        "Turn the audio into readable text.",
        "Write all spoken words from the audio.",
        "Generate a transcript of the speech in the file.",
        "Convert the recording into a text transcript.",
        "Accurately transcribe the audio recording.",
        "Write down dialogue from the given audio.",
        "Transcribe all speech in this audio file.",
        "Create an accurate text version of the speech.",
        "Perform a complete transcription of the audio."
    ],
}

ANSWER_SUFFIX = "<end_of_turn>"
_IGNORE_INDEX = -100

model_name_or_path = 'junnei/gemma-3-4b-it-speech'
use_flash_attention = True

output_dir = '/workspace/output'
batch_size = 128
batch_size_per_gpu = 32
learning_rate = 4.0e-5
wd = 0.01
num_train_epochs = 5

revision = "main"

processor = AutoProcessor.from_pretrained(
    model_name_or_path,
    revision=revision,
    trust_remote_code=True,
)

model = create_model(
    model_name_or_path,
    revision=revision,
    use_flash_attention=use_flash_attention,
)

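
# Build the training mixture: English and Korean ASR plus en<->ko AST. Every dataset
# yields examples in the same format, so they can be concatenated for joint training.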
train_datasets = []

# CoVoST2 English ASR
covost_asr_dataset = CoVoSTDataset(
    processor=processor,
    data_dir="/workspace/CommonVoice/EN",
    split="train",
    ast=False,
    lang=("en_ko", "Korean")
)
train_datasets.append(covost_asr_dataset)

# CoVoST2 English -> Korean AST
covost_dataset = CoVoSTDataset(
    processor=processor,
    data_dir="/workspace/CommonVoice/EN",
    split="train",
    ast=True,
    lang=("en_ko", "Korean")
)
train_datasets.append(covost_dataset)

# LibriSpeech English ASR (train.360, clean)
libri_speech_clean = LibriSpeechDataset(
    processor=processor,
    subset="clean",
    split="train.360"
)
train_datasets.append(libri_speech_clean)

# LibriSpeech English ASR (train.500, other)
libri_speech_other = LibriSpeechDataset(
    processor=processor,
    subset="other",
    split="train.500"
)
train_datasets.append(libri_speech_other)

# FLEURS English ASR
en_asr_fleurs = FleursDataset(
    processor=processor,
    split="train",
    source_lang="en_us",
    mode="asr"
)
train_datasets.append(en_asr_fleurs)

# FLEURS English -> Korean AST
en_ko_ast_fleurs = FleursDataset(
    processor=processor,
    split="train",
    source_lang="en_us",
    target_lang="ko_kr",
    mode="ast"
)
train_datasets.append(en_ko_ast_fleurs)

# CoVoST2 Korean ASR
covost_ko_asr_dataset = CoVoSTDataset(
    processor=processor,
    data_dir="/workspace/CommonVoice/ko",
    split="train",
    ast=False,
    lang=("ko_en", "English")
)
train_datasets.append(covost_ko_asr_dataset)

# CoVoST2 Korean -> English AST
covost_ko_dataset = CoVoSTDataset(
    processor=processor,
    data_dir="/workspace/CommonVoice/ko",
    split="train",
    ast=True,
    lang=("ko_en", "English")
)
train_datasets.append(covost_ko_dataset)

# Zeroth Korean ASR
ko_asr_zeroth = ZerothKoreanDataset(
    processor=processor,
    split="train"
)
train_datasets.append(ko_asr_zeroth)

# FLEURS Korean ASR
ko_asr_fleurs = FleursDataset(
    processor=processor,
    split="train",
    source_lang="ko_kr",
    mode="asr"
)
train_datasets.append(ko_asr_fleurs)

# FLEURS Korean -> English AST
ko_en_ast_fleurs = FleursDataset(
    processor=processor,
    split="train",
    source_lang="ko_kr",
    target_lang="en_us",
    mode="ast"
)
train_datasets.append(ko_en_ast_fleurs)

print("Number of datasets:", len(train_datasets))
print([len(dataset) for dataset in train_datasets])

train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
print("Total number of training samples:", len(train_dataset))

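
# The global batch size is reached through gradient accumulation:
#   gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)
# e.g. with 2 GPUs: 128 // (2 * 32) = 2 accumulation steps per optimizer update.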
num_gpus = torch.cuda.device_count()
print(f'training on {num_gpus} GPUs')

assert (
    batch_size % (num_gpus * batch_size_per_gpu) == 0
), 'Batch size must be divisible by the number of GPUs times the per-GPU batch size'
gradient_accumulation_steps = batch_size // (num_gpus * batch_size_per_gpu)

training_args = TrainingArguments(
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size_per_gpu,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim='adamw_torch',
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-7,
    learning_rate=learning_rate,
    weight_decay=wd,
    max_grad_norm=1.0,
    lr_scheduler_type='cosine',
    warmup_steps=50,
    logging_steps=50,
    output_dir=output_dir,
    save_strategy='no',
    save_total_limit=10,
    save_only_model=True,
    bf16=True,
    fp16=False,
    remove_unused_columns=False,
    report_to='none',
    deepspeed=None,
    disable_tqdm=False,
    dataloader_num_workers=4,
    ddp_find_unused_parameters=True,
)

out_path = Path(training_args.output_dir)
out_path.mkdir(parents=True, exist_ok=True)
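
# The optimizer only receives the parameters left trainable by create_model().
# The scheduler slot in `optimizers` is None, so Trainer builds the cosine schedule
# with 50 warmup steps from training_args.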
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=learning_rate,
    weight_decay=wd,
    betas=(0.9, 0.95),
    eps=1e-7,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=covost_collate_fn,
    train_dataset=train_dataset,
    optimizers=(optimizer, None),
)

trainer.train()

import shutil

# Save the fine-tuned language model (with LoRA weights) and then the full model,
# dropping the auto-generated model card before the final save.
model.language_model.model.save_pretrained(output_dir)

markdown_file = os.path.join(output_dir, "README.md")
if os.path.exists(markdown_file):
    os.remove(markdown_file)

model.save_pretrained(output_dir)

del model
del trainer
import gc

gc.collect()
torch.cuda.empty_cache()

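
# Prepare the upload: clone the Hub repository into a local working directory and
# copy the training outputs into it.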
from huggingface_hub import HfApi, login, create_repo, Repository, upload_folder

upload_dir = "/workspace/upload"

repo_id = "junnei/gemma-3-4b-it-speech"
branch_name = "main"

repo = Repository(local_dir=upload_dir, clone_from=repo_id)
repo.git_checkout(branch_name, create_branch_ok=True)

for item in os.listdir(output_dir):
    s = os.path.join(output_dir, item)
    d = os.path.join(upload_dir, item)
    if os.path.isdir(s):
        shutil.copytree(s, d, dirs_exist_ok=True)
    else:
        shutil.copy2(s, d)