# EasyGhibli / app.py
# Last commit: "Update app.py" by abidlabs (HF Staff), f824bf1 (verified)
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr
from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
# Load the FLUX.1-dev base pipeline together with the project's custom transformer.
base_path = "black-forest-labs/FLUX.1-dev"
# Directory holding the LoRA weight files (e.g. Ghibli.safetensors).
lora_base_path = "./models"

# Load the custom transformer first and hand it to from_pretrained as a
# pre-built component. Otherwise the pipeline would also load the stock
# transformer weights only for them to be replaced and discarded right away.
transformer = FluxTransformer2DModel.from_pretrained(
    base_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(
    base_path, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` container of every attention processor.

    Called after each generation (presumably to drop state the processors
    cached during that request; confirm against src.lora_helper).

    Args:
        transformer: Model exposing an ``attn_processors`` mapping whose values
            each carry a ``bank_kv`` container with a ``clear()`` method.
    """
    # Only the processors matter here; the mapping's keys (layer names) are unused.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()
# Generation handler wired to the Gradio UI. With launch(mcp_server=True),
# Gradio also uses this function's docstring as the MCP tool description.
@spaces.GPU()
def single_condition_generate_image(spatial_img):
    """
    Convert an image into a Studio Ghibli style image.

    Args:
        spatial_img: The input image to restyle, passed to the pipeline as the
            spatial condition. If no image is given, the pipeline is called with
            no spatial condition (prompt only).

    Returns:
        The generated 768x768 Ghibli-style image.
    """
    prompt = "Ghibli Studio style, Charming hand-drawn anime-style illustration"
    use_zero_init = False
    zero_steps = 1
    height = 768
    width = 768
    seed = 42

    # Attach the Ghibli LoRA to the transformer before generating.
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

    # Test against None explicitly instead of relying on the image object's
    # truthiness (which is ambiguous or raises for array-like inputs).
    spatial_imgs = [spatial_img] if spatial_img is not None else []

    try:
        image = pipe(
            prompt,
            height=height,
            width=width,
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            # Fixed CPU-side seed so the same input gives reproducible output.
            generator=torch.Generator("cpu").manual_seed(seed),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=512,
            use_zero_init=use_zero_init,
            zero_steps=zero_steps,
        ).images[0]
    finally:
        # Reset the processors' bank_kv state even when generation fails, so a
        # failed request cannot carry stale state into the next one.
        clear_cache(pipe.transformer)
    return image
# Control types offered by this app (only the Ghibli LoRA is wired up).
control_types = ["Ghibli"]

# Build the UI: an input column (image upload plus a generate button) beside an
# output column that shows the generated image.
with gr.Blocks() as demo:
    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                # Upload the image file to be restyled.
                spatial_img = gr.Image(label="Ghibli Image", type="pil")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Clicking the button runs the generator on the uploaded image.
        single_generate_btn.click(
            fn=single_condition_generate_image,
            inputs=[spatial_img],
            outputs=single_output_image,
        )
# Launch the Gradio app only when run as a script, so importing this module
# does not start a server. mcp_server=True additionally serves the app's
# endpoints as MCP tools.
if __name__ == "__main__":
    demo.launch(mcp_server=True)