linoyts HF Staff commited on
Commit
ca154e9
·
verified ·
1 Parent(s): d13fe65

Update IP_Composer/perform_swap.py

Browse files
Files changed (1) hide show
  1. IP_Composer/perform_swap.py +40 -14
IP_Composer/perform_swap.py CHANGED
@@ -1,36 +1,62 @@
1
  import torch
2
  import numpy as np
 
 
 
3
 
4
- def compute_dataset_embeds_svd(all_embeds, rank):
5
-
 
 
6
  # Perform SVD on the combined matrix
7
- u, s, vh = np.linalg.svd(all_embeds, full_matrices=False)
8
 
9
  # Select the top `rank` singular vectors to construct the projection matrix
10
- vh = vh[:rank] # Top `rank` right singular vectors
11
- projection_matrix = vh.T @ vh # Shape: (feature_dim, feature_dim)
12
 
13
  return projection_matrix
14
 
15
- def get_embedding_composition(embed, projections_data):
16
- # Initialize the combined embedding with the input embed
 
 
 
 
 
 
 
 
17
  combined_embeds = embed.copy()
18
 
19
  for proj_data in projections_data:
20
-
21
- # Add the combined projection to the result
22
- combined_embeds -= embed @ proj_data["projection_matrix"]
23
- combined_embeds += proj_data["embed"] @ proj_data["projection_matrix"]
24
 
25
  return combined_embeds
26
 
27
 
28
- def get_modified_images_embeds_composition(embed, projections_data, ip_model, prompt=None, scale=1.0, num_samples=3, seed=420, num_inference_steps=50):
29
-
 
 
 
 
 
 
 
30
  final_embeds = get_embedding_composition(embed, projections_data)
31
  clip_embeds = torch.from_numpy(final_embeds)
32
 
33
- images = ip_model.generate(clip_image_embeds=clip_embeds, prompt=prompt, num_samples=num_samples, num_inference_steps=num_inference_steps, seed=seed, guidance_scale=7.5, scale=scale)
 
 
 
 
 
 
 
 
34
  return images
35
 
36
 
 
1
  import torch
2
  import numpy as np
3
+ from typing import List, Dict, Optional
4
+ from PIL.Image import Image as PILImage
5
+ from IP_Adapter import IPAdapterXL
6
 
7
+ def compute_dataset_embeds_svd(
8
+ all_embeds: np.ndarray,
9
+ rank: int
10
+ ) -> np.ndarray:
11
  # Perform SVD on the combined matrix
12
+ _, _, v = np.linalg.svd(all_embeds, full_matrices=False)
13
 
14
  # Select the top `rank` singular vectors to construct the projection matrix
15
+ v = v[:rank]
16
+ projection_matrix = v.T @ v
17
 
18
  return projection_matrix
19
 
20
+ def get_projected_embedding(
21
+ embed: np.ndarray,
22
+ projection_matrix: np.ndarray
23
+ ) -> np.ndarray:
24
+ return embed @ projection_matrix
25
+
26
+ def get_embedding_composition(
27
+ embed: np.ndarray,
28
+ projections_data: List[Dict[str, np.ndarray]]
29
+ ) -> np.ndarray:
30
  combined_embeds = embed.copy()
31
 
32
  for proj_data in projections_data:
33
+ combined_embeds -= get_projected_embedding(embed, proj_data["projection_matrix"])
34
+ combined_embeds += get_projected_embedding(proj_data["embed"], proj_data["projection_matrix"])
 
 
35
 
36
  return combined_embeds
37
 
38
 
39
+ def get_modified_images_embeds_composition(
40
+ embed: np.ndarray,
41
+ projections_data: List[Dict[str, np.ndarray]],
42
+ ip_model: IPAdapterXL,
43
+ prompt: Optional[str] = None,
44
+ scale: float = 1.0,
45
+ num_samples: int = 3,
46
+ seed: int = 420
47
+ ) -> List[PILImage]:
48
  final_embeds = get_embedding_composition(embed, projections_data)
49
  clip_embeds = torch.from_numpy(final_embeds)
50
 
51
+ images: List[PILImage] = ip_model.generate(
52
+ clip_image_embeds=clip_embeds,
53
+ prompt=prompt,
54
+ num_samples=num_samples,
55
+ num_inference_steps=50,
56
+ seed=seed,
57
+ guidance_scale=7.5,
58
+ scale=scale
59
+ )
60
  return images
61
 
62