|
import argparse
import functools
import os

import torch
from diffusers.utils import load_image, check_min_version

from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from transformer_flux import FluxTransformer2DModel
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Build the FLUX ControlNet inpainting pipeline on the GPU, once per process.

    Loads the inpainting ControlNet, the FLUX.1-dev transformer and the rest of
    the FLUX.1-dev pipeline in bfloat16. The result is cached, so repeated
    ``main`` calls in the same process reuse the loaded weights instead of
    calling ``from_pretrained`` again each time.
    """
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    pipe = FluxControlNetInpaintingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Slice attention computation (per diffusers docs, trades speed for lower
    # peak memory).
    pipe.enable_attention_slicing(1)
    return pipe


def _configure_cuda():
    """Apply the process-wide CUDA / TF32 settings this script relies on."""
    # The CUDA caching allocator reads this variable when it first initializes,
    # so set it before any CUDA allocation. setdefault keeps a value the user
    # may already have exported instead of silently overwriting it.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()


def main(image, mask, prompt, *, size=(384, 384), num_inference_steps=28, seed=24):
    """Inpaint ``image`` in the region given by ``mask`` guided by ``prompt``.

    Args:
        image: Input image. Must support ``.convert("RGB")`` (e.g. a PIL image);
            it is resized to ``size``.
        mask: Mask image, same requirements as ``image``. NOTE(review): the
            meaning of mask colours is defined by the pipeline — presumably
            white marks the region to repaint; confirm.
        prompt: Text prompt describing the desired content.
        size: ``(width, height)`` that both inputs are resized to and the
            output is generated at.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        seed: Seed for the CUDA ``torch.Generator``, for reproducible output.

    Returns:
        The first image of the pipeline's ``images`` output.
    """
    check_min_version("0.30.2")
    _configure_cuda()

    pipe = _load_pipeline()

    width, height = size
    image = image.convert("RGB").resize((width, height))
    mask = mask.convert("RGB").resize((width, height))

    generator = torch.Generator(device="cuda").manual_seed(seed)

    # The models are loaded in bfloat16. The deprecated torch.cuda.amp.autocast()
    # defaults to float16, which mixes precisions and risks fp16 overflow, so
    # autocast explicitly to bfloat16 instead.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        result = pipe(
            prompt=prompt,
            height=height,
            width=width,
            control_image=image,
            control_mask=mask,
            num_inference_steps=num_inference_steps,
            generator=generator,
            controlnet_conditioning_scale=0.9,
            guidance_scale=3.5,
            negative_prompt="",
            true_guidance_scale=1.0,
        ).images[0]

    torch.cuda.empty_cache()

    print("Successfully inpaint image")
    return result
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inpaint an image using FluxControlNetInpaintingPipeline."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image."
    )
    parser.add_argument(
        "--mask_path", type=str, required=True, help="Path to the mask image."
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the inpainting process."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.png",
        help="Where to save the inpainted image.",
    )

    args = parser.parse_args()

    # main() calls .convert() on its image arguments, so it needs image objects,
    # not the raw path strings argparse returns.
    input_image = load_image(args.image_path)
    mask_image = load_image(args.mask_path)

    result = main(input_image, mask_image, args.prompt)

    result.save(args.output_path)
|
|
|