junbiao.chen
Trellis update
cc0c59d
from abc import abstractmethod
import os
import time
import json
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
import numpy as np
from torchvision import utils
from torch.utils.tensorboard import SummaryWriter
from .utils import *
from ..utils.general_utils import *
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
class Trainer:
    """
    Base class for training.

    Subclasses provide the model-specific pieces (``init_models_and_more``,
    ``load``, ``save``, ``finetune_from``, ``run_snapshot``, ``update_ema``,
    ``check_ddp``, ``training_losses`` and ``run_step``). This class handles
    per-rank batch-size bookkeeping, data loading and prefetching, periodic
    sampling, logging to ``log.txt`` / TensorBoard and the checkpoint cadence.

    NOTE: the class does not use ``abc.ABCMeta``, so the ``@abstractmethod``
    markers are documentation only; a subclass missing an override can still
    be instantiated and will hit the ``pass`` stub at call time.
    """
    def __init__(self,
        models,
        dataset,
        *,
        output_dir,
        load_dir,
        step,
        max_steps,
        batch_size=None,
        batch_size_per_gpu=None,
        batch_split=None,
        optimizer=None,
        lr_scheduler=None,
        elastic=None,
        grad_clip=None,
        ema_rate=0.9999,
        fp16_mode='inflat_all',
        fp16_scale_growth=1e-3,
        finetune_ckpt=None,
        log_param_stats=False,
        prefetch_data=True,
        i_print=1000,
        i_log=500,
        i_sample=10000,
        i_save=10000,
        i_ddpcheck=10000,
        **kwargs
    ):
        """
        Set up distributed bookkeeping, models, data and checkpoint state.

        Args:
            models: Mapping of name -> model (``.items()``/``.values()`` are used).
            dataset: Training dataset. ``collate_fn`` and ``visualize_sample``
                are used when present; ``value_range`` is read when saving images.
            output_dir: Root directory for ``ckpts/``, ``samples/``, ``tb_logs/``
                and ``log.txt``.
            load_dir, step: When both are not None, ``load(load_dir, step)`` is
                called to resume.
            max_steps: ``run`` stops once ``self.step`` reaches this value.
            batch_size: Global batch size across all ranks.
            batch_size_per_gpu: Per-rank batch size; takes precedence over
                ``batch_size`` when both are given.
            batch_split: Number of micro-batches each per-rank batch is split
                into (default 1).
            optimizer: Optimizer config, stored as ``optimizer_config``
                (defaults to a fresh ``{}`` per instance).
            lr_scheduler, elastic, grad_clip: Configs stored for subclasses.
            ema_rate: EMA decay rate(s); a single float is wrapped in a list.
            fp16_mode: ``'amp'`` or ``'inflat_all'``; selects which loss-scale
                value ``run`` logs.
            fp16_scale_growth, log_param_stats: Stored for subclasses.
            finetune_ckpt: Passed to ``finetune_from`` when not resuming.
            prefetch_data: Move the next batch to the device one step ahead.
            i_print, i_log, i_sample, i_save, i_ddpcheck: Step intervals for
                progress printing, log flushing, sampling, checkpointing and
                DDP checks (``i_ddpcheck=None`` disables the periodic check).
            **kwargs: Forwarded to ``init_models_and_more`` and
                ``prepare_dataloader``.

        Raises:
            AssertionError: If no batch size is given, or the sizes are not
                divisible by the world size / ``batch_split``.
        """
        assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'

        self.models = models
        self.dataset = dataset
        self.batch_split = batch_split if batch_split is not None else 1
        self.max_steps = max_steps
        # Fresh dict per instance: a shared mutable default would leak edits
        # made by one trainer into every other trainer.
        self.optimizer_config = {} if optimizer is None else optimizer
        self.lr_scheduler_config = lr_scheduler
        self.elastic_controller_config = elastic
        self.grad_clip = grad_clip
        self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
        self.fp16_mode = fp16_mode
        self.fp16_scale_growth = fp16_scale_growth
        self.log_param_stats = log_param_stats
        self.prefetch_data = prefetch_data
        if self.prefetch_data:
            self._data_prefetched = None

        self.output_dir = output_dir
        self.i_print = i_print
        self.i_log = i_log
        self.i_sample = i_sample
        self.i_save = i_save
        self.i_ddpcheck = i_ddpcheck

        if dist.is_initialized():
            # Multi-GPU params
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            # Clamp so a CPU-only (e.g. gloo) process group does not divide by zero.
            self.local_rank = self.rank % max(torch.cuda.device_count(), 1)
            self.is_master = self.rank == 0
        else:
            # Single-GPU params
            self.world_size = 1
            self.rank = 0
            self.local_rank = 0
            self.is_master = True

        # batch_size_per_gpu wins when both are given.
        self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
        self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
        assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
        assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'

        self.init_models_and_more(**kwargs)
        self.prepare_dataloader(**kwargs)

        # Load checkpoint (resume takes priority over finetuning).
        self.step = 0
        if load_dir is not None and step is not None:
            self.load(load_dir, step)
        elif finetune_ckpt is not None:
            self.finetune_from(finetune_ckpt)

        # Only rank 0 writes files, so only rank 0 owns a SummaryWriter.
        if self.is_master:
            os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
            self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))

        if self.world_size > 1:
            self.check_ddp()

        if self.is_master:
            print('\n\nTrainer initialized.')
            print(self)

    @property
    def device(self):
        """
        Device used for training data.

        Returns the ``.device`` attribute of the first model that exposes
        one; otherwise the device of the first parameter of the first model.
        """
        for _, model in self.models.items():
            if hasattr(model, 'device'):
                return model.device
        return next(list(self.models.values())[0].parameters()).device

    @abstractmethod
    def init_models_and_more(self, **kwargs):
        """
        Initialize models and more.
        """
        pass

    @staticmethod
    def _default_num_workers():
        """
        Return the dataloader worker count: CPU cores divided across the
        visible CUDA devices, rounded up.

        Both counts are clamped to at least 1, so CPU-only hosts
        (``torch.cuda.device_count() == 0``) and platforms where
        ``os.cpu_count()`` returns ``None`` do not crash. ``persistent_workers``
        also requires at least one worker.
        """
        cpu_count = os.cpu_count() or 1
        gpu_count = max(torch.cuda.device_count(), 1)
        return int(np.ceil(cpu_count / gpu_count))

    def prepare_dataloader(self, **kwargs):
        """
        Prepare dataloader.

        Builds a shuffling ``ResumableSampler`` and a ``DataLoader`` that
        yields ``batch_size_per_gpu`` samples per batch (an incomplete last
        batch is dropped), then wraps it with ``cycle`` (presumably an endless
        iterator) so ``load_data`` can call ``next`` on it every step.
        ``kwargs`` are accepted for subclass overrides and ignored here.
        """
        self.data_sampler = ResumableSampler(
            self.dataset,
            shuffle=True,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size_per_gpu,
            num_workers=self._default_num_workers(),
            pin_memory=True,
            drop_last=True,
            persistent_workers=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
            sampler=self.data_sampler,
        )
        self.data_iterator = cycle(self.dataloader)

    @abstractmethod
    def load(self, load_dir, step=0):
        """
        Load a checkpoint.
        Should be called by all processes.
        """
        pass

    @abstractmethod
    def save(self):
        """
        Save a checkpoint.
        Should be called only by the rank 0 process.
        """
        pass

    @abstractmethod
    def finetune_from(self, finetune_ckpt):
        """
        Finetune from a checkpoint.
        Should be called by all processes.
        """
        pass

    @abstractmethod
    def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
        """
        Run a snapshot of the model.

        ``snapshot`` expects a dict mapping names to
        ``{'value': tensor, 'type': 'sample' | 'image' | 'number'}``.
        """
        pass

    @torch.no_grad()
    def visualize_sample(self, sample):
        """
        Convert a sample to an image.

        Delegates to ``dataset.visualize_sample`` when the dataset provides
        it; otherwise returns ``sample`` unchanged. The result may be a single
        image tensor or a dict of them.
        """
        if hasattr(self.dataset, 'visualize_sample'):
            return self.dataset.visualize_sample(sample)
        else:
            return sample

    @torch.no_grad()
    def snapshot_dataset(self, num_samples=100):
        """
        Sample images from the dataset.

        Draws one shuffled batch of ``num_samples`` items, visualizes it and
        saves it as ``samples/dataset.jpg`` (or ``samples/dataset_<key>.jpg``
        per entry when ``visualize_sample`` returns a dict).
        """
        dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=num_samples,
            num_workers=0,
            shuffle=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )
        data = next(iter(dataloader))
        data = recursive_to_device(data, self.device)
        vis = self.visualize_sample(data)
        if isinstance(vis, dict):
            save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
        else:
            save_cfg = [('dataset', vis)]
        for name, image in save_cfg:
            utils.save_image(
                image,
                os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
                nrow=int(np.sqrt(num_samples)),
                normalize=True,
                value_range=self.dataset.value_range,
            )

    @staticmethod
    def _normalize_min_max(values):
        """
        Min-max normalize ``values`` to ``[0, 1]``.

        Returns ``(normalized, vmin, vmax)``. A constant tensor (zero span)
        maps to zeros instead of the NaNs a plain ``0 / 0`` would produce.
        """
        vmin = values.min()
        vmax = values.max()
        span = vmax - vmin
        if span > 0:
            normalized = (values - vmin) / span
        else:
            # Keep the floating dtype that true division would have produced.
            normalized = torch.zeros_like(values, dtype=torch.result_type(values, 1.0))
        return normalized, vmin, vmax

    @torch.no_grad()
    def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
        """
        Sample images from the model.
        NOTE: This function should be called by all processes.

        Each rank generates ``ceil(num_samples / world_size)`` samples with
        ``run_snapshot``; ``'sample'`` entries are converted to images with
        ``visualize_sample``. With several ranks the tensors are gathered on
        rank 0 (a collective call, hence the all-processes requirement) and
        trimmed to ``num_samples``. Rank 0 then writes
        ``samples/<suffix>/<key>_<suffix>.jpg`` for every ``'image'`` and
        ``'number'`` entry.

        Args:
            suffix: Output sub-directory and file suffix; defaults to
                ``step<step:07d>``.
            num_samples: Total number of samples across all ranks.
            batch_size: Forwarded to ``run_snapshot``.
            verbose: Forwarded to ``run_snapshot``.
        """
        if self.is_master:
            print(f'\nSampling {num_samples} images...', end='')

        if suffix is None:
            suffix = f'step{self.step:07d}'

        # Assign tasks: every rank over-generates slightly; the excess is
        # trimmed after gathering.
        num_samples_per_process = int(np.ceil(num_samples / self.world_size))
        samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)

        # Preprocess images. Iterate over a copy of the keys because dict-valued
        # visualizations add new entries and delete the original one.
        for key in list(samples.keys()):
            if samples[key]['type'] == 'sample':
                vis = self.visualize_sample(samples[key]['value'])
                if isinstance(vis, dict):
                    for k, v in vis.items():
                        samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
                    del samples[key]
                else:
                    samples[key] = {'value': vis, 'type': 'image'}

        # Gather results on rank 0. dist.gather is collective: every rank must
        # reach it with identically shaped tensors for each key.
        if self.world_size > 1:
            for key in samples.keys():
                samples[key]['value'] = samples[key]['value'].contiguous()
                if self.is_master:
                    all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
                else:
                    all_images = []
                dist.gather(samples[key]['value'], all_images, dst=0)
                if self.is_master:
                    samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]

        # Save images (rank 0 only).
        if self.is_master:
            sample_dir = os.path.join(self.output_dir, 'samples', suffix)
            os.makedirs(sample_dir, exist_ok=True)
            nrow = int(np.sqrt(num_samples))
            for key, entry in samples.items():
                path = os.path.join(sample_dir, f'{key}_{suffix}.jpg')
                if entry['type'] == 'image':
                    utils.save_image(
                        entry['value'],
                        path,
                        nrow=nrow,
                        normalize=True,
                        value_range=self.dataset.value_range,
                    )
                elif entry['type'] == 'number':
                    images, vmin, vmax = self._normalize_min_max(entry['value'])
                    images = utils.make_grid(
                        images,
                        nrow=nrow,
                        normalize=False,
                    )
                    save_image_with_notes(
                        images,
                        path,
                        notes=f'{key} min: {vmin}, max: {vmax}',
                    )
            print(' Done.')

    @abstractmethod
    def update_ema(self):
        """
        Update exponential moving average.
        Should only be called by the rank 0 process.
        """
        pass

    @abstractmethod
    def check_ddp(self):
        """
        Check if DDP is working properly.
        Should be called by all process.
        """
        pass

    @abstractmethod
    def training_losses(self, **mb_data):
        """
        Compute training losses for one micro-batch.
        """
        pass

    def load_data(self):
        """
        Load data.

        Fetches the next batch and moves it to ``self.device`` (one batch
        ahead when ``prefetch_data`` is set, so the non-blocking transfer can
        overlap the current step).

        Returns:
            A list of micro-batches. A dict batch is split along its first
            dimension into ``batch_split`` dicts; a list batch is returned
            as-is.

        Raises:
            ValueError: If the batch is neither a dict nor a list.
        """
        if self.prefetch_data:
            if self._data_prefetched is None:
                self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
            data = self._data_prefetched
            self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
        else:
            data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)

        # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
        if isinstance(data, dict):
            if self.batch_split == 1:
                data_list = [data]
            else:
                batch_size = list(data.values())[0].shape[0]
                data_list = [
                    {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
                    for i in range(self.batch_split)
                ]
        elif isinstance(data, list):
            data_list = data
        else:
            raise ValueError('Data must be a dict or a list of dicts.')
        return data_list

    @abstractmethod
    def run_step(self, data_list):
        """
        Run a training step on the micro-batches from ``load_data``.

        May return a (possibly nested) dict of scalars, which ``run`` merges
        into the step log.
        """
        pass

    def run(self):
        """
        Run training.

        Must be called by all processes: ``snapshot`` and ``check_ddp`` are
        entered by every rank, while logging, checkpointing and the
        TensorBoard writer are rank-0 only.
        """
        if self.is_master:
            print('\nStarting training...')
            self.snapshot_dataset()
        if self.step == 0:
            self.snapshot(suffix='init')
        else: # resume
            self.snapshot(suffix=f'resume_step{self.step:07d}')

        log = []
        time_last_print = 0.0
        time_elapsed = 0.0
        while self.step < self.max_steps:
            time_start = time.time()

            data_list = self.load_data()
            step_log = self.run_step(data_list)

            time_end = time.time()
            time_elapsed += time_end - time_start

            self.step += 1

            # Print progress (speed covers only data loading + run_step time).
            if self.is_master and self.step % self.i_print == 0:
                speed = self.i_print / (time_elapsed - time_last_print) * 3600
                columns = [
                    f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
                    f'Elapsed: {time_elapsed / 3600:.2f} h',
                    f'Speed: {speed:.2f} steps/h',
                    f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
                ]
                print(' | '.join([c.ljust(25) for c in columns]), flush=True)
                time_last_print = time_elapsed

            # Check ddp (all ranks)
            if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
                self.check_ddp()

            # Sample images (all ranks: snapshot gathers across processes)
            if self.step % self.i_sample == 0:
                self.snapshot()

            if self.is_master:
                log.append((self.step, {}))

                # Log time
                log[-1][1]['time'] = {
                    'step': time_end - time_start,
                    'elapsed': time_elapsed,
                }

                # Log losses
                if step_log is not None:
                    log[-1][1].update(step_log)

                # Log scale (self.scaler / self.log_scale are expected to be
                # created by the subclass for the matching fp16_mode)
                if self.fp16_mode == 'amp':
                    log[-1][1]['scale'] = self.scaler.get_scale()
                elif self.fp16_mode == 'inflat_all':
                    log[-1][1]['log_scale'] = self.log_scale

                # Save log
                if self.step % self.i_log == 0:
                    ## save to log file
                    log_str = '\n'.join([
                        f'{log_step}: {json.dumps(entry)}' for log_step, entry in log
                    ])
                    with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
                        log_file.write(log_str + '\n')

                    # show with tensorboard: average non-NaN entries per key
                    log_show = [entry for _, entry in log if not dict_any(entry, lambda x: np.isnan(x))]
                    log_show = dict_reduce(log_show, lambda x: np.mean(x))
                    log_show = dict_flatten(log_show, sep='/')
                    for key, value in log_show.items():
                        self.writer.add_scalar(key, value, self.step)
                    log = []

                # Save checkpoint (rank 0 only, per save()'s contract)
                if self.step % self.i_save == 0:
                    self.save()

        # snapshot() performs a collective gather, so every rank must call it;
        # calling it on rank 0 alone would hang multi-GPU runs.
        self.snapshot(suffix='final')
        if self.is_master:
            self.writer.close()
            print('Training finished.')

    def profile(self, wait=2, warmup=3, active=5):
        """
        Profile the training loop.

        Runs ``wait + warmup + active`` training steps under
        ``torch.profiler`` (memory and stack recording enabled) and writes a
        TensorBoard trace to ``<output_dir>/profile``. ``self.step`` is not
        advanced.
        """
        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
            profile_memory=True,
            with_stack=True,
        ) as prof:
            for _ in range(wait + warmup + active):
                # run_step needs the micro-batch list produced by load_data.
                self.run_step(self.load_data())
                prof.step()