File size: 2,679 Bytes
5a48378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e7f992
f824bf1
 
 
c323f44
 
 
 
 
 
 
5a48378
 
 
 
 
 
 
 
 
 
 
 
 
 
e2fd487
5a48378
 
aa83627
 
5a48378
 
 
 
 
 
 
 
 
 
 
 
 
 
123497a
aa83627
5a48378
 
 
 
 
 
c323f44
5a48378
 
 
 
6757b56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr

from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# Model identifier handed to from_pretrained, and the folder LoRA files are read from
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"


# Build the pipeline once at import time, replace its transformer with the
# project's FluxTransformer2DModel (both loaded in bfloat16), then move it to CUDA.
pipe = FluxPipeline.from_pretrained(
    base_path,
    torch_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    base_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer
pipe.to("cuda")

def clear_cache(transformer):
    """Empty the ``bank_kv`` container of every attention processor.

    Args:
        transformer: Object exposing an ``attn_processors`` mapping whose
            values each carry a ``bank_kv`` attribute with a ``clear()`` method.
    """
    # Only the processors themselves are needed; the mapping keys are unused.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()

# Inference entry point wired to the Gradio button below
@spaces.GPU()
def single_condition_generate_image(spatial_img):
    """
    Convert an image into a Studio Ghibli style image

    Args:
        spatial_img: Input image to stylize, or None when nothing was uploaded.

    Returns:
        The first image produced by the pipeline.
    """
    prompt = "Ghibli Studio style, Charming hand-drawn anime-style illustration"
    use_zero_init = False
    zero_steps = 1
    height = 768
    width = 768
    seed = 42
    cond_size = 512

    # The LoRA is (re)applied on every call, exactly as before.
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=cond_size)

    # Explicit None check instead of relying on the truthiness of an image object.
    spatial_imgs = [spatial_img] if spatial_img is not None else []

    try:
        result = pipe(
            prompt,
            height=height,
            width=width,
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            # Fixed CPU generator seed so repeated calls start from the same noise.
            generator=torch.Generator("cpu").manual_seed(seed),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=cond_size,
            use_zero_init=use_zero_init,
            zero_steps=zero_steps,
        )
    finally:
        # Always reset the processors' bank_kv, even when generation fails,
        # so a failed request cannot leave state behind for the next one.
        clear_cache(pipe.transformer)
    return result.images[0]

# List of control types (not referenced elsewhere in the visible code)
control_types = ["Ghibli"]


# Assemble the Gradio Blocks UI: a single tab holding an input and an output column
with gr.Blocks() as demo:
    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            # First column: image input and the trigger button
            with gr.Column():
                spatial_img = gr.Image(label="Ghibli Image", type="pil")  # upload an image file
                single_generate_btn = gr.Button("Generate Image")

            # Second column: the generated result
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

    # Wire the button click to the generation function
    single_generate_btn.click(
        fn=single_condition_generate_image,
        inputs=[spatial_img],
        outputs=single_output_image,
    )

# Start the app with the MCP server option enabled
demo.launch(mcp_server=True)