import re import torch import torch.nn as nn from dataclasses import dataclass from diffusers.models.attention_processor import Attention from diffusers import ( DDPMScheduler, UNet2DConditionModel, AutoencoderKL ) from diffusers.loaders import AttnProcsLayers class LoRALinearLayerwBias(nn.Module): r""" A linear layer that is used with LoRA, can be used with bias. Parameters: in_features (`int`): Number of input features. out_features (`int`): Number of output features. rank (`int`, `optional`, defaults to 4): The rank of the LoRA layer. network_alpha (`float`, `optional`, defaults to `None`): The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning device (`torch.device`, `optional`, defaults to `None`): The device to use for the layer's weights. dtype (`torch.dtype`, `optional`, defaults to `None`): The dtype to use for the layer's weights. """ def __init__( self, in_features: int, out_features: int, rank: int = 4, network_alpha=None, device=None, dtype=None, with_bias: bool = False ): super().__init__() self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) if with_bias: self.bias = nn.Parameter(torch.zeros([1, 1, out_features], device=device, dtype=dtype)) self.with_bias = with_bias # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning self.network_alpha = network_alpha self.rank = rank self.out_features = out_features self.in_features = in_features nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_dtype = hidden_states.dtype dtype = self.down.weight.dtype down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) if self.with_bias: up_hidden_states = up_hidden_states + self.bias if self.network_alpha is not None: up_hidden_states *= self.network_alpha / self.rank return up_hidden_states.to(orig_dtype) class TriplaneLoRAConv2dLayer(nn.Module): r""" A convolutional layer that is used with LoRA. Parameters: in_features (`int`): Number of input features. out_features (`int`): Number of output features. rank (`int`, `optional`, defaults to 4): The rank of the LoRA layer. kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1): The kernel size of the convolution. stride (`int` or `tuple` of two `int`, `optional`, defaults to 1): The stride of the convolution. padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0): The padding of the convolution. network_alpha (`float`, `optional`, defaults to `None`): The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning """ def __init__( self, in_features: int, out_features: int, rank: int = 4, kernel_size = (1, 1), stride = (1, 1), padding = 0, network_alpha = None, with_bias: bool = False, locon_type: str = "hexa_v1", #hexa_v2, vanilla_v1, vanilla_v2 ): super().__init__() assert locon_type in ["hexa_v1", "hexa_v2", "vanilla_v1", "vanilla_v2"], "The LoCON type is not supported." if locon_type == "hexa_v1": self.down_xy_geo = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.down_xz_geo = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.down_yz_geo = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.down_xy_tex = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.down_xz_tex = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.down_yz_tex = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) # according to the official kohya_ss trainer kernel_size are always fixed for the up layer # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 self.up_xy_geo = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) self.up_xz_geo = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) self.up_yz_geo = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) self.up_xy_tex = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) self.up_xz_tex = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) self.up_yz_tex = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning elif locon_type == "hexa_v2": self.down_xy_geo = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False) self.down_xz_geo = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False) self.down_yz_geo = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False) self.down_xy_tex = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False) self.down_xz_tex = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False) self.down_yz_tex = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1),padding=padding, bias=False) self.up_xy_geo = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) self.up_xz_geo = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) self.up_yz_geo = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) self.up_xy_tex = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) self.up_xz_tex = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) self.up_yz_tex = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) elif locon_type == "vanilla_v1": self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=with_bias) elif locon_type == "vanilla_v2": self.down = nn.Conv2d(in_features, rank, kernel_size=(1, 1), stride=(1, 1), padding=padding, bias=False) self.up = nn.Conv2d(rank, out_features, kernel_size=kernel_size, stride=stride, bias=with_bias) self.network_alpha = network_alpha self.rank = rank self.locon_type = locon_type self._init_weights() def _init_weights(self): for layer in [ "down_xy_geo", "down_xz_geo", "down_yz_geo", "down_xy_tex", "down_xz_tex", "down_yz_tex", # in case of hexa_vX "up_xy", "up_xz", "up_yz", "up_xy_tex", "up_xz_tex", "up_yz_tex", # in case of hexa_vX "down", "up" # in case of vanilla ]: if hasattr(self, layer): # initialize the weights if "down" in layer: nn.init.normal_(getattr(self, layer).weight, std=1 / self.rank) elif "up" in layer: nn.init.zeros_(getattr(self, layer).weight) # initialize the bias if getattr(self, layer).bias is not None: nn.init.zeros_(getattr(self, layer).bias) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_dtype = hidden_states.dtype dtype = self.down_xy_geo.weight.dtype if "hexa" in self.locon_type else self.down.weight.dtype if "hexa" in self.locon_type: # xy plane hidden_states_xy_geo = self.up_xy_geo(self.down_xy_geo(hidden_states[0::6].to(dtype))) hidden_states_xy_tex = self.up_xy_tex(self.down_xy_tex(hidden_states[3::6].to(dtype))) lora_hidden_states = torch.concat( [torch.zeros_like(hidden_states_xy_tex)] * 6, dim=0 ) lora_hidden_states[0::6] = hidden_states_xy_geo lora_hidden_states[3::6] = hidden_states_xy_tex # xz plane lora_hidden_states[1::6] = self.up_xz_geo(self.down_xz_geo(hidden_states[1::6].to(dtype))) lora_hidden_states[4::6] = self.up_xz_tex(self.down_xz_tex(hidden_states[4::6].to(dtype))) # yz plane lora_hidden_states[2::6] = self.up_yz_geo(self.down_yz_geo(hidden_states[2::6].to(dtype))) lora_hidden_states[5::6] = self.up_yz_tex(self.down_yz_tex(hidden_states[5::6].to(dtype))) elif "vanilla" in self.locon_type: lora_hidden_states = self.up(self.down(hidden_states.to(dtype))) if self.network_alpha is not None: lora_hidden_states *= self.network_alpha / self.rank return lora_hidden_states.to(orig_dtype) class TriplaneSelfAttentionLoRAAttnProcessor(nn.Module): """ Perform for implementing the Triplane Self-Attention LoRA Attention Processor. """ def __init__( self, hidden_size: int, rank: int = 4, network_alpha=None, with_bias: bool = False, lora_type: str = "hexa_v1", # vanilla, ): super().__init__() assert lora_type in ["hexa_v1", "vanilla", "none", "basic"], "The LoRA type is not supported." self.hidden_size = hidden_size self.rank = rank self.lora_type = lora_type if lora_type in ["hexa_v1"]: # lora for 1st plane geometry self.to_q_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 1st plane texture self.to_q_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 2nd plane geometry self.to_q_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 2nd plane texture self.to_q_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 3nd plane geometry self.to_q_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 3nd plane texture self.to_q_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) elif lora_type in ["vanilla", "basic"]: self.to_q_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None ): assert encoder_hidden_states is None, "The encoder_hidden_states should be None." residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) ############################################################################################################ # query if self.lora_type in ["hexa_v1",]: query = attn.to_q(hidden_states) _query_new = torch.zeros_like(query) # lora for xy plane geometry _query_new[0::6] = self.to_q_xy_lora_geo(hidden_states[0::6]) # lora for xy plane texture _query_new[3::6] = self.to_q_xy_lora_tex(hidden_states[3::6]) # lora for xz plane geometry _query_new[1::6] = self.to_q_xz_lora_geo(hidden_states[1::6]) # lora for xz plane texture _query_new[4::6] = self.to_q_xz_lora_tex(hidden_states[4::6]) # lora for yz plane geometry _query_new[2::6] = self.to_q_yz_lora_geo(hidden_states[2::6]) # lora for yz plane texture _query_new[5::6] = self.to_q_yz_lora_tex(hidden_states[5::6]) query = query + scale * _query_new # # speed up inference # query[0::6] += self.to_q_xy_lora_geo(hidden_states[0::6]) * scale # query[3::6] += self.to_q_xy_lora_tex(hidden_states[3::6]) * scale # query[1::6] += self.to_q_xz_lora_geo(hidden_states[1::6]) * scale # query[4::6] += self.to_q_xz_lora_tex(hidden_states[4::6]) * scale # query[2::6] += self.to_q_yz_lora_geo(hidden_states[2::6]) * scale # query[5::6] += self.to_q_yz_lora_tex(hidden_states[5::6]) * scale elif self.lora_type in ["vanilla", "basic"]: query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) elif self.lora_type in ["none"]: query = attn.to_q(hidden_states) else: raise NotImplementedError("The LoRA type is not supported for the query in HplaneSelfAttentionLoRAAttnProcessor.") ############################################################################################################ if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) ############################################################################################################ # key and value if self.lora_type in ["hexa_v1",]: key = attn.to_k(encoder_hidden_states) _key_new = torch.zeros_like(key) # lora for xy plane geometry _key_new[0::6] = self.to_k_xy_lora_geo(encoder_hidden_states[0::6]) # lora for xy plane texture _key_new[3::6] = self.to_k_xy_lora_tex(encoder_hidden_states[3::6]) # lora for xz plane geometry _key_new[1::6] = self.to_k_xz_lora_geo(encoder_hidden_states[1::6]) # lora for xz plane texture _key_new[4::6] = self.to_k_xz_lora_tex(encoder_hidden_states[4::6]) # lora for yz plane geometry _key_new[2::6] = self.to_k_yz_lora_geo(encoder_hidden_states[2::6]) # lora for yz plane texture _key_new[5::6] = self.to_k_yz_lora_tex(encoder_hidden_states[5::6]) key = key + scale * _key_new # # speed up inference # key[0::6] += self.to_k_xy_lora_geo(encoder_hidden_states[0::6]) * scale # key[3::6] += self.to_k_xy_lora_tex(encoder_hidden_states[3::6]) * scale # key[1::6] += self.to_k_xz_lora_geo(encoder_hidden_states[1::6]) * scale # key[4::6] += self.to_k_xz_lora_tex(encoder_hidden_states[4::6]) * scale # key[2::6] += self.to_k_yz_lora_geo(encoder_hidden_states[2::6]) * scale # key[5::6] += self.to_k_yz_lora_tex(encoder_hidden_states[5::6]) * scale value = attn.to_v(encoder_hidden_states) _value_new = torch.zeros_like(value) # lora for xy plane geometry _value_new[0::6] = self.to_v_xy_lora_geo(encoder_hidden_states[0::6]) # lora for xy plane texture _value_new[3::6] = self.to_v_xy_lora_tex(encoder_hidden_states[3::6]) # lora for xz plane geometry _value_new[1::6] = self.to_v_xz_lora_geo(encoder_hidden_states[1::6]) # lora for xz plane texture _value_new[4::6] = self.to_v_xz_lora_tex(encoder_hidden_states[4::6]) # lora for yz plane geometry _value_new[2::6] = self.to_v_yz_lora_geo(encoder_hidden_states[2::6]) # lora for yz plane texture _value_new[5::6] = self.to_v_yz_lora_tex(encoder_hidden_states[5::6]) value = value + scale * _value_new # # speed up inference # value[0::6] += self.to_v_xy_lora_geo(encoder_hidden_states[0::6]) * scale # value[3::6] += self.to_v_xy_lora_tex(encoder_hidden_states[3::6]) * scale # value[1::6] += self.to_v_xz_lora_geo(encoder_hidden_states[1::6]) * scale # value[4::6] += self.to_v_xz_lora_tex(encoder_hidden_states[4::6]) * scale # value[2::6] += self.to_v_yz_lora_geo(encoder_hidden_states[2::6]) * scale # value[5::6] += self.to_v_yz_lora_tex(encoder_hidden_states[5::6]) * scale elif self.lora_type in ["vanilla", "basic"]: key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) elif self.lora_type in ["none", ]: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) else: raise NotImplementedError("The LoRA type is not supported for the key and value in HplaneSelfAttentionLoRAAttnProcessor.") ############################################################################################################ # attention scores # in self-attention, query of each plane should be used to calculate the attention scores of all planes if self.lora_type in ["hexa_v1", "vanilla",]: query = attn.head_to_batch_dim( query.view(batch_size // 6, sequence_length * 6, self.hidden_size) ) key = attn.head_to_batch_dim( key.view(batch_size // 6, sequence_length * 6, self.hidden_size) ) value = attn.head_to_batch_dim( value.view(batch_size // 6, sequence_length * 6, self.hidden_size) ) # calculate the attention scores attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # split the hidden states into 6 planes hidden_states = hidden_states.view(batch_size, sequence_length, self.hidden_size) elif self.lora_type in ["none", "basic"]: query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) # calculate the attention scores attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) else: raise NotImplementedError("The LoRA type is not supported for attention scores calculation in HplaneSelfAttentionLoRAAttnProcessor.") ############################################################################################################ # linear proj if self.lora_type in ["hexa_v1", ]: hidden_states = attn.to_out[0](hidden_states) _hidden_states_new = torch.zeros_like(hidden_states) # lora for xy plane geometry _hidden_states_new[0::6] = self.to_out_xy_lora_geo(hidden_states[0::6]) # lora for xy plane texture _hidden_states_new[3::6] = self.to_out_xy_lora_tex(hidden_states[3::6]) # lora for xz plane geometry _hidden_states_new[1::6] = self.to_out_xz_lora_geo(hidden_states[1::6]) # lora for xz plane texture _hidden_states_new[4::6] = self.to_out_xz_lora_tex(hidden_states[4::6]) # lora for yz plane geometry _hidden_states_new[2::6] = self.to_out_yz_lora_geo(hidden_states[2::6]) # lora for yz plane texture _hidden_states_new[5::6] = self.to_out_yz_lora_tex(hidden_states[5::6]) hidden_states = hidden_states + scale * _hidden_states_new # # speed up inference # hidden_states[0::6] += self.to_out_xy_lora_geo(hidden_states[0::6]) * scale # hidden_states[3::6] += self.to_out_xy_lora_tex(hidden_states[3::6]) * scale # hidden_states[1::6] += self.to_out_xz_lora_geo(hidden_states[1::6]) * scale # hidden_states[4::6] += self.to_out_xz_lora_tex(hidden_states[4::6]) * scale # hidden_states[2::6] += self.to_out_yz_lora_geo(hidden_states[2::6]) * scale # hidden_states[5::6] += self.to_out_yz_lora_tex(hidden_states[5::6]) * scale elif self.lora_type in ["vanilla", "basic"]: hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) elif self.lora_type in ["none",]: hidden_states = attn.to_out[0](hidden_states) else: raise NotImplementedError("The LoRA type is not supported for the to_out layer in HplaneSelfAttentionLoRAAttnProcessor.") # dropout hidden_states = attn.to_out[1](hidden_states) ############################################################################################################ if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class TriplaneCrossAttentionLoRAAttnProcessor(nn.Module): """ Perform for implementing the Triplane Cross-Attention LoRA Attention Processor. """ def __init__( self, hidden_size: int, cross_attention_dim: int, rank: int = 4, network_alpha = None, with_bias: bool = False, lora_type: str = "hexa_v1", # vanilla, ): super().__init__() assert lora_type in ["hexa_v1", "vanilla", "none"], "The LoRA type is not supported." self.hidden_size = hidden_size self.rank = rank self.lora_type = lora_type if lora_type in ["hexa_v1"]: # lora for 1st plane geometry self.to_q_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xy_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xy_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xy_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 1st plane texture self.to_q_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xy_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xy_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xy_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 2nd plane geometry self.to_q_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 2nd plane texture self.to_q_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_xz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_xz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_xz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 3nd plane geometry self.to_q_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_yz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_yz_lora_geo = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_yz_lora_geo = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) # lora for 3nd plane texture self.to_q_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_yz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_yz_lora_tex = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_yz_lora_tex = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) elif lora_type in ["vanilla"]: # lora for all planes self.to_q_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_k_lora = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_v_lora = LoRALinearLayerwBias(cross_attention_dim, hidden_size, rank, network_alpha, with_bias=with_bias) self.to_out_lora = LoRALinearLayerwBias(hidden_size, hidden_size, rank, network_alpha, with_bias=with_bias) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None ): assert encoder_hidden_states is not None, "The encoder_hidden_states should not be None." residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) ############################################################################################################ # query if self.lora_type in ["hexa_v1",]: query = attn.to_q(hidden_states) _query_new = torch.zeros_like(query) # lora for xy plane geometry _query_new[0::6] = self.to_q_xy_lora_geo(hidden_states[0::6]) # lora for xy plane texture _query_new[3::6] = self.to_q_xy_lora_tex(hidden_states[3::6]) # lora for xz plane geometry _query_new[1::6] = self.to_q_xz_lora_geo(hidden_states[1::6]) # lora for xz plane texture _query_new[4::6] = self.to_q_xz_lora_tex(hidden_states[4::6]) # lora for yz plane geometry _query_new[2::6] = self.to_q_yz_lora_geo(hidden_states[2::6]) # lora for yz plane texture _query_new[5::6] = self.to_q_yz_lora_tex(hidden_states[5::6]) query = query + scale * _query_new elif self.lora_type == "vanilla": query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) elif self.lora_type == "none": query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) ############################################################################################################ if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) ############################################################################################################ # key and value if self.lora_type in ["hexa_v1",]: key = attn.to_k(encoder_hidden_states) _key_new = torch.zeros_like(key) # lora for xy plane geometry _key_new[0::6] = self.to_k_xy_lora_geo(encoder_hidden_states[0::6]) # lora for xy plane texture _key_new[3::6] = self.to_k_xy_lora_tex(encoder_hidden_states[3::6]) # lora for xz plane geometry _key_new[1::6] = self.to_k_xz_lora_geo(encoder_hidden_states[1::6]) # lora for xz plane texture _key_new[4::6] = self.to_k_xz_lora_tex(encoder_hidden_states[4::6]) # lora for yz plane geometry _key_new[2::6] = self.to_k_yz_lora_geo(encoder_hidden_states[2::6]) # lora for yz plane texture _key_new[5::6] = self.to_k_yz_lora_tex(encoder_hidden_states[5::6]) key = key + scale * _key_new value = attn.to_v(encoder_hidden_states) _value_new = torch.zeros_like(value) # lora for xy plane geometry _value_new[0::6] = self.to_v_xy_lora_geo(encoder_hidden_states[0::6]) # lora for xy plane texture _value_new[3::6] = self.to_v_xy_lora_tex(encoder_hidden_states[3::6]) # lora for xz plane geometry _value_new[1::6] = self.to_v_xz_lora_geo(encoder_hidden_states[1::6]) # lora for xz plane texture _value_new[4::6] = self.to_v_xz_lora_tex(encoder_hidden_states[4::6]) # lora for yz plane geometry _value_new[2::6] = self.to_v_yz_lora_geo(encoder_hidden_states[2::6]) # lora for yz plane texture _value_new[5::6] = self.to_v_yz_lora_tex(encoder_hidden_states[5::6]) value = value + scale * _value_new elif self.lora_type in ["vanilla",]: key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) elif self.lora_type in ["none",]: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) ############################################################################################################ # calculate the attention scores attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) ############################################################################################################ # linear proj if self.lora_type in ["hexa_v1", ]: hidden_states = attn.to_out[0](hidden_states) _hidden_states_new = torch.zeros_like(hidden_states) # lora for xy plane geometry _hidden_states_new[0::6] = self.to_out_xy_lora_geo(hidden_states[0::6]) # lora for xy plane texture _hidden_states_new[3::6] = self.to_out_xy_lora_tex(hidden_states[3::6]) # lora for xz plane geometry _hidden_states_new[1::6] = self.to_out_xz_lora_geo(hidden_states[1::6]) # lora for xz plane texture _hidden_states_new[4::6] = self.to_out_xz_lora_tex(hidden_states[4::6]) # lora for yz plane geometry _hidden_states_new[2::6] = self.to_out_yz_lora_geo(hidden_states[2::6]) # lora for yz plane texture _hidden_states_new[5::6] = self.to_out_yz_lora_tex(hidden_states[5::6]) hidden_states = hidden_states + scale * _hidden_states_new elif self.lora_type in ["vanilla",]: hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) elif self.lora_type in ["none",]: hidden_states = attn.to_out[0](hidden_states) else: raise NotImplementedError("The LoRA type is not supported for the to_out layer in HplaneCrossAttentionLoRAAttnProcessor.") # dropout hidden_states = attn.to_out[1](hidden_states) ############################################################################################################ if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states @dataclass class GeneratorConfig: training_type: str = "self_lora_rank_16-cross_lora_rank_16-locon_rank_16" output_dim: int = 32 self_lora_type: str = "hexa_v1" cross_lora_type: str = "hexa_v1" locon_type: str = "vanilla_v1" vae_attn_type: str = "basic" prompt_bias: bool = False class OneStepTriplaneDualStableDiffusion(nn.Module): """ One-step Triplane Stable Diffusion module. """ def __init__( self, config, vae: AutoencoderKL, unet: UNet2DConditionModel, ): super().__init__() # Convert dict to GeneratorConfig if needed self.cfg = GeneratorConfig(**config) if isinstance(config, dict) else config self.output_dim = self.cfg.output_dim # Load models self.unet = unet self.vae = vae # Get device from one of the models self.device = next(self.unet.parameters()).device # Remove unused components del vae.encoder del vae.quant_conv # Get training type from config training_type = self.cfg.training_type # save trainable parameters if not "full" in training_type: # then paramter-efficient training trainable_params = {} assert "lora" in training_type or "locon" in training_type, "The training type is not supported." @dataclass class SubModules: unet: UNet2DConditionModel vae: AutoencoderKL self.submodules = SubModules( unet=unet.to(self.device), vae=vae.to(self.device), ) # free all the parameters for param in self.unet.parameters(): param.requires_grad_(False) for param in self.vae.parameters(): param.requires_grad_(False) ############################################################ # overwrite the unet and vae with the customized processors if "lora" in training_type: # parse the rank from the training type, with the template "lora_rank_{}" assert "self_lora_rank" in training_type, "The self_lora_rank is not specified." rank = re.search(r"self_lora_rank_(\d+)", training_type).group(1) self.self_lora_rank = int(rank) assert "cross_lora_rank" in training_type, "The cross_lora_rank is not specified." rank = re.search(r"cross_lora_rank_(\d+)", training_type).group(1) self.cross_lora_rank = int(rank) # if the finetuning is with bias self.w_lora_bias = False if "with_bias" in training_type: self.w_lora_bias = True # specify the attn_processor for unet lora_attn_procs = self._set_attn_processor( self.unet, self_attn_name="attn1.processor", self_lora_type=self.cfg.self_lora_type, cross_lora_type=self.cfg.cross_lora_type ) self.unet.set_attn_processor(lora_attn_procs) # update the trainable parameters trainable_params.update(self.unet.attn_processors) # specify the attn_processor for vae lora_attn_procs = self._set_attn_processor( self.vae, self_attn_name="processor", self_lora_type=self.cfg.vae_attn_type, # hard-coded for vae cross_lora_type="vanilla" ) self.vae.set_attn_processor(lora_attn_procs) # update the trainable parameters trainable_params.update(self.vae.attn_processors) else: raise NotImplementedError("The training type is not supported.") if "locon" in training_type: # parse the rank from the training type, with the template "locon_rank_{}" rank = re.search(r"locon_rank_(\d+)", training_type).group(1) self.locon_rank = int(rank) # if the finetuning is with bias self.w_locon_bias = False if "with_bias" in training_type: self.w_locon_bias = True # specify the conv_processor for unet locon_procs = self._set_conv_processor( self.unet, locon_type=self.cfg.locon_type ) # update the trainable parameters trainable_params.update(locon_procs) # specify the conv_processor for vae locon_procs = self._set_conv_processor( self.vae, locon_type="vanilla_v1", # hard-coded for vae decoder ) # update the trainable parameters trainable_params.update(locon_procs) else: raise NotImplementedError("The training type is not supported.") # overwrite the outconv # conv_out_orig = self.vae.decoder.conv_out conv_out_new = nn.Conv2d( in_channels=128, # conv_out_orig.in_channels, hard-coded out_channels=self.cfg.output_dim, kernel_size=3, padding=1 ) # update the trainable parameters self.vae.decoder.conv_out = conv_out_new trainable_params["vae.decoder.conv_out"] = conv_out_new # save the trainable parameters self.peft_layers = AttnProcsLayers(trainable_params).to(self.device) self.peft_layers._load_state_dict_pre_hooks.clear() self.peft_layers._state_dict_hooks.clear() # hard-coded for now self.num_planes = 6 if self.cfg.prompt_bias: self.prompt_bias = nn.Parameter(torch.zeros(self.num_planes, 77, 1024)) @property def unet(self): return self.submodules.unet @property def vae(self): return self.submodules.vae def _set_conv_processor( self, module, conv_name: str = "LoRACompatibleConv", locon_type: str = "vanilla_v1", ): locon_procs = {} for _name, _module in module.named_modules(): if _module.__class__.__name__ == conv_name: # append the locon processor to the module locon_proc = TriplaneLoRAConv2dLayer( in_features=_module.in_channels, out_features=_module.out_channels, rank=self.locon_rank, kernel_size=_module.kernel_size, stride=_module.stride, padding=_module.padding, with_bias = self.w_locon_bias, locon_type= locon_type, ) # add the locon processor to the module _module.lora_layer = locon_proc # update the trainable parameters key_name = f"{_name}.lora_layer" locon_procs[key_name] = locon_proc return locon_procs def _set_attn_processor( self, module, self_attn_name: str = "attn1.processor", self_attn_procs = TriplaneSelfAttentionLoRAAttnProcessor, self_lora_type: str = "hexa_v1", cross_attn_procs = TriplaneCrossAttentionLoRAAttnProcessor, cross_lora_type: str = "hexa_v1", ): lora_attn_procs = {} for name in module.attn_processors.keys(): if name.startswith("mid_block"): hidden_size = module.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(module.config.block_out_channels))[ block_id ] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = module.config.block_out_channels[block_id] elif name.startswith("decoder"): # special case for decoder in SD hidden_size = 512 if name.endswith(self_attn_name): # it is self-attention cross_attention_dim = None lora_attn_procs[name] = self_attn_procs( hidden_size, self.self_lora_rank, with_bias = self.w_lora_bias, lora_type = self_lora_type ) else: # it is cross-attention cross_attention_dim = module.config.cross_attention_dim lora_attn_procs[name] = cross_attn_procs( hidden_size, cross_attention_dim, self.cross_lora_rank, with_bias = self.w_lora_bias, lora_type = cross_lora_type ) return lora_attn_procs def forward( self, text_embed, styles, ): return None def forward_denoise( self, text_embed, noisy_input, t, ): batch_size = text_embed.size(0) noise_shape = noisy_input.size(-2) if text_embed.ndim == 3: # same text_embed for all planes # text_embed = text_embed.repeat(self.num_planes, 1, 1) # wrong!!! text_embed = text_embed.repeat_interleave(self.num_planes, dim=0) elif text_embed.ndim == 4: # different text_embed for each plane text_embed = text_embed.view(batch_size * self.num_planes, *text_embed.shape[-2:]) else: raise ValueError("The text_embed should be either 3D or 4D.") if hasattr(self, "prompt_bias"): text_embed = text_embed + self.prompt_bias.repeat(batch_size, 1, 1) * self.cfg.prompt_bias_lr_multiplier noisy_input = noisy_input.view(-1, 4, noise_shape, noise_shape) noise_pred = self.unet( noisy_input, t, encoder_hidden_states=text_embed ).sample return noise_pred def forward_decode( self, latents, ): latents = latents.view(-1, 4, *latents.shape[-2:]) triplane = self.vae.decode(latents).sample triplane = triplane.view(-1, self.num_planes, self.cfg.output_dim, *triplane.shape[-2:]) return triplane