import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr

from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# Model and LoRA paths
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"
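# Assumption: the Ghibli control LoRA checkpoint is stored as ./models/Ghibli.safetensors,
# matching the path built in single_condition_generate_image below.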


pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
pipe.to("cuda")
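# Note: FLUX.1-dev in bfloat16 still needs a large-memory GPU; on smaller cards,
# pipe.enable_model_cpu_offload() (a standard diffusers fallback) may work instead of
# pipe.to("cuda"), assuming this custom FluxPipeline keeps the usual offload hooks.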

def clear_cache(transformer):
    # Clear the key/value banks cached by the LoRA attention processors so that
    # state from one generation does not carry over into the next.
    for name, attn_processor in transformer.attn_processors.items():
        attn_processor.bank_kv.clear()

# Generation function used by the Gradio interface
@spaces.GPU()
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, use_zero_init, zero_steps):
    # Resolve the control LoRA (only "Ghibli" is available; guard against an empty dropdown)
    if control_type != "Ghibli":
        raise gr.Error("Please select a control type.")
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
    
    # Process the image
    spatial_imgs = [spatial_img] if spatial_img else []
    image = pipe(
        prompt,
        height=int(height),
        width=int(width),
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        subject_images=[],
        spatial_images=spatial_imgs,
        cond_size=512,
        use_zero_init=use_zero_init,
        zero_steps=int(zero_steps)
    ).images[0]
    clear_cache(pipe.transformer)
    return image

# Define the Gradio interface components
control_types = ["Ghibli"]

# Example data
single_examples = [
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", True, 1],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 560, 1024, 42, "Ghibli", False, 0],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 568, 1024, 1, "Ghibli", False, 0],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 768, 672, 1, "Ghibli", False, 0],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 896, 1024, 1, "Ghibli", False, 0],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 528, 800, 1, "Ghibli", False, 0],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 696, 1024, 1, "Ghibli", False, 0],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 896, 1024, 1, "Ghibli", False, 0],
]


# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli Studio: Управление генерацией изображений на GPT-ChatBot")
    gr.Markdown("Модель обучена всего на 100 реальных лицах в паре с аналогами, сгенерированными GPT-4o в стиле Гибли, и сохраняет черты лица, применяя культовую эстетику аниме.")
   
    gr.Markdown("**Внимание:** Рекомендуемые подсказки для использования Ghibli Control LoRA должны включать триггерные слова: Ghibli Studio style, Charming hand-drawn anime-style illustration. Вы также можете добавить несколько подробных описаний для лучшего результата.")

    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
                spatial_img = gr.Image(label="Ghibli Image", type="pil")  # Upload the condition image
                height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
                width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
                seed = gr.Number(label="Seed", value=42)
                control_type = gr.Dropdown(choices=control_types, label="Control Type")
                use_zero_init = gr.Checkbox(label="Use CFG zero star", value=False)
                zero_steps = gr.Number(label="Zero Init Steps", value=1)
                single_generate_btn = gr.Button("Generate Image")
                
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Add examples for Single Condition Generation
        gr.Examples(
            examples=single_examples,
            inputs=[prompt, spatial_img, height, width, seed, control_type, use_zero_init, zero_steps],
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,  # do not pre-compute example outputs
            label="Single Condition Examples"
        )

    # Link the buttons to the functions
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, spatial_img, height, width, seed, control_type, use_zero_init, zero_steps],
        outputs=single_output_image
    )

# Launch the Gradio app
demo.queue().launch()
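
# Hypothetical programmatic usage (a sketch, bypassing the UI), assuming the Ghibli
# LoRA and ./test_imgs/00.png are present; kept commented out so it does not run:
#   img = single_condition_generate_image(
#       "Ghibli Studio style, Charming hand-drawn anime-style illustration",
#       Image.open("./test_imgs/00.png"), 680, 1024, 5, "Ghibli", True, 1,
#   )
#   img.save("ghibli_out.png")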