# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
from dataclasses import dataclass

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor


def load_model(ckpt, device='cpu'):
    """Load a state dict from either a .safetensors file or a torch checkpoint."""
    if ckpt.endswith('safetensors'):
        pl_sd = {}
        with safe_open(ckpt, framework="pt", device=device) as f:
            for k in f.keys():
                pl_sd[k] = f.get_tensor(k)
    else:
        pl_sd = torch.load(ckpt, map_location=device)
    return pl_sd


def load_safetensors(path):
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def get_lora_rank(checkpoint):
    # A LoRA down-projection weight has shape (rank, in_features), so the
    # first ".down.weight" tensor found gives the rank.
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]


def load_checkpoint(local_path, repo_id, name):
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    elif repo_id is not None and name is not None:
        print(f"Loading checkpoint {name} from repo id {repo_id}")
        checkpoint = load_from_repo_id(repo_id, name)
    else:
        raise ValueError(
            "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
        )
    return checkpoint


def c_crop(image):
    """Center-crop a PIL image to a square of its shorter side."""
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


def pad64(x):
    """Number of pixels needed to pad x up to the next multiple of 64."""
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def HWC3(x):
    """Convert a uint8 image array to 3-channel HWC, compositing alpha over white."""
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    repo_id_ae: str | None
configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fp8": ModelSpec(
        repo_id="XLabs-AI/flux-dev-fp8",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux-dev-fp8.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV_FP8"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_from_repo_id(repo_id, checkpoint_name):
    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
    sd = load_sft(ckpt_path, device='cpu')
    return sd


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_model(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_flow_model_only_lora(
    name: str,
    device: str | torch.device = "cuda",
    hf_download: bool = True,
    lora_rank: int = 16,
):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
"dit_lora.safetensors") try: lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors") except: lora_ckpt_path = os.environ.get("LORA", None) else: lora_ckpt_path = os.environ.get("LORA", None) with torch.device("meta" if ckpt_path is not None else device): model = Flux(configs[name].params) model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device) if ckpt_path is not None: print("Loading lora") lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\ else torch.load(lora_ckpt_path, map_location='cpu') print("Loading main checkpoint") # load_sft doesn't support torch.device if ckpt_path.endswith('safetensors'): sd = load_sft(ckpt_path, device=str(device)) sd.update(lora_sd) missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) else: dit_state = torch.load(ckpt_path, map_location='cpu') sd = {} for k in dit_state.keys(): sd[k.replace('module.','')] = dit_state[k] sd.update(lora_sd) missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) model.to(str(device)) print_load_warning(missing, unexpected) return model def set_lora( model: Flux, lora_rank: int, double_blocks_indices: list[int] | None = None, single_blocks_indices: list[int] | None = None, device: str | torch.device = "cpu", ) -> Flux: double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \ else single_blocks_indices lora_attn_procs = {} with torch.device(device): for name, attn_processor in model.attn_processors.items(): match = re.search(r'\.(\d+)\.', name) if match: layer_index = int(match.group(1)) if name.startswith("double_blocks") and layer_index in double_blocks_indices: lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank) elif name.startswith("single_blocks") and layer_index in single_blocks_indices: lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank) else: lora_attn_procs[name] = attn_processor model.set_attn_processor(lora_attn_procs) return model def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): # Loading Flux from optimum.quanto import requantize print("Init model") ckpt_path = configs[name].ckpt_path if ( ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_flow is not None and hf_download ): ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') model = Flux(configs[name].params).to(torch.bfloat16) print("Loading checkpoint") # load_sft doesn't support torch.device sd = load_sft(ckpt_path, device='cpu') with open(json_path, "r") as f: quantization_map = json.load(f) print("Start a quantization process...") requantize(model, sd, quantization_map, device=device) print("Model is quantized!") return model def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: # max length 64, 128, 256 and 512 should work (if your sequence is short enough) version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders") return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device) def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: version = os.environ.get("CLIP", 
"openai/clip-vit-large-patch14") return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device) def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: ckpt_path = configs[name].ae_path if ( ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_ae is not None and hf_download ): ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) # Loading the autoencoder print("Init AE") with torch.device("meta" if ckpt_path is not None else device): ae = AutoEncoder(configs[name].ae_params) if ckpt_path is not None: sd = load_sft(ckpt_path, device=str(device)) missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) print_load_warning(missing, unexpected) return ae