# uno-final / app.py
# Last change: "Update app.py" by Manireddy1508 (commit 970da07, verified)
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# app.py
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
from pathlib import Path
import gradio as gr
import torch
import openai
import os
from uno.flux.pipeline import UNOPipeline
from uno.utils.prompt_enhancer import enhance_prompt_with_chatgpt
# Ask the CUDA caching allocator to use expandable segments (reduces fragmentation).
# NOTE(review): PyTorch reads this variable when its CUDA allocator is first
# initialised; if any import above has already allocated CUDA memory, this
# assignment has no effect. Setting it before `import torch` would be safer.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Module-level API key for the `openai` package; None when OPENAI_API_KEY is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")

from huggingface_hub import login

# Only authenticate when a token is actually configured: `login(token=None)`
# falls back to an interactive prompt, which blocks or fails when the app runs
# headless (e.g. as a hosted Space without the secret set).
_hf_token = os.getenv("HUGGINGFACE_TOKEN")
if _hf_token:
    login(token=_hf_token)
def get_examples(examples_dir: str = "assets/examples") -> list:
    """Load Gradio example rows from per-example sub-directories.

    Every immediate sub-directory of *examples_dir* must contain a
    ``config.json`` with the keys ``useage`` (the spelling this loader reads),
    ``prompt`` and ``seed``, plus optional ``image_ref1`` .. ``image_ref4``
    entries holding paths relative to that sub-directory. Plain files directly
    inside *examples_dir* are ignored.

    Args:
        examples_dir: Directory containing one sub-directory per example.

    Returns:
        One list per example, ordered by sub-directory path:
        ``[useage, prompt, ref1, ref2, ref3, ref4, seed]``, where each ``refN``
        is the joined path as a string, or ``None`` when the key is absent.

    Raises:
        FileNotFoundError: if *examples_dir* or a sub-directory's
            ``config.json`` does not exist.
        KeyError: if a config lacks ``useage``, ``prompt`` or ``seed``.
    """
    ans = []
    # Sort so the example gallery order does not depend on the filesystem's
    # directory-listing order (which `iterdir` does not guarantee).
    for example in sorted(Path(examples_dir).iterdir()):
        if not example.is_dir():
            continue
        # Explicit encoding: prompts may contain non-ASCII text and the
        # platform default encoding is locale-dependent.
        with open(example / "config.json", encoding="utf-8") as f:
            example_dict = json.load(f)
        example_list = [example_dict["useage"], example_dict["prompt"]]
        for key in ("image_ref1", "image_ref2", "image_ref3", "image_ref4"):
            example_list.append(str(example / example_dict[key]) if key in example_dict else None)
        example_list.append(example_dict["seed"])
        ans.append(example_list)
    return ans
