import random
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image

from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we

from .registry import BaseInference, INFERENCES
from .utils import ACEPlusImageProcessor

@INFERENCES.register_class()
class ACEInference(BaseInference):
    '''
    ACE++ inference wrapper that reuses the LDM pipeline code.
    '''
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)
        # Build the diffusion pipeline from the config and move it to the current device.
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        # Collect default sampling arguments from the config, keyed by lower-cased names.
        self.input = {
            k.lower(): dict(v).get('DEFAULT', None) if isinstance(v, (dict, OrderedDict, Config)) else v
            for k, v in cfg.SAMPLE_ARGS.items()
        }
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        if isinstance(prompt, str):
            prompt = [prompt]
        # Draw a random seed when none is provided.
        seed = seed if seed >= 0 else random.randint(0, 2**24 - 1)
        # Preprocess the reference/edit image and mask to the target resolution.
        (image, mask, change_image, content_image,
         out_h, out_w, slice_w) = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)
        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]

        # Wrap everything as a single-sample batch for the pipeline.
        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]
        # Run the diffusion pipeline under mixed precision.
        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                sampler=sampler,
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )
        # Convert the sampled tensors back to PIL images.
        imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
                for x_i in out_image]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        edit_image = Image.fromarray(
            (torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        # change_image stays None when the preprocessor did not produce a change image.
        change_image = None if change_image[0] is None else Image.fromarray(
            (torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
        # Map the generated image back to the requested output size and return it
        # together with the preprocessed inputs and the seed actually used.
        return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed
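
# Minimal usage sketch of this wrapper, end to end. The config path, the
# Config(load=True, cfg_file=...) call and the input file names below are
# illustrative assumptions, not guarantees about this module's API:
#
#     cfg = Config(load=True, cfg_file='path/to/ace_plus_infer.yaml')
#     infer = ACEInference(cfg)
#     output, edit_img, change_img, mask, used_seed = infer(
#         edit_image=Image.open('input.png').convert('RGB'),
#         edit_mask=Image.open('mask.png').convert('L'),
#         prompt='repaint the masked region as a wooden table',
#         output_height=1024,
#         output_width=1024,
#         sample_steps=28,
#         guide_scale=50,
#         seed=2024)
#     output.save('result.png')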