wuwenxu.01 committed
Commit c62efeb · 1 Parent(s): 9d31e57

fix: remove unused parameters

app.py CHANGED
@@ -44,7 +44,6 @@ def get_examples(examples_dir: str = "assets/examples") -> list:
                 example_list.append(None)
 
         example_list.append(example_dict["seed"])
-        example_list.append(example_dict["ref_long_side"])
 
         ans.append(example_list)
     return ans
@@ -58,23 +57,27 @@ def create_demo(
     pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
     pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
 
+
+    badges_text = r"""
+    <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+        <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
+        <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
+        <a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
+        <a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
+    </div>
+    """.strip()
+
     with gr.Blocks() as demo:
         gr.Markdown(f"# UNO by UNO team")
+        gr.Markdown(badges_text)
         with gr.Row():
             with gr.Column():
                 prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                 with gr.Row():
-                    image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
-                    image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
-                    image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
-                    image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")
-
-                with gr.Row():
-                    with gr.Column():
-                        ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
-                    with gr.Column():
-                        gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
-                        gr.Markdown(" 1->512 / 2,3,4->320")
+                    image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil")
+                    image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil")
+                    image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil")
+                    image_prompt4 = gr.Image(label="Ref img4", visible=True, interactive=True, type="pil")
 
                 with gr.Row():
                     with gr.Column():
@@ -87,7 +90,7 @@ def create_demo(
                             " and the higher size gives a better visual effect but is less stable"
                         )
 
-                with gr.Accordion("Generation Options", open=False):
+                with gr.Accordion("Advanced Options", open=False):
                     with gr.Row():
                         num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                         guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
@@ -102,7 +105,7 @@ def create_demo(
 
         inputs = [
            prompt, width, height, guidance, num_steps,
-           seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
+           seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
        ]
        generate_btn.click(
            fn=pipeline.gradio_generate,
@@ -118,11 +121,10 @@ def create_demo(
            inputs=[
                example_text, prompt,
                image_prompt1, image_prompt2, image_prompt3, image_prompt4,
-               seed, ref_long_side
+               seed, output_image
            ],
        )
 
-
    return demo
 
 if __name__ == "__main__":
@@ -145,4 +147,4 @@ if __name__ == "__main__":
    args = args_tuple[0]
 
    demo = create_demo(args.name, args.device, args.offload)
-   demo.launch(server_port=args.port)
+   demo.launch(server_port=args.port)
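A minimal, self-contained sketch (hypothetical handler and component set, not the repo's app.py) of the wiring this commit leaves in place: the four optional reference images and the seed feed the click handler directly, with no ref_long_side slider in between.

    import gradio as gr
    from PIL import Image

    def fake_generate(prompt, seed, img1, img2, img3, img4):
        # drop empty reference slots, as gradio_generate does after this commit
        refs = [im for im in (img1, img2, img3, img4) if isinstance(im, Image.Image)]
        return f"prompt={prompt!r}, seed={seed}, {len(refs)} reference image(s)"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        with gr.Row():
            imgs = [gr.Image(label=f"Ref Img{i}", type="pil") for i in range(1, 5)]
        seed = gr.Slider(-1, 10**8, -1, step=1, label="Seed (-1 for random)")
        out = gr.Textbox(label="Result")
        gr.Button("Generate").click(fake_generate, inputs=[prompt, seed, *imgs], outputs=out)

    if __name__ == "__main__":
        demo.launch()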
assets/examples/3one2one/config.json DELETED
@@ -1,8 +0,0 @@
-{
-    "prompt": "3d cartoon style, a woman.",
-    "seed": 2,
-    "ref_long_side": 512,
-    "useage": "one2one",
-    "image_ref1": "./ref1.png",
-    "image_result": "./result.png"
-}
assets/examples/3one2one/ref1.png DELETED
Git LFS Details
  • SHA256: 434929ca5eeb1daf036bfff7c0d4297ccb7017967bd60141e0287c409203e0ae
  • Pointer size: 131 Bytes
  • Size of remote file: 574 kB
assets/examples/3one2one/result.png DELETED
Git LFS Details
  • SHA256: dc87fa4fa14fb69cb37abb65775525cec3bfb90f9d9c072ee5cbe5adaf4dd146
  • Pointer size: 131 Bytes
  • Size of remote file: 303 kB
assets/examples/{5two2one → 3two2one}/config.json RENAMED
@@ -1,6 +1,6 @@
 {
     "prompt": "The figurine is in the crystal ball",
-    "seed": 1,
+    "seed": 0,
     "ref_long_side": 320,
     "useage": "two2one",
     "image_ref1": "./ref1.png",
assets/examples/{5two2one → 3two2one}/ref1.png RENAMED
File without changes
assets/examples/{5two2one → 3two2one}/ref2.png RENAMED
File without changes
assets/examples/{5two2one → 3two2one}/result.png RENAMED
File without changes
assets/examples/4two2one/config.json CHANGED
@@ -1,6 +1,6 @@
 {
     "prompt": "The logo is printed on the cup",
-    "seed": 0,
+    "seed": 61733557,
     "ref_long_side": 320,
     "useage": "two2one",
     "image_ref1": "./ref1.png",
assets/examples/{6many2one → 5many2one}/config.json RENAMED
File without changes
assets/examples/{6many2one → 5many2one}/ref1.png RENAMED
File without changes
assets/examples/{6many2one → 5many2one}/ref2.png RENAMED
File without changes
assets/examples/{6many2one → 5many2one}/ref3.png RENAMED
File without changes
assets/examples/{6many2one → 5many2one}/result.png RENAMED
File without changes
assets/examples/{7t2i → 6t2i}/config.json RENAMED
File without changes
assets/examples/{7t2i → 6t2i}/result.png RENAMED
File without changes
uno/flux/pipeline.py CHANGED
@@ -27,7 +27,7 @@ from uno.flux.modules.layers import (
     SingleStreamBlockLoraProcessor,
     SingleStreamBlockProcessor,
 )
-from uno.flux.sampling import denoise, get_noise, get_schedule, prepare, prepare_multi_ip, unpack
+from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
 from uno.flux.util import (
     get_lora_rank,
     load_ae,
@@ -185,10 +185,6 @@ class UNOPipeline:
         guidance: float = 4,
         num_steps: int = 50,
         seed: int = 123456789,
-        true_gs: float = 3,
-        neg_prompt: str = '',
-        neg_image_prompt: Image = None,
-        timestep_to_start_cfg: int = 0,
         **kwargs
     ):
         width = 16 * (width // 16)
@@ -201,9 +197,6 @@ class UNOPipeline:
             guidance,
             num_steps,
             seed,
-            timestep_to_start_cfg=timestep_to_start_cfg,
-            true_gs=true_gs,
-            neg_prompt=neg_prompt,
             **kwargs
         )
 
@@ -216,7 +209,6 @@ class UNOPipeline:
         guidance: float,
         num_steps: int,
         seed: int,
-        ref_long_side: int,
         image_prompt1: Image.Image,
         image_prompt2: Image.Image,
         image_prompt3: Image.Image,
@@ -224,6 +216,7 @@ class UNOPipeline:
     ):
         ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
         ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
+        ref_long_side = 512 if len(ref_imgs) <= 1 else 320
        ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
 
        seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
@@ -250,9 +243,6 @@ class UNOPipeline:
        guidance: float,
        num_steps: int,
        seed: int,
-       timestep_to_start_cfg: int = 1e5,  # TODO: unused, remove
-       true_gs: float = 3.5,
-       neg_prompt: str = "",
        ref_imgs: list[Image.Image] | None = None,
        pe: Literal['d', 'h', 'w', 'o'] = 'd',
    ):
@@ -283,11 +273,6 @@ class UNOPipeline:
            img=x,
            prompt=prompt, ref_imgs=x_1_refs, pe=pe
        )
-       neg_inp_cond = prepare_multi_ip(
-           t5=self.t5, clip=self.clip,
-           img=x,
-           prompt=neg_prompt, ref_imgs=x_1_refs, pe=pe
-       )
 
        if self.offload:
            self.offload_model_to_cpu(self.t5, self.clip)
@@ -298,11 +283,6 @@ class UNOPipeline:
            **inp_cond,
            timesteps=timesteps,
            guidance=guidance,
-           timestep_to_start_cfg=timestep_to_start_cfg,
-           neg_txt=neg_inp_cond['txt'],
-           neg_txt_ids=neg_inp_cond['txt_ids'],
-           neg_vec=neg_inp_cond['vec'],
-           true_gs=true_gs,
        )
 
        if self.offload:
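The slider removed from the UI is replaced by the heuristic added in gradio_generate above. A small sketch of that rule with a stand-in resize helper (the repo's preprocess_ref may differ in details): a single reference keeps a 512 px long side, two or more drop to 320 px, matching the old UI hint "1->512 / 2,3,4->320".

    from PIL import Image

    def resize_long_side(img: Image.Image, long_side: int) -> Image.Image:
        # scale so the longer edge equals long_side, keeping the aspect ratio
        scale = long_side / max(img.size)
        return img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)

    def prepare_refs(candidates: list) -> list:
        ref_imgs = [im for im in candidates if isinstance(im, Image.Image)]
        ref_long_side = 512 if len(ref_imgs) <= 1 else 320  # mirrors the change above
        return [resize_long_side(im, ref_long_side) for im in ref_imgs]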
uno/flux/sampling.py CHANGED
@@ -215,14 +215,9 @@ def denoise(
     txt: Tensor,
     txt_ids: Tensor,
     vec: Tensor,
-    neg_txt: Tensor,
-    neg_txt_ids: Tensor,
-    neg_vec: Tensor,
     # sampling parameters
     timesteps: list[float],
     guidance: float = 4.0,
-    true_gs = 1,
-    timestep_to_start_cfg=0,
     ref_img: Tensor=None,
     ref_img_ids: Tensor=None,
 ):
@@ -241,20 +236,6 @@
             timesteps=t_vec,
             guidance=guidance_vec
         )
-        if i >= timestep_to_start_cfg:
-            # not test
-            neg_pred = model(
-                img=img,
-                img_ids=img_ids,
-                ref_img=ref_img,  # TODO: neg img embedding
-                ref_img_ids=ref_img_ids,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
-                timesteps=t_vec,
-                guidance=guidance_vec,
-            )
-            pred = neg_pred + true_gs * (pred - neg_pred)
         img = img + (t_prev - t_curr) * pred
         i += 1
     return img
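With the CFG branch gone, the loop in denoise reduces to a plain Euler step along the predicted velocity, img = img + (t_prev - t_curr) * pred, for each consecutive timestep pair. A toy illustration (dummy model, not the repo's flux transformer) of that remaining update:

    import torch

    def euler_denoise(model_fn, img: torch.Tensor, timesteps: list[float]) -> torch.Tensor:
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            pred = model_fn(img, t_curr)            # velocity prediction at t_curr
            img = img + (t_prev - t_curr) * pred    # Euler step toward t_prev
        return img

    # usage with a dummy "model" that just pushes the sample toward zero
    x = torch.randn(1, 16, 64, 64)
    x0 = euler_denoise(lambda z, t: -z, x, timesteps=[1.0, 0.75, 0.5, 0.25, 0.0])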
uno/flux/util.py CHANGED
@@ -271,7 +271,11 @@ def load_flow_model_only_lora(
         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
 
     if hf_download:
-        lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
+        # lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
+        try:
+            lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
+        except:
+            lora_ckpt_path = os.environ.get("LORA", None)
     else:
         lora_ckpt_path = os.environ.get("LORA", None)
 
@@ -362,10 +366,12 @@ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf
 
 def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
-    return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+    version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
+    return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
 
 def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
-    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+    version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
+    return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
 
 
 def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
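Usage sketch: after this change the text encoders and the LoRA checkpoint can be redirected through environment variables before the loaders run. "T5", "CLIP", and "LORA" are the variable names read in util.py; the local paths below are placeholders, not shipped defaults.

    import os

    os.environ["T5"] = "/models/xflux_text_encoders"        # local T5 copy (placeholder path)
    os.environ["CLIP"] = "/models/clip-vit-large-patch14"   # local CLIP copy (placeholder path)
    os.environ["LORA"] = "/models/dit_lora.safetensors"     # used only if the HF download fails or hf_download=False

    from uno.flux.util import load_clip, load_t5

    t5 = load_t5()      # picks up the T5 override at call time
    clip = load_clip()  # picks up the CLIP override at call time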