import gc
import math
import sys
from IPython import display
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import utils as tv_utils
from torchvision.transforms import functional as TF
import gradio as gr
from git.repo.base import Repo
from os.path import exists as path_exists
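
# Clone the v-diffusion-pytorch and CLIP repositories on first run so their
# modules can be imported below.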
if not (path_exists(f"v-diffusion-pytorch")):
Repo.clone_from("https://github.com/crowsonkb/v-diffusion-pytorch", "v-diffusion-pytorch")
if not (path_exists(f"CLIP")):
Repo.clone_from("https://github.com/openai/CLIP", "CLIP")
sys.path.append('v-diffusion-pytorch')
from huggingface_hub import hf_hub_download
from CLIP import clip
from diffusion import get_model, sampling, utils
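
# MakeCutouts takes random square crops of varying sizes from an image batch and
# resizes them to the CLIP input resolution; averaging CLIP scores over many
# cutouts gives a smoother guidance signal than scoring the full image once.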
class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
            cutouts.append(cutout)
        return torch.cat(cutouts)
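
# Loss based on the angular (great-circle) distance between L2-normalized
# embeddings, used below as the CLIP guidance loss.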
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
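
# Download the cc12m_1_cfg checkpoint from the Hugging Face Hub, then load the
# diffusion model (fp16, on CUDA, gradients disabled) and the CLIP model it was
# conditioned on.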
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
#cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
model = get_model('cc12m_1_cfg')()
_, side_y, side_x = model.shape
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
model = model.half().cuda().eval().requires_grad_(False)
#model_small = get_model('cc12m_1')()
#model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
#model_small = model_small.half().cuda().eval().requires_grad_(False)
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
gc.collect()
torch.cuda.empty_cache()
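
# Generate images for a prompt. With CLIP guidance enabled, a single image is
# sampled with extra steps and a gradient-based CLIP loss steers the sampler;
# otherwise classifier-free guidance alone is used.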
def run_all(prompt, steps, n_images, weight, clip_guided):
    gc.collect()
    torch.cuda.empty_cache()
    import random
    seed = int(random.randint(0, 2147483647))
    target_embed = clip_model.encode_text(clip.tokenize(prompt).to('cuda')).float()#.cuda()
    if clip_guided:
        n_images = 1
        steps = steps * 5
        clip_guidance_scale = weight * 100
        prompts = [prompt]
        target_embeds, weights = [], []

        def parse_prompt(prompt):
            if prompt.startswith('http://') or prompt.startswith('https://'):
                vals = prompt.rsplit(':', 2)
                vals = [vals[0] + ':' + vals[1], *vals[2:]]
            else:
                vals = prompt.rsplit(':', 1)
            vals = vals + ['', '1'][len(vals):]
            return vals[0], float(vals[1])

        for prompt in prompts:
            txt, weight = parse_prompt(prompt)
            target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to('cuda')).float())
            weights.append(weight)
        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device='cuda')
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('The weights must not sum to 0.')
        weights /= weights.sum().abs()
        clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
        clip_embed = target_embed.repeat([n_images, 1])
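
    # Classifier-free guidance: run the model on a conditional and an
    # unconditional copy of the batch and extrapolate between the two
    # predictions by the guidance weight.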
    def cfg_model_fn(x, t):
        """The CFG wrapper function."""
        n = x.shape[0]
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed_repeat = target_embed.repeat([n, 1])
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed_repeat), clip_embed_repeat])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        v = v_uncond + (v_cond - v_uncond) * weight
        return v
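
    # Wrap the model so that, at every step, the gradient returned by cond_fn
    # is combined with the model's velocity prediction to steer sampling.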
    def make_cond_model_fn(model, cond_fn):
        def cond_model_fn(x, t, **extra_args):
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                v = model(x, t, **extra_args)
                alphas, sigmas = utils.t_to_alpha_sigma(t)
                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
            return v
        return cond_model_fn
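
    # CLIP guidance loss: embed random cutouts of the current denoised estimate
    # and take the gradient of the spherical distance to the prompt embedding.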
    def cond_fn(x, t, pred, clip_embed):
        if min(pred.shape[2:4]) < 256:
            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([16, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * clip_guidance_scale
        grad = -torch.autograd.grad(loss, x)[0]
        return grad
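
    # Sample with PLMS from pure noise, using the noise schedule that matches
    # the model, then convert the outputs to PIL images for the gallery.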
    torch.manual_seed(seed)
    x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
    t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
    if model.min_t == 0:
        step_list = utils.get_spliced_ddpm_cosine_schedule(t)
    else:
        step_list = utils.get_ddpm_schedule(t)
    if not clip_guided:
        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
    else:
        extra_args = {'clip_embed': clip_embed}
        cond_fn_ = cond_fn
        model_fn = make_cond_model_fn(model, cond_fn_)
        outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
    images_out = []
    for out in outs:
        images_out.append(utils.to_pil_image(out))
    return images_out
##################### START GRADIO HERE ############################
gallery = gr.outputs.Carousel(label="Individual images", components=["image"])
iface = gr.Interface(
    fn=run_all,
    inputs=[
        gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'", default="an eerie alien forest"),
        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate", default=40, maximum=80, minimum=1, step=1),
        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
        gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
        gr.inputs.Checkbox(label="CLIP Guided - improves coherence with complex prompts, but is slower (with CLIP guidance only one image is generated)"),
    ],
    outputs=gallery,
    title="Generate images from text with V-Diffusion",
    description="<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/crowsonkb/v-diffusion-pytorch' target='_blank'>V-Diffusion</a> is a diffusion text-to-image model created by <a href='https://twitter.com/RiversHaveWings' target='_blank'>Katherine Crowson</a> and <a href='https://twitter.com/jd_pressman' target='_blank'>JDP</a>, trained on the <a href='https://github.com/google-research-datasets/conceptual-12m' target='_blank'>CC12M dataset</a>. The UI to the model was assembled by <a style='color: rgb(99, 102, 241);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a>, keep up with the <a style='color: rgb(99, 102, 241);' href='https://multimodal.art/news' target='_blank'>latest multimodal ai art news here</a> and consider <a style='color: rgb(99, 102, 241);' href='https://www.patreon.com/multimodalart' target='_blank'>supporting us on Patreon</a></div>",
)
iface.launch(enable_queue=True)