# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import os
from typing import Literal

import openai
from accelerate import Accelerator
from PIL import Image
from transformers import HfArgumentParser

from uno.flux.pipeline import UNOPipeline, preprocess_ref
from uno.utils.prompt_enhancer import enhance_prompt_with_chatgpt

openai.api_key = os.getenv("OPENAI_API_KEY")
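# Note: enhance_prompt_with_chatgpt uses the OpenAI client configured above;
# if OPENAI_API_KEY is unset, api_key is None and enhancement calls will fail
# at request time.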


def horizontal_concat(images):
    """Paste a list of PIL images side by side on a shared canvas."""
    widths, heights = zip(*(img.size for img in images))
    total_width = sum(widths)
    max_height = max(heights)

    new_im = Image.new("RGB", (total_width, max_height))

    x_offset = 0
    for img in images:
        new_im.paste(img, (x_offset, 0))
        x_offset += img.size[0]

    return new_im
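# Example (names are hypothetical): horizontal_concat([gen, ref_a, ref_b])
# returns one strip with the generated image on the left and each reference
# pasted to its right; main() uses this when --concat_refs is set.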


@dataclasses.dataclass
class InferenceArgs:
    prompt: str | None = None
    image_paths: list[str] | None = None
    eval_json_path: str | None = None
    offload: bool = False
    num_images_per_prompt: int = 1
    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
    width: int = 512
    height: int = 512
    ref_size: int = -1
    num_steps: int = 25
    guidance: float = 4
    seed: int = 3407
    save_path: str = "output/inference"
    only_lora: bool = True
    concat_refs: bool = False
    lora_rank: int = 512
    data_resolution: int = 512
    pe: Literal["d", "h", "w", "o"] = "d"


def main(args: InferenceArgs):
    accelerator = Accelerator()

    # Validate inputs before loading the model.
    assert args.prompt is not None or args.eval_json_path is not None, \
        "Please provide either prompt or eval_json_path"

    pipeline = UNOPipeline(
        args.model_type,
        accelerator.device,
        args.offload,
        only_lora=args.only_lora,
        lora_rank=args.lora_rank,
    )

    if args.eval_json_path:
        with open(args.eval_json_path, "rt") as f:
            data_dicts = json.load(f)
        data_root = os.path.dirname(args.eval_json_path)
    else:
        data_root = "./"
        data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}]
    for i, data_dict in enumerate(data_dicts):
        try:
            ref_imgs = [
                Image.open(os.path.join(data_root, img_path))
                for img_path in data_dict["image_paths"]
            ]
        except Exception as e:
            print(f"❌ [ERROR] Failed to load reference images: {e}")
            continue

        # Resolve the reference size locally so the per-item default
        # (512 for a single reference, 320 for several) does not overwrite
        # args.ref_size for later items.
        ref_size = args.ref_size if args.ref_size != -1 else (512 if len(ref_imgs) == 1 else 320)
        ref_imgs = [preprocess_ref(img, ref_size) for img in ref_imgs]
print(f"\nπ§ [DEBUG] Enhancing prompt: '{data_dict['prompt']}'") | |
enhanced_prompts = enhance_prompt_with_chatgpt( | |
user_prompt=data_dict["prompt"], | |
num_prompts=args.num_images_per_prompt, | |
reference_images=ref_imgs | |
) | |
# Pad if needed | |
while len(enhanced_prompts) < args.num_images_per_prompt: | |
print(f"β οΈ [DEBUG] Padding prompts: returning user prompt as fallback.") | |
enhanced_prompts.append(data_dict["prompt"]) | |
        for j in range(args.num_images_per_prompt):
            # Shard work across ranks: each process handles every
            # num_processes-th (prompt, sample) pair.
            if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
                continue

            prompt_j = enhanced_prompts[j]
            print(f"\n--- Generating image [{i}_{j}] ---")
            print(f"Enhanced Prompt: {prompt_j}")
            print(f"Image paths: {data_dict['image_paths']}")
            print(f"Seed: {args.seed + j}")
            print(f"Resolution: {args.width}x{args.height}")
            print("------------------------------")
            try:
                image_gen = pipeline(
                    prompt=prompt_j,
                    width=args.width,
                    height=args.height,
                    guidance=args.guidance,
                    num_steps=args.num_steps,
                    seed=args.seed + j,
                    ref_imgs=ref_imgs,
                    pe=args.pe,
                )
                if args.concat_refs:
                    image_gen = horizontal_concat([image_gen, *ref_imgs])

                os.makedirs(args.save_path, exist_ok=True)
                image_gen.save(os.path.join(args.save_path, f"{i}_{j}.png"))

                # Save the generation context. Copy the args via
                # dataclasses.asdict so the per-image prompt/paths do not
                # mutate the shared args instance (vars() would).
                args_dict = dataclasses.asdict(args)
                args_dict["prompt"] = prompt_j
                args_dict["image_paths"] = data_dict["image_paths"]
                with open(os.path.join(args.save_path, f"{i}_{j}.json"), "w") as f:
                    json.dump(args_dict, f, indent=4)
            except Exception as e:
                print(f"❌ [ERROR] Failed to generate or save image {i}_{j}: {e}")


if __name__ == "__main__":
    parser = HfArgumentParser([InferenceArgs])
    args = parser.parse_args_into_dataclasses()[0]
    main(args)
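
# Example invocations (illustrative only; the script name and paths are
# hypothetical, and the flags are auto-generated from the InferenceArgs
# fields by HfArgumentParser):
#
#   python inference.py --prompt "a plush toy on a beach" \
#       --image_paths assets/ref1.png assets/ref2.png
#
#   python inference.py --eval_json_path datasets/eval.json \
#       --num_images_per_prompt 4 --save_path output/eval_run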