|
import glob |
|
from os import path |
|
from paths import get_file_name, FastStableDiffusionPaths |
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _lora_info:
    """Record describing one LoRA: its file path, adapter name and weight."""

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        """Store the LoRA path and weight and derive the adapter name.

        Args:
            path: Path of the LoRA file; it is also handed to
                ``get_file_name`` to obtain the adapter name.
            weight: Weight associated with this adapter.
        """
        name = get_file_name(path)
        self.path, self.adapter_name, self.weight = path, name, weight

    def __del__(self):
        """Reset the path and adapter name to ``None`` on finalization."""
        self.path = self.adapter_name = None
|
|
|
|
|
# LoRA records registered for the pipeline held in _current_pipeline,
# in the order they were registered.
_loaded_loras: list = []
# Pipeline the records above belong to; None until one has been stored.
_current_pipeline = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Register the LoRA described by ``lcm_diffusion_setting.lora`` and,
    when it is enabled, load it into ``pipeline``.

    Args:
        pipeline: Object providing ``load_lora_weights``, ``set_adapters``
            and ``fuse_lora``.
        lcm_diffusion_setting: Settings object whose ``lora`` attribute
            exposes ``path``, ``weight``, ``enabled`` and ``fuse``.

    Raises:
        ValueError: If ``lora.path`` is empty.
        FileNotFoundError: If ``lora.path`` does not exist on disk.
    """
    global _loaded_loras
    global _current_pipeline

    lora_setting = lcm_diffusion_setting.lora
    # Both exceptions are subclasses of Exception, so callers that catch
    # Exception keep working; the messages are unchanged.
    if not lora_setting.path:
        raise ValueError("Empty lora model path")
    if not path.exists(lora_setting.path):
        raise FileNotFoundError("Lora model path is invalid")

    if pipeline != _current_pipeline:
        # A different pipeline was passed in: drop every record that belonged
        # to the previous one and remember the new pipeline.
        _loaded_loras = []
        _current_pipeline = pipeline

    current_lora = _lora_info(lora_setting.path, lora_setting.weight)

    if not lora_setting.enabled:
        # NOTE(review): a disabled LoRA is still recorded although it is never
        # loaded into the pipeline (behaviour kept as-is) - confirm intended.
        _loaded_loras.append(current_lora)
        return

    print(f"LoRA adapter name : {current_lora.adapter_name}")
    pipeline.load_lora_weights(
        FastStableDiffusionPaths.get_lora_models_path(),
        weight_name=Path(lora_setting.path).name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Record the adapter only after the load succeeded, so a failed load does
    # not leave a name behind that update_lora_weights would later pass to
    # set_adapters.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    # NOTE(review): fusing is treated as part of the "enabled" path - confirm.
    if lora_setting.fuse:
        pipeline.fuse_lora()
|
|
|
|
|
def get_lora_models(root_dir: str):
    """Collect the ``.safetensors`` files found recursively under ``root_dir``.

    Args:
        root_dir: Directory to search. It is taken literally: glob
            metacharacters in it are escaped.

    Returns:
        A dict mapping ``get_file_name(file_path)`` to ``file_path`` for
        every match whose name is not ``None``. When two files yield the
        same name, the one visited last wins.
    """
    # Escape the directory so that characters such as "[" or "*" in it are
    # not interpreted as part of the pattern; only "**/*.safetensors" is.
    pattern = f"{glob.escape(str(root_dir))}/**/*.safetensors"
    lora_models_map = {}
    for file_path in glob.glob(pattern, recursive=True):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
|
|
|
|
|
|
|
|
|
def get_active_lora_weights():
    """Return one ``(adapter_name, weight)`` tuple per entry of
    ``_loaded_loras``, in list order."""
    return [(entry.adapter_name, entry.weight) for entry in _loaded_loras]
|
|
|
|
|
|
|
|
|
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Apply the recorded LoRA weights to ``pipeline`` via ``set_adapters``.

    Args:
        pipeline: Must be the pipeline stored in ``_current_pipeline``;
            otherwise a message is printed and nothing is changed.
        lcm_diffusion_setting: Settings object; when ``use_lcm_lora`` is
            truthy an ``"lcm"`` adapter with weight 1.0 is put first.
        lora_weights: Optional sequence of ``(adapter_name, weight)`` pairs
            matched by position against ``_loaded_loras``. An entry whose
            name differs from the record at the same position, or that has
            no record at that position, is reported and skipped.
    """
    if pipeline != _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        known = len(_loaded_loras)
        for idx, lora in enumerate(lora_weights):
            # Guard the index first: more entries than recorded adapters used
            # to raise IndexError instead of being skipped like a mismatch.
            if idx >= known or _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    # The optional "lcm" adapter goes first, followed by the recorded LoRAs.
    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    adapter_names.extend(lora.adapter_name for lora in _loaded_loras)
    adapter_weights.extend(lora.weight for lora in _loaded_loras)

    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")
|
|