multimodalart (HF Staff) committed
Commit bf89172
Parent(s): 9cd412c

Add more logic to clip embeds

Files changed (1): app.py (+26 -2)

app.py CHANGED
@@ -59,8 +59,32 @@ make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
 def run_all(prompt, steps, n_images, weight, clip_guided):
     import random
     seed = int(random.randint(0, 2147483647))
-    target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()
-    clip_embed = target_embed.repeat([n_images, 1])
+    target_embed = clip_model.encode_text(clip.tokenize(prompt)).float()#.cuda()
+
+    if(clip_guided):
+        prompts = [prompt]
+        def parse_prompt(prompt):
+            if prompt.startswith('http://') or prompt.startswith('https://'):
+                vals = prompt.rsplit(':', 2)
+                vals = [vals[0] + ':' + vals[1], *vals[2:]]
+            else:
+                vals = prompt.rsplit(':', 1)
+            vals = vals + ['', '1'][len(vals):]
+            return vals[0], float(vals[1])
+
+        for prompt in prompts:
+            txt, weight = parse_prompt(prompt)
+            target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
+            weights.append(weight)
+
+        target_embeds = torch.cat(target_embeds)
+        weights = torch.tensor(weights, device=device)
+        if weights.sum().abs() < 1e-3:
+            raise RuntimeError('The weights must not sum to 0.')
+        weights /= weights.sum().abs()
+        clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
+    clip_embed = target_embed.repeat([n_images, 1])
+
     def cfg_model_fn(x, t):
         """The CFG wrapper function."""
         n = x.shape[0]
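
For readers following along, below is a self-contained sketch of the weighted-prompt CLIP embedding logic this commit adds. It is an illustration rather than the app's exact code: the CLIP variant ('ViT-B/16'), the device handling, and the combined_clip_embed wrapper are assumptions, and the helper initializes the target_embeds/weights lists itself since their setup falls outside this hunk.

# Standalone sketch (assumptions: OpenAI CLIP package, ViT-B/16, and that app.py
# initializes target_embeds/weights as empty lists outside the hunk shown above).
import torch
import torch.nn.functional as F
import clip

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, _ = clip.load('ViT-B/16', device=device)

def parse_prompt(prompt):
    # Split 'text:weight' into (text, weight); URLs keep their 'http(s):' colon,
    # and a missing weight defaults to 1.
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', '1'][len(vals):]
    return vals[0], float(vals[1])

def combined_clip_embed(prompts, n_images):
    # Encode each prompt with CLIP, blend the embeddings by their normalized
    # weights, L2-normalize the blend, and tile it once per requested image.
    target_embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()
    clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
    return clip_embed.repeat([n_images, 1])

# Usage: weight 'a watercolor landscape' twice as heavily as 'sharp photograph'.
# embed = combined_clip_embed(['a watercolor landscape:2', 'sharp photograph:1'], n_images=4)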