import math
import os
import random
import glob
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.optim import Adam
from torchvision.utils import make_grid
from PIL import Image
from transformers import (
    DistilBertModel,
    DistilBertTokenizer,
    CLIPTokenizer,
    CLIPTextModel,
)
dataset_params = {
    "image_path": "data/CelebAMask-HQ",
    "image_channels": 3,
    "image_size": 256,
    "name": "celebhq",
}
diffusion_params = {
    "num_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
}
ldm_params = {
    "down_channels": [256, 384, 512, 768],
    "mid_channels": [768, 512],
    "down_sample": [True, True, True],
    "attn_down": [True, True, True],  # Self-attention in the DownBlocks and UpBlocks of the diffusion UNet
    "time_emb_dim": 512,
    "norm_channels": 32,
    "num_heads": 16,
    "conv_out_channels": 128,
    "num_down_layers": 2,
    "num_mid_layers": 2,
    "num_up_layers": 2,
    "condition_config": {
        "condition_types": ["text", "image"],
        "text_condition_config": {
            "text_embed_model": "clip",
            "train_text_embed_model": False,
            "text_embed_dim": 512,  # Each token maps to a text_embed_dim-sized vector
            "cond_drop_prob": 0.1,  # Probability of dropping the conditioning during training so the model can also generate unconditionally
        },
        "image_condition_config": {
            "image_condition_input_channels": 18,  # CelebAMask-HQ has 18 mask classes excluding background
            "image_condition_output_channels": 3,
            "image_condition_h": 512,  # Mask height
            "image_condition_w": 512,  # Mask width
            "cond_drop_prob": 0.1,  # Probability of dropping the conditioning during training so the model can also generate unconditionally
        },
    },
}
autoencoder_params = {
    "z_channels": 4,
    "codebook_size": 8192,
    "down_channels": [64, 128, 256, 256],
    "mid_channels": [256, 256],
    "down_sample": [True, True, True],
    "attn_down": [False, False, False],  # No attention in the DownBlocks and UpBlocks of the VQ-VAE
    "norm_channels": 32,
    "num_heads": 4,
    "num_down_layers": 2,
    "num_mid_layers": 2,
    "num_up_layers": 2,
}
train_params = {
    "seed": 1111,
    "task_name": "celebhq",  # Folder to save models and images to
    "ldm_batch_size": 16,
    "autoencoder_batch_size": 4,
    "disc_start": 15000,
    "disc_weight": 0.5,
    "codebook_weight": 1,
    "commitment_beta": 0.2,
    "perceptual_weight": 1,
    "kl_weight": 0.000005,
    "ldm_epochs": 100,
    "autoencoder_epochs": 20,
    "num_samples": 1,
    "num_grid_rows": 1,
    "ldm_lr": 0.000005,
    "autoencoder_lr": 0.00001,
    "autoencoder_acc_steps": 4,
    "autoencoder_img_save_steps": 64,
    "save_latents": True,
    "cf_guidance_scale": 1.0,
    "vqvae_latent_dir_name": "vqvae_latents",
    "ldm_ckpt_name": "ddpm_ckpt_class_cond.pth",
    "vqvae_autoencoder_ckpt_name": "vqvae_autoencoder_ckpt.pth",
}
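
# Illustrative sketch (an assumption, not part of the original listing): how the
# seed in train_params would typically be applied for reproducibility.
def _demo_set_seed():
    seed = train_params["seed"]
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)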
def get_config_value(config, key, default_value):
    return config[key] if key in config else default_value
def spatial_average(in_tens, keepdim=True):
    return in_tens.mean([2, 3], keepdim=keepdim)
class LinearNoiseScheduler: | |
def __init__(self, num_timesteps, beta_start, beta_end): | |
self.num_timesteps = num_timesteps | |
self.beta_start = beta_start | |
self.beta_end = beta_end | |
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2 | |
self.alphas = 1.0 - self.betas | |
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) | |
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) | |
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) | |
def add_noise(self, original, noise, t): | |
# original: (batch_size, c, h, w), t: tensor of timesteps (batch_size,) | |
batch_size = original.shape[0] | |
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].view( | |
batch_size, 1, 1, 1 | |
) | |
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to( | |
original.device | |
)[t].view(batch_size, 1, 1, 1) | |
return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise | |
def sample_prev_timestep(self, xt, noise_pred, t): | |
batch_size = xt.shape[0] | |
alpha_cum_prod_t = self.alpha_cum_prod.to(xt.device)[t].view( | |
batch_size, 1, 1, 1 | |
) | |
sqrt_one_minus_alpha_cum_prod_t = self.sqrt_one_minus_alpha_cum_prod.to( | |
xt.device | |
)[t].view(batch_size, 1, 1, 1) | |
x0 = (xt - sqrt_one_minus_alpha_cum_prod_t * noise_pred) / torch.sqrt( | |
alpha_cum_prod_t | |
) | |
x0 = torch.clamp(x0, -1.0, 1.0) | |
betas_t = self.betas.to(xt.device)[t].view(batch_size, 1, 1, 1) | |
mean = ( | |
xt - betas_t / sqrt_one_minus_alpha_cum_prod_t * noise_pred | |
) / torch.sqrt(self.alphas.to(xt.device)[t].view(batch_size, 1, 1, 1)) | |
if t[0] == 0: | |
return mean, x0 | |
else: | |
prev_alpha_cum_prod = self.alpha_cum_prod.to(xt.device)[ | |
(t - 1).clamp(min=0) | |
].view(batch_size, 1, 1, 1) | |
variance = (1 - prev_alpha_cum_prod) / (1 - alpha_cum_prod_t) * betas_t | |
sigma = variance.sqrt() | |
z = torch.randn_like(xt) | |
return mean + sigma * z, x0 | |
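
# Illustrative sketch (the dummy latent batch is an assumption): exercising the
# scheduler above with diffusion_params to check the forward/reverse shapes.
def _demo_noise_scheduler():
    scheduler = LinearNoiseScheduler(**diffusion_params)
    x0 = torch.randn(2, 4, 32, 32)  # stand-in for a batch of VQ-VAE latents
    noise = torch.randn_like(x0)
    t = torch.randint(0, diffusion_params["num_timesteps"], (2,))
    xt = scheduler.add_noise(x0, noise, t)  # noised latents, same shape as x0
    x_prev, pred_x0 = scheduler.sample_prev_timestep(xt, noise, t)
    return xt.shape, x_prev.shape, pred_x0.shape  # all (2, 4, 32, 32)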
def get_tokenizer_and_model(model_type, device, eval_mode=True): | |
assert model_type in ( | |
"bert", | |
"clip", | |
), "Text model can only be one of 'clip' or 'bert'" | |
if model_type == "bert": | |
text_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") | |
text_model = DistilBertModel.from_pretrained("distilbert-base-uncased").to( | |
device | |
) | |
else: | |
text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16") | |
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16").to( | |
device | |
) | |
if eval_mode: | |
text_model.eval() | |
return text_tokenizer, text_model | |
def get_text_representation(text, text_tokenizer, text_model, device, max_length=77): | |
token_output = text_tokenizer( | |
text, | |
truncation=True, | |
padding="max_length", | |
return_attention_mask=True, | |
max_length=max_length, | |
) | |
tokens_tensor = torch.tensor(token_output["input_ids"]).to(device) | |
mask_tensor = torch.tensor(token_output["attention_mask"]).to(device) | |
text_embed = text_model(tokens_tensor, attention_mask=mask_tensor).last_hidden_state | |
return text_embed | |
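
# Illustrative sketch (the prompt and CPU device are assumptions; downloads CLIP
# weights on first call): producing the text conditioning tensor used by the UNet below.
def _demo_text_representation(device="cpu"):
    tokenizer, model = get_tokenizer_and_model("clip", device)
    text_embed = get_text_representation(
        ["A person with blond hair"], tokenizer, model, device
    )
    return text_embed.shape  # (1, 77, 512) for the CLIP text encoder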
def get_time_embedding(time_steps, temb_dim): | |
""" | |
Convert time steps tensor into an embedding using the sinusoidal time embedding formula | |
time_steps: 1D tensor of length batch size | |
temb_dim: Dimension of the embedding | |
""" | |
assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" | |
# factor = 10000^(2i/d_model) | |
factor = 10000 ** ( | |
( | |
torch.arange( | |
start=0, | |
end=temb_dim // 2, | |
dtype=torch.float32, | |
device=time_steps.device, | |
) | |
/ (temb_dim // 2) | |
) | |
) | |
t_emb = time_steps.unsqueeze(dim=-1).repeat(1, temb_dim // 2) / factor | |
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) | |
return t_emb # (batch_size, temb_dim) | |
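
# Illustrative shape check (dummy timesteps are an assumption) for the sinusoidal
# time embedding consumed by the diffusion UNet.
def _demo_time_embedding():
    t = torch.randint(0, diffusion_params["num_timesteps"], (8,))
    t_emb = get_time_embedding(t, ldm_params["time_emb_dim"])
    return t_emb.shape  # (8, 512)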
class DownBlock(nn.Module): | |
""" | |
Down conv block with attention. | |
1. Resnet block with time embedding | |
2. Attention block | |
3. Downsample | |
in_channels: Number of channels in the input feature map. | |
out_channels: Number of channels produced by this block. | |
t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. | |
down_sample: Whether to apply downsampling at the end. | |
num_heads: Number of attention heads (used if attention is enabled). | |
num_layers: How many sub-blocks to apply in sequence. | |
attn: Whether to apply self-attention | |
norm_channels: Number of groups for GroupNorm. | |
cross_attn: Whether to apply cross-attention. | |
context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. | |
""" | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
t_emb_dim, | |
down_sample, | |
num_heads, | |
num_layers, | |
attn, | |
norm_channels, | |
cross_attn=False, | |
context_dim=None, | |
): | |
super().__init__() | |
self.num_layers = num_layers | |
self.down_sample = down_sample | |
self.attn = attn | |
self.context_dim = context_dim | |
self.cross_attn = cross_attn | |
self.t_emb_dim = t_emb_dim | |
self.resnet_conv_first = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm( | |
norm_channels, in_channels if i == 0 else out_channels | |
), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w) | |
) | |
for i in range(num_layers) | |
] | |
) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
self.t_emb_layers = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.SiLU(), | |
nn.Linear( | |
in_features=self.t_emb_dim, out_features=out_channels | |
), # (batch_size, t_emb_dim) -> (batch_size, out_channels) | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.resnet_conv_second = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm(norm_channels, out_channels), | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.residual_input_conv = nn.ModuleList( | |
[ | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w) | |
for i in range(num_layers) | |
] | |
) | |
if self.attn: | |
self.attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
# Cross attention for text conditioning | |
if self.cross_attn: | |
assert ( | |
context_dim is not None | |
), "Context Dimension must be passed for cross attention" | |
self.cross_attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.cross_attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.context_proj = nn.ModuleList( | |
[ | |
nn.Linear(in_features=context_dim, out_features=out_channels) | |
for i in range(num_layers) | |
] | |
) | |
# Down sample by a factor of 2 | |
self.down_sample_conv = ( | |
nn.Conv2d( | |
in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=4, | |
stride=2, | |
padding=1, | |
) | |
if self.down_sample | |
else nn.Identity() | |
) # (batch_size, out_channels, h / 2, w / 2) | |
def forward(self, x, t_emb=None, context=None): | |
out = x | |
for i in range(self.num_layers): | |
# Resnet block of UNET | |
resnet_input = out # (batch_size, c, h, w) | |
out = self.resnet_conv_first[i](out) # (batch_size, out_channels, h, w) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
# Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) | |
out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( | |
dim=-1 | |
) # (batch_size, out_channels, h, w) | |
out = self.resnet_conv_second[i]( | |
out | |
) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) | |
# Residual Connection | |
out = out + self.residual_input_conv[i]( | |
resnet_input | |
) # (batch_size, out_channels, h, w) | |
# Only do for Diffusion and not for AutoEncoder | |
if self.attn: | |
# Attention block of UNET | |
batch_size, channels, h, w = ( | |
out.shape | |
) # (batch_size, out_channels, h, w) | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w) | |
in_attn = self.attention_norms[i](in_attn) | |
in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) | |
# Self-Attention | |
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) | |
out_attn = out_attn.transpose(1, 2).reshape( | |
batch_size, channels, h, w | |
) # (batch_size, out_channels h, w) | |
# Skip connection | |
out = out + out_attn # (batch_size, out_channels h, w) | |
if self.cross_attn: | |
assert ( | |
context is not None | |
), "context cannot be None if cross attention layers are used" | |
batch_size, channels, h, w = ( | |
out.shape | |
) # (batch_size, out_channels, h, w) | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w) | |
in_attn = self.cross_attention_norms[i](in_attn) | |
in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) | |
assert ( | |
context.shape[0] == x.shape[0] | |
and context.shape[-1] == self.context_dim | |
) # Make sure the batch_size and context_dim match with the model's parameters | |
context_proj = self.context_proj[i]( | |
context | |
) # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels) | |
# Cross-Attention | |
out_attn, attn_weights = self.cross_attentions[i]( | |
in_attn, context_proj, context_proj | |
) # (batch_size, h * w, out_channels) | |
out_attn = out_attn.transpose(1, 2).reshape( | |
batch_size, channels, h, w | |
) # (batch_size, out_channels, h, w) | |
# Skip Connection | |
out = out + out_attn # (batch_size, out_channels, h, w) | |
# Downsampling | |
out = self.down_sample_conv(out) # (batch_size, out_channels, h / 2, w / 2) | |
return out | |
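
# Illustrative sketch (dummy tensors are assumptions): one DownBlock forward pass with
# time embedding, self-attention and text cross-attention, as the diffusion UNet uses it.
def _demo_down_block():
    block = DownBlock(
        in_channels=256, out_channels=384, t_emb_dim=512, down_sample=True,
        num_heads=16, num_layers=2, attn=True, norm_channels=32,
        cross_attn=True, context_dim=512,
    )
    x = torch.randn(2, 256, 32, 32)
    t_emb = torch.randn(2, 512)
    context = torch.randn(2, 77, 512)  # e.g. CLIP text embeddings
    return block(x, t_emb, context).shape  # (2, 384, 16, 16) after downsampling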
class MidBlock(nn.Module): | |
""" | |
Mid conv block with attention. | |
1. Resnet block with time embedding | |
2. Attention block | |
3. Resnet block with time embedding | |
in_channels: Number of channels in the input feature map. | |
out_channels: Number of channels produced by this block. | |
t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. | |
num_heads: Number of attention heads (used if attention is enabled). | |
num_layers: How many sub-blocks to apply in sequence. | |
norm_channels: Number of groups for GroupNorm. | |
cross_attn: Whether to apply cross-attention. | |
context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. | |
""" | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
t_emb_dim, | |
num_heads, | |
num_layers, | |
norm_channels, | |
cross_attn=None, | |
context_dim=None, | |
): | |
super().__init__() | |
self.num_layers = num_layers | |
self.t_emb_dim = t_emb_dim | |
self.context_dim = context_dim | |
self.cross_attn = cross_attn | |
self.resnet_conv_first = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm( | |
norm_channels, in_channels if i == 0 else out_channels | |
), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w) | |
) | |
for i in range(num_layers + 1) | |
] | |
) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
self.t_emb_layers = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.SiLU(), | |
nn.Linear( | |
in_features=self.t_emb_dim, out_features=out_channels | |
), # (batch_size, t_emb_dim) -> (batch_size, out_channels) | |
) | |
for i in range(num_layers + 1) | |
] | |
) | |
self.resnet_conv_second = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm(norm_channels, out_channels), | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) | |
) | |
for i in range(num_layers + 1) | |
] | |
) | |
self.residual_input_conv = nn.ModuleList( | |
[ | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w) | |
for i in range(num_layers + 1) | |
] | |
) | |
self.attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
# Cross attention for text conditioning | |
if self.cross_attn: | |
assert ( | |
context_dim is not None | |
), "Context Dimension must be passed for cross attention" | |
self.cross_attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.cross_attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.context_proj = nn.ModuleList( | |
[ | |
nn.Linear(in_features=context_dim, out_features=out_channels) | |
for i in range(num_layers) | |
] | |
) | |
def forward(self, x, t_emb=None, context=None): | |
out = x | |
# First ResNet block | |
resnet_input = out # (batch_size, c, h, w) | |
out = self.resnet_conv_first[0](out) # (batch_size, out_channels, h, w) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
# Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) | |
out = out + self.t_emb_layers[0](t_emb).unsqueeze(dim=-1).unsqueeze( | |
dim=-1 | |
) # (batch_size, out_channels, h, w) | |
out = self.resnet_conv_second[0]( | |
out | |
) # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) | |
# Residual Connection | |
out = out + self.residual_input_conv[0]( | |
resnet_input | |
) # (batch_size, out_channels, h, w) | |
for i in range(self.num_layers): | |
# Attention Block | |
batch_size, channels, h, w = out.shape # (batch_size, out_channels, h, w) | |
# Do for both Diffusion and AutoEncoder | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w) | |
in_attn = self.attention_norms[i](in_attn) | |
in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) | |
# Self-Attention | |
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) | |
out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) | |
# Skip connection | |
out = out + out_attn # (batch_size, out_channels h, w) | |
if self.cross_attn: | |
assert ( | |
context is not None | |
), "context cannot be None if cross attention layers are used" | |
batch_size, channels, h, w = out.shape | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w) | |
in_attn = self.cross_attention_norms[i](in_attn) | |
in_attn = in_attn.transpose(1, 2) # (batch_size, h * w, out_channels) | |
assert ( | |
context.shape[0] == x.shape[0] | |
and context.shape[-1] == self.context_dim | |
) # Make sure the batch_size and context_dim match with the model's parameters | |
context_proj = self.context_proj[i]( | |
context | |
                )  # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels)
# Cross-Attention | |
out_attn, attn_weights = self.cross_attentions[i]( | |
in_attn, context_proj, context_proj | |
) | |
out_attn = out_attn.transpose(1, 2).reshape( | |
batch_size, channels, h, w | |
) # (batch_size, out_channels, h, w) | |
# Skip Connection | |
out = out + out_attn # (batch_size, out_channels h, w) | |
# Resnet Block | |
resnet_input = out | |
out = self.resnet_conv_first[i + 1]( | |
out | |
) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
# Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) | |
out = out + self.t_emb_layers[i + 1](t_emb).unsqueeze(dim=-1).unsqueeze( | |
dim=-1 | |
) # (batch_size, out_channels h, w) | |
out = self.resnet_conv_second[i + 1]( | |
out | |
) # (batch_size, out_channels h, w) -> (batch_size, out_channels h, w) | |
# Residual Connection | |
out = out + self.residual_input_conv[i + 1]( | |
resnet_input | |
) # (batch_size, out_channels, h, w) | |
return out | |
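
# Illustrative sketch (dummy inputs are assumptions): a MidBlock forward pass at the
# UNet bottleneck with self-attention plus optional text cross-attention.
def _demo_mid_block():
    block = MidBlock(
        in_channels=768, out_channels=512, t_emb_dim=512, num_heads=16,
        num_layers=2, norm_channels=32, cross_attn=True, context_dim=512,
    )
    x = torch.randn(2, 768, 8, 8)
    t_emb = torch.randn(2, 512)
    context = torch.randn(2, 77, 512)
    return block(x, t_emb, context).shape  # (2, 512, 8, 8)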
class UpBlock(nn.Module): | |
""" | |
Up conv block with attention. | |
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block
in_channels: Number of channels in the input feature map. | |
out_channels: Number of channels produced by this block. | |
t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. | |
up_sample: Whether to apply upsampling at the end. | |
num_heads: Number of attention heads (used if attention is enabled). | |
num_layers: How many sub-blocks to apply in sequence. | |
attn: Whether to apply self-attention | |
norm_channels: Number of groups for GroupNorm. | |
""" | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
t_emb_dim, | |
up_sample, | |
num_heads, | |
num_layers, | |
attn, | |
norm_channels, | |
): | |
super().__init__() | |
self.num_layers = num_layers | |
self.up_sample = up_sample | |
self.t_emb_dim = t_emb_dim | |
self.attn = attn | |
# Upsample by a factor of 2 | |
self.up_sample_conv = ( | |
nn.ConvTranspose2d( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
kernel_size=4, | |
stride=2, | |
padding=1, | |
) | |
if self.up_sample | |
else nn.Identity() | |
) # (batch_size, c, h * 2, w * 2) | |
self.resnet_conv_first = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm( | |
norm_channels, in_channels if i == 0 else out_channels | |
), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, c, h, w) -> (batch_size, out_channels, h, w) | |
) | |
for i in range(num_layers) | |
] | |
) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
self.t_emb_layers = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.SiLU(), | |
nn.Linear( | |
in_features=self.t_emb_dim, out_features=out_channels | |
), # (batch_size, t_emb_dim) -> (batch_size, out_channels) | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.resnet_conv_second = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm(norm_channels, out_channels), | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, out_channels, h, w) -> (batch_size, out_channels, h, w) | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.residual_input_conv = nn.ModuleList( | |
[ | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) # (batch_size, in_channels, h, w) -> (batch_size, out_channels, h, w) | |
for i in range(num_layers) | |
] | |
) | |
if self.attn: | |
self.attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
def forward(self, x, out_down=None, t_emb=None): | |
# x shape: (batch_size, c, h, w) | |
# Upsample | |
x = self.up_sample_conv( | |
x | |
) # (batch_size, c, h, w) -> (batch_size, c, h * 2, w * 2) | |
# *Only do for diffusion | |
# Concatenate with the output of respective DownBlock | |
if out_down is not None: | |
x = torch.cat( | |
[x, out_down], dim=1 | |
) # (batch_size, c, h * 2, w * 2) -> (batch_size, c * 2, h * 2, w * 2) | |
out = x # (batch_size, c, h * 2, w * 2) | |
for i in range(self.num_layers): | |
# Resnet block | |
resnet_input = out | |
out = self.resnet_conv_first[i]( | |
out | |
) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) | |
# Only add the time embedding for diffusion and not AutoEncoder | |
if self.t_emb_dim is not None: | |
# Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) | |
out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( | |
dim=-1 | |
) # (batch_size, out_channels, h * 2, w * 2) | |
out = self.resnet_conv_second[i]( | |
out | |
) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) | |
# Residual Connection | |
out = out + self.residual_input_conv[i]( | |
resnet_input | |
) # (batch_size, out_channels, h * 2, w * 2) | |
# Only do for Diffusion and not for AutoEncoder | |
if self.attn: | |
# Attention block of UNET | |
batch_size, channels, h, w = out.shape | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w * 4) | |
in_attn = self.attention_norms[i](in_attn) | |
in_attn = in_attn.transpose( | |
1, 2 | |
) # (batch_size, h * w * 4, out_channels) | |
# Self-Attention | |
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) | |
out_attn = out_attn.transpose(1, 2).reshape( | |
batch_size, channels, h, w | |
) # (batch_size, out_channels h * 2, w * 2) | |
# Skip connection | |
out = out + out_attn # (batch_size, out_channels h * 2, w * 2) | |
return out # (batch_size, out_channels h * 2, w * 2) | |
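
# Illustrative sketch (the dummy input is an assumption): the autoencoder-style UpBlock,
# which upsamples without a skip connection (out_down=None) and without a time embedding.
def _demo_up_block():
    block = UpBlock(
        in_channels=256, out_channels=128, t_emb_dim=None, up_sample=True,
        num_heads=4, num_layers=2, attn=False, norm_channels=32,
    )
    x = torch.randn(2, 256, 16, 16)
    return block(x).shape  # (2, 128, 32, 32)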
class UpBlockUNet(nn.Module): | |
""" | |
Up conv block with attention. | |
    1. Upsample
    2. Concatenate Down block output
    3. Resnet block with time embedding
    4. Attention Block
in_channels: Number of channels in the input feature map. (It is passed in multiplied by 2 for concatenation with DownBlock output) | |
out_channels: Number of channels produced by this block. | |
t_emb_dim: Dimension of the time embedding. Only use for UNet for Diffusion. In an AutoEncoder, set it to None. | |
up_sample: Whether to apply upsampling at the end. | |
num_heads: Number of attention heads (used if attention is enabled). | |
num_layers: How many sub-blocks to apply in sequence. | |
norm_channels: Number of groups for GroupNorm. | |
cross_attn: Whether to apply cross-attention. | |
context_dim: If performing cross-attention, provide a context_dim for extra conditioning context. | |
""" | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
t_emb_dim, | |
up_sample, | |
num_heads, | |
num_layers, | |
norm_channels, | |
cross_attn=False, | |
context_dim=None, | |
): | |
super().__init__() | |
self.num_layers = num_layers | |
self.up_sample = up_sample | |
self.t_emb_dim = t_emb_dim | |
self.cross_attn = cross_attn | |
self.context_dim = context_dim | |
self.up_sample_conv = ( | |
nn.ConvTranspose2d( | |
in_channels=(in_channels // 2), | |
out_channels=(in_channels // 2), | |
kernel_size=4, | |
stride=2, | |
padding=1, | |
) | |
if self.up_sample | |
else nn.Identity() | |
) # (batch_size, in_channels // 2, h * 2, w * 2) | |
self.resnet_conv_first = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm( | |
norm_channels, in_channels if i == 0 else out_channels | |
), # Normalizes over channels. For the first sub-block, the in_channels=in_channels, else out_channels | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
                    ),  # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) - starts at in_channels rather than in_channels // 2 because of the concatenation
) | |
for i in range(num_layers) | |
] | |
) | |
# Only add the time embedding if needed for UNET in diffusion | |
# Do not add the time embedding in the AutoEncoder | |
if self.t_emb_dim is not None: | |
self.t_emb_layers = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.SiLU(), | |
nn.Linear( | |
in_features=self.t_emb_dim, out_features=out_channels | |
), # (batch_size, t_emb_dim) -> (batch_size, out_channels) | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.resnet_conv_second = nn.ModuleList( | |
[ | |
nn.Sequential( | |
nn.GroupNorm(norm_channels, out_channels), | |
nn.SiLU(), | |
nn.Conv2d( | |
in_channels=out_channels, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
), # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.residual_input_conv = nn.ModuleList( | |
[ | |
nn.Conv2d( | |
in_channels=(in_channels if i == 0 else out_channels), | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) | |
for i in range( | |
num_layers | |
) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) | |
] | |
) | |
self.attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
# Cross attention for text conditioning | |
if self.cross_attn: | |
assert ( | |
context_dim is not None | |
), "Context Dimension must be passed for cross attention" | |
self.cross_attention_norms = nn.ModuleList( | |
[nn.GroupNorm(norm_channels, out_channels) for i in range(num_layers)] | |
) | |
self.cross_attentions = nn.ModuleList( | |
[ | |
nn.MultiheadAttention( | |
embed_dim=out_channels, num_heads=num_heads, batch_first=True | |
) | |
for i in range(num_layers) | |
] | |
) | |
self.context_proj = nn.ModuleList( | |
[ | |
nn.Linear(in_features=context_dim, out_features=out_channels) | |
for i in range(num_layers) | |
] | |
) | |
def forward(self, x, out_down=None, t_emb=None, context=None): | |
# x shape: (batch_size, in_channels // 2, h, w) | |
# Upsample | |
x = self.up_sample_conv( | |
x | |
) # (batch_size, in_channels // 2, h, w) -> (batch_size, in_channels // 2, h * 2, w * 2) | |
# Concatenate with the output of respective DownBlock | |
if out_down is not None: | |
x = torch.cat( | |
[x, out_down], dim=1 | |
) # (batch_size, in_channels // 2, h * 2, w * 2) -> (batch_size, in_channels, h * 2, w * 2) | |
out = x # (batch_size, in_channels, h * 2, w * 2) | |
for i in range(self.num_layers): | |
# Resnet block | |
resnet_input = out | |
out = self.resnet_conv_first[i]( | |
out | |
) # (batch_size, in_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) | |
if self.t_emb_dim is not None: | |
# Add the embeddings for timesteps - (batch_size, t_emb_dim) -> (batch_size, out_channels, 1, 1) | |
out = out + self.t_emb_layers[i](t_emb).unsqueeze(dim=-1).unsqueeze( | |
dim=-1 | |
) # (batch_size, out_channels, h * 2, w * 2) | |
out = self.resnet_conv_second[i]( | |
out | |
) # (batch_size, out_channels, h * 2, w * 2) -> (batch_size, out_channels, h * 2, w * 2) | |
# Residual Connection | |
out = out + self.residual_input_conv[i]( | |
resnet_input | |
) # (batch_size, out_channels, h * 2, w * 2) | |
# Attention block of UNET | |
batch_size, channels, h, w = ( | |
out.shape | |
) # (batch_size, out_channels, h * 2, w * 2) | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w * 4) | |
in_attn = self.attention_norms[i](in_attn) | |
in_attn = in_attn.transpose(1, 2) # (batch_size, h * w * 4, out_channels) | |
# Self-Attention | |
out_attn, attn_weights = self.attentions[i](in_attn, in_attn, in_attn) | |
out_attn = out_attn.transpose(1, 2).reshape( | |
batch_size, channels, h, w | |
) # (batch_size, out_channels h * 2, w * 2) | |
# Skip connection | |
out = out + out_attn # (batch_size, out_channels h * 2, w * 2) | |
if self.cross_attn: | |
assert ( | |
context is not None | |
), "context cannot be None if cross attention layers are used" | |
batch_size, channels, h, w = out.shape | |
in_attn = out.reshape( | |
batch_size, channels, h * w | |
) # (batch_size, out_channels, h * w * 4) | |
in_attn = self.cross_attention_norms[i](in_attn) | |
in_attn = in_attn.transpose( | |
1, 2 | |
) # (batch_size, h * w * 4, out_channels) | |
assert ( | |
len(context.shape) == 3 | |
), "Context shape does not match batch_size, _, context_dim" | |
assert ( | |
context.shape[0] == x.shape[0] | |
and context.shape[-1] == self.context_dim | |
), "Context shape does not match batch_size, _, context_dim" # Make sure the batch_size and context_dim match with the model's parameters | |
context_proj = self.context_proj[i]( | |
context | |
                )  # (batch_size, seq_len, context_dim) -> (batch_size, seq_len, out_channels)
# Cross-Attention | |
out_attn, attn_weights = self.cross_attentions[i]( | |
in_attn, context_proj, context_proj | |
) | |
out_attn = out_attn.transpose(1, 2).reshape( | |
batch_size, channels, h, w | |
) # (batch_size, out_channels, h * 2, w * 2) | |
# Skip Connection | |
out = out + out_attn # (batch_size, out_channels h * 2, w * 2) | |
return out # (batch_size, out_channels h * 2, w * 2) | |
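
# Illustrative sketch (dummy tensors are assumptions): the UNet UpBlock, whose input is
# the upsampled decoder feature map concatenated with the matching DownBlock output.
def _demo_up_block_unet():
    block = UpBlockUNet(
        in_channels=512 * 2, out_channels=384, t_emb_dim=512, up_sample=True,
        num_heads=16, num_layers=2, norm_channels=32,
        cross_attn=True, context_dim=512,
    )
    x = torch.randn(2, 512, 8, 8)           # current decoder feature map
    out_down = torch.randn(2, 512, 16, 16)  # skip connection from the down path
    t_emb = torch.randn(2, 512)
    context = torch.randn(2, 77, 512)
    return block(x, out_down, t_emb, context).shape  # (2, 384, 16, 16)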
class VQVAE(nn.Module): | |
def __init__(self, image_channels, model_config): | |
super().__init__() | |
self.down_channels = model_config["down_channels"] | |
self.mid_channels = model_config["mid_channels"] | |
self.down_sample = model_config["down_sample"] | |
self.num_down_layers = model_config["num_down_layers"] | |
self.num_mid_layers = model_config["num_mid_layers"] | |
self.num_up_layers = model_config["num_up_layers"] | |
# To disable attention in Downblock of Encoder and Upblock of Decoder | |
self.attns = model_config["attn_down"] | |
# Latent Dimension | |
self.z_channels = model_config[ | |
"z_channels" | |
] # number of channels in the latent representation | |
self.codebook_size = model_config[ | |
"codebook_size" | |
] # number of discrete code vectors available | |
self.norm_channels = model_config["norm_channels"] | |
self.num_heads = model_config["num_heads"] | |
assert self.mid_channels[0] == self.down_channels[-1] | |
assert self.mid_channels[-1] == self.down_channels[-1] | |
assert len(self.down_sample) == len(self.down_channels) - 1 | |
assert len(self.attns) == len(self.down_channels) - 1 | |
# Wherever we downsample in the encoder, use upsampling in the decoder at the corresponding location | |
self.up_sample = list(reversed(self.down_sample)) | |
# Encoder | |
self.encoder_conv_in = nn.Conv2d( | |
in_channels=image_channels, | |
out_channels=self.down_channels[0], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) # (batch_size, 3, h, w) -> (batch_size, c, h, w) | |
# Downblock + Midblock | |
self.encoder_layers = nn.ModuleList([]) | |
for i in range(len(self.down_channels) - 1): | |
self.encoder_layers.append( | |
DownBlock( | |
in_channels=self.down_channels[i], | |
out_channels=self.down_channels[i + 1], | |
t_emb_dim=None, | |
down_sample=self.down_sample[i], | |
num_heads=self.num_heads, | |
num_layers=self.num_down_layers, | |
attn=self.attns[i], | |
norm_channels=self.norm_channels, | |
) | |
) | |
self.encoder_mids = nn.ModuleList([]) | |
for i in range(len(self.mid_channels) - 1): | |
self.encoder_mids.append( | |
MidBlock( | |
in_channels=self.mid_channels[i], | |
out_channels=self.mid_channels[i + 1], | |
t_emb_dim=None, | |
num_heads=self.num_heads, | |
num_layers=self.num_mid_layers, | |
norm_channels=self.norm_channels, | |
) | |
) | |
self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) | |
self.encoder_conv_out = nn.Conv2d( | |
in_channels=self.down_channels[-1], | |
out_channels=self.z_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) # (batch_size, z_channels, h', w') | |
# Pre Quantization Convolution | |
self.pre_quant_conv = nn.Conv2d( | |
in_channels=self.z_channels, | |
out_channels=self.z_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) # (batch_size, z_channels, h', w') | |
# Codebook Vectors | |
self.embedding = nn.Embedding( | |
self.codebook_size, self.z_channels | |
) # (codebook_size, z_channels) | |
# Decoder | |
# Post Quantization Convolution | |
self.post_quant_conv = nn.Conv2d( | |
in_channels=self.z_channels, | |
out_channels=self.z_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
) # (batch_size, z_channels, h', w') | |
self.decoder_conv_in = nn.Conv2d( | |
in_channels=self.z_channels, | |
out_channels=self.mid_channels[-1], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) # (batch_size, c, h', w') | |
# Midblock + Upblock | |
self.decoder_mids = nn.ModuleList([]) | |
for i in reversed(range(1, len(self.mid_channels))): | |
self.decoder_mids.append( | |
MidBlock( | |
in_channels=self.mid_channels[i], | |
out_channels=self.mid_channels[i - 1], | |
t_emb_dim=None, | |
num_heads=self.num_heads, | |
num_layers=self.num_mid_layers, | |
norm_channels=self.norm_channels, | |
) | |
) | |
self.decoder_layers = nn.ModuleList([]) | |
for i in reversed(range(1, len(self.down_channels))): | |
self.decoder_layers.append( | |
UpBlock( | |
in_channels=self.down_channels[i], | |
out_channels=self.down_channels[i - 1], | |
t_emb_dim=None, | |
up_sample=self.down_sample[i - 1], | |
num_heads=self.num_heads, | |
num_layers=self.num_up_layers, | |
attn=self.attns[i - 1], | |
norm_channels=self.norm_channels, | |
) | |
) | |
self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) | |
self.decoder_conv_out = nn.Conv2d( | |
in_channels=self.down_channels[0], | |
out_channels=image_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) # (batch_size, c, h, w) | |
def quantize(self, x): | |
batch_size, c, h, w = x.shape # (batch_size, z_channels, h, w) | |
x = x.permute( | |
0, 2, 3, 1 | |
) # (batch_size, z_channels, h, w) -> (batch_size, h, w, z_channels) | |
x = x.reshape( | |
batch_size, -1, c | |
) # (batch_size, h, w, z_channels) -> (batch_size, h * w, z_channels) | |
# Find the nearest codebook vector with distance between (batch_size, h * w, z_channels) and (batch_size, code_book_size, z_channels) -> (batch_size, h * w, code_book_size) | |
dist = torch.cdist( | |
x, self.embedding.weight.unsqueeze(dim=0).repeat((batch_size, 1, 1)) | |
) # cdist calculates the batched p-norm distance | |
        # (batch_size, h * w) Get the index of the closest codebook vector
min_encoding_indices = torch.argmin(dist, dim=-1) | |
# Replace the encoder output with the nearest codebook | |
quant_out = torch.index_select( | |
self.embedding.weight, 0, min_encoding_indices.view(-1) | |
) # (batch_size, h * w, z_channels) | |
x = x.reshape((-1, c)) # (batch_size * h * w, z_channels) | |
        # Commitment and codebook losses using MSE
commitment_loss = torch.mean((quant_out.detach() - x) ** 2) | |
codebook_loss = torch.mean((quant_out - x.detach()) ** 2) | |
quantize_losses = { | |
"codebook_loss": codebook_loss, | |
"commitment_loss": commitment_loss, | |
} | |
# Straight through estimation | |
quant_out = x + (quant_out - x).detach() | |
quant_out = quant_out.reshape(batch_size, h, w, c).permute( | |
0, 3, 1, 2 | |
) # (batch_size, z_channels, h, w) | |
min_encoding_indices = min_encoding_indices.reshape( | |
(-1, h, w) | |
) # (batch_size, h, w) | |
return quant_out, quantize_losses, min_encoding_indices | |
def encode(self, x): | |
out = self.encoder_conv_in(x) # (batch_size, self.down_channels[0], h, w) | |
# (batch_size, self.down_channels[0], h, w) -> (batch_size, self.down_channels[-1], h', w') | |
for idx, down in enumerate(self.encoder_layers): | |
out = down(out) | |
# (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.mid_channels[-1], h', w') | |
for mid in self.encoder_mids: | |
out = mid(out) | |
out = self.encoder_norm_out(out) | |
out = F.silu(out) | |
out = self.encoder_conv_out( | |
out | |
) # (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.z_channels, h', w') | |
out = self.pre_quant_conv( | |
out | |
) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w') | |
out, quant_losses, min_encoding_indices = self.quantize( | |
out | |
) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss), (batch_size, h, w) | |
return out, quant_losses | |
def decode(self, z): | |
out = z | |
out = self.post_quant_conv( | |
out | |
) # (batch_size, self.z_channels, h', w') -> (batch_size, self.z_channels, h', w') | |
out = self.decoder_conv_in( | |
out | |
) # (batch_size, self.z_channels, h', w') -> (batch_size, self.mid_channels[-1], h', w') | |
# (batch_size, self.mid_channels[-1], h', w') -> (batch_size, self.down_channels[-1], h', w') | |
for mid in self.decoder_mids: | |
out = mid(out) | |
# (batch_size, self.down_channels[-1], h', w') -> (batch_size, self.down_channels[0], h, w) | |
for idx, up in enumerate(self.decoder_layers): | |
out = up(out) | |
out = self.decoder_norm_out(out) | |
out = F.silu(out) | |
out = self.decoder_conv_out( | |
out | |
) # (batch_size, self.down_channels[0], h, w) -> (batch_size, c, h, w) | |
return out | |
def forward(self, x): | |
# x shape: (batch_size, c, h, w) | |
z, quant_losses = self.encode( | |
x | |
) # (batch_size, self.z_channels, h', w'), (codebook_loss, commitment_loss) | |
out = self.decode(z) # (batch_size, c, h, w) | |
return out, z, quant_losses | |
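
# Illustrative sketch (the dummy image batch is an assumption): the VQ-VAE round trip with
# autoencoder_params; three downsampling stages map a 256x256 image to a 4x32x32 latent.
def _demo_vqvae():
    vae = VQVAE(
        image_channels=dataset_params["image_channels"], model_config=autoencoder_params
    )
    x = torch.randn(2, 3, 256, 256)
    with torch.no_grad():
        recon, z, quant_losses = vae(x)
    return recon.shape, z.shape  # (2, 3, 256, 256) and (2, 4, 32, 32)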
def validate_image_conditional_input(cond_input, x): | |
assert ( | |
"image" in cond_input | |
), "Model initialized with image conditioning but cond_input has no image information" | |
assert ( | |
cond_input["image"].shape[0] == x.shape[0] | |
), "Batch size mismatch of image condition and input" | |
assert ( | |
cond_input["image"].shape[2] % x.shape[2] == 0 | |
), "Height/Width of image condition must be divisible by latent input" | |
def validate_class_conditional_input(cond_input, x, num_classes): | |
assert ( | |
"class" in cond_input | |
), "Model initialized with class conditioning but cond_input has no class information" | |
assert cond_input["class"].shape == ( | |
x.shape[0], | |
num_classes, | |
), "Shape of class condition input must match (Batch Size, )" | |
class UNet(nn.Module): | |
""" | |
Unet model comprising | |
Down blocks, Midblocks and Uplocks | |
""" | |
def __init__(self, image_channels, model_config): | |
super().__init__() | |
self.down_channels = model_config["down_channels"] | |
self.mid_channels = model_config["mid_channels"] | |
self.t_emb_dim = model_config["time_emb_dim"] | |
self.down_sample = model_config["down_sample"] | |
self.num_down_layers = model_config["num_down_layers"] | |
self.num_mid_layers = model_config["num_mid_layers"] | |
self.num_up_layers = model_config["num_up_layers"] | |
self.attns = model_config["attn_down"] | |
self.norm_channels = model_config["norm_channels"] | |
self.num_heads = model_config["num_heads"] | |
self.conv_out_channels = model_config["conv_out_channels"] | |
assert self.mid_channels[0] == self.down_channels[-1] | |
assert self.mid_channels[-1] == self.down_channels[-2] | |
assert len(self.down_sample) == len(self.down_channels) - 1 | |
assert len(self.attns) == len(self.down_channels) - 1 | |
# Class, Mask, and Text Conditioning Config | |
self.class_cond = False | |
self.text_cond = False | |
self.image_cond = False | |
self.text_embed_dim = None | |
self.condition_config = get_config_value( | |
model_config, "condition_config", None | |
) # Get the dictionary containing conditional information | |
if self.condition_config is not None: | |
assert ( | |
"condition_types" in self.condition_config | |
), "Condition Type not provided in model config" | |
condition_types = self.condition_config["condition_types"] | |
# For class, text, and image, get necessary parameters | |
if "class" in condition_types: | |
self.class_cond = True | |
self.num_classes = self.condition_config["class_condition_config"][ | |
"num_classes" | |
] | |
if "text" in condition_types: | |
self.text_cond = True | |
self.text_embed_dim = self.condition_config["text_condition_config"][ | |
"text_embed_dim" | |
] | |
if "image" in condition_types: | |
self.image_cond = True | |
self.image_cond_input_channels = self.condition_config[ | |
"image_condition_config" | |
]["image_condition_input_channels"] | |
self.image_cond_output_channels = self.condition_config[ | |
"image_condition_config" | |
]["image_condition_output_channels"] | |
if self.class_cond: | |
# For class conditioning, do not add the class embedding information for unconditional generation | |
self.class_emb = nn.Embedding( | |
self.num_classes, self.t_emb_dim | |
) # (num_classes, t_emb_dim) | |
if self.image_cond: | |
# Map the mask image to a image_cond_output_channels channel image, and concat with input across the channel dimension | |
self.cond_conv_in = nn.Conv2d( | |
in_channels=self.image_cond_input_channels, | |
out_channels=self.image_cond_output_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
bias=False, | |
) | |
self.conv_in_concat = nn.Conv2d( | |
in_channels=(image_channels + self.image_cond_output_channels), | |
out_channels=self.down_channels[0], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) | |
else: | |
self.conv_in = nn.Conv2d( | |
in_channels=image_channels, | |
out_channels=self.down_channels[0], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) # (batch_size, image_channels, h, w) -> (batch_size, self.down_channels[0], h, w) | |
self.cond = self.text_cond or self.image_cond or self.class_cond | |
# Initial projection from sinusoidal time embedding | |
self.t_proj = nn.Sequential( | |
nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim), | |
nn.SiLU(), | |
nn.Linear(in_features=self.t_emb_dim, out_features=self.t_emb_dim), | |
) # (batch_size, t_emb_dim) | |
self.up_sample = list(reversed(self.down_sample)) | |
self.downs = nn.ModuleList([]) | |
for i in range(len(self.down_channels) - 1): | |
# Cross attention and Context Dim are only used for text conditioning | |
self.downs.append( | |
DownBlock( | |
in_channels=self.down_channels[i], | |
out_channels=self.down_channels[i + 1], | |
t_emb_dim=self.t_emb_dim, | |
down_sample=self.down_sample[i], | |
num_heads=self.num_heads, | |
num_layers=self.num_down_layers, | |
attn=self.attns[i], | |
norm_channels=self.norm_channels, | |
cross_attn=self.text_cond, | |
context_dim=self.text_embed_dim, | |
) | |
) | |
self.mids = nn.ModuleList([]) | |
for i in range(len(self.mid_channels) - 1): | |
# Cross attention and Context Dim are only used for text conditioning | |
self.mids.append( | |
MidBlock( | |
in_channels=self.mid_channels[i], | |
out_channels=self.mid_channels[i + 1], | |
t_emb_dim=self.t_emb_dim, | |
num_heads=self.num_heads, | |
num_layers=self.num_mid_layers, | |
norm_channels=self.norm_channels, | |
cross_attn=self.text_cond, | |
context_dim=self.text_embed_dim, | |
) | |
) | |
self.ups = nn.ModuleList([]) | |
for i in reversed(range(len(self.down_channels) - 1)): | |
# Cross attention and Context Dim are only used for text conditioning | |
self.ups.append( | |
UpBlockUNet( | |
in_channels=(self.down_channels[i] * 2), | |
out_channels=( | |
self.down_channels[i - 1] if i != 0 else self.conv_out_channels | |
), | |
t_emb_dim=self.t_emb_dim, | |
up_sample=self.down_sample[i], | |
num_heads=self.num_heads, | |
num_layers=self.num_up_layers, | |
norm_channels=self.norm_channels, | |
cross_attn=self.text_cond, | |
context_dim=self.text_embed_dim, | |
) | |
) | |
self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) | |
self.conv_out = nn.Conv2d( | |
in_channels=self.conv_out_channels, | |
out_channels=image_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
) # (batch_size, conv_out_channels, h, w) -> (batch_size, image_channels, h, w) | |
def forward(self, x, t, cond_input=None): | |
# x shape: (batch_size, c, h, w) | |
# cond_input is the conditioning vector | |
        # For class conditioning, it will be a one-hot vector of size (batch_size, num_classes)
if self.cond: | |
assert ( | |
cond_input is not None | |
), "Model initialized with conditioning so cond_input cannot be None" | |
if self.image_cond: | |
# Mask Conditioning | |
validate_image_conditional_input(cond_input, x) | |
image_cond = cond_input["image"] | |
image_cond = F.interpolate(image_cond, size=x.shape[-2:]) | |
image_cond = self.cond_conv_in(image_cond) | |
assert image_cond.shape[-2:] == x.shape[-2:] | |
x = torch.cat( | |
[x, image_cond], dim=1 | |
) # (batch_size, image_channels + image_cond_output_channels, h, w) | |
out = self.conv_in_concat(x) # (batch_size, down_channels[0], h, w) | |
else: | |
out = self.conv_in(x) # (batch_size, down_channels[0], h, w) | |
t_emb = get_time_embedding( | |
torch.as_tensor(t).long(), self.t_emb_dim | |
) # (batch_size, t_emb_dim) | |
t_emb = self.t_proj(t_emb) # (batch_size, t_emb_dim) | |
# Class Conditioning | |
if self.class_cond: | |
validate_class_conditional_input(cond_input, x, self.num_classes) | |
# Take the matrix for class embedding vectors and matrix multiply it with the embedding matrix to get the class embedding for all images in a batch | |
class_embed = torch.matmul( | |
cond_input["class"].float(), self.class_emb.weight | |
) # (batch_size, t_emb_dim) | |
t_emb += class_embed # Add the class embedding to the time embedding | |
context_hidden_states = None | |
# Only use context hidden states in cross-attention for text conditioning | |
if self.text_cond: | |
assert ( | |
"text" in cond_input | |
), "Model initialized with text conditioning but cond_input has no text information" | |
context_hidden_states = cond_input["text"] | |
down_outs = [] | |
for idx, down in enumerate(self.downs): | |
down_outs.append(out) | |
out = down( | |
out, t_emb, context_hidden_states | |
) # Use context_hidden_states for cross-attention | |
# out = (batch_size, c4, h / 4, w / 4) | |
for mid in self.mids: | |
out = mid(out, t_emb, context_hidden_states) | |
# out = (batch_size, c3, h / 4, w / 4) | |
for up in self.ups: | |
down_out = down_outs.pop() | |
out = up(out, down_out, t_emb, context_hidden_states) | |
# out = (batch_size, self.conv_out_channels, h, w) | |
out = F.silu(self.norm_out(out)) | |
out = self.conv_out( | |
out | |
) # (batch_size, self.conv_out_channels, h, w) -> (batch_size, image_channels, h, w) | |
return out # (batch_size, image_channels, h, w) |
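
# Illustrative end-to-end sketch (dummy latents, text embeddings and masks are
# assumptions): one denoising prediction from the UNet configured with ldm_params.
def _demo_unet():
    unet = UNet(image_channels=autoencoder_params["z_channels"], model_config=ldm_params)
    x = torch.randn(2, 4, 32, 32)  # stand-in for VQ-VAE latents
    t = torch.randint(0, diffusion_params["num_timesteps"], (2,))
    text_dim = ldm_params["condition_config"]["text_condition_config"]["text_embed_dim"]
    cond_input = {
        "text": torch.randn(2, 77, text_dim),   # stand-in for CLIP text embeddings
        "image": torch.randn(2, 18, 512, 512),  # one channel per CelebAMask-HQ mask class
    }
    with torch.no_grad():
        noise_pred = unet(x, t, cond_input)
    return noise_pred.shape  # (2, 4, 32, 32), matching the latent input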