Spaces:
Sleeping
Sleeping
File size: 2,679 Bytes
5a48378 5e7f992 f824bf1 c323f44 5a48378 e2fd487 5a48378 aa83627 5a48378 123497a aa83627 5a48378 c323f44 5a48378 6757b56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr
from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
# Load the FLUX.1-dev diffusion pipeline (bfloat16) with the project's custom
# transformer and move it to the GPU.
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"  # directory holding the LoRA .safetensors files
# Load the custom transformer first and hand it to the pipeline, so the stock
# transformer weights are not loaded only to be thrown away immediately after.
transformer = FluxTransformer2DModel.from_pretrained(
    base_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(
    base_path, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
def clear_cache(transformer):
    """Empty the ``bank_kv`` store of every attention processor on *transformer*.

    Called after each generation so state accumulated in ``bank_kv`` during one
    request is not carried into the next. Processors that have no ``bank_kv``
    attribute (e.g. stock processors, if any were not replaced by the LoRA
    helper) are skipped instead of raising ``AttributeError``.

    Args:
        transformer: Model exposing an ``attn_processors`` mapping of
            name -> attention processor.
    """
    for attn_processor in transformer.attn_processors.values():
        bank_kv = getattr(attn_processor, "bank_kv", None)
        if bank_kv is not None:
            bank_kv.clear()
# Inference handler wired to the Gradio button (also exposed via the MCP server).
@spaces.GPU()
def single_condition_generate_image(spatial_img):
    """
    Convert an image into a Studio Ghibli style image.

    Args:
        spatial_img: The input image to restyle. If no image is provided,
            generation runs without a spatial condition image.

    Returns:
        The first 768x768 image produced by the pipeline.
    """
    prompt = "Ghibli Studio style, Charming hand-drawn anime-style illustration"
    use_zero_init = False
    zero_steps = 1
    height = 768
    width = 768
    seed = 42
    cond_size = 512

    # Install the Ghibli LoRA on the transformer's attention processors.
    lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=cond_size)

    # Only a missing upload (None) means "no condition image".
    spatial_imgs = [] if spatial_img is None else [spatial_img]

    try:
        image = pipe(
            prompt,
            height=height,
            width=width,
            guidance_scale=3.5,
            num_inference_steps=25,
            max_sequence_length=512,
            # Fixed CPU-side seed keeps results reproducible across requests.
            generator=torch.Generator("cpu").manual_seed(seed),
            subject_images=[],
            spatial_images=spatial_imgs,
            cond_size=cond_size,
            use_zero_init=use_zero_init,
            zero_steps=zero_steps,
        ).images[0]
    finally:
        # Always empty the processors' bank_kv stores, even if inference fails,
        # so state from this request is never reused by the next one.
        clear_cache(pipe.transformer)
    return image
# Available control types (only the Ghibli LoRA is offered; not referenced by
# the visible UI code below).
control_types = ["Ghibli"]

# Build the Gradio Blocks UI: image input + button on the left, the generated
# image on the right.
with gr.Blocks() as demo:
    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                spatial_img = gr.Image(label="Ghibli Image", type="pil")  # upload the source image
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Run the generator when the button is clicked.
        single_generate_btn.click(
            single_condition_generate_image,
            inputs=[spatial_img],
            outputs=single_output_image,
        )

# Launch the Gradio app with its MCP server enabled.
demo.launch(mcp_server=True)