Commit 3a72088 · Update app.py
1 Parent(s): 1473645
app.py CHANGED
@@ -29,11 +29,11 @@ model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
 model = model.half().cuda().eval().requires_grad_(False)
 clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0]
 
-def run_all(prompt, steps, n_images, weight):
+def run_all(prompt, steps, n_images, weight, clip_guided):
     import random
     seed = int(random.randint(0, 2147483647))
     target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()
-
+    clip_embed = target_embed.repeat([n, 1])
     def cfg_model_fn(x, t):
         """The CFG wrapper function."""
         n = x.shape[0]

@@ -44,14 +44,41 @@ def run_all(prompt, steps, n_images, weight):
         v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
         v = v_uncond + (v_cond - v_uncond) * weight
         return v
-
+
+    def make_cond_model_fn(model, cond_fn):
+        def cond_model_fn(x, t, **extra_args):
+            with torch.enable_grad():
+                x = x.detach().requires_grad_()
+                v = model(x, t, **extra_args)
+                alphas, sigmas = utils.t_to_alpha_sigma(t)
+                pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
+                cond_grad = cond_fn(x, t, pred, **extra_args).detach()
+                v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
+            return v
+        return cond_model_fn
+    def cond_fn(x, t, pred, clip_embed):
+        if min(pred.shape[2:4]) < 256:
+            pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
+        clip_in = normalize(make_cutouts((pred + 1) / 2))
+        image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
+        losses = spherical_dist_loss(image_embeds, clip_embed[None])
+        loss = losses.mean(0).sum() * args.clip_guidance_scale
+        grad = -torch.autograd.grad(loss, x)[0]
+        return grad
+
     gc.collect()
     torch.cuda.empty_cache()
     torch.manual_seed(seed)
     x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
     t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
     step_list = utils.get_spliced_ddpm_cosine_schedule(t)
-
+    if(not clip_guided):
+        outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
+    else:
+        extra_args = {'clip_embed': clip_embed}
+        cond_fn_ = cond_fn
+        model_fn = make_cond_model_fn(model, cond_fn_)
+        outs = sampling.plms_sample(model_fn, x, steps, extra_args)
     images_out = []
     for i, out in enumerate(outs):
         images_out.append(utils.to_pil_image(out))

@@ -65,15 +92,10 @@ iface = gr.Interface(
     fn=run_all,
     inputs=[
         gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
-        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=
-        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1,step=1),
-        gr.inputs.Slider(label="Weight", default=5, maximum=15, minimum=0, step=1),
-
-        #gr.inputs.Dropdown(label="Flavor",choices=["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]),
-        #markdown,
-        #gr.inputs.Dropdown(label="Style",choices=["Default","Balanced","Detailed","Consistent Creativity","Realistic","Smooth","Subtle MSE","Hyper Fast Results"],default="Hyper Fast Results"),
-        #gr.inputs.Radio(label="Width", choices=[32,64,128,256,512],default=512),
-        #gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=512),
+        gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=40,maximum=80,minimum=1,step=1),
+        gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
+        gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
+        gr.inputs.Checkbox(label="CLIP Guided - improves coherence with prompt, makes it slower"),
     ],
     outputs=gallery,
     title="Generate images from text with V-Diffusion CC12M CFG",
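For context on the unchanged cfg_model_fn in the first two hunks: it implements classifier-free guidance, evaluating the model once on a zeroed CLIP embedding and once on the prompt embedding, then blending the two v predictions with the Weight slider. The lines that build the doubled batch (x_in, t_in, clip_embed_in) fall outside the shown hunks, so the sketch below is an illustrative reconstruction of the usual pattern rather than the Space's exact code, and the make_cfg_model_fn name is invented here. (Note also that in the new run_all, clip_embed is assigned from target_embed.repeat([n, 1]) before n is defined anywhere in the shown context; the batch size available at that point is n_images.)

import torch

def make_cfg_model_fn(model, target_embed, weight):
    """Illustrative classifier-free guidance wrapper (hypothetical helper, not from app.py)."""
    def cfg_model_fn(x, t):
        n = x.shape[0]
        # Double the batch: first half unconditional (zero embedding), second half prompt-conditioned.
        x_in = x.repeat([2, 1, 1, 1])
        t_in = t.repeat([2])
        clip_embed = target_embed.repeat([n, 1])
        clip_embed_in = torch.cat([torch.zeros_like(clip_embed), clip_embed])
        v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
        # weight == 1 reproduces the conditional prediction; larger values push harder toward the prompt.
        return v_uncond + (v_cond - v_uncond) * weight
    return cfg_model_fn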
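The new cond_model_fn/cond_fn pair relies on two helpers that are defined elsewhere in the Space and are not part of this diff: utils.t_to_alpha_sigma, which maps a timestep to the signal and noise scales used to turn a v prediction into a denoised estimate (pred = x * alpha - v * sigma), and spherical_dist_loss, the squared spherical distance between normalized CLIP embeddings. In v-diffusion-pytorch-style code they are conventionally defined roughly as below; treat this as a reference sketch, not the Space's source.

import math
import torch
import torch.nn.functional as F

def t_to_alpha_sigma(t):
    # Continuous cosine schedule: alpha scales the clean image, sigma scales the noise,
    # and alpha**2 + sigma**2 == 1 for every t in [0, 1].
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def spherical_dist_loss(x, y):
    # Squared geodesic distance between L2-normalized embedding vectors.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(4)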
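cond_fn also calls make_cutouts and normalize, which likewise live outside this diff: CLIP expects fixed-size, CLIP-normalized inputs, so the current denoised estimate is cut into args.cutn random crops before being embedded and compared against the prompt embedding. A typical construction, assuming the Space follows the common CLIP-guidance recipe (a sketch, not necessarily identical to the Space's version):

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

# CLIP's published normalization constants.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size   # CLIP input resolution, e.g. 224
        self.cutn = cutn           # number of random crops per image
        self.cut_pow = cut_pow

    def forward(self, input):
        side_y, side_x = input.shape[2:4]
        max_size = min(side_x, side_y)
        min_size = min(side_x, side_y, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Random crop size and position, resized down to CLIP's input resolution.
            size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
            offset_x = torch.randint(0, side_x - size + 1, ())
            offset_y = torch.randint(0, side_y - size + 1, ())
            cutout = input[:, :, offset_y:offset_y + size, offset_x:offset_x + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)

# Example wiring (assumed, not shown in the diff):
# make_cutouts = MakeCutouts(clip_model.visual.input_resolution, args.cutn)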
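On the interface side, the commit replaces the commented-out controls with concrete sliders and adds the "CLIP Guided" checkbox. Gradio passes component values to fn positionally, so the checkbox's boolean arrives as the new clip_guided argument of run_all. Below is a minimal, standalone illustration of that wiring using the same legacy gr.inputs API as the Space; demo_fn is a hypothetical stand-in for run_all, not the Space's code.

import gradio as gr

def demo_fn(prompt, steps, n_images, weight, clip_guided):
    # Values arrive in the same order as the components listed in `inputs`.
    mode = "CLIP-guided" if clip_guided else "CFG only"
    return f"{prompt!r}: {steps} steps, {n_images} image(s), weight {weight}, {mode}"

demo = gr.Interface(
    fn=demo_fn,
    inputs=[
        gr.inputs.Textbox(default="chalk pastel drawing of a dog wearing a funny hat", label="Prompt"),
        gr.inputs.Slider(minimum=1, maximum=80, step=1, default=40, label="Steps"),
        gr.inputs.Slider(minimum=1, maximum=4, step=1, default=2, label="Number of images"),
        gr.inputs.Slider(minimum=0, maximum=15, step=1, default=5, label="Weight"),
        gr.inputs.Checkbox(label="CLIP Guided"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()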