import time, sys, subprocess, json, re
from pathlib import Path
import os, random
import torch
import math, pickle
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
import copy
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.utils.data.distributed import DistributedSampler
import logging
# from data import librilight, gigaspeech, gigaspeech_waveform
from data import combined_dataset
from models import voice_star
from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, StatefulSampler, AverageMeter, print_model_info
from .optim import ScaledAdam, Eden
import run_gen
import wandb, socket
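
# Trainer encapsulates distributed training for VoiceStar: it seeds the RNGs, builds the
# datasets/samplers, model, optimizer and LR scheduler, wraps the model in DDP, and runs
# the train / validate / checkpoint loop defined below. It expects world_size / rank /
# local_rank to be passed in by the caller (assumption: a torchrun-style launch with one
# process per GPU).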
class Trainer:
    def __init__(self, args, world_size, rank, local_rank):
        self.start_time = time.time()
        self.args = args
        if self.args.val_max_num_tokens is None:
            self.args.val_max_num_tokens = self.args.max_num_tokens
        self.world_size, self.rank, self.local_rank = world_size, rank, local_rank
        self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
        if self.rank == 0:
            self.writer = SummaryWriter(args.exp_dir)
            self.wandb = wandb.init(project="voice_editor", name=args.exp_dir.split("/")[-1], config=args, dir=args.exp_dir, entity=self.args.wandb_entity)
        self.seed_everything(seed=self.args.seed)
        self.meters = self._setup_meters()
        self.progress, self.total_progress = self._setup_progress()
        self.model, self.trainables, self.optim_states, self.scheduler_states, self.phn2num = self._setup_models()
        self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader()  # both loaders use a DistributedSampler; the train sampler is stateful
        if self.args.num_steps is not None:
            self.total_step = self.args.num_steps
            self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None
        else:
            self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size)) * self.args.num_epochs
        self.optimizer, self.scheduler = self._setup_optimizer()
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank], find_unused_parameters=False)
        self.early_stop_accu_steps = 0
        if self.rank == 0:
            if self.args.dynamic_batching:
                logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in an inference batch: {self.args.val_max_num_tokens}")
            else:
                logging.info(f"batch size (per gpu): {self.args.batch_size}")
        self.args.inference_every_n_steps = getattr(self.args, "inference_every_n_steps", self.args.val_every_n_steps * 5)
        assert self.args.inference_every_n_steps > self.args.val_every_n_steps and self.args.inference_every_n_steps % self.args.val_every_n_steps == 0, "inference_every_n_steps should be a multiple of val_every_n_steps, otherwise the code will never get a chance to run inference"
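
    # Main training loop: iterates over epochs and batches, splits each batch into
    # micro-batches for gradient accumulation, skips optimizer steps where any rank hit a
    # NaN loss, and periodically logs, saves checkpoints, and runs validation.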
    def train(self):
        flag = True
        skip_flag = False
        data_start_time = time.time()
        if self.progress['step'] >= self.total_step:
            if self.rank == 0:
                self.writer.close()
                self.wandb.finish()
            return
        while flag:
            self.train_sampler.set_epoch(self.progress['epoch'])
            for i, batch in enumerate(self.train_loader):
                if len(batch['y_lens']) < self.args.gradient_accumulation_steps:
                    continue
                data_end_time = time.time()
                self.model.train()
                if self.progress['step'] >= getattr(self.args, "uniform_weight_start_step", 1e50):
                    if self.progress['step'] == getattr(self.args, "uniform_weight_start_step", 1e50) and self.rank == 0:
                        logging.info("NOTE: start using uniform weight from step: {}".format(self.progress['step']))
                    self.args.codebook_weight = [2.5, 2, 1.5, 0.6]
                    self.model.module.args.codebook_weight = [2.5, 2, 1.5, 0.6]
                if self.progress['step'] >= self.total_step:
                    dist.barrier()
                    flag = False
                    self.validate_and_save()
                    if self.rank == 0:
                        self.writer.close()
                        self.wandb.finish()
                    break
                if isinstance(self.scheduler, Eden):
                    self.scheduler.step_epoch(self.progress['step'] // self.args.pseudo_epoch_size + 1)
                if self.args.optimizer_name == "ScaledAdam":
                    cur_lr = self.scheduler.get_last_lr()[0]
                else:
                    lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
                    assert lrs[0] == lrs[1]
                    cur_lr = lrs[0]
                if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                    self.writer.add_scalar("train/lr", cur_lr, self.progress['step'])
                    self.wandb.log({"train/lr": cur_lr}, step=self.progress['step'])
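                # Per-step accumulators; micro-batches are sliced out of `batch` below so
                # that one optimizer step covers the whole batch via gradient accumulation.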
                all_inds = list(range(len(batch['y'])))
                sum_losses = 0
                sum_top10acc = 0
                sum_ntoken = 0
                sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
                # extra losses
                sum_extra_losses = {}
                # When using prompt-based training, the prompt can make the total sequence length much longer,
                # which makes the effective batch size in each accumulation step much bigger and can lead to OOM.
                # Therefore we re-calculate gradient_accumulation_steps based on the effective batch size.
                if self.args.neighbor_prompt_prob > 0:
                    effective_batch_size = self.args.max_num_tokens // self.args.gradient_accumulation_steps
                    total_batch_size = sum(batch['y_lens']).item()
                    cur_gradient_accumulation_steps = max(self.args.gradient_accumulation_steps, total_batch_size // effective_batch_size)
                    gas = torch.tensor(cur_gradient_accumulation_steps, dtype=torch.int, device=self.local_rank)
                    dist.all_reduce(gas, op=dist.ReduceOp.MAX)
                    cur_gradient_accumulation_steps = gas.item()
                    len_batch = torch.tensor(len(batch['y']), dtype=torch.int, device=self.local_rank)
                    dist.all_reduce(len_batch, op=dist.ReduceOp.MIN)
                    len_batch = len_batch.item()
                    cur_gradient_accumulation_steps = min(cur_gradient_accumulation_steps, len_batch)
                    # if cur_gradient_accumulation_steps * effective_batch_size < total_batch_size, only keep the first samples that fit within max_num_tokens
                    cur_len = 0
                    final_all_inds = []
                    pointer = 0
                    while cur_len < self.args.max_num_tokens and pointer < len(all_inds):
                        cur_len += batch['y_lens'][pointer]
                        final_all_inds.append(all_inds[pointer])
                        pointer += 1
                    all_inds = final_all_inds
                else:
                    cur_gradient_accumulation_steps = self.args.gradient_accumulation_steps
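                # Local (per-rank) sums; they are folded into a single all_reduce after the
                # micro-batch loop instead of one collective per metric.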
                sum_losses_local = 0.0
                sum_top10acc_local = 0.0
                sum_entropy_loss_local = 0.0
                sum_ctc_loss_local = 0.0
                sum_ntoken_local = 0.0
                sum_top10acc_cbi_local = [0.0 for _ in range(self.args.n_codebooks)]
                global_nan_flag = 0
                for j in range(cur_gradient_accumulation_steps):
                    cur_ind = all_inds[j::cur_gradient_accumulation_steps]
                    cur_batch = {key: batch[key][cur_ind] for key in batch}
                    # Automatic casting
                    if self.args.precision == "float16":
                        precision_used = torch.float16
                    elif self.args.precision in ["bf16", "bfloat16"]:
                        precision_used = torch.bfloat16
                    else:
                        precision_used = torch.float32
                    with torch.amp.autocast('cuda', dtype=precision_used):
                        out = self.model(cur_batch, calc_loss=True)
                        if out is None:
                            continue
                    if torch.isnan(out['loss']).any():
                        local_nan_flag = torch.tensor(1, device=self.local_rank)
                    else:
                        local_nan_flag = torch.tensor(0, device=self.local_rank)
                    # All ranks check if *any* rank got a NaN
                    dist.all_reduce(local_nan_flag, op=dist.ReduceOp.SUM)
                    global_nan_flag = local_nan_flag.item()
                    if global_nan_flag > 0:
                        # Now *all* ranks break at the same j
                        logging.info(f"rank: {self.rank}. Loss at micro-batch {j} in step {self.progress['step']} was NaN on at least one rank; skipping.")
                        break
                    # Accumulate local values
                    record_loss = out['loss'].detach()
                    top10acc = out['top10acc'].detach()
                    effective_ntoken = out['effective_ntoken'].detach()
                    sum_losses_local += record_loss.item()
                    sum_top10acc_local += top10acc.item()
                    sum_ntoken_local += effective_ntoken.item()
                    # Optional losses
                    if 'entropy_loss' in out:
                        sum_entropy_loss_local += out['entropy_loss'].detach().item()
                    if 'ctc_loss' in out:
                        sum_ctc_loss_local += out['ctc_loss'].detach().item()
                    # Codebook accuracy
                    if 'top10acc_by_codebook' in out:
                        for cb in range(self.args.n_codebooks):
                            sum_top10acc_cbi_local[cb] += out['top10acc_by_codebook'][cb].detach().item()
                    # Backprop on this micro-batch
                    if self.args.optimizer_name == "ScaledAdam":
                        self.scaler.scale(out['loss']).backward()
                    else:
                        self.scaler.scale(out['loss'] / out['effective_ntoken']).backward()
                if global_nan_flag > 0:
                    # If *any* rank had NaN, skip this step
                    logging.info(f"rank: {self.rank}. Loss at one micro-batch in step {self.progress['step']} was NaN on at least one rank; skipping.")
                    self.progress['step'] += 1
                    self.progress['cur_step'] += 1
                    self.optimizer.zero_grad()
                    continue
                # Otherwise, do one big reduce for the summed metrics
                metrics_tensor = torch.tensor([
                    sum_losses_local,
                    sum_top10acc_local,
                    sum_entropy_loss_local,
                    sum_ctc_loss_local,
                    sum_ntoken_local
                ], device=self.local_rank, dtype=torch.float32)
                dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
                # Also reduce the codebook array in one shot if needed
                codebook_tensor = torch.tensor(sum_top10acc_cbi_local, device=self.local_rank, dtype=torch.float32)
                dist.all_reduce(codebook_tensor, op=dist.ReduceOp.SUM)
                # Convert them back to Python scalars
                sum_losses = metrics_tensor[0].item()
                sum_top10acc = metrics_tensor[1].item()
                sum_entropy_loss = metrics_tensor[2].item()
                sum_ctc_loss = metrics_tensor[3].item()
                sum_ntoken = metrics_tensor[4].item()
                sum_top10acc_cbi = codebook_tensor.tolist()
                if self.args.optimizer_name != "ScaledAdam":
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                if self.args.optimizer_name == "ScaledAdam":
                    self.scheduler.step_batch(self.progress['step'])
                else:
                    self.scheduler.step()
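                # Metrics below are normalized by the total effective token count across all
                # ranks, so the logged loss/accuracy are per-token averages for the whole step.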
                # logging
                if self.rank == 0:
                    average_loss = sum_losses / sum_ntoken
                    average_top10acc = sum_top10acc / sum_ntoken
                    average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)]
                    self.meters['train_loss'].update(average_loss, batch['x'].shape[0] * self.world_size)
                    self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0] * self.world_size)
                    for cb in range(self.args.n_codebooks):
                        self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0] * self.world_size)
                    self.meters['data_time'].update(data_end_time - data_start_time)
                    self.meters['train_time'].update(time.time() - data_end_time)
                    # log extra losses
                    for key in sum_extra_losses:
                        if "train_" + key not in self.meters:
                            self.meters["train_" + key] = AverageMeter()
                        self.meters["train_" + key].update(sum(sum_extra_losses[key]) / len(sum_extra_losses[key]), batch['x'].shape[0] * self.world_size)
                    if self.progress['step'] % self.args.tb_write_every_n_steps == 0:
                        self.writer.add_scalar('train/loss', average_loss, self.progress['step'])
                        self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step'])
                        self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step'])
                        self.wandb.log({"train/loss": average_loss, "train/top10acc": average_top10acc, "train/ntokens": sum_ntoken, "train/data_time": data_end_time - data_start_time, "train/train_time": time.time() - data_end_time}, step=self.progress['step'])
                        for cb in range(self.args.n_codebooks):
                            self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step'])
                            self.wandb.log({f'train/top10acc_cb{cb+1}': average_top10acc_cbi[cb]}, step=self.progress['step'])
                        self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step'])
                        self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step'])
                        # write extra losses
                        for key in sum_extra_losses:
                            self.writer.add_scalar(f"train/{key}", sum(sum_extra_losses[key]) / len(sum_extra_losses[key]), self.progress['step'])
                            self.wandb.log({f"train/{key}": sum(sum_extra_losses[key]) / len(sum_extra_losses[key])}, step=self.progress['step'])
                    # logging.info(f"ntoken: {sum_ntoken}")
                    # console logging
                    if self.progress['step'] % self.args.print_every_n_steps == 0:
                        log_out = {}
                        log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}"
                        log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}"
                        log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}"
                        log_out['lr'] = f"{cur_lr:.7f}"
                        log_out['ntokens'] = f"{sum_ntoken}"
                        for key in self.meters:
                            if self.meters[key].val != 0 or self.meters[key].avg != 0:
                                log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}"
                        logging.info(log_out)
                        if np.isnan(self.meters['train_loss'].avg):
                            logging.warning("training diverged...")
                            raise RuntimeError("training diverged...")
                # save the model only
                if self.progress['step'] % self.args.save_every_n_steps == 0:
                    dist.barrier()
                    if self.rank == 0:
                        save_path = os.path.join(self.args.exp_dir, f"bundle_step{self.progress['step']}.pth")
                        self.save_progress(name=f"step{self.progress['step']}")
                        torch.save(
                            {
                                "model": self.model.module.state_dict(),
                                "args": self.args,
                                "phn2num": self.train_loader.dataset.phn2num,
                                "optimizer": self.optimizer.state_dict(),
                                "scheduler": self.scheduler.state_dict(),
                            }, save_path
                        )
                        logging.info(f"save model, optimizer, scheduler and progress at {save_path} at global step {self.progress['step']}")
                    dist.barrier()
                # validation and save models
                if self.progress['step'] % self.args.val_every_n_steps == 0:
                    dist.barrier()
                    continue_training = self.validate_and_save()
                    # broadcast continue_training to all processes, so that all processes get into the generation stage
                    continue_training = torch.tensor(int(continue_training), dtype=torch.int, device=self.local_rank)
                    dist.broadcast(continue_training, src=0)
                    continue_training = bool(continue_training.item())
                    dist.barrier()  # need this to ensure all processes get to the next line
                    logging.info(f"rank: {self.rank}, continue_training: {continue_training}")
                    if not continue_training:
                        if self.rank == 0:
                            self.writer.close()
                            self.wandb.finish()
                        flag = False
                        break
                self.progress['step'] += 1
                self.progress['cur_step'] += 1
                data_start_time = time.time()
            self.progress['epoch'] += 1
            self.progress['cur_step'] = 0  # reset cur_step to 0
        dist.destroy_process_group()
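
    # Runs validation, handles early stopping, and writes checkpoints: `bundle.pth` on
    # every validation, plus `best_bundle.pth` whenever the validation loss improves.
    # Returns False when training should stop.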
    def validate_and_save(self):
        self.model.eval()
        score = self.validate(self.valid_loader)
        if self.args.early_stop_threshold > 0:
            if self.progress['best_score'] - score < self.args.early_stop_threshold:
                self.early_stop_accu_steps += self.args.val_every_n_steps
                if self.early_stop_accu_steps >= self.args.early_stop_step - 1:
                    logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}")
                    logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}")
                    return False
            else:
                self.early_stop_accu_steps = 0
        if self.rank == 0:
            save_path = os.path.join(self.args.exp_dir, "bundle.pth")
            if os.path.isfile(save_path):
                os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}")
            torch.save(
                {
                    "model": self.model.module.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.scheduler.state_dict(),
                    "args": self.args,
                    "phn2num": self.train_loader.dataset.phn2num
                }, save_path
            )
            self.save_progress()
            logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}")
            if score < self.progress['best_score']:
                self.progress['best_step'] = self.progress['step']
                self.progress['best_score'] = score
                save_path = os.path.join(self.args.exp_dir, "best_bundle.pth")
                if os.path.isfile(save_path):
                    os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}")
                torch.save(
                    {
                        "model": self.model.module.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": self.scheduler.state_dict(),
                        "args": self.args,
                        "phn2num": self.train_loader.dataset.phn2num
                    }, save_path
                )
                logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}")
        # sync best score and best step, so that all processes early stop at the same time
        best_score_tensor = torch.tensor(self.progress['best_score'], device=self.local_rank)
        dist.broadcast(best_score_tensor, src=0)
        self.progress['best_score'] = float(best_score_tensor.item())
        best_step_tensor = torch.tensor(self.progress['best_step'], device=self.local_rank)
        dist.broadcast(best_step_tensor, src=0)
        self.progress['best_step'] = int(best_step_tensor.item())
        dist.barrier()
        return True
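
    # Computes validation metrics under torch.no_grad(); per-rank sums are combined with
    # dist.all_reduce so every rank returns the same per-token validation loss.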
    def validate(self, valid_loader=None, hide_progress=True):
        if valid_loader is None:
            valid_loader = self.valid_loader
        self.model.eval()
        start_val_time = time.time()
        sum_losses = 0
        sum_top10acc = 0
        sum_ntoken = 0
        sum_dur_loss = 0
        sum_dur_acc = 0
        sum_entropy_loss = 0
        sum_ctc_loss = 0
        sum_top10acc_cbi = [0 for _ in range(self.args.n_codebooks)]
        mean_perplexity_cbi = [0 for _ in range(self.args.n_codebooks)]
        with torch.no_grad():
            for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)):
                out = self.model(batch, calc_loss=True)  # no reduction is applied to loss
                sum_losses += out['loss']
                sum_top10acc += out['top10acc']
                sum_ntoken += out['effective_ntoken']
                if "dur_loss" in out:
                    sum_dur_loss += out['dur_loss']
                    sum_dur_acc += out['dur_acc']
                if "entropy_loss" in out:
                    sum_entropy_loss += out['entropy_loss']
                if "ctc_loss" in out:
                    sum_ctc_loss += out['ctc_loss']
                # logging.info(f"iter {i}::: {sum_losses}, {sum_top10acc}, {sum_ntoken}")
                if 'top10acc_by_codebook' in out:
                    for cb in range(self.args.n_codebooks):
                        sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb]
                if 'perplexity_by_codebook' in out:
                    for cb in range(self.args.n_codebooks):
                        mean_perplexity_cbi[cb] += out['perplexity_by_codebook'][cb]
                # if i > 10:
                #     break
        dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM)
        if "dur_loss" in out:
            dist.all_reduce(sum_dur_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(sum_dur_acc, op=dist.ReduceOp.SUM)
        if "entropy_loss" in out:
            dist.all_reduce(sum_entropy_loss, op=dist.ReduceOp.SUM)
        if "ctc_loss" in out:
            dist.all_reduce(sum_ctc_loss, op=dist.ReduceOp.SUM)
        if 'top10acc_by_codebook' in out:
            for cb in range(self.args.n_codebooks):
                dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM)
        if 'perplexity_by_codebook' in out:
            for cb in range(self.args.n_codebooks):
                dist.all_reduce(mean_perplexity_cbi[cb], op=dist.ReduceOp.SUM)
        val_loss = sum_losses / sum_ntoken
        val_top10acc = sum_top10acc / sum_ntoken
        if self.rank == 0:
            if "dur_loss" in out:
                val_dur_loss = sum_dur_loss / sum_ntoken
                val_dur_acc = sum_dur_acc / sum_ntoken
                self.meters['val_dur_loss'].update(val_dur_loss)
                logging.info(f"val dur_loss: {val_dur_loss:.5f}")
                self.meters['val_dur_acc'].update(val_dur_acc)
                logging.info(f"val dur_acc: {val_dur_acc:.5f}")
                self.writer.add_scalar("val/dur_loss", val_dur_loss, self.progress['step'])
                self.writer.add_scalar("val/dur_acc", val_dur_acc, self.progress['step'])
                self.wandb.log({"val/dur_loss": val_dur_loss, "val/dur_acc": val_dur_acc}, step=self.progress['step'])
            # logging
            self.meters['val_loss'].update(val_loss)
            logging.info(f"val loss: {val_loss:.5f}")
            self.writer.add_scalar("val/loss", val_loss, self.progress['step'])
            self.wandb.log({"val/loss": val_loss}, step=self.progress['step'])
            self.meters['val_top10acc'].update(val_top10acc)
            logging.info(f"val top10acc: {val_top10acc:.5f}")
            self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step'])
            self.wandb.log({"val/top10acc": val_top10acc}, step=self.progress['step'])
            for cb in range(self.args.n_codebooks):
                average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks
                self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi)
                self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step'])
                self.wandb.log({f'val/top10acc_cb{cb+1}': average_top10acc_cbi}, step=self.progress['step'])
                temp = mean_perplexity_cbi[cb] / len(valid_loader)
                self.writer.add_scalar(f'val/perplexity_cb{cb+1}', temp, self.progress['step'])
                self.wandb.log({f'val/perplexity_cb{cb+1}': temp}, step=self.progress['step'])
            average_perplexity = sum(mean_perplexity_cbi) / (self.args.n_codebooks * len(valid_loader))
            self.wandb.log({"val/average_perplexity": average_perplexity}, step=self.progress['step'])
            self.writer.add_scalar('val/average_perplexity', average_perplexity, self.progress['step'])
            # log entropy and ctc loss
            if "entropy_loss" in out:
                val_entropy_loss = sum_entropy_loss / ((i + 1) * self.world_size)
                self.meters['val_entropy_loss'].update(val_entropy_loss)
                logging.info(f"val entropy_loss: {val_entropy_loss:.5f}")
                self.writer.add_scalar("val/entropy_loss", val_entropy_loss, self.progress['step'])
                self.wandb.log({"val/entropy_loss": val_entropy_loss}, step=self.progress['step'])
            if "ctc_loss" in out:
                val_ctc_loss = sum_ctc_loss / ((i + 1) * self.world_size)
                self.meters['val_ctc_loss'].update(val_ctc_loss)
                logging.info(f"val ctc_loss: {val_ctc_loss:.5f}")
                self.writer.add_scalar("val/ctc_loss", val_ctc_loss, self.progress['step'])
                self.wandb.log({"val/ctc_loss": val_ctc_loss}, step=self.progress['step'])
            logging.info(f"validation takes: {time.time() - start_val_time:.2f}s")
            logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}")
        return val_loss.item()
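
    # AverageMeter trackers backing the train/val statistics logged above.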
    def _setup_meters(self):
        meters = {}
        meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time']
        meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc']
        meter_names += ['val_perplexity']
        meter_names += ['val_entropy_loss', 'val_ctc_loss']  # optional losses updated in validate()
        meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        meter_names += [f'val_perplexity_cb{cb+1}' for cb in range(self.args.n_codebooks)]
        for name in meter_names:
            meters[name] = AverageMeter()
        return meters
    def _setup_progress(self):
        """
        Set up progress counters, and restore them from progress.pkl when resuming.
        """
        progress = {}
        progress['best_step'] = 1
        progress['best_score'] = np.inf  # this records the best (lowest) validation loss
        progress['step'] = 1
        progress['epoch'] = 1
        progress['cur_step'] = 0  # step within the current epoch, used for resuming the sampler
        total_progress = []
        # if self.args.resume or self.args.validate:
        if self.args.resume:
            progress_pkl = "%s/progress.pkl" % self.args.exp_dir
            with open(progress_pkl, "rb") as f:
                total_progress = pickle.load(f)
            progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1]
            if self.rank == 0:
                logging.info("\nResume training from:")
                logging.info("  epoch = %s" % progress['epoch'])
                logging.info("  cur_step = %s" % progress['cur_step'])
                logging.info("  step = %s" % progress['step'])
                logging.info("  best_step = %s" % progress['best_step'])
                logging.info("  best_score = %s" % progress['best_score'])
        return progress, total_progress
    def save_progress(self, name=None):
        self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step'] + 1), self.progress['epoch'], int(self.progress['cur_step'] + 1), time.time() - self.start_time])
        if name is not None:
            progress_fn = f"{self.args.exp_dir}/progress_{name}.pkl"
        else:
            progress_fn = f"{self.args.exp_dir}/progress.pkl"
        with open(progress_fn, "wb") as f:
            pickle.dump(self.total_progress, f)
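
    # Builds the train/valid datasets and loaders. With dynamic batching, batches are formed
    # by token budget (max_num_tokens) via DistributedDynamicBatchSampler; otherwise a fixed
    # per-GPU batch size is used with (stateful) distributed samplers so training can resume
    # mid-epoch.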
    def _setup_dataloader(self):
        train_dataset, val_dataset = combined_dataset.dataset(self.args, 'train'), combined_dataset.dataset(self.args, 'valid')
        if self.args.dynamic_batching:
            train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0)
            valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0)
        else:
            train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size // self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True)
            valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False)
        if self.progress['step'] > 1:
            train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step'])
        assert self.phn2num is not None
        if self.phn2num is not None:
            train_dataset.phn2num = self.phn2num
            val_dataset.phn2num = self.phn2num
        if self.args.dynamic_batching:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                                                       batch_sampler=train_sampler,
                                                       num_workers=self.args.num_workers,
                                                       collate_fn=train_dataset.collate, persistent_workers=True
                                                       )
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                                                       batch_sampler=valid_sampler,
                                                       num_workers=self.args.num_workers,
                                                       collate_fn=val_dataset.collate, persistent_workers=True
                                                       )
        else:
            train_loader = torch.utils.data.DataLoader(train_dataset,
                                                       batch_size=self.args.batch_size, sampler=train_sampler, num_workers=self.args.num_workers,
                                                       collate_fn=train_dataset.collate, persistent_workers=True
                                                       )
            valid_loader = torch.utils.data.DataLoader(val_dataset,
                                                       batch_size=self.args.batch_size, sampler=valid_sampler,
                                                       num_workers=self.args.num_workers,
                                                       collate_fn=val_dataset.collate, persistent_workers=True
                                                       )
        return len(train_dataset), train_sampler, train_loader, valid_loader
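
    # Builds the VoiceStar model, optionally restores weights (resume bundle or
    # args.load_model_from), and groups parameters for the optimizer: when using AdamW,
    # biases, embeddings, and layer norms get no weight decay.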
    def _setup_models(self):
        model = voice_star.VoiceStar(self.args)
        if self.rank == 0:
            logging.info(model)
            logging.info("model parameters")
            print_model_info(model)
        phn2num = None
        optim_states = None
        scheduler_states = None
        if self.progress['step'] > 1:
            bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu")
            model.load_state_dict(bundle['model'])
            optim_states = bundle['optimizer']
            scheduler_states = bundle['scheduler']
            phn2num = bundle['phn2num']
            if self.rank == 0:
                logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step']))
            del bundle['model']
        if self.args.load_model_from is not None and self.progress['step'] <= 1:
            logging.info(f"load weights from {self.args.load_model_from}")
            sd = torch.load(self.args.load_model_from, map_location="cpu")
            if hasattr(model, "carefully_load_state_dict"):
                model.carefully_load_state_dict(sd['model'])
            else:
                model.load_state_dict(sd['model'])
            phn2num = sd['phn2num']
            del sd
        #### the operations below gather the params for the optimizer, which is handled at this wrapper level ####
        if self.args.optimizer_name == "ScaledAdam":
            trainables = [p for p in model.parameters() if p.requires_grad]
        else:
            no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
                    "weight_decay": 0.0,
                },
            ]
            if len(optimizer_grouped_parameters[1]['params']) == 0:
                logging.info("there are no embedding weight, bias, or layernorm parameters in the model, which should not be the case; check model parameter names")
                trainables = optimizer_grouped_parameters[0]
            else:
                trainables = optimizer_grouped_parameters
        model.to(self.device)
        return model, trainables, optim_states, scheduler_states, phn2num
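
    # ScaledAdam is paired with the Eden scheduler; otherwise AdamW is used with a linear
    # warmup for total_step * warmup_fraction steps followed by linear decay to zero.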
    def _setup_optimizer(self):
        if self.args.optimizer_name == "ScaledAdam":
            parameters_names = []
            _model = self.model.module if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) else self.model
            parameters_names.append([n for n, p in _model.named_parameters() if p.requires_grad])
            optimizer = ScaledAdam(
                self.trainables,
                lr=self.args.lr,
                betas=(0.9, 0.95),
                clipping_scale=2.0,
                parameters_names=parameters_names,
                show_dominant_parameters=False,
                clipping_update_period=self.args.clipping_update_period,
            )
            scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction)  # NOTE: when using ScaledAdam, we use the Eden scheduler
        else:
            optimizer = AdamW(self.trainables, lr=self.args.lr)
            warmup_steps = self.total_step * self.args.warmup_fraction
            def lr_lambda(current_step: int):
                if current_step < warmup_steps:
                    return float(current_step) / float(max(1, warmup_steps))
                return max(
                    0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps))
                )
            scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
        # when resuming, restore optimizer and scheduler states
        if self.progress['step'] > 1:
            optimizer.load_state_dict(self.optim_states)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
            del self.optim_states
            scheduler.load_state_dict(self.scheduler_states)
        optimizer.zero_grad()
        return optimizer, scheduler
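
    # Seeds Python, NumPy, and torch RNGs and makes cuDNN deterministic for reproducibility.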
    def seed_everything(self, seed=1):
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True