import math import os import random import glob import pickle import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms from torch.optim import Adam from torchvision.utils import make_grid from PIL import Image from transformers import ( DistilBertModel, DistilBertTokenizer, CLIPTokenizer, CLIPTextModel, ) dataset_params = { "image_path": "data/CelebAMask-HQ", "image_channels": 3, "image_size": 256, "name": "celebhq", } diffusion_params = { "num_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012, } ldm_params = { "down_channels": [256, 384, 512, 768], "mid_channels": [768, 512], "down_sample": [True, True, True], "attn_down": [True, True, True], # Attention in the DownBlock and UpBlock of VQ-VAE "time_emb_dim": 512, "norm_channels": 32, "num_heads": 16, "conv_out_channels": 128, "num_down_layers": 2, "num_mid_layers": 2, "num_up_layers": 2, "condition_config": { "condition_types": ["text", "image"], "text_condition_config": { "text_embed_model": "clip", "train_text_embed_model": False, "text_embed_dim": 512, # Each token should map to text_embed_dim sized vector "cond_drop_prob": 0.1, # Probability of dropping conditioning during training to allow the model to generate images without conditioning as well }, "image_condition_config": { "image_condition_input_channels": 18, # CelebA has 18 classes excluding background "image_condition_output_channels": 3, "image_condition_h": 512, # Mask height "image_condition_w": 512, # Mask width "cond_drop_prob": 0.1, # Probability of dropping conditioning during training to allow the model to generate images without conditioning as well }, }, } autoencoder_params = { "z_channels": 4, "codebook_size": 8192, "down_channels": [64, 128, 256, 256], "mid_channels": [256, 256], "down_sample": [True, True, True], "attn_down": [ False, False, False, ], # No attention in the DownBlock and UpBlock of VQ-VAE "norm_channels": 32, "num_heads": 4, "num_down_layers": 2, "num_mid_layers": 2, "num_up_layers": 2, } train_params = { "seed": 1111, "task_name": "celebhq", # Folder to save models and images to "ldm_batch_size": 16, "autoencoder_batch_size": 4, "disc_start": 15000, "disc_weight": 0.5, "codebook_weight": 1, "commitment_beta": 0.2, "perceptual_weight": 1, "kl_weight": 0.000005, "ldm_epochs": 100, "autoencoder_epochs": 20, "num_samples": 1, "num_grid_rows": 1, "ldm_lr": 0.000005, "autoencoder_lr": 0.00001, "autoencoder_acc_steps": 4, "autoencoder_img_save_steps": 64, "save_latents": True, "cf_guidance_scale": 1.0, "vqvae_latent_dir_name": "vqvae_latents", "ldm_ckpt_name": "ddpm_ckpt_class_cond.pth", "vqvae_autoencoder_ckpt_name": "vqvae_autoencoder_ckpt.pth", } def get_config_value(config, key, default_value): return config[key] if key in config else default_value def spatial_average(in_tens, keepdim=True): return in_tens.mean([2, 3], keepdim=keepdim) class LinearNoiseScheduler: def __init__(self, num_timesteps, beta_start, beta_end): self.num_timesteps = num_timesteps self.beta_start = beta_start self.beta_end = beta_end self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2 self.alphas = 1.0 - self.betas self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) def add_noise(self, original, noise, t): # original: (batch_size, c, h, w), t: tensor of timesteps (batch_size,) batch_size = original.shape[0] sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].view( batch_size, 1, 1, 1 ) sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to( original.device )[t].view(batch_size, 1, 1, 1) return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise def sample_prev_timestep(self, xt, noise_pred, t): batch_size = xt.shape[0] alpha_cum_prod_t = self.alpha_cum_prod.to(xt.device)[t].view( batch_size, 1, 1, 1 ) sqrt_one_minus_alpha_cum_prod_t = self.sqrt_one_minus_alpha_cum_prod.to( xt.device )[t].view(batch_size, 1, 1, 1) x0 = (xt - sqrt_one_minus_alpha_cum_prod_t * noise_pred) / torch.sqrt( alpha_cum_prod_t ) x0 = torch.clamp(x0, -1.0, 1.0) betas_t = self.betas.to(xt.device)[t].view(batch_size, 1, 1, 1) mean = ( xt - betas_t / sqrt_one_minus_alpha_cum_prod_t * noise_pred ) / torch.sqrt(self.alphas.to(xt.device)[t].view(batch_size, 1, 1, 1)) if t[0] == 0: return mean, x0 else: prev_alpha_cum_prod = self.alpha_cum_prod.to(xt.device)[ (t - 1).clamp(min=0) ].view(batch_size, 1, 1, 1) variance = (1 - prev_alpha_cum_prod) / (1 - alpha_cum_prod_t) * betas_t sigma = variance.sqrt() z = torch.randn_like(xt) return mean + sigma * z, x0 def get_tokenizer_and_model(model_type, device, eval_mode=True): assert model_type in ( "bert", "clip", ), "Text model can only be one of 'clip' or 'bert'" if model_type == "bert": text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") text_model = DistilBertModel.from_pretrained("distilbert-base-uncased").to( device ) else: text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16") text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16").to( device ) if eval_mode: text_model.eval() return text_tokenizer, text_model def get_text_representation(text, text_tokenizer, text_model, device, max_length=77): token_output = text_tokenizer( text, truncation=True, padding="max_length", return_attention_mask=True, max_length=max_length, ) tokens_tensor = torch.tensor(token_output["input_ids"]).to(device) mask_tensor = torch.tensor(token_output["attention_mask"]).to(device) text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state return text_embed def get_time_embedding(time_steps, temb_dim): """ Convert time steps tensor into an embedding using the sinusoidal time embedding formula time_steps: 1D tensor of length batch size temb_dim: Dimension of the embedding """ assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" # factor = 10000^(2i/d_model) factor = 10000 ** ( ( torch.arange( start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device, ) / (temb_dim // 2) ) ) t_emb = time_steps.unsqueeze(dim=-1).repeat(1, temb_dim // 2) / factor t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) return t_emb # (batch_size, temb_dim) class DownBlock(nn.Module): """ Down conv block with attention. 1. Resnet block with time embedding 2. Attention block 3. Downsample in_channels: Number of channels in the input feature map. out_channels: Number of channels produced by this block. t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. down_sample: Whether to apply downsampling at the end. num_heads: Number of attention heads (used if attention is enabled). num_layers: How many sub-blocks to apply in sequence. attn: Whether to apply self-attention norm_channels: Number of groups for GroupNorm. cross_attn: Whether to apply cross-attention. context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. """ def __init__( self, in_channels, out_channels, t_emb_dim, down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None, ): super().__init__() self.num_layers = num_layers self.down_sample = down_sample self.attn = attn self.context_dim = context_dim self.cross_attn = cross_attn self.t_emb_dim = t_emb_dim self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm( norm_channels, in_channels if i == 0 else out_channels ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels nn.SiLU(), nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w) ) for i in range(num_layers) ] ) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential( nn.SiLU(), nn.Linear( in_features=self.t_emb_dim, out_features=out_channels ), # (batch_size, t_emb_dim) -> (batch_size, out_channels) ) for i in range(num_layers) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) ) for i in range(num_layers) ] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=1, stride=1, padding=0, ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w) for i in range(num_layers) ] ) if self.attn: self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) # Cross attention for text conditioning if self.cross_attn: assert ( context_dim is not None ), "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) self.context_proj = nn.ModuleList( [ nn.Linear(in_features=context_dim, out_features=out_channels) for i in range(num_layers) ] ) # Down sample by a factor of 2 self.down_sample_conv = ( nn.Conv2d( in_channels=out_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, ) if self.down_sample else nn.Identity() ) # (batch_size, out_channels, h / 2, w / 2) def forward(self, x, t_emb=None, context=None): out = x for i in range(self.num_layers): # Resnet block of UNET resnet_input = out # (batch_size, c, h, w) out = self.resnet_conv_first[i](out) # (batch_size, out_channels, h, w) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( dim=-1 ) # (batch_size, out_channels, h, w) out = self.resnet_conv_second[i]( out ) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) # Residual Connection out = out + self.residual_input_conv[i]( resnet_input ) # (batch_size, out_channels, h, w) # Only do for Diffusion and not for AutoEncoder if self.attn: # Attention block of UNET batch_size, channels, h, w = ( out.shape ) # (batch_size, out_channels, h, w) in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) # Self-Attention out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape( batch_size, channels, h, w ) # (batch_size, out_channels h, w) # Skip connection out = out + out_attn # (batch_size, out_channels h, w) if self.cross_attn: assert ( context is not None ), "context cannot be None if cross attention layers are used" batch_size, channels, h, w = ( out.shape ) # (batch_size, out_channels, h, w) in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) assert ( context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim ) # Make sure the batch_size and context_dim match with the model's parameters context_proj = self.context_proj[i]( context ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels) # Cross-Attention out_attn, attn_weights = self.cross_attentions[i]( in_attn, context_proj, context_proj ) # (batch_size, h * w, out_channels) out_attn = out_attn.transpose(1, 2).reshape( batch_size, channels, h, w ) # (batch_size, out_channels, h, w) # Skip Connection out = out + out_attn # (batch_size, out_channels, h, w) # Downsampling out = self.down_sample_conv(out) # (batch_size, out_channels, h / 2, w / 2) return out class MidBlock(nn.Module): """ Mid conv block with attention. 1. Resnet block with time embedding 2. Attention block 3. Resnet block with time embedding in_channels: Number of channels in the input feature map. out_channels: Number of channels produced by this block. t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. num_heads: Number of attention heads (used if attention is enabled). num_layers: How many sub-blocks to apply in sequence. norm_channels: Number of groups for GroupNorm. cross_attn: Whether to apply cross-attention. context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. """ def __init__( self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None, ): super().__init__() self.num_layers = num_layers self.t_emb_dim = t_emb_dim self.context_dim = context_dim self.cross_attn = cross_attn self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm( norm_channels, in_channels if i == 0 else out_channels ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels nn.SiLU(), nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w) ) for i in range(num_layers + 1) ] ) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential( nn.SiLU(), nn.Linear( in_features=self.t_emb_dim, out_features=out_channels ), # (batch_size, t_emb_dim) -> (batch_size, out_channels) ) for i in range(num_layers + 1) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) ) for i in range(num_layers + 1) ] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=1, stride=1, padding=0, ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w) for i in range(num_layers + 1) ] ) self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) # Cross attention for text conditioning if self.cross_attn: assert ( context_dim is not None ), "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) self.context_proj = nn.ModuleList( [ nn.Linear(in_features=context_dim, out_features=out_channels) for i in range(num_layers) ] ) def forward(self, x, t_emb=None, context=None): out = x # First ResNet block resnet_input = out # (batch_size, c, h, w) out = self.resnet_conv_first[0](out) # (batch_size, out_channels, h, w) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) out = out + self.t_emb_layers[0](t_emb).unsqueeze(dim=-1).unsqueeze( dim=-1 ) # (batch_size, out_channels, h, w) out = self.resnet_conv_second[0]( out ) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) # Residual Connection out = out + self.residual_input_conv[0]( resnet_input ) # (batch_size, out_channels, h, w) for i in range(self.num_layers): # Attention Block batch_size, channels, h, w = out.shape # (batch_size, out_channels, h, w) # Do for both Diffusion and AutoEncoder in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) # Self-Attention out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) # Skip connection out = out + out_attn # (batch_size, out_channels h, w) if self.cross_attn: assert ( context is not None ), "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) assert ( context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim ) # Make sure the batch_size and context_dim match with the model's parameters context_proj = self.context_proj[i]( context ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, context_dim) # Cross-Attention out_attn, attn_weights = self.cross_attentions[i]( in_attn, context_proj, context_proj ) out_attn = out_attn.transpose(1, 2).reshape( batch_size, channels, h, w ) # (batch_size, out_channels, h, w) # Skip Connection out = out + out_attn # (batch_size, out_channels h, w) # Resnet Block resnet_input = out out = self.resnet_conv_first[i + 1]( out ) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) out = out + self.t_emb_layers[i + 1](t_emb).unsqueeze(dim=-1).unsqueeze( dim=-1 ) # (batch_size, out_channels h, w) out = self.resnet_conv_second[i + 1]( out ) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w) # Residual Connection out = out + self.residual_input_conv[i + 1]( resnet_input ) # (batch_size, out_channels, h, w) return out class UpBlock(nn.Module): """ Up conv block with attention. 1. Upsample 1. Concatenate Down block output 2. Resnet block with time embedding 3. Attention Block in_channels: Number of channels in the input feature map. out_channels: Number of channels produced by this block. t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. up_sample: Whether to apply upsampling at the end. num_heads: Number of attention heads (used if attention is enabled). num_layers: How many sub-blocks to apply in sequence. attn: Whether to apply self-attention norm_channels: Number of groups for GroupNorm. """ def __init__( self, in_channels, out_channels, t_emb_dim, up_sample, num_heads, num_layers, attn, norm_channels, ): super().__init__() self.num_layers = num_layers self.up_sample = up_sample self.t_emb_dim = t_emb_dim self.attn = attn # Upsample by a factor of 2 self.up_sample_conv = ( nn.ConvTranspose2d( in_channels=in_channels, out_channels=in_channels, kernel_size=4, stride=2, padding=1, ) if self.up_sample else nn.Identity() ) # (batch_size, c, h * 2, w * 2) self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm( norm_channels, in_channels if i == 0 else out_channels ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels nn.SiLU(), nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w) ) for i in range(num_layers) ] ) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential( nn.SiLU(), nn.Linear( in_features=self.t_emb_dim, out_features=out_channels ), # (batch_size, t_emb_dim) -> (batch_size, out_channels) ) for i in range(num_layers) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) ) for i in range(num_layers) ] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=1, stride=1, padding=0, ) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w) for i in range(num_layers) ] ) if self.attn: self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) def forward(self, x, out_down=None, t_emb=None): # x shape: (batch_size, c, h, w) # Upsample x = self.up_sample_conv( x ) # (batch_size, c, h, w) -> (batch_size, c, h * 2, w * 2) # *Only do for diffusion # Concatenate with the output of respective DownBlock if out_down is not None: x = torch.cat( [x, out_down], dim=1 ) # (batch_size, c, h * 2, w * 2) -> (batch_size, c * 2, h * 2, w * 2) out = x # (batch_size, c, h * 2, w * 2) for i in range(self.num_layers): # Resnet block resnet_input = out out = self.resnet_conv_first[i]( out ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) # Only add the time embedding for diffusion and not AutoEncoder if self.t_emb_dim is not None: # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( dim=-1 ) # (batch_size, out_channels, h * 2, w * 2) out = self.resnet_conv_second[i]( out ) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) # Residual Connection out = out + self.residual_input_conv[i]( resnet_input ) # (batch_size, out_channels, h * 2, w * 2) # Only do for Diffusion and not for AutoEncoder if self.attn: # Attention block of UNET batch_size, channels, h, w = out.shape in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w * 4) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose( 1, 2 ) # (batch_size, h * w * 4, out_channels) # Self-Attention out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape( batch_size, channels, h, w ) # (batch_size, out_channels h * 2, w * 2) # Skip connection out = out + out_attn # (batch_size, out_channels h * 2, w * 2) return out # (batch_size, out_channels h * 2, w * 2) class UpBlockUNet(nn.Module): """ Up conv block with attention. 1. Upsample 1. Concatenate Down block output 2. Resnet block with time embedding 3. Attention Block in_channels: Number of channels in the input feature map. (It is passed in multiplied by 2 for concatenation with DownBlock output) out_channels: Number of channels produced by this block. t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. up_sample: Whether to apply upsampling at the end. num_heads: Number of attention heads (used if attention is enabled). num_layers: How many sub-blocks to apply in sequence. norm_channels: Number of groups for GroupNorm. cross_attn: Whether to apply cross-attention. context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. """ def __init__( self, in_channels, out_channels, t_emb_dim, up_sample, num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None, ): super().__init__() self.num_layers = num_layers self.up_sample = up_sample self.t_emb_dim = t_emb_dim self.cross_attn = cross_attn self.context_dim = context_dim self.up_sample_conv = ( nn.ConvTranspose2d( in_channels=(in_channels // 2), out_channels=(in_channels // 2), kernel_size=4, stride=2, padding=1, ) if self.up_sample else nn.Identity() ) # (batch_size, in_channels // 2, h * 2, w * 2) self.resnet_conv_first = nn.ModuleList( [ nn.Sequential( nn.GroupNorm( norm_channels, in_channels if i == 0 else out_channels ), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels nn.SiLU(), nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, in_channels, h * 2, w. * 2) -> (batch_size, out_channels, h * 2, w * 2) - Starts at in_channels and not in_channels // 2 because of concatenation ) for i in range(num_layers) ] ) # Only add the time embedding if needed for UNET in diffusion # Do not add the time embedding in the AutoEncoder if self.t_emb_dim is not None: self.t_emb_layers = nn.ModuleList( [ nn.Sequential( nn.SiLU(), nn.Linear( in_features=self.t_emb_dim, out_features=out_channels ), # (batch_size, t_emb_dim) -> (batch_size, out_channels) ) for i in range(num_layers) ] ) self.resnet_conv_second = nn.ModuleList( [ nn.Sequential( nn.GroupNorm(norm_channels, out_channels), nn.SiLU(), nn.Conv2d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, ), # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) ) for i in range(num_layers) ] ) self.residual_input_conv = nn.ModuleList( [ nn.Conv2d( in_channels=(in_channels if i == 0 else out_channels), out_channels=out_channels, kernel_size=1, stride=1, padding=0, ) for i in range( num_layers ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) ] ) self.attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) # Cross attention for text conditioning if self.cross_attn: assert ( context_dim is not None ), "Context Dimension must be passed for cross attention" self.cross_attention_norms = nn.ModuleList( [nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] ) self.cross_attentions = nn.ModuleList( [ nn.MultiheadAttention( embed_dim=out_channels, num_heads=num_heads, batch_first=True ) for i in range(num_layers) ] ) self.context_proj = nn.ModuleList( [ nn.Linear(in_features=context_dim, out_features=out_channels) for i in range(num_layers) ] ) def forward(self, x, out_down=None, t_emb=None, context=None): # x shape: (batch_size, in_channels // 2, h, w) # Upsample x = self.up_sample_conv( x ) # (batch_size, in_channels // 2, h, w) -> (batch_size, in_channels // 2, h * 2, w * 2) # Concatenate with the output of respective DownBlock if out_down is not None: x = torch.cat( [x, out_down], dim=1 ) # (batch_size, in_channels // 2, h * 2, w * 2) -> (batch_size, in_channels, h * 2, w * 2) out = x # (batch_size, in_channels, h * 2, w * 2) for i in range(self.num_layers): # Resnet block resnet_input = out out = self.resnet_conv_first[i]( out ) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) if self.t_emb_dim is not None: # Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( dim=-1 ) # (batch_size, out_channels, h * 2, w * 2) out = self.resnet_conv_second[i]( out ) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) # Residual Connection out = out + self.residual_input_conv[i]( resnet_input ) # (batch_size, out_channels, h * 2, w * 2) # Attention block of UNET batch_size, channels, h, w = ( out.shape ) # (batch_size, out_channels, h * 2, w * 2) in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w * 4) in_attn = self.attention_norms[i](in_attn) in_attn = in_attn.transpose(1, 2) # (batch_size, h * w * 4, out_channels) # Self-Attention out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) out_attn = out_attn.transpose(1, 2).reshape( batch_size, channels, h, w ) # (batch_size, out_channels h * 2, w * 2) # Skip connection out = out + out_attn # (batch_size, out_channels h * 2, w * 2) if self.cross_attn: assert ( context is not None ), "context cannot be None if cross attention layers are used" batch_size, channels, h, w = out.shape in_attn = out.reshape( batch_size, channels, h * w ) # (batch_size, out_channels, h * w * 4) in_attn = self.cross_attention_norms[i](in_attn) in_attn = in_attn.transpose( 1, 2 ) # (batch_size, h * w * 4, out_channels) assert ( len(context.shape) == 3 ), "Context shape does not match batch_size, _, context_dim" assert ( context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim ), "Context shape does not match batch_size, _, context_dim" # Make sure the batch_size and context_dim match with the model's parameters context_proj = self.context_proj[i]( context ) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, context_dim) # Cross-Attention out_attn, attn_weights = self.cross_attentions[i]( in_attn, context_proj, context_proj ) out_attn = out_attn.transpose(1, 2).reshape( batch_size, channels, h, w ) # (batch_size, out_channels, h * 2, w * 2) # Skip Connection out = out + out_attn # (batch_size, out_channels h * 2, w * 2) return out # (batch_size, out_channels h * 2, w * 2) class VQVAE(nn.Module): def __init__(self, image_channels, model_config): super().__init__() self.down_channels = model_config["down_channels"] self.mid_channels = model_config["mid_channels"] self.down_sample = model_config["down_sample"] self.num_down_layers = model_config["num_down_layers"] self.num_mid_layers = model_config["num_mid_layers"] self.num_up_layers = model_config["num_up_layers"] # To disable attention in Downblock of Encoder and Upblock of Decoder self.attns = model_config["attn_down"] # Latent Dimension self.z_channels = model_config[ "z_channels" ] # number of channels in the latent representation self.codebook_size = model_config[ "codebook_size" ] # number of discrete code vectors available self.norm_channels = model_config["norm_channels"] self.num_heads = model_config["num_heads"] assert self.mid_channels[0] == self.down_channels[-1] assert self.mid_channels[-1] == self.down_channels[-1] assert len(self.down_sample) == len(self.down_channels) - 1 assert len(self.attns) == len(self.down_channels) - 1 # Wherever we downsample in the encoder, use upsampling in the decoder at the corresponding location self.up_sample = list(reversed(self.down_sample)) # Encoder self.encoder_conv_in = nn.Conv2d( in_channels=image_channels, out_channels=self.down_channels[0], kernel_size=3, stride=1, padding=1, ) # (batch_size, 3, h, w) -> (batch_size, c, h, w) # Downblock + Midblock self.encoder_layers = nn.ModuleList([]) for i in range(len(self.down_channels) - 1): self.encoder_layers.append( DownBlock( in_channels=self.down_channels[i], out_channels=self.down_channels[i + 1], t_emb_dim=None, down_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_down_layers, attn=self.attns[i], norm_channels=self.norm_channels, ) ) self.encoder_mids = nn.ModuleList([]) for i in range(len(self.mid_channels) - 1): self.encoder_mids.append( MidBlock( in_channels=self.mid_channels[i], out_channels=self.mid_channels[i + 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels, ) ) self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) self.encoder_conv_out = nn.Conv2d( in_channels=self.down_channels[-1], out_channels=self.z_channels, kernel_size=3, stride=1, padding=1, ) # (batch_size, z_channels, h', w') # Pre Quantization Convolution self.pre_quant_conv = nn.Conv2d( in_channels=self.z_channels, out_channels=self.z_channels, kernel_size=1, stride=1, padding=0, ) # (batch_size, z_channels, h', w') # Codebook Vectors self.embedding = nn.Embedding( self.codebook_size, self.z_channels ) # (codebook_size, z_channels) # Decoder # Post Quantization Convolution self.post_quant_conv = nn.Conv2d( in_channels=self.z_channels, out_channels=self.z_channels, kernel_size=1, stride=1, padding=0, ) # (batch_size, z_channels, h', w') self.decoder_conv_in = nn.Conv2d( in_channels=self.z_channels, out_channels=self.mid_channels[-1], kernel_size=3, stride=1, padding=1, ) # (batch_size, c, h', w') # Midblock + Upblock self.decoder_mids = nn.ModuleList([]) for i in reversed(range(1, len(self.mid_channels))): self.decoder_mids.append( MidBlock( in_channels=self.mid_channels[i], out_channels=self.mid_channels[i - 1], t_emb_dim=None, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels, ) ) self.decoder_layers = nn.ModuleList([]) for i in reversed(range(1, len(self.down_channels))): self.decoder_layers.append( UpBlock( in_channels=self.down_channels[i], out_channels=self.down_channels[i - 1], t_emb_dim=None, up_sample=self.down_sample[i - 1], num_heads=self.num_heads, num_layers=self.num_up_layers, attn=self.attns[i - 1], norm_channels=self.norm_channels, ) ) self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) self.decoder_conv_out = nn.Conv2d( in_channels=self.down_channels[0], out_channels=image_channels, kernel_size=3, stride=1, padding=1, ) # (batch_size, c, h, w) def quantize(self, x): batch_size, c, h, w = x.shape # (batch_size, z_channels, h, w) x = x.permute( 0, 2, 3, 1 ) # (batch_size, z_channels, h, w) -> (batch_size, h, w, z_channels) x = x.reshape( batch_size, -1, c ) # (batch_size, h, w, z_channels) -> (batch_size, h * w, z_channels) # Find the nearest codebook vector with distance between (batch_size, h * w, z_channels) and (batch_size, code_book_size, z_channels) -> (batch_size, h * w, code_book_size) dist = torch.cdist( x, self.embedding.weight.unsqueeze(dim=0).repeat((batch_size, 1, 1)) ) # cdist calculates the batched p-norm distance # (batch_size, h * w) Get the index of the closet codebook vector min_encoding_indices = torch.argmin(dist, dim=-1) # Replace the encoder output with the nearest codebook quant_out = torch.index_select( self.embedding.weight, 0, min_encoding_indices.view(-1) ) # (batch_size, h * w, z_channels) x = x.reshape((-1, c)) # (batch_size * h * w, z_channels) # Commitment and Codebook Loss using mSE commitment_loss = torch.mean((quant_out.detach() - x) ** 2) codebook_loss = torch.mean((quant_out - x.detach()) ** 2) quantize_losses = { "codebook_loss": codebook_loss, "commitment_loss": commitment_loss, } # Straight through estimation quant_out = x + (quant_out - x).detach() quant_out = quant_out.reshape(batch_size, h, w, c).permute( 0, 3, 1, 2 ) # (batch_size, z_channels, h, w) min_encoding_indices = min_encoding_indices.reshape( (-1, h, w) ) # (batch_size, h, w) return quant_out, quantize_losses, min_encoding_indices def encode(self, x): out = self.encoder_conv_in(x) # (batch_size, self.down_channels[0], h, w) # (batch_size, self.down_channels[0], h, w) -> (batch_size, self.down_channels[-1], h', w') for idx, down in enumerate(self.encoder_layers): out = down(out) # (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.mid_channels[-1], h', w') for mid in self.encoder_mids: out = mid(out) out = self.encoder_norm_out(out) out = F.silu(out) out = self.encoder_conv_out( out ) # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.z_channels, h', w') out = self.pre_quant_conv( out ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w') out, quant_losses, min_encoding_indices = self.quantize( out ) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss), (batch_size, h, w) return out, quant_losses def decode(self, z): out = z out = self.post_quant_conv( out ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w') out = self.decoder_conv_in( out ) # (batch_size, self.z_channels, h', w') -> (batch_size, self.mid_channels[-1], h', w') # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.down_channels[-1], h', w') for mid in self.decoder_mids: out = mid(out) # (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.down_channels[0], h, w) for idx, up in enumerate(self.decoder_layers): out = up(out) out = self.decoder_norm_out(out) out = F.silu(out) out = self.decoder_conv_out( out ) # (batch_size, self.down_channels[0], h, w) -> (batch_size, c, h, w) return out def forward(self, x): # x shape: (batch_size, c, h, w) z, quant_losses = self.encode( x ) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss) out = self.decode(z) # (batch_size, c, h, w) return out, z, quant_losses def validate_image_conditional_input(cond_input, x): assert ( "image" in cond_input ), "Model initialized with image conditioning but cond_input has no image information" assert ( cond_input["image"].shape[0] == x.shape[0] ), "Batch size mismatch of image condition and input" assert ( cond_input["image"].shape[2] % x.shape[2] == 0 ), "Height/Width of image condition must be divisible by latent input" def validate_class_conditional_input(cond_input, x, num_classes): assert ( "class" in cond_input ), "Model initialized with class conditioning but cond_input has no class information" assert cond_input["class"].shape == ( x.shape[0], num_classes, ), "Shape of class condition input must match (Batch Size, )" def get_config_value(config, key, default_value): return config[key] if key in config else default_value class UNet(nn.Module): """ Unet model comprising Down blocks, Midblocks and Uplocks """ def __init__(self, image_channels, model_config): super().__init__() self.down_channels = model_config["down_channels"] self.mid_channels = model_config["mid_channels"] self.t_emb_dim = model_config["time_emb_dim"] self.down_sample = model_config["down_sample"] self.num_down_layers = model_config["num_down_layers"] self.num_mid_layers = model_config["num_mid_layers"] self.num_up_layers = model_config["num_up_layers"] self.attns = model_config["attn_down"] self.norm_channels = model_config["norm_channels"] self.num_heads = model_config["num_heads"] self.conv_out_channels = model_config["conv_out_channels"] assert self.mid_channels[0] == self.down_channels[-1] assert self.mid_channels[-1] == self.down_channels[-2] assert len(self.down_sample) == len(self.down_channels) - 1 assert len(self.attns) == len(self.down_channels) - 1 # Class, Mask, and Text Conditioning Config self.class_cond = False self.text_cond = False self.image_cond = False self.text_embed_dim = None self.condition_config = get_config_value( model_config, "condition_config", None ) # Get the dictionary containing conditional information if self.condition_config is not None: assert ( "condition_types" in self.condition_config ), "Condition Type not provided in model config" condition_types = self.condition_config["condition_types"] # For class, text, and image, get necessary parameters if "class" in condition_types: self.class_cond = True self.num_classes = self.condition_config["class_condition_config"][ "num_classes" ] if "text" in condition_types: self.text_cond = True self.text_embed_dim = self.condition_config["text_condition_config"][ "text_embed_dim" ] if "image" in condition_types: self.image_cond = True self.image_cond_input_channels = self.condition_config[ "image_condition_config" ]["image_condition_input_channels"] self.image_cond_output_channels = self.condition_config[ "image_condition_config" ]["image_condition_output_channels"] if self.class_cond: # For class conditioning, do not add the class embedding information for unconditional generation self.class_emb = nn.Embedding( self.num_classes, self.t_emb_dim ) # (num_classes, t_emb_dim) if self.image_cond: # Map the mask image to a image_cond_output_channels channel image, and concat with input across the channel dimension self.cond_conv_in = nn.Conv2d( in_channels=self.image_cond_input_channels, out_channels=self.image_cond_output_channels, kernel_size=1, stride=1, padding=0, bias=False, ) self.conv_in_concat = nn.Conv2d( in_channels=(image_channels + self.image_cond_output_channels), out_channels=self.down_channels[0], kernel_size=3, stride=1, padding=1, ) else: self.conv_in = nn.Conv2d( in_channels=image_channels, out_channels=self.down_channels[0], kernel_size=3, stride=1, padding=1, ) # (batch_size, image_channels, h, w) -> (batch_size, self.down_channels[0], h, w) self.cond = self.text_cond or self.image_cond or self.class_cond # Initial projection from sinusoidal time embedding self.t_proj = nn.Sequential( nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim), nn.SiLU(), nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim), ) # (batch_size, t_emb_dim) self.up_sample = list(reversed(self.down_sample)) self.downs = nn.ModuleList([]) for i in range(len(self.down_channels) - 1): # Cross attention and Context Dim are only used for text conditioning self.downs.append( DownBlock( in_channels=self.down_channels[i], out_channels=self.down_channels[i + 1], t_emb_dim=self.t_emb_dim, down_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_down_layers, attn=self.attns[i], norm_channels=self.norm_channels, cross_attn=self.text_cond, context_dim=self.text_embed_dim, ) ) self.mids = nn.ModuleList([]) for i in range(len(self.mid_channels) - 1): # Cross attention and Context Dim are only used for text conditioning self.mids.append( MidBlock( in_channels=self.mid_channels[i], out_channels=self.mid_channels[i + 1], t_emb_dim=self.t_emb_dim, num_heads=self.num_heads, num_layers=self.num_mid_layers, norm_channels=self.norm_channels, cross_attn=self.text_cond, context_dim=self.text_embed_dim, ) ) self.ups = nn.ModuleList([]) for i in reversed(range(len(self.down_channels) - 1)): # Cross attention and Context Dim are only used for text conditioning self.ups.append( UpBlockUNet( in_channels=(self.down_channels[i] * 2), out_channels=( self.down_channels[i - 1] if i != 0 else self.conv_out_channels ), t_emb_dim=self.t_emb_dim, up_sample=self.down_sample[i], num_heads=self.num_heads, num_layers=self.num_up_layers, norm_channels=self.norm_channels, cross_attn=self.text_cond, context_dim=self.text_embed_dim, ) ) self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) self.conv_out = nn.Conv2d( in_channels=self.conv_out_channels, out_channels=image_channels, kernel_size=3, stride=1, padding=1, ) # (batch_size, conv_out_channels, h, w) -> (batch_size, image_channels, h, w) def forward(self, x, t, cond_input=None): # x shape: (batch_size, c, h, w) # cond_input is the conditioning vector # For class conditioning, it will be a one-hot vector of size # (batch_size, num_classes) if self.cond: assert ( cond_input is not None ), "Model initialized with conditioning so cond_input cannot be None" if self.image_cond: # Mask Conditioning validate_image_conditional_input(cond_input, x) image_cond = cond_input["image"] image_cond = F.interpolate(image_cond, size=x.shape[-2:]) image_cond = self.cond_conv_in(image_cond) assert image_cond.shape[-2:] == x.shape[-2:] x = torch.cat( [x, image_cond], dim=1 ) # (batch_size, image_channels + image_cond_output_channels, h, w) out = self.conv_in_concat(x) # (batch_size, down_channels[0], h, w) else: out = self.conv_in(x) # (batch_size, down_channels[0], h, w) t_emb = get_time_embedding( torch.as_tensor(t).long(), self.t_emb_dim ) # (batch_size, t_emb_dim) t_emb = self.t_proj(t_emb) # (batch_size, t_emb_dim) # Class Conditioning if self.class_cond: validate_class_conditional_input(cond_input, x, self.num_classes) # Take the matrix for class embedding vectors and matrix multiply it with the embedding matrix to get the class embedding for all images in a batch class_embed = torch.matmul( cond_input["class"].float(), self.class_emb.weight ) # (batch_size, t_emb_dim) t_emb += class_embed # Add the class embedding to the time embedding context_hidden_states = None # Only use context hidden states in cross-attention for text conditioning if self.text_cond: assert ( "text" in cond_input ), "Model initialized with text conditioning but cond_input has no text information" context_hidden_states = cond_input["text"] down_outs = [] for idx, down in enumerate(self.downs): down_outs.append(out) out = down( out, t_emb, context_hidden_states ) # Use context_hidden_states for cross-attention # out = (batch_size, c4, h / 4, w / 4) for mid in self.mids: out = mid(out, t_emb, context_hidden_states) # out = (batch_size, c3, h / 4, w / 4) for up in self.ups: down_out = down_outs.pop() out = up(out, down_out, t_emb, context_hidden_states) # out = (batch_size, self.conv_out_channels, h, w) out = F.silu(self.norm_out(out)) out = self.conv_out( out ) # (batch_size, self.conv_out_channels, h, w) -> (batch_size, image_channels, h, w) return out # (batch_size, image_channels, h, w)