Spaces:
Runtime error
Runtime error
Commit
·
d6f9b71
1
Parent(s):
7af4a09
Add spherical dist loss
Browse files
app.py
CHANGED
@@ -46,6 +46,11 @@ class MakeCutouts(nn.Module):
|
|
46 |
cutouts.append(cutout)
|
47 |
return torch.cat(cutouts)
|
48 |
|
|
|
|
|
|
|
|
|
|
|
49 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
50 |
model = get_model('cc12m_1_cfg')()
|
51 |
_, side_y, side_x = model.shape
|
|
|
46 |
cutouts.append(cutout)
|
47 |
return torch.cat(cutouts)
|
48 |
|
49 |
+
def spherical_dist_loss(x, y):
|
50 |
+
x = F.normalize(x, dim=-1)
|
51 |
+
y = F.normalize(y, dim=-1)
|
52 |
+
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
|
53 |
+
|
54 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
55 |
model = get_model('cc12m_1_cfg')()
|
56 |
_, side_y, side_x = model.shape
|