multimodalart (HF Staff) committed
Commit 3a72088 · 1 Parent(s): 1473645

Update app.py

Files changed (1):
  1. app.py (+35, -13)
app.py CHANGED

@@ -29,11 +29,11 @@ model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
  model = model.half().cuda().eval().requires_grad_(False)
  clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0]

- def run_all(prompt, steps, n_images, weight):
+ def run_all(prompt, steps, n_images, weight, clip_guided):
      import random
      seed = int(random.randint(0, 2147483647))
      target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()
-
+     clip_embed = target_embed.repeat([n, 1])
      def cfg_model_fn(x, t):
          """The CFG wrapper function."""
          n = x.shape[0]
@@ -44,14 +44,41 @@ def run_all(prompt, steps, n_images, weight):
          v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
          v = v_uncond + (v_cond - v_uncond) * weight
          return v
-
+
+     def make_cond_model_fn(model, cond_fn):
+         def cond_model_fn(x, t, **extra_args):
+             with torch.enable_grad():
+                 x = x.detach().requires_grad_()
+                 v = model(x, t, **extra_args)
+                 alphas, sigmas = utils.t_to_alpha_sigma(t)
+                 pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
+                 cond_grad = cond_fn(x, t, pred, **extra_args).detach()
+                 v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
+                 return v
+         return cond_model_fn
+     def cond_fn(x, t, pred, clip_embed):
+         if min(pred.shape[2:4]) < 256:
+             pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
+         clip_in = normalize(make_cutouts((pred + 1) / 2))
+         image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
+         losses = spherical_dist_loss(image_embeds, clip_embed[None])
+         loss = losses.mean(0).sum() * args.clip_guidance_scale
+         grad = -torch.autograd.grad(loss, x)[0]
+         return grad
+
      gc.collect()
      torch.cuda.empty_cache()
      torch.manual_seed(seed)
      x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
      t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
      step_list = utils.get_spliced_ddpm_cosine_schedule(t)
-     outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
+     if(not clip_guided):
+         outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
+     else:
+         extra_args = {'clip_embed': clip_embed}
+         cond_fn_ = cond_fn
+         model_fn = make_cond_model_fn(model, cond_fn_)
+         outs = sampling.plms_sample(model_fn, x, steps, extra_args)
      images_out = []
      for i, out in enumerate(outs):
          images_out.append(utils.to_pil_image(out))
@@ -65,15 +92,10 @@ iface = gr.Interface(
      fn=run_all,
      inputs=[
          gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
-         gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=50,maximum=250,minimum=1,step=1),
-         gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1,step=1),
-         gr.inputs.Slider(label="Weight", default=5, maximum=15, minimum=0, step=1),
-         #gr.inputs.Checkbox(label="CLIP Guided"),
-         #gr.inputs.Dropdown(label="Flavor",choices=["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]),
-         #markdown,
-         #gr.inputs.Dropdown(label="Style",choices=["Default","Balanced","Detailed","Consistent Creativity","Realistic","Smooth","Subtle MSE","Hyper Fast Results"],default="Hyper Fast Results"),
-         #gr.inputs.Radio(label="Width", choices=[32,64,128,256,512],default=512),
-         #gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=512),
+         gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=40,maximum=80,minimum=1,step=1),
+         gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
+         gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
+         gr.inputs.Checkbox(label="CLIP Guided - improves coherence with prompt, makes it slower"),
      ],
      outputs=gallery,
      title="Generate images from text with V-Diffusion CC12M CFG",
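
A note for reading the new CLIP-guided branch: make_cond_model_fn converts the model's v-prediction into a denoised estimate (pred = x * alpha - v * sigma) and then shifts v by the gradient returned by cond_fn, scaled by sigma / alpha. The sketch below illustrates that arithmetic in isolation; it assumes the usual v-diffusion convention where utils.t_to_alpha_sigma(t) returns (cos(t·π/2), sin(t·π/2)), a helper that is not part of this diff.

import math
import torch

def t_to_alpha_sigma(t):
    # Stand-in for utils.t_to_alpha_sigma; assumes the cosine schedule
    # used by v-diffusion (alpha = cos(t*pi/2), sigma = sin(t*pi/2)).
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

x = torch.randn(2, 3, 256, 256)    # noisy images at timestep t (placeholder data)
v = torch.randn_like(x)            # model's v-prediction (placeholder)
cond_grad = torch.randn_like(x)    # gradient from cond_fn (placeholder)
t = torch.full((2,), 0.7)          # timesteps in [0, 1]

alphas, sigmas = t_to_alpha_sigma(t)
# Denoised estimate handed to cond_fn, as in make_cond_model_fn above.
pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
# Guidance step: nudge v against the conditioning gradient.
v_guided = v - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])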
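
The added cond_fn also relies on normalize, make_cutouts, spherical_dist_loss, args.cutn, and args.clip_guidance_scale, which are defined elsewhere in app.py and not shown in this diff. For reference, a minimal sketch of spherical_dist_loss as it commonly appears in CLIP-guided diffusion code (the definition in app.py may differ):

import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    # Squared great-circle distance between L2-normalized embedding vectors.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

# Example: distance between a batch of image embeddings and one text embedding.
image_embeds = torch.randn(4, 512)
text_embed = torch.randn(1, 512)
losses = spherical_dist_loss(image_embeds, text_embed)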