|
""" |
|
This file contains some tools |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import os |
|
from tqdm import tqdm |
|
from torchvision import transforms |
|
from torchvision.utils import save_image |
|
from absl import logging |
|
from PIL import Image, ImageDraw, ImageFont |
|
import textwrap |
|
|
|
def save_image_with_caption(image_tensor, caption, filename, font_size=20,
                            font_path='/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'):
    """Save an image tensor to `filename` with `caption` rendered beneath it."""
    image_tensor = image_tensor.clone().detach()
    image_tensor = torch.clamp(image_tensor, min=0, max=1)
    image_pil = transforms.ToPILImage()(image_tensor)
    draw = ImageDraw.Draw(image_pil)

    font = ImageFont.truetype(font_path, font_size)
    # Wrap the caption into roughly four lines.
    wrap_text = textwrap.wrap(caption, width=len(caption) // 4 + 1)
    # Measure each line with textbbox (the textsize API was removed in Pillow 10).
    text_sizes = []
    for line in wrap_text:
        left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
        text_sizes.append((right - left, bottom - top))
    total_text_height = sum(size[1] for size in text_sizes) + 15

    # Extend the canvas below the image to make room for the caption.
    new_height = image_pil.height + total_text_height + 25
    new_image = Image.new('RGB', (image_pil.width, new_height), 'white')
    new_image.paste(image_pil, (0, 0))
    current_y = image_pil.height + 5
    draw = ImageDraw.Draw(new_image)

    # Draw each wrapped line centered horizontally.
    for line, size in zip(wrap_text, text_sizes):
        x = (new_image.width - size[0]) / 2
        draw.text((x, current_y), line, font=font, fill='black')
        current_y += size[1] + 5
    new_image.save(filename)
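# Hedged usage sketch (the tensor shape and output name are illustrative; the
# default DejaVu font path exists on most Debian/Ubuntu systems but not all):
#   img = torch.rand(3, 256, 256)  # CHW, values in [0, 1]
#   save_image_with_caption(img, 'a photo of a corgi wearing a hat', 'out.png')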
|
|
|
|
|
def set_logger(log_level='info', fname=None):
    """Configure absl logging; optionally mirror logs to the file `fname`."""
    import logging as _logging
    handler = logging.get_absl_handler()
    formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
    handler.setFormatter(formatter)
    logging.set_verbosity(log_level)
    if fname is not None:
        handler = _logging.FileHandler(fname)
        handler.setFormatter(formatter)
        logging.get_absl_logger().addHandler(handler)
|
|
|
|
|
def dct2str(dct):
    return str({k: f'{v:.6g}' for k, v in dct.items()})
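# Example: dct2str({'loss': 0.123456789, 'lr': 3e-4}) returns
# "{'loss': '0.123457', 'lr': '0.0003'}" -- handy for compact log lines.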
|
|
|
|
|
def get_nnet(name, **kwargs):
    if name == 'dimr':
        from libs.model.dimr_t2i import MRModel
        return MRModel(kwargs["model_args"])
    elif name == 'dit':
        from libs.model.dit_t2i import DiT_H_2
        return DiT_H_2(kwargs["model_args"])
    else:
        raise NotImplementedError(name)
|
|
|
|
|
def set_seed(seed: int):
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
|
|
|
|
|
def get_optimizer(params, name, **kwargs):
    if name == 'adam':
        from torch.optim import Adam
        return Adam(params, **kwargs)
    elif name == 'adamw':
        from torch.optim import AdamW
        return AdamW(params, **kwargs)
    else:
        raise NotImplementedError(name)
|
|
|
|
|
def customized_lr_scheduler(optimizer, warmup_steps=-1):
    from torch.optim.lr_scheduler import LambdaLR

    def fn(step):
        # Linear warmup from 0 to the base LR over `warmup_steps`, then constant.
        if warmup_steps > 0:
            return min(step / warmup_steps, 1)
        else:
            return 1
    return LambdaLR(optimizer, fn)
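# Sketch of the schedule: with warmup_steps=1000 the LR multiplier is
# fn(0)=0.0, fn(500)=0.5, fn(1000)=1.0 and stays at 1.0 afterwards;
# warmup_steps<=0 disables warmup entirely.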
|
|
|
|
|
def get_lr_scheduler(optimizer, name, **kwargs):
    if name == 'customized':
        return customized_lr_scheduler(optimizer, **kwargs)
    elif name == 'cosine':
        from torch.optim.lr_scheduler import CosineAnnealingLR
        return CosineAnnealingLR(optimizer, **kwargs)
    else:
        raise NotImplementedError(name)
|
|
|
|
|
def ema(model_dest: nn.Module, model_src: nn.Module, rate):
    """In-place EMA update: dest = rate * dest + (1 - rate) * src."""
    param_dict_src = dict(model_src.named_parameters())
    for p_name, p_dest in model_dest.named_parameters():
        p_src = param_dict_src[p_name]
        assert p_src is not p_dest
        p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
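# Note: rate=0 copies the source weights into the destination outright, which
# is how initialize_train_state below seeds the EMA model; rate=0.9999 gives
# the slow tracking average used during training.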
|
|
|
|
|
class TrainState(object):
    def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.step = step
        self.nnet = nnet
        self.nnet_ema = nnet_ema

    def ema_update(self, rate=0.9999):
        if self.nnet_ema is not None:
            ema(self.nnet_ema, self.nnet, rate)

    def save(self, path):
        os.makedirs(path, exist_ok=True)
        torch.save(self.step, os.path.join(path, 'step.pth'))
        for key, val in self.__dict__.items():
            if key != 'step' and val is not None:
                torch.save(val.state_dict(), os.path.join(path, f'{key}.pth'))

    def load(self, path):
        logging.info(f'load from {path}')
        self.step = torch.load(os.path.join(path, 'step.pth'))
        for key, val in self.__dict__.items():
            if key != 'step' and val is not None:
                val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))

    def resume(self, ckpt_root, step=None):
        if not os.path.exists(ckpt_root):
            return
        if step is None:
            # Pick the checkpoint with the largest step, e.g. '50000.ckpt'.
            ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root)))
            if not ckpts:
                return
            steps = map(lambda x: int(x.split(".")[0]), ckpts)
            step = max(steps)
        ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
        logging.info(f'resume from {ckpt_path}')
        self.load(ckpt_path)

    def to(self, device):
        for key, val in self.__dict__.items():
            if isinstance(val, nn.Module):
                val.to(device)
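# Checkpoint layout implied by save()/resume() (a hedged sketch; callers are
# expected to save into directories named '<step>.ckpt'):
#   ckpt_root/50000.ckpt/{step.pth, optimizer.pth, lr_scheduler.pth, nnet.pth, nnet_ema.pth}
#   train_state.save(os.path.join(ckpt_root, f'{train_state.step}.ckpt'))
#   train_state.resume(ckpt_root)  # loads the largest-step checkpoint found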
|
|
|
|
|
def trainable_parameters(nnet):
    """Split parameters into weight-decay and no-weight-decay groups by name suffix."""
    params_decay = []
    params_nodecay = []
    for name, param in nnet.named_parameters():
        if name.endswith(".nodecay_weight") or name.endswith(".nodecay_bias"):
            params_nodecay.append(param)
        else:
            params_decay.append(param)
    print("params_decay", len(params_decay))
    print("params_nodecay", len(params_nodecay))
    params = [
        {'params': params_decay},
        {'params': params_nodecay, 'weight_decay': 0.0}
    ]
    return params
|
|
|
|
|
def initialize_train_state(config, device):
    nnet = get_nnet(**config.nnet)
    nnet_ema = get_nnet(**config.nnet)
    nnet_ema.eval()

    optimizer = get_optimizer(trainable_parameters(nnet), **config.optimizer)
    lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)

    train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0,
                             nnet=nnet, nnet_ema=nnet_ema)
    train_state.ema_update(0)  # rate=0 copies nnet's weights into nnet_ema
    train_state.to(device)
    return train_state
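# Hedged usage sketch, assuming an attribute-style config whose sub-dicts feed
# the factories above (the real project config schema may differ, and
# model_args is a hypothetical placeholder):
#   from types import SimpleNamespace
#   config = SimpleNamespace(
#       nnet={'name': 'dit', 'model_args': model_args},
#       optimizer={'name': 'adamw', 'lr': 2e-4, 'weight_decay': 0.03},
#       lr_scheduler={'name': 'customized', 'warmup_steps': 5000},
#   )
#   train_state = initialize_train_state(config, device='cuda')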
|
|
|
|
|
def amortize(n_samples, batch_size):
    """Split n_samples into full batches plus a remainder batch if needed."""
    k = n_samples // batch_size
    r = n_samples % batch_size
    return k * [batch_size] if r == 0 else k * [batch_size] + [r]
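# Example: amortize(10, 4) == [4, 4, 2] and amortize(8, 4) == [4, 4]; the
# sampling loops below iterate over these per-round batch sizes.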
|
|
|
|
|
def sample2dir(accelerator, path, n_samples, mini_batch_size, sample_fn,
               unpreprocess_fn=None, return_clipScore=False, ClipSocre_model=None, config=None):
    os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes
    clip_score_list = []

    if return_clipScore:
        assert ClipSocre_model is not None

    for _batch_size in tqdm(amortize(n_samples, batch_size),
                            disable=not accelerator.is_main_process, desc='sample2dir'):
        samples, clip_score = sample_fn(mini_batch_size, return_clipScore=return_clipScore,
                                        ClipSocre_model=ClipSocre_model, config=config)
        if unpreprocess_fn is not None:
            samples = unpreprocess_fn(samples)
        samples = accelerator.gather(samples.contiguous())[:_batch_size]
        if return_clipScore:  # clip_score may be None otherwise
            clip_score_list.append(accelerator.gather(clip_score)[:_batch_size])
        if accelerator.is_main_process:
            for sample in samples:
                save_image(sample, os.path.join(path, f"{idx}.png"))
                idx += 1

    return clip_score_list if return_clipScore else None
|
|
|
|
|
def sample2dir_wCLIP(accelerator, path, n_samples, mini_batch_size, sample_fn,
                     unpreprocess_fn=None, return_clipScore=False, ClipSocre_model=None, config=None):
    """Variant of sample2dir that samples only the first batch (note the `break`)."""
    os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes
    clip_score_list = []

    if return_clipScore:
        assert ClipSocre_model is not None

    for _batch_size in amortize(n_samples, batch_size):
        samples, clip_score = sample_fn(mini_batch_size, return_clipScore=return_clipScore,
                                        ClipSocre_model=ClipSocre_model, config=config)
        if unpreprocess_fn is not None:
            samples = unpreprocess_fn(samples)
        samples = accelerator.gather(samples.contiguous())[:_batch_size]
        if return_clipScore:  # clip_score may be None otherwise
            clip_score_list.append(accelerator.gather(clip_score)[:_batch_size])
        if accelerator.is_main_process:
            for sample in samples:
                save_image(sample, os.path.join(path, f"{idx}.png"))
                idx += 1
        break  # only the first batch is sampled

    return clip_score_list if return_clipScore else None
|
|
|
|
|
def sample2dir_wPrompt(accelerator, path, n_samples, mini_batch_size, sample_fn,
                       unpreprocess_fn=None, config=None):
    os.makedirs(path, exist_ok=True)
    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes

    for _batch_size in tqdm(amortize(n_samples, batch_size),
                            disable=not accelerator.is_main_process, desc='sample2dir'):
        samples, samples_caption = sample_fn(mini_batch_size, return_caption=True, config=config)
        if unpreprocess_fn is not None:
            samples = unpreprocess_fn(samples)
        samples = accelerator.gather(samples.contiguous())[:_batch_size]
        if accelerator.is_main_process:
            for sample, caption in zip(samples, samples_caption):
                try:
                    save_image_with_caption(sample, caption, os.path.join(path, f"{idx}.png"))
                except Exception:
                    # Fall back to a plain save if caption rendering fails
                    # (e.g. the font file is missing).
                    save_image(sample, os.path.join(path, f"{idx}.png"))
                idx += 1
|
|
|
|
|
def grad_norm(model):
    """Global L2 norm of all parameter gradients."""
    total_norm = 0.
    for p in model.parameters():
        if p.grad is None:  # skip params untouched by the backward pass
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm
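# This matches the "total norm" computed by torch.nn.utils.clip_grad_norm_
# (the 2-norm over per-parameter gradient 2-norms); call it after backward()
# and before optimizer.step() when logging gradient statistics.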
|
|