256
Browse files
app.py
CHANGED
@@ -27,7 +27,7 @@ from diffusion.flow_matching import ODEEulerFlowMatchingSolver
|
|
27 |
import utils
|
28 |
import libs.autoencoder
|
29 |
from libs.clip import FrozenCLIPEmbedder
|
30 |
-
from configs import t2i_512px_clip_dimr
|
31 |
|
32 |
|
33 |
def unpreprocess(x: torch.Tensor) -> torch.Tensor:
|
@@ -93,7 +93,8 @@ def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
|
|
93 |
return context, token_mask, tokens, captions
|
94 |
|
95 |
# Load configuration and initialize models.
|
96 |
-
config_dict = t2i_512px_clip_dimr.get_config()
|
|
|
97 |
config = ml_collections.ConfigDict(config_dict)
|
98 |
|
99 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -108,7 +109,8 @@ MAX_IMAGE_SIZE = 1024 # Currently not used.
|
|
108 |
|
109 |
# Load the main diffusion model.
|
110 |
repo_id = "QHL067/CrossFlow"
|
111 |
-
filename = "pretrained_models/t2i_512px_clip_dimr.pth"
|
|
|
112 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
113 |
nnet = utils.get_nnet(**config.nnet)
|
114 |
nnet = nnet.to(device)
|
|
|
27 |
import utils
|
28 |
import libs.autoencoder
|
29 |
from libs.clip import FrozenCLIPEmbedder
|
30 |
+
from configs import t2i_512px_clip_dimr, t2i_256px_clip_dimr
|
31 |
|
32 |
|
33 |
def unpreprocess(x: torch.Tensor) -> torch.Tensor:
|
|
|
93 |
return context, token_mask, tokens, captions
|
94 |
|
95 |
# Load configuration and initialize models.
|
96 |
+
# config_dict = t2i_512px_clip_dimr.get_config()
|
97 |
+
config_dict = t2i_256px_clip_dimr.get_config()
|
98 |
config = ml_collections.ConfigDict(config_dict)
|
99 |
|
100 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
109 |
|
110 |
# Load the main diffusion model.
|
111 |
repo_id = "QHL067/CrossFlow"
|
112 |
+
# filename = "pretrained_models/t2i_512px_clip_dimr.pth"
|
113 |
+
filename = "pretrained_models/t2i_256px_clip_dimr.pth"
|
114 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
115 |
nnet = utils.get_nnet(**config.nnet)
|
116 |
nnet = nnet.to(device)
|