def create_demo(model_type: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
    """Build the Gradio UI for prompt-enhanced UNO generation.

    On each click of the generate button, the user's prompt is expanded into
    ``num_outputs`` variants by ``enhance_prompt_with_chatgpt``. One image is
    then rendered per variant with ``UNOPipeline.gradio_generate``, using the
    same (up to four) reference images for every variant.

    Args:
        model_type: Model name forwarded to ``UNOPipeline``.
        device: Torch device string forwarded to ``UNOPipeline``.
        offload: Forwarded to ``UNOPipeline`` (the CLI help describes it as
            sequentially offloading unused models to CPU).

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` demo.
    """
    pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
    # Number of (image, enhanced prompt) slots in the UI. This is also the
    # upper bound of the "Number of Enhanced Prompts / Images" slider, so the
    # two can never disagree.
    max_outputs = 5

    with gr.Blocks() as demo:
        gr.Markdown("# UNO by UNO team")
        gr.Markdown(
            """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://github.com/bytedance/UNO"><img alt="Build" src="https://img.shields.io/github/stars/bytedance/UNO"></a>
<a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
<a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
<a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
<a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
</div>
"""
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                with gr.Row():
                    image_prompt1 = gr.Image(label="Ref Img1", type="pil")
                    image_prompt2 = gr.Image(label="Ref Img2", type="pil")
                    image_prompt3 = gr.Image(label="Ref Img3", type="pil")
                    image_prompt4 = gr.Image(label="Ref Img4", type="pil")
                with gr.Row():
                    with gr.Column():
                        width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
                        height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
                    with gr.Column():
                        gr.Markdown("📌 Trained on 512x512. Larger size = better quality, but less stable.")
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                        guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance")
                    seed = gr.Number(-1, label="Seed (-1 for random)")
                num_outputs = gr.Slider(1, max_outputs, max_outputs, step=1, label="Number of Enhanced Prompts / Images")
                generate_btn = gr.Button("Generate Enhanced Images")
            with gr.Column():
                # Flat list of alternating components: [Image 1, Prompt 1, Image 2, ...].
                outputs = []
                for i in range(max_outputs):
                    outputs.append(gr.Image(label=f"Image {i+1}"))
                    outputs.append(gr.Textbox(label=f"Enhanced Prompt {i+1}"))

        def run_generation(prompt, width, height, guidance, num_steps, seed,
                           img1, img2, img3, img4, num_outputs):
            """Generate one image per enhanced prompt.

            Returns a list aligned with ``outputs``: alternating image / prompt
            values, with unused slots padded by ``None`` (image) and ``""``
            (prompt). Per-image failures are reported in that image's prompt slot.
            """
            # Slider values may arrive as floats; clamp to the slider's range so
            # the returned list always fits the available output slots.
            num_outputs = max(1, min(int(num_outputs), max_outputs))
            uploaded_images = [img for img in (img1, img2, img3, img4) if img is not None]
            print(f"\n📥 [DEBUG] User prompt: {prompt}")
            try:
                prompts = list(enhance_prompt_with_chatgpt(
                    user_prompt=prompt,
                    num_prompts=num_outputs,
                    reference_images=uploaded_images,
                ))
            except Exception as e:
                # The enhancer presumably calls the OpenAI API (the module sets
                # openai.api_key), so it can fail on network or key errors.
                # Degrade to the user's own prompt instead of failing the request.
                print(f"❌ [ERROR] Prompt enhancement failed, using original prompt: {e}")
                prompts = []
            print(f"\n🧠 [DEBUG] Final Prompt List (len={len(prompts)}):")
            for idx, p in enumerate(prompts):
                print(f" [{idx+1}] {p}")
            # Top up with the raw prompt if fewer variants came back than requested.
            while len(prompts) < num_outputs:
                prompts.append(prompt)

            # A cleared Number field yields None; treat it like -1 (random).
            use_random_seed = seed is None or seed == -1
            results = []
            for i in range(num_outputs):
                try:
                    seed_val = torch.randint(0, 10**8, (1,)).item() if use_random_seed else int(seed)
                    print(f"🧪 [DEBUG] Using seed: {seed_val} for image {i+1}")
                    gen_image, _ = pipeline.gradio_generate(
                        prompt=prompts[i],
                        width=width,
                        height=height,
                        guidance=guidance,
                        num_steps=num_steps,
                        seed=seed_val,
                        image_prompt1=img1,
                        image_prompt2=img2,
                        image_prompt3=img3,
                        image_prompt4=img4,
                    )
                    print(f"✅ [DEBUG] Image {i+1} generated using prompt: {prompts[i]}")
                    results.append(gen_image)
                    results.append(prompts[i])
                except Exception as e:
                    # UI boundary: report the failure in this slot and keep
                    # generating the remaining images.
                    print(f"❌ [ERROR] Failed to generate image {i+1}: {e}")
                    results.append(None)
                    results.append(f"⚠️ Failed to generate: {e}")
            # Pad unused slots: even indices are images (None), odd are prompts ("").
            while len(results) < len(outputs):
                results.append(None if len(results) % 2 == 0 else "")
            return results

        generate_btn.click(
            fn=run_generation,
            inputs=[
                prompt, width, height, guidance, num_steps,
                seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4, num_outputs,
            ],
            outputs=outputs,
        )

        example_text = gr.Text("", visible=False, label="Case For:")
        examples = get_examples("./assets/examples")
        gr.Examples(
            examples=examples,
            inputs=[
                example_text, prompt,
                image_prompt1, image_prompt2, image_prompt3, image_prompt4,
                seed, outputs[0],
            ],
        )
    return demo
if __name__ == "__main__":
    from typing import Literal

    from transformers import HfArgumentParser

    # Command-line options for launching the demo; parsed by HfArgumentParser.
    @dataclasses.dataclass
    class AppArgs:
        # Model variant handed to create_demo / UNOPipeline.
        name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
        # Torch device; CUDA when available at startup, otherwise CPU.
        device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
        offload: bool = dataclasses.field(
            default=False,
            metadata={"help": "If True, sequentially offload unused models to CPU"},
        )
        # Port the Gradio server listens on.
        port: int = 7860

    # Exactly one dataclass is registered, so exactly one parsed object comes back.
    (app_args,) = HfArgumentParser([AppArgs]).parse_args_into_dataclasses()
    demo = create_demo(app_args.name, app_args.device, app_args.offload)
    demo.launch(server_port=app_args.port)