|
""" |
|
Wrapper class to call the stablediffusion.cpp shared library for GGUF support |
|
""" |
|
|
|
import ctypes |
|
import platform |
|
from ctypes import ( |
|
POINTER, |
|
c_bool, |
|
c_char_p, |
|
c_float, |
|
c_int, |
|
c_int64, |
|
c_void_p, |
|
) |
|
from dataclasses import dataclass |
|
from os import path |
|
from typing import List, Any |
|
|
|
import numpy as np |
|
from PIL import Image |
|
|
|
from backend.gguf.sdcpp_types import ( |
|
RngType, |
|
SampleMethod, |
|
Schedule, |
|
SDCPPLogLevel, |
|
SDImage, |
|
SdType, |
|
) |
|
|
|
|
|
@dataclass
class ModelConfig:
    """Options used by ``GGUFDiffusion`` to create a stable-diffusion.cpp context.

    Every field is forwarded, in declaration order, as one positional argument
    of the library's ``new_sd_ctx`` function. String fields are encoded to
    UTF-8; a field left as ``""`` is sent as an empty C string.
    """

    # --- file / directory locations (str) ---
    model_path: str = ""
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""

    # --- boolean / integer options ---
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False
    n_threads: int = 4

    # --- enum-valued options (types from backend.gguf.sdcpp_types) ---
    wtype: SdType = SdType.SD_TYPE_Q4_0
    rng_type: RngType = RngType.CUDA_RNG
    schedule: Schedule = Schedule.DEFAULT

    # --- keep_* boolean options ---
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
|
|
|
|
|
@dataclass
class Txt2ImgConfig:
    """Arguments for one ``txt2img`` call of stable-diffusion.cpp.

    ``GGUFDiffusion.generate_text2mg`` passes these fields positionally, in
    declaration order, right after the context handle. Keep the order in sync
    with the ``argtypes`` list declared there.
    """

    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1
    # Number of images requested; also the number of entries read back from
    # the returned image buffer.
    batch_count: int = 2
    # Passed as the ``POINTER(SDImage)`` argument of ``txt2img``. It must be a
    # ctypes pointer to an SDImage, or None for "no control image". A PIL image
    # cannot be converted by ctypes, hence the ``Any`` annotation rather than
    # the PIL ``Image`` module.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    # Already-encoded bytes, forwarded unchanged as a ``c_char_p``.
    input_id_images_path: bytes = b""
