import gc import math import sys from IPython import display import torch from torchvision import utils as tv_utils from torchvision.transforms import functional as TF import gradio as gr from git.repo.base import Repo from os.path import exists as path_exists if not (path_exists(f"v-diffusion-pytorch")): Repo.clone_from("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch") if not (path_exists(f"CLIP")): Repo.clone_from("https://github.com/openai/CLIP", "CLIP") sys.path.append('v-diffusion-pytorch') from huggingface_hub import hf_hub_download from CLIP import clip from diffusion import get_model, sampling, utils cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth") model = get_model('cc12m_1_cfg')() _, side_y, side_x = model.shape model.load_state_dict(torch.load(cc12m_model, map_location='cpu')) model = model.half().cuda().eval().requires_grad_(False) clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0] def run_all(prompt, steps, n_images, weight, clip_guided): import random seed = int(random.randint(0, 2147483647)) target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda() clip_embed = target_embed.repeat([n, 1]) def cfg_model_fn(x, t): """The CFG wrapper function.""" n = x.shape[0] x_in = x.repeat([2, 1, 1, 1]) t_in = t.repeat([2]) clip_embed_repeat = target_embed.repeat([n, 1]) clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat]) v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0) v = v_uncond + (v_cond - v_uncond) * weight return v def make_cond_model_fn(model, cond_fn): def cond_model_fn(x, t, **extra_args): with torch.enable_grad(): x = x.detach().requires_grad_() v = model(x, t, **extra_args) alphas, sigmas = utils.t_to_alpha_sigma(t) pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None] cond_grad = cond_fn(x, t, pred, **extra_args).detach() v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None]) return v return cond_model_fn def cond_fn(x, t, pred, clip_embed): if min(pred.shape[2:4]) < 256: pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False) clip_in = normalize(make_cutouts((pred + 1) / 2)) image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1]) losses = spherical_dist_loss(image_embeds, clip_embed[None]) loss = losses.mean(0).sum() * args.clip_guidance_scale grad = -torch.autograd.grad(loss, x)[0] return grad gc.collect() torch.cuda.empty_cache() torch.manual_seed(seed) x = torch.randn([n_images, 3, side_y, side_x], device='cuda') t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1] step_list = utils.get_spliced_ddpm_cosine_schedule(t) if(not clip_guided): outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback) else: extra_args = {'clip_embed': clip_embed} cond_fn_ = cond_fn model_fn = make_cond_model_fn(model, cond_fn_) outs = sampling.plms_sample(model_fn, x, steps, extra_args) images_out = [] for i, out in enumerate(outs): images_out.append(utils.to_pil_image(out)) return(images_out) ##################### START GRADIO HERE ############################ #image = gr.outputs.Image(type="pil", label="Your result") gallery = gr.Gallery(css={"height": "256px","width":"256px"}) iface = gr.Interface( fn=run_all, inputs=[ gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"), gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=40,maximum=80,minimum=1,step=1), gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1), gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1), gr.inputs.Checkbox(label="CLIP Guided - improves coherence with prompt, makes it slower"), ], outputs=gallery, title="Generate images from text with V-Diffusion CC12M CFG", description="