|
import ml_collections |
|
from dataclasses import dataclass |
|
|
|
@dataclass |
|
class Args: |
|
def __init__(self, **kwargs): |
|
for key, value in kwargs.items(): |
|
setattr(self, key, value) |
|
|
|
model = Args( |
|
latent_size = 64, |
|
learn_sigma = False, |
|
channels = 4, |
|
block_grad_to_lowres = False, |
|
norm_type = "TDRMSN", |
|
use_t2i = True, |
|
clip_dim=4096, |
|
num_clip_token=77, |
|
gradient_checking=True, |
|
cfg_indicator=0.10, |
|
textVAE = Args( |
|
num_blocks = 11, |
|
hidden_dim = 1024, |
|
hidden_token_length = 256, |
|
num_attention_heads = 8, |
|
dropout_prob = 0.1, |
|
), |
|
) |
|
|
|
def d(**kwargs): |
|
"""Helper of creating a config dict.""" |
|
return ml_collections.ConfigDict(initial_dictionary=kwargs) |
|
|
|
|
|
def get_config(): |
|
config = ml_collections.ConfigDict() |
|
|
|
config.seed = 1234 |
|
config.z_shape = (4, 64, 64) |
|
|
|
config.autoencoder = d( |
|
pretrained_path='assets/stable-diffusion/autoencoder_kl.pth', |
|
scale_factor=0.23010 |
|
) |
|
|
|
config.train = d( |
|
n_steps=1000000, |
|
batch_size=1024, |
|
mode='cond', |
|
log_interval=10, |
|
eval_interval=5000, |
|
save_interval=50000, |
|
) |
|
|
|
config.optimizer = d( |
|
name='adamw', |
|
lr=0.00002, |
|
weight_decay=0.03, |
|
betas=(0.9, 0.9), |
|
) |
|
|
|
config.lr_scheduler = d( |
|
name='customized', |
|
warmup_steps=5000 |
|
) |
|
|
|
global model |
|
config.nnet = d( |
|
name='dit', |
|
model_args=model, |
|
) |
|
config.loss_coeffs = [] |
|
|
|
config.dataset = d( |
|
name='JDB_demo_features', |
|
resolution=512, |
|
llm='t5', |
|
train_path='/data/qihao/dataset/JDB_demo_feature/', |
|
val_path='/data/qihao/dataset/coco_val_features/', |
|
cfg=False |
|
) |
|
|
|
config.sample = d( |
|
sample_steps=50, |
|
n_samples=30000, |
|
mini_batch_size=10, |
|
cfg=False, |
|
scale=7, |
|
path='' |
|
) |
|
|
|
return config |
|
|