# CrossFlow/configs/t2i_training_demo.py
import ml_collections


class Args:
    """Simple attribute namespace for nested model configuration."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
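
# Example (sketch): Args simply stores its keyword arguments as attributes,
# so nested configs read as plain attribute chains. `_demo` below is a
# hypothetical name used only for illustration:
#
#   _demo = Args(textVAE=Args(num_blocks=11))
#   assert _demo.textVAE.num_blocks == 11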

model = Args(
    channels=4,
    block_grad_to_lowres=False,
    norm_type="TDRMSN",
    use_t2i=True,
    clip_dim=768,  # 768 for CLIP, 4096 for T5-XXL
    num_clip_token=77,
    gradient_checking=True,
    cfg_indicator=0.1,
    textVAE=Args(
        num_blocks=11,
        hidden_dim=1024,
        hidden_token_length=256,
        num_attention_heads=8,
        dropout_prob=0.1,
    ),
    stage_configs=[  # multi-resolution DiMR stages; this is just an example
        Args(  # stage 1: 16x16 Transformer blocks
            block_type="TransformerBlock",
            dim=960,
            hidden_dim=1920,
            num_attention_heads=16,
            num_blocks=29,
            max_height=16,
            max_width=16,
            image_input_ratio=1,
            input_feature_ratio=4,
            final_kernel_size=3,
            dropout_prob=0,
        ),
        Args(  # stage 2: 32x32 ConvNeXt blocks
            block_type="ConvNeXtBlock",
            dim=480,
            hidden_dim=960,
            kernel_size=7,
            num_blocks=15,
            max_height=32,
            max_width=32,
            image_input_ratio=1,
            input_feature_ratio=2,
            final_kernel_size=3,
            dropout_prob=0,
        ),
        Args(  # stage 3: 64x64 ConvNeXt blocks
            block_type="ConvNeXtBlock",
            dim=240,
            hidden_dim=480,
            kernel_size=7,
            num_blocks=15,
            max_height=64,
            max_width=64,
            image_input_ratio=1,
            input_feature_ratio=1,
            final_kernel_size=3,
            dropout_prob=0,
        ),
    ],
)
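
# Note: the three stages above operate at 16x16, 32x32, and 64x64 feature
# resolutions; config.loss_coeffs in get_config() below weights the per-block
# losses (1/4, 1/2, 1), presumably in the same stage order.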

def d(**kwargs):
    """Helper for creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)
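
# Example (sketch): d(lr=1e-5) returns an ml_collections.ConfigDict, so values
# are reachable both as attributes and as items:
#
#   assert d(lr=1e-5).lr == d(lr=1e-5)['lr'] == 1e-5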

def get_config():
    config = ml_collections.ConfigDict()

    config.seed = 1234  # random seed
    config.z_shape = (4, 64, 64)  # image latent size

    config.autoencoder = d(
        pretrained_path='assets/stable-diffusion/autoencoder_kl.pth',  # path to the pretrained VAE checkpoint from LDM
        scale_factor=0.23010
    )

    config.train = d(
        n_steps=1000000,  # total training iterations
        batch_size=4,  # overall batch size across ALL GPUs, where batch_size_per_gpu == batch_size / number_of_gpus
        mode='cond',
        log_interval=10,
        eval_interval=10,  # iteration interval for visual testing on the specified prompts
        save_interval=100,  # iteration interval for saving checkpoints and testing FID
        n_samples_eval=5,  # number of samples during visual testing; depends on your GPU memory and can be any integer between 1 and 15 (we provide only 15 prompts)
    )

    config.optimizer = d(
        name='adamw',
        lr=0.00001,  # learning rate
        weight_decay=0.03,
        betas=(0.9, 0.9),
    )

    config.lr_scheduler = d(
        name='customized',
        warmup_steps=5000  # warmup steps
    )

    config.nnet = d(
        name='dimr',
        model_args=model,
    )
    config.loss_coeffs = [1/4, 1/2, 1]  # per-block loss weights, only needed for DiMR. Here, loss = 1/4 * loss_block1 + 1/2 * loss_block2 + 1 * loss_block3

    config.dataset = d(
        name='JDB_demo_features',  # dataset name
        resolution=512,  # dataset resolution
        llm='clip',  # language model used to generate the text embeddings
        train_path='/data/qihao/dataset/JDB_demo_feature/',  # training set path
        val_path='/data/qihao/dataset/coco_val_features/',  # validation set path
        cfg=False
    )

    config.sample = d(
        sample_steps=50,  # sampling steps during inference/testing
        n_samples=30000,  # number of samples for testing (during training, we sample 10K images, which is hardcoded in the training script)
        mini_batch_size=10,  # batch size for testing (i.e., the number of images generated per GPU)
        cfg=False,
        scale=7,  # CFG scale
        path=''
    )

    return config
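

# Minimal smoke test (a sketch, not part of the CrossFlow training entry
# point): build the config and print a few fields to verify that the file
# parses and the nested Args/ConfigDict structure resolves as expected.
if __name__ == "__main__":
    cfg = get_config()
    print(cfg.nnet.name)  # 'dimr'
    print(cfg.nnet.model_args.stage_configs[0].block_type)  # 'TransformerBlock'
    print(cfg.train.batch_size, cfg.sample.sample_steps)  # 4 50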