import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Vital attention layers identified by Stable Flow for FLUX.1-dev.
MULTIMODAL_VITAL_LAYERS = [0, 1, 17, 18]
# Single-stream blocks are listed in whole-model numbering; subtracting the
# 19 double-stream (multimodal) blocks converts them to single-stream indices.
SINGLE_MODAL_VITAL_LAYERS = list(np.array([28, 53, 54, 56, 25]) - 19)
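# For reference: after the offset of 19, the whole-model indices
# [28, 53, 54, 56, 25] become single-stream indices [9, 34, 35, 37, 6],
# matching the placeholder shown in the "vital single layers" textbox below.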
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# NOTE: the pipeline calls below pass Stable Flow-specific arguments
# (invert_image, inverted_latent_list, mm_copy_blocks, single_copy_blocks),
# so this Space relies on a modified FLUX pipeline rather than stock diffusers.
# Optional Hyper-SD 8-step LoRA, toggled at runtime via check_hyper_flux_lora:
# pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), lora_scale=0.125)
# pipe.fuse_lora(lora_scale=0.125)
# pipe.enable_lora()
pipe.to("cuda")
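# For orientation, a plain text-to-image call with this pipeline uses the
# standard diffusers FLUX API (independent of the Stable Flow editing
# arguments); a minimal sketch:
#
# image = pipe(
#     "a photo of a cat",
#     height=1024,
#     width=1024,
#     guidance_scale=3.5,
#     num_inference_steps=28,
#     max_sequence_length=512,
# ).images[0]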
def get_examples():
    case = [
        [Image.open("metal.png"), "dragon.png", "a dragon, in 3d melting gold metal", 0.9, 0.5, 0, 5, 28, 28, 0, False, False, 2, False, "text/image guided stylization"],
        [Image.open("doll.png"), "anime.png", "anime illustration", 0.9, 0.5, 0, 6, 28, 28, 0, False, False, 2, False, "text/image guided stylization"],
        [Image.open("doll.png"), "raccoon.png", "raccoon, made of yarn", 0.9, 0.5, 0, 4, 28, 28, 0, False, False, 2, False, "local subject edits"],
        [Image.open("cat.jpg"), "parrot.png", "a parrot", 0.9, 0.5, 2, 8, 28, 28, 0, False, False, 1, False, "local subject edits"],
        [Image.open("cat.jpg"), "tiger.png", "a tiger", 0.9, 0.5, 0, 4, 8, 8, 789385745, False, False, 1, True, "local subject edits"],
        [Image.open("metal.png"), "dragon.png", "a dragon, in 3d melting gold metal", 0.9, 0.5, 0, 4, 8, 8, 789385745, False, True, 2, True, "text/image guided stylization"],
    ]
    return case
def reset_do_inversion():
    # Changing the input image, inversion steps, or seed invalidates the cached
    # inverted latents, so signal that inversion must run again.
    return True
def resize_img(image, max_size=1024):
width, height = image.size
scaling_factor = min(max_size / width, max_size / height)
new_width = int(width * scaling_factor)
new_height = int(height * scaling_factor)
return image.resize((new_width, new_height), Image.LANCZOS)
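# Example: resize_img(Image.open("cat.jpg")) rescales the image so its longer
# side is exactly 1024 px while preserving the aspect ratio. Note that this
# helper is unused in this file: image2latent below preprocesses at a fixed
# 1024x1024 instead.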
@torch.no_grad()
@spaces.GPU(duration=85)
def image2latent(image, latent_nudging_scalar=1.15):
    # Encode the image to VAE latents, normalized the way FLUX expects.
    image = pipe.image_processor.preprocess(image, height=1024, width=1024).type(pipe.vae.dtype).to("cuda")
    latents = pipe.vae.encode(image)["latent_dist"].mean
    latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    # Slightly scale the latents to stabilize the inversion trajectory.
    latents = latents * latent_nudging_scalar

    # Pack the 2D latents into the flattened token sequence consumed by
    # FLUX's transformer.
    height = pipe.default_sample_size * pipe.vae_scale_factor
    width = pipe.default_sample_size * pipe.vae_scale_factor
    num_channels_latents = pipe.transformer.config.in_channels // 4
    height = 2 * (height // (pipe.vae_scale_factor * 2))
    width = 2 * (width // (pipe.vae_scale_factor * 2))
    latents = pipe._pack_latents(
        latents=latents,
        batch_size=1,
        num_channels_latents=num_channels_latents,
        height=height,
        width=width,
    )
    return latents
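# For debugging, the inverse mapping (packed latents -> PIL image) can be
# sketched with the stock FluxPipeline helpers; an illustrative sketch only,
# not part of the app's flow, assuming the nudging above should be undone:
#
# def latent2image(latents, height=1024, width=1024, latent_nudging_scalar=1.15):
#     latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
#     latents = latents / latent_nudging_scalar
#     latents = latents / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor
#     image = pipe.vae.decode(latents.to(pipe.vae.dtype)).sample
#     return pipe.image_processor.postprocess(image, output_type="pil")[0]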
def check_hyper_flux_lora(enable_hyper_flux):
    # Toggle the Hyper-SD accelerator LoRA: with it fused, 8 steps suffice;
    # without it, fall back to the standard 28 steps.
    if enable_hyper_flux:
        pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), lora_scale=0.125)
        pipe.fuse_lora(lora_scale=0.125)
        return 8, 8
    else:
        pipe.unfuse_lora()
        return 28, 28
def convert_string_to_list(s):
    # Parse a comma-separated string of layer indices, ignoring blank entries.
    return [int(x) for x in s.split(",") if x.strip()]
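# e.g. convert_string_to_list("0, 1, 17, 18") == [0, 1, 17, 18]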
@spaces.GPU(duration=150)
def invert_and_edit(
    image,
    source_prompt,
    edit_prompt,
    multimodal_layers,
    single_layers,
    num_inversion_steps,
    num_inference_steps,
    seed,
    randomize_seed,
    width=1024,
    height=1024,
    inverted_latent_list=None,
    do_inversion=True,
):
    # NOTE: the pipeline calls below run at a fixed 1024x1024; the width/height
    # slider values are accepted here but not currently used.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    if do_inversion:
        # Inversion pass: trace the input image back along the flow under the
        # source prompt and cache the intermediate latents for every step.
        inverted_latent_list = pipe(
            source_prompt,
            height=1024,
            width=1024,
            guidance_scale=1,
            output_type="pil",
            num_inference_steps=num_inversion_steps,
            max_sequence_length=512,
            latents=image2latent(image),
            invert_image=True,
        )
        do_inversion = False
    else:
        # Move the cached latents back to the GPU (ZeroGPU workers cannot keep
        # CUDA tensors alive inside gr.State between calls).
        inverted_latent_list = [tensor.to(DEVICE) for tensor in inverted_latent_list]

    # Parse the user-supplied layer lists; fall back to the paper's vital
    # layers on invalid or empty input.
    try:
        multimodal_layers = convert_string_to_list(multimodal_layers)
        single_layers = convert_string_to_list(single_layers)
    except (ValueError, AttributeError):
        multimodal_layers, single_layers = [], []
    if not multimodal_layers:
        multimodal_layers = MULTIMODAL_VITAL_LAYERS
    if not single_layers:
        single_layers = SINGLE_MODAL_VITAL_LAYERS
    # Joint pass: batch the source and edit prompts, start both from the final
    # inverted latent, and inject the cached source activations at the vital
    # layers so that only the prompted change is applied.
    output = pipe(
        [source_prompt, edit_prompt],
        height=1024,
        width=1024,
        guidance_scale=[1, 3],  # scale 1 reconstructs the source; 3 drives the edit
        output_type="pil",
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        latents=inverted_latent_list[-1].tile(2, 1, 1),
        inverted_latent_list=inverted_latent_list,
        mm_copy_blocks=multimodal_layers,
        single_copy_blocks=single_layers,
    ).images[1]

    # Move the latents back to the CPU so they can be stored in gr.State under ZeroGPU.
    inverted_latent_list = [tensor.cpu() for tensor in inverted_latent_list]
    return output, inverted_latent_list, do_inversion, seed
# UI CSS
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
"""
# Create the Gradio interface
with gr.Blocks(css=css) as demo:
inverted_latents = gr.State()
do_inversion = gr.State(True)
with gr.Column(elem_id="col-container"):
gr.Markdown(f"""# Stable Flow 🖌️🏞️
### Edit real images with FLUX.1 [dev]
following the algorithm proposed in [*Stable Flow: Vital Layers for Training-Free Image Editing* by Avrahami et al.](https://arxiv.org/pdf/2411.14430)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[project page](https://omriavrahami.com/stable-flow/)] [[arxiv](https://arxiv.org/pdf/2411.14430)]
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(
label="Input Image",
type="pil"
)
                source_prompt = gr.Text(
                    label="Source Prompt",
                    max_lines=1,
                    placeholder="describe the input image",
                )
edit_prompt = gr.Text(
label="Edit Prompt",
max_lines=1,
placeholder="describe the edited output",
)
with gr.Row():
                    multimodal_layers = gr.Text(
                        info="Comma-separated indices of double-stream attention blocks used for injection",
                        label="vital multimodal layers",
                        max_lines=1,
                        placeholder="0, 1, 17, 18",
                    )
                    single_layers = gr.Text(
                        info="Comma-separated indices of single-stream attention blocks used for injection",
                        label="vital single layers",
                        max_lines=1,
                        placeholder="9, 34, 35, 37, 6",
                    )
with gr.Row():
enable_hyper_flux = gr.Checkbox(label="8-step LoRA", value=False, info="may reduce edit quality", visible=False)
run_button = gr.Button("Edit", variant="primary")
with gr.Column():
result = gr.Image(label="Result")
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
num_inference_steps = gr.Slider(
label="num inference steps",
minimum=1,
maximum=50,
step=1,
value=25,
)
with gr.Row():
num_inversion_steps = gr.Slider(
label="num inversion steps",
minimum=1,
maximum=50,
step=1,
value=25,
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
run_button.click(
fn=invert_and_edit,
inputs=[
input_image,
source_prompt,
edit_prompt,
multimodal_layers,
single_layers,
num_inversion_steps,
num_inference_steps,
seed,
randomize_seed,
width,
height,
inverted_latents,
do_inversion
],
outputs=[result, inverted_latents, do_inversion, seed],
)
# gr.Examples(
# examples=get_examples(),
# inputs=[input_image,result, prompt, num_inversion_steps, num_inference_steps, seed, randomize_seed, enable_hyper_flux ],
# outputs=[result],
# )
input_image.change(
fn=reset_do_inversion,
outputs=[do_inversion]
)
num_inversion_steps.change(
fn=reset_do_inversion,
outputs=[do_inversion]
)
seed.change(
fn=reset_do_inversion,
outputs=[do_inversion]
)
enable_hyper_flux.change(
fn=check_hyper_flux_lora,
inputs=[enable_hyper_flux],
outputs=[num_inversion_steps, num_inference_steps]
)
if __name__ == "__main__":
demo.launch()