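"""Stable Flow demo (app.py).

Gradio Space for training-free image editing with FLUX.1 [dev], following
"Stable Flow: Vital Layers for Training-Free Image Editing" (Avrahami et al.,
https://arxiv.org/pdf/2411.14430). The input image is inverted to latents, and
the edit is injected only into the model's "vital" attention layers.
"""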
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image

#from gradio_imageslider import ImageSlider

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
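# "Vital layers" from the Stable Flow paper. FLUX.1-dev has 19 multimodal
# (MM-DiT) blocks followed by 38 single-stream blocks; the paper indexes the
# single blocks over the full 57-block stack, so subtracting 19 converts those
# indices to positions within the single-block list (giving 9, 34, 35, 37, 6).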
MULTIMODAL_VITAL_LAYERS = [0, 1, 17, 18]
SINGLE_MODAL_VITAL_LAYERS = list(np.array([28, 53, 54, 56, 25]) - 19)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
#pipe.enable_lora()
pipe.to(DEVICE, dtype=torch.float16)
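# Note: the `invert_image`, `inverted_latent_list`, `mm_copy_blocks`, and
# `single_copy_blocks` kwargs used below are not part of the stock diffusers
# FluxPipeline; this Space presumably loads a customized Stable Flow pipeline
# that performs inversion and copies attention features in the vital layers.
# The weights load in bfloat16 and are then cast to float16 to match the
# float16 latents constructed in `invert_and_edit`.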
def get_examples():
    case = [
        [Image.open("metal.png"), "dragon.png", "a dragon, in 3d melting gold metal", 0.9, 0.5, 0, 5, 28, 28, 0, False, False, 2, False, "text/image guided stylization"],
        [Image.open("doll.png"), "anime.png", "anime illustration", 0.9, 0.5, 0, 6, 28, 28, 0, False, False, 2, False, "text/image guided stylization"],
        [Image.open("doll.png"), "raccoon.png", "raccoon, made of yarn", 0.9, 0.5, 0, 4, 28, 28, 0, False, False, 2, False, "local subject edits"],
        [Image.open("cat.jpg"), "parrot.png", "a parrot", 0.9, 0.5, 2, 8, 28, 28, 0, False, False, 1, False, "local subject edits"],
        [Image.open("cat.jpg"), "tiger.png", "a tiger", 0.9, 0.5, 0, 4, 8, 8, 789385745, False, False, 1, True, "local subject edits"],
        [Image.open("metal.png"), "dragon.png", "a dragon, in 3d melting gold metal", 0.9, 0.5, 0, 4, 8, 8, 789385745, False, True, 2, True, "text/image guided stylization"],
    ]
    return case
def reset_image_input():
    return True

def reset_do_inversion(image_input):
    # Re-invert only when an input image is actually present.
    return bool(image_input)
def resize_img(image, max_size=1024):
    # Downscale so the longer side fits max_size, preserving aspect ratio.
    width, height = image.size
    scaling_factor = min(max_size / width, max_size / height)
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    return image.resize((new_width, new_height), Image.LANCZOS)
@torch.no_grad()
@spaces.GPU(duration=85)
def image2latent(image, latent_nudging_scalar=1.15):
    # Encode the image to VAE latents, normalize them, and apply the
    # "latent nudging" trick from the paper to stabilize inversion.
    image = pipe.image_processor.preprocess(image, height=1024, width=1024).type(pipe.vae.dtype).to(DEVICE)
    latents = pipe.vae.encode(image)["latent_dist"].mean
    latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    latents = latents * latent_nudging_scalar

    # Pack the 2D latents into the flat token sequence FLUX's transformer expects.
    height = pipe.default_sample_size * pipe.vae_scale_factor
    width = pipe.default_sample_size * pipe.vae_scale_factor
    num_channels_latents = pipe.transformer.config.in_channels // 4
    height = 2 * (height // (pipe.vae_scale_factor * 2))
    width = 2 * (width // (pipe.vae_scale_factor * 2))
    latents = pipe._pack_latents(
        latents=latents,
        batch_size=1,
        num_channels_latents=num_channels_latents,
        height=height,
        width=width,
    )
    return latents
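# Shape check (derived from the code above): a 1024x1024 image encodes to a
# 128x128x16 VAE latent, which _pack_latents flattens into 2x2 patches,
# giving a (1, 4096, 64) tensor, i.e. the same token layout as the random
# (4096, 64) latents used in `invert_and_edit` when no input image is given.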
def check_hyper_flux_lora(enable_hyper_flux):
    # Toggle the Hyper-SD 8-step LoRA and return matching step counts
    # for (num_inversion_steps, num_inference_steps).
    if enable_hyper_flux:
        pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
        pipe.fuse_lora(lora_scale=0.125)
        return 8, 8
    else:
        pipe.unfuse_lora()
        return 28, 28
def convert_string_to_list(s):
    # Parse a comma-separated string of layer indices into a list of ints.
    return [int(x) for x in s.split(",") if x.strip()]
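# Example: convert_string_to_list("0, 1, 17, 18") -> [0, 1, 17, 18]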
@spaces.GPU(duration=150)
def invert_and_edit(
    image,
    source_prompt,
    edit_prompt,
    multimodal_layers,
    single_layers,
    num_inversion_steps,
    num_inference_steps,
    seed,
    randomize_seed,
    latent_nudging_scalar,
    guidance_scale,
    width=1024,  # width/height sliders are currently ignored; generation is fixed at 1024x1024
    height=1024,
    inverted_latent_list=None,
    do_inversion=True,
    image_input=False,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if image_input and (image is not None):
        if do_inversion:
            # Invert the input image once and cache the per-step latents.
            inverted_latent_list = pipe(
                source_prompt,
                height=1024,
                width=1024,
                guidance_scale=1,
                output_type="pil",
                num_inference_steps=num_inversion_steps,
                max_sequence_length=512,
                latents=image2latent(image, latent_nudging_scalar),
                invert_image=True,
            )
            do_inversion = False
        else:
            # Move cached latents back to the GPU (ZeroGPU detaches gr.State tensors).
            inverted_latent_list = [tensor.to(DEVICE) for tensor in inverted_latent_list]
        num_inference_steps = num_inversion_steps
        # Duplicate the final inverted latent for the [source, edit] prompt batch.
        latents = inverted_latent_list[-1].tile(2, 1, 1)
        guidance_scale = [1, 3]
        image_input = True
    else:
        # No input image: sample seeded random packed latents for both prompts.
        latents = torch.randn(
            (4096, 64),
            generator=torch.Generator(device=DEVICE).manual_seed(seed),
            dtype=torch.float16,
            device=DEVICE,
        ).tile(2, 1, 1)
        image_input = False

    try:
        multimodal_layers = convert_string_to_list(multimodal_layers)
        single_layers = convert_string_to_list(single_layers)
    except Exception:
        # Fall back to the paper's vital layers if the text fields don't parse.
        multimodal_layers = MULTIMODAL_VITAL_LAYERS
        single_layers = SINGLE_MODAL_VITAL_LAYERS

    output = pipe(
        [source_prompt, edit_prompt],
        height=1024,
        width=1024,
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        latents=latents,
        inverted_latent_list=inverted_latent_list,
        mm_copy_blocks=multimodal_layers,
        single_copy_blocks=single_layers,
    ).images

    # Move latents back to the CPU before storing them in gr.State (ZeroGPU).
    if inverted_latent_list is not None:
        inverted_latent_list = [tensor.cpu() for tensor in inverted_latent_list]

    if image is None:
        image = output[0]

    return image, output[1], inverted_latent_list, do_inversion, image_input, seed
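# The return order matches the click handler's outputs below:
# (input_image, result, inverted_latents, do_inversion, image_input, seed).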
# UI CSS
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""
# Create the Gradio interface
with gr.Blocks(css=css) as demo:
    inverted_latents = gr.State()
    do_inversion = gr.State(False)
    image_input = gr.State(False)
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Stable Flow 🌊🖌️
### Edit real images with FLUX.1 [dev]
following the algorithm proposed in [*Stable Flow: Vital Layers for Training-Free Image Editing* by Avrahami et al.](https://arxiv.org/pdf/2411.14430)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[project page](https://omriavrahami.com/stable-flow/)] [[arxiv](https://arxiv.org/pdf/2411.14430)]
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="pil",
                )
                source_prompt = gr.Text(
                    label="Source Prompt",
                    max_lines=1,
                    placeholder="describe the input image",
                )
                edit_prompt = gr.Text(
                    label="Edit Prompt",
                    max_lines=1,
                    placeholder="describe the edited output",
                )
                with gr.Row():
                    multimodal_layers = gr.Text(
                        info="MMDiT attention layers used for editing",
                        label="vital multimodal layers",
                        max_lines=1,
                        value="0, 1, 17, 18",
                    )
                    single_layers = gr.Text(
                        info="DiT attention layers used for editing",
                        label="vital single layers",
                        max_lines=1,
                        value="9, 34, 35, 37, 6",
                    )
                with gr.Row():
                    enable_hyper_flux = gr.Checkbox(label="8-step LoRA", value=False, info="may reduce edit quality", visible=False)
                run_button = gr.Button("Edit", variant="primary")
            with gr.Column():
                result = gr.Image(label="Result")
                # with gr.Column():
                #     with gr.Group():
                #         result = ImageSlider(position=0.5)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="num inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=8,
                )
                guidance_scale = gr.Slider(
                    label="guidance scale",
                    minimum=1,
                    maximum=25,
                    step=0.5,
                    value=3.5,
                )
            with gr.Row():
                num_inversion_steps = gr.Slider(
                    label="num inversion steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=25,
                )
                latent_nudging_scalar = gr.Slider(
                    label="latent nudging scalar",
                    minimum=1,
                    maximum=5,
                    step=0.01,
                    value=1.15,
                )
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
    run_button.click(
        fn=invert_and_edit,
        inputs=[
            input_image,
            source_prompt,
            edit_prompt,
            multimodal_layers,
            single_layers,
            num_inversion_steps,
            num_inference_steps,
            seed,
            randomize_seed,
            latent_nudging_scalar,
            guidance_scale,
            width,
            height,
            inverted_latents,
            do_inversion,
            image_input,
        ],
        outputs=[input_image, result, inverted_latents, do_inversion, image_input, seed],
    )

    # gr.Examples(
    #     examples=get_examples(),
    #     inputs=[input_image, result, prompt, num_inversion_steps, num_inference_steps, seed, randomize_seed, enable_hyper_flux],
    #     outputs=[result],
    # )

    input_image.input(
        fn=reset_image_input,
        outputs=[image_input],
    ).then(
        fn=reset_do_inversion,
        inputs=[image_input],
        outputs=[do_inversion],
    )
    source_prompt.change(
        fn=reset_do_inversion,
        inputs=[image_input],
        outputs=[do_inversion],
    )
    num_inversion_steps.change(
        fn=reset_do_inversion,
        inputs=[image_input],
        outputs=[do_inversion],
    )
    seed.change(
        fn=reset_do_inversion,
        inputs=[image_input],
        outputs=[do_inversion],
    )
    enable_hyper_flux.change(
        fn=check_hyper_flux_lora,
        inputs=[enable_hyper_flux],
        outputs=[num_inversion_steps, num_inference_steps],
    )
if __name__ == "__main__":
demo.launch()