QHL067 commited on
Commit
7a820f6
·
1 Parent(s): 50929d3
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -27,7 +27,7 @@ from diffusion.flow_matching import ODEEulerFlowMatchingSolver
27
  import utils
28
  import libs.autoencoder
29
  from libs.clip import FrozenCLIPEmbedder
30
- from configs import t2i_512px_clip_dimr
31
 
32
 
33
  def unpreprocess(x: torch.Tensor) -> torch.Tensor:
@@ -93,7 +93,8 @@ def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
93
  return context, token_mask, tokens, captions
94
 
95
  # Load configuration and initialize models.
96
- config_dict = t2i_512px_clip_dimr.get_config()
 
97
  config = ml_collections.ConfigDict(config_dict)
98
 
99
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -108,7 +109,8 @@ MAX_IMAGE_SIZE = 1024 # Currently not used.
108
 
109
  # Load the main diffusion model.
110
  repo_id = "QHL067/CrossFlow"
111
- filename = "pretrained_models/t2i_512px_clip_dimr.pth"
 
112
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
113
  nnet = utils.get_nnet(**config.nnet)
114
  nnet = nnet.to(device)
 
27
  import utils
28
  import libs.autoencoder
29
  from libs.clip import FrozenCLIPEmbedder
30
+ from configs import t2i_512px_clip_dimr, t2i_256px_clip_dimr
31
 
32
 
33
  def unpreprocess(x: torch.Tensor) -> torch.Tensor:
 
93
  return context, token_mask, tokens, captions
94
 
95
  # Load configuration and initialize models.
96
+ # config_dict = t2i_512px_clip_dimr.get_config()
97
+ config_dict = t2i_256px_clip_dimr.get_config()
98
  config = ml_collections.ConfigDict(config_dict)
99
 
100
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
109
 
110
  # Load the main diffusion model.
111
  repo_id = "QHL067/CrossFlow"
112
+ # filename = "pretrained_models/t2i_512px_clip_dimr.pth"
113
+ filename = "pretrained_models/t2i_256px_clip_dimr.pth"
114
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
115
  nnet = utils.get_nnet(**config.nnet)
116
  nnet = nnet.to(device)