|
|
|
|
|
class GGUFDiffusion:
    """GGUF Diffusion

    Support for GGUF diffusion models through the stable-diffusion.cpp
    shared library.

    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md

    Implemented against the C API declared in stable-diffusion.h.
    """

    # ``sd_set_log_callback`` takes no context handle, so the registration is
    # library-wide and can outlive any single instance. The ctypes thunk must
    # therefore stay alive for the rest of the process. Otherwise the library
    # could call freed memory after the instance that created it is collected.
    _c_log_callback_keepalive = None

    def __init__(
        self,
        libpath: str,
        config: "ModelConfig",
        logging_enabled: bool = False,
    ):
        """Load the shared library, validate model files and create a context.

        Args:
            libpath: Directory containing the stable-diffusion.cpp library.
            config: Model file locations and context options.
            logging_enabled: When True, library log messages are printed.

        Raises:
            ValueError: If the platform is unsupported, the library cannot be
                loaded, a required model file is missing, or the library
                returns a NULL context.
        """
        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            raise ValueError(f"Error: {e}") from e

        required_model_files = (
            (config.clip_l_path, "CLIP"),
            (config.t5xxl_path, "T5XXL"),
            (config.diffusion_model_path, "Diffusion"),
            (config.vae_path, "VAE"),
        )
        for model_file, label in required_model_files:
            if not model_file or not path.exists(model_file):
                raise ValueError(
                    f"{label} model file not found,please check readme.md for GGUF model usage"
                )

        self.model_config = config

        # Install the log callback before creating the context so that
        # messages emitted during context creation are forwarded as well.
        if logging_enabled:
            self._set_logcallback()

        self.libsdcpp.new_sd_ctx.argtypes = [
            # model_path, clip_l_path, t5xxl_path, diffusion_model_path,
            # vae_path, taesd_path, control_net_path, lora_model_dir,
            # embed_dir, stacked_id_embed_dir
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            c_char_p,
            # vae_decode_only, vae_tiling, free_params_immediately
            c_bool,
            c_bool,
            c_bool,
            # n_threads
            c_int,
            # wtype, rng_type, schedule
            SdType,
            RngType,
            Schedule,
            # keep_clip_on_cpu, keep_control_net_cpu, keep_vae_on_cpu
            c_bool,
            c_bool,
            c_bool,
        ]
        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        cfg = self.model_config
        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )
        # A NULL handle signals that context creation failed. Fail fast
        # instead of passing NULL into later txt2img calls.
        if not self.sd_ctx:
            raise ValueError("Failed to create stable-diffusion.cpp context")

    def _set_logcallback(self):
        """Register ``log_callback`` as the library's log sink."""
        print("Setting logging callback")

        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )
        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None

        # Create the thunk once and keep it on the class (see the comment on
        # _c_log_callback_keepalive). ctypes caches CFUNCTYPE types, so the
        # cached instance matches the argtypes declared above.
        if GGUFDiffusion._c_log_callback_keepalive is None:
            GGUFDiffusion._c_log_callback_keepalive = SdLogCallbackType(
                GGUFDiffusion.log_callback
            )
        self.c_log_callback = GGUFDiffusion._c_log_callback_keepalive
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the platform-specific library file path inside *root_path*.

        Raises:
            ValueError: If the platform is not Windows, Linux or Darwin.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")

        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            print("Unknown platform.")
            # Fail early with a clear message rather than handing an empty
            # path to ctypes.CDLL.
            raise ValueError(f"Unsupported platform: {system_name}")
        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print a log message received from stable-diffusion.cpp.

        ``level`` and ``data`` are ignored. ctypes delivers ``text`` as bytes,
        or None for a NULL pointer, in which case nothing is printed.
        Undecodable bytes are replaced, so an exception is never raised
        inside the C callback.
        """
        if text is None:
            return
        print(text.decode("utf-8", errors="replace"), end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode *in_str* for a ``c_char_p`` argument; falsy input gives ``b""``."""
        if in_str:
            return in_str.encode(encoding)
        return b""

    def generate_text2mg(self, txt2img_cfg: "Txt2ImgConfig") -> List[Any]:
        """Run ``txt2img`` on the loaded context and return the images.

        The method name is kept as-is for backward compatibility.

        Args:
            txt2img_cfg: Generation parameters. Its fields are passed
                positionally, in declaration order, after the context handle.

        Returns:
            One PIL image per batch entry, or an empty list if the library
            returns a NULL buffer.
        """
        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx
            c_char_p,  # prompt
            c_char_p,  # negative_prompt
            c_int,  # clip_skip
            c_float,  # cfg_scale
            c_float,  # guidance
            c_int,  # width
            c_int,  # height
            SampleMethod,  # sample_method
            c_int,  # sample_steps
            c_int64,  # seed
            c_int,  # batch_count
            POINTER(SDImage),  # control_cond
            c_float,  # control_strength
            c_float,  # style_strength
            c_bool,  # normalize_input
            c_char_p,  # input_id_images_path
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the ``SDImage`` array returned by ``txt2img`` to PIL images.

        Args:
            image_buffer: ``POINTER(SDImage)`` from the library. A falsy value
                (NULL pointer or None) yields an empty list.
            batch_count: Number of entries to read from ``image_buffer``.

        Raises:
            ValueError: If an image has a channel count other than 1, 3 or 4.
        """
        images = []
        if not image_buffer:
            return images

        channel_modes = {1: "L", 3: "RGB", 4: "RGBA"}
        for i in range(batch_count):
            image = image_buffer[i]
            print(
                f"Generated image: {image.width}x{image.height} with {image.channel} channels"
            )

            width = image.width
            height = image.height
            channels = image.channel
            mode = channel_modes.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")

            # NOTE(review): as_array returns a view over memory owned by the C
            # library, and this wrapper never frees those buffers. PIL may share
            # that memory for some modes, so any future change that frees them
            # must copy the pixel data first.
            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            )
            if channels == 1:
                # Drop only the channel axis; squeeze() would also remove a
                # height or width of 1.
                pixel_data = pixel_data[:, :, 0]

            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context and drop the library handle.

        Calling this again afterwards does nothing, because ``libsdcpp`` is
        then None.
        """
        if not self.libsdcpp:
            return
        if self.sd_ctx:
            self.libsdcpp.free_sd_ctx.argtypes = [c_void_p]
            self.libsdcpp.free_sd_ctx.restype = None
            self.libsdcpp.free_sd_ctx(self.sd_ctx)
        self.sd_ctx = None
        self.libsdcpp = None
|
|