multimodalart HF Staff commited on
Commit
449a298
·
1 Parent(s): ab9e9c4

Cutouts function

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -23,6 +23,27 @@ from huggingface_hub import hf_hub_download
23
  from CLIP import clip
24
  from diffusion import get_model, sampling, utils
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
27
  model = get_model('cc12m_1_cfg')()
28
  _, side_y, side_x = model.shape
 
23
  from CLIP import clip
24
  from diffusion import get_model, sampling, utils
25
 
26
+ class MakeCutouts(nn.Module):
27
+ def __init__(self, cut_size, cutn, cut_pow=1.):
28
+ super().__init__()
29
+ self.cut_size = cut_size
30
+ self.cutn = cutn
31
+ self.cut_pow = cut_pow
32
+
33
+ def forward(self, input):
34
+ sideY, sideX = input.shape[2:4]
35
+ max_size = min(sideX, sideY)
36
+ min_size = min(sideX, sideY, self.cut_size)
37
+ cutouts = []
38
+ for _ in range(self.cutn):
39
+ size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
40
+ offsetx = torch.randint(0, sideX - size + 1, ())
41
+ offsety = torch.randint(0, sideY - size + 1, ())
42
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
43
+ cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
44
+ cutouts.append(cutout)
45
+ return torch.cat(cutouts)
46
+
47
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
48
  model = get_model('cc12m_1_cfg')()
49
  _, side_y, side_x = model.shape