import argparse

import torch
from omegaconf import OmegaConf

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
from ldm.modules.extra_condition.api import ExtraCondition
from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict

DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
                          'fewer digits, cropped, worst quality, low quality'


def get_base_argument_parser() -> argparse.ArgumentParser:
    """get the base argument parser for inference scripts"""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--outdir',
        type=str,
        help='dir to write results to',
        default=None,
    )

    parser.add_argument(
        '--prompt',
        type=str,
        nargs='?',
        default=None,
        help='positive prompt',
    )

    parser.add_argument(
        '--neg_prompt',
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help='negative prompt',
    )

    parser.add_argument(
        '--cond_path',
        type=str,
        default=None,
        help='condition image path',
    )

    parser.add_argument(
        '--cond_inp_type',
        type=str,
        default='image',
        help='the type of the input condition image; taking depth T2I as an example, the input can be a raw image, '
             'from which the depth map will be computed, or it can directly be a depth map image',
    )

    parser.add_argument(
        '--sampler',
        type=str,
        default='ddim',
        choices=['ddim', 'plms'],
        help='sampling algorithm; currently only ddim and plms are supported, more are on the way',
    )

    parser.add_argument(
        '--steps',
        type=int,
        default=50,
        help='number of sampling steps',
    )

    parser.add_argument(
        '--sd_ckpt',
        type=str,
        default='models/sd-v1-4.ckpt',
        help='path to checkpoint of stable diffusion model; both .ckpt and .safetensors are supported',
    )

    parser.add_argument(
        '--vae_ckpt',
        type=str,
        default=None,
        help='vae checkpoint; anime SD models usually have a separate vae ckpt that needs to be loaded',
    )

    parser.add_argument(
        '--adapter_ckpt',
        type=str,
        default=None,
        help='path to checkpoint of adapter',
    )

    parser.add_argument(
        '--config',
        type=str,
        default='configs/stable-diffusion/sd-v1-inference.yaml',
        help='path to config which constructs SD model',
    )

    parser.add_argument(
        '--max_resolution',
        type=float,
        default=512 * 512,
        help='max image height * width; only needed for computers with limited VRAM',
    )

    parser.add_argument(
        '--resize_short_edge',
        type=int,
        default=None,
        help='resize the short edge of the input image to this size; if set, max_resolution will not be used',
    )

    parser.add_argument(
        '--C',
        type=int,
        default=4,
        help='latent channels',
    )

    parser.add_argument(
        '--f',
        type=int,
        default=8,
        help='downsampling factor',
    )

    parser.add_argument(
        '--scale',
        type=float,
        default=7.5,
        help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    )

    parser.add_argument(
        '--cond_tau',
        type=float,
        default=1.0,
        help='timestep parameter that determines until which step the adapter is applied, '
             'similar to the Prompt-to-Prompt tau',
    )

    parser.add_argument(
        '--style_cond_tau',
        type=float,
        default=1.0,
        help='timestep parameter that determines until which step the style adapter is applied, '
             'similar to the Prompt-to-Prompt tau',
    )

    parser.add_argument(
        '--cond_weight',
        type=float,
        default=1.0,
        help='the adapter features are multiplied by cond_weight; the larger the value, the more closely the '
             'generated image will follow the condition, but generation quality may be reduced',
    )

    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='random seed',
    )

    parser.add_argument(
        '--n_samples',
        type=int,
        default=4,
        help='# of samples to generate',
    )

    return parser

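# Usage sketch (illustrative, not part of the module): a demo script would typically
# extend this base parser with its own task-specific flags before parsing, e.g.
#
#     parser = get_base_argument_parser()
#     parser.add_argument('--which_cond', type=str, default='sketch')  # hypothetical flag
#     opt = parser.parse_args()
#     opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'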


def get_sd_models(opt):
    """
    build the stable diffusion model and the sampler
    """
    # fall back to an automatically detected device if the calling script did not set one
    if getattr(opt, 'device', None) is None:
        opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # SD
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
    model = model.half()
    sd_model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError

    return sd_model, sampler

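# Usage sketch (assumes `opt` came from get_base_argument_parser plus a prompt):
#
#     sd_model, sampler = get_sd_models(opt)
#     # sd_model is the half-precision Stable Diffusion model on opt.device;
#     # sampler is the DDIMSampler or PLMSSampler wrapping it.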


def get_t2i_adapter_models(opt):
    """build the stable diffusion model with the adapter weights loaded into it, and the sampler"""
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
    adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
    if adapter_ckpt_path is None:
        adapter_ckpt_path = getattr(opt, 'adapter_ckpt')
    adapter_ckpt = read_state_dict(adapter_ckpt_path)
    # prefix the adapter weights so they land on the model's `adapter` submodule
    new_state_dict = {}
    for k, v in adapter_ckpt.items():
        if not k.startswith('adapter.'):
            new_state_dict[f'adapter.{k}'] = v
        else:
            new_state_dict[k] = v
    m, u = model.load_state_dict(new_state_dict, strict=False)
    if len(u) > 0:
        print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
        print(u)

    model = model.to(opt.device)

    # sampler
    if opt.sampler == 'plms':
        sampler = PLMSSampler(model)
    elif opt.sampler == 'ddim':
        sampler = DDIMSampler(model)
    else:
        raise NotImplementedError

    return model, sampler

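# Usage sketch: this path expects the calling script to have set `opt.device` and
# `opt.which_cond` (e.g. 'sketch'), as the demo scripts that use this module do.
#
#     model, sampler = get_t2i_adapter_models(opt)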


def get_cond_ch(cond_type: ExtraCondition):
    # sketch and canny conditions are single-channel; all other condition images are RGB
    if cond_type == ExtraCondition.sketch or cond_type == ExtraCondition.canny:
        return 1
    return 3



def get_adapters(opt, cond_type: ExtraCondition):
    adapter = {}
    # a per-condition weight (e.g. opt.sketch_weight) overrides the global --cond_weight if present
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    if cond_type == ExtraCondition.style:
        adapter['model'] = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
    elif cond_type == ExtraCondition.color:
        adapter['model'] = Adapter_light(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280],
            nums_rb=4).to(opt.device)
    else:
        adapter['model'] = Adapter(
            cin=64 * get_cond_ch(cond_type),
            channels=[320, 640, 1280, 1280][:4],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False).to(opt.device)
    # a per-condition checkpoint (e.g. opt.sketch_adapter_ckpt) overrides the global --adapter_ckpt if present
    ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    if ckpt_path is None:
        ckpt_path = getattr(opt, 'adapter_ckpt')
    adapter['model'].load_state_dict(torch.load(ckpt_path))

    return adapter

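# Usage sketch (assumes the adapter checkpoint path is set on `opt`):
#
#     adapter = get_adapters(opt, ExtraCondition.sketch)
#     # adapter['model'] is the loaded adapter module and adapter['cond_weight'] its weight;
#     # the calling script runs the condition image through adapter['model'] (scaled by
#     # cond_weight) to obtain the features passed to diffusion_inference as adapter_features.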


def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
    # prepare the text conditioning; the negative prompt is only needed for classifier-free guidance
    c = model.get_learned_conditioning([opt.prompt])
    if opt.scale != 1.0:
        uc = model.get_learned_conditioning([opt.neg_prompt])
    else:
        uc = None
    c, uc = fix_cond_shapes(model, c, uc)

    # default to 512x512 generation if the calling script did not set a resolution
    if not hasattr(opt, 'H'):
        opt.H = 512
        opt.W = 512
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    samples_latents, _ = sampler.sample(
        S=opt.steps,
        conditioning=c,
        batch_size=1,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uc,
        x_T=None,
        features_adapter=adapter_features,
        append_to_context=append_to_context,
        cond_tau=opt.cond_tau,
        style_cond_tau=opt.style_cond_tau,
    )

    # decode the latents and map the output from [-1, 1] to [0, 1]
    x_samples = model.decode_first_stage(samples_latents)
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

    return x_samples
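

# End-to-end sketch (illustrative only; the feature-extraction helpers in
# ldm.modules.extra_condition.api are an assumption here, check their actual
# signatures before use):
#
#     opt = get_base_argument_parser().parse_args()
#     opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     sd_model, sampler = get_sd_models(opt)
#     adapter = get_adapters(opt, ExtraCondition.sketch)
#     # adapter_features = ...  # obtained from the condition image via the api helpers
#     # x_samples = diffusion_inference(opt, sd_model, sampler, adapter_features)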