import torch
import torch.nn as nn

from modules.basic_layers import (
    SinusoidalPositionalEmbedding,
    ResGatedBlock,
    MaxViTBlock,
    Downsample,
    Upsample
)
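
# U-Net style synthesis network built from residual gated convolution blocks
# and MaxViT attention blocks, conditioned on a sinusoidal timestep embedding
# and on two pyramids of warped feature maps (warp0, warp1), one per scale.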


class UnetDownBlock(nn.Module):
    """Encoder stage: downsample, fuse the warped features, convolve, attend."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()
        self.pool = Downsample(
            in_channels=in_channels,
            out_channels=in_channels,
            use_conv=True
        )
        # The pooled features are concatenated with the two warped feature
        # maps at this scale, giving 3 * in_channels + 2 input channels.
        in_channels = 3 * in_channels + 2
        self.conv = nn.ModuleList([
            ResGatedBlock(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                emb_channels=temb_channels,
                gated_conv=True
            ) for i in range(num_conv_blocks)
        ])
        self.maxvit = MaxViTBlock(
            channels=out_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout,
            emb_channels=temb_channels
        )

    def forward(
        self,
        x: torch.Tensor,
        warp0: torch.Tensor,
        warp1: torch.Tensor,
        temb: torch.Tensor
    ):
        x = self.pool(x)
        x = torch.cat([x, warp0, warp1], dim=1)
        for conv in self.conv:
            x = conv(x, temb)
        x = self.maxvit(x, temb)
        return x
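
# Shape sketch for one down stage (assuming Downsample halves the spatial
# resolution): x (B, C, H, W) -> pool -> (B, C, H/2, W/2); concatenating the
# two warp tensors, each presumably (B, C + 1, H/2, W/2), yields
# (B, 3C + 2, H/2, W/2), which the gated conv blocks project to out_channels.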


class UnetMiddleBlock(nn.Module):
    """Bottleneck stage: gated conv, MaxViT attention, gated conv."""

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        dropout: float = 0.0
    ):
        super().__init__()
        self.middle_blocks = nn.ModuleList([
            ResGatedBlock(
                in_channels=in_channels,
                out_channels=mid_channels,
                emb_channels=temb_channels,
                gated_conv=True
            ),
            MaxViTBlock(
                channels=mid_channels,
                heads=heads,
                window_size=window_size,
                window_attn=window_attn,
                grid_attn=grid_attn,
                expansion_rate=expansion_rate,
                dropout=dropout,
                emb_channels=temb_channels
            ),
            ResGatedBlock(
                in_channels=mid_channels,
                out_channels=out_channels,
                emb_channels=temb_channels,
                gated_conv=True
            )
        ])

    def forward(self, x, temb):
        for block in self.middle_blocks:
            x = block(x, temb)
        return x


class UnetUpBlock(nn.Module):
    """Decoder stage: fuse the skip connection, attend, upsample, convolve."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()
        # The skip connection doubles the channel count before attention.
        in_channels = 2 * in_channels
        self.maxvit = MaxViTBlock(
            channels=in_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout,
            emb_channels=temb_channels
        )
        self.upsample = Upsample(
            in_channels=in_channels,
            out_channels=in_channels,
            use_conv=True
        )
        self.conv = nn.ModuleList([
            ResGatedBlock(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                emb_channels=temb_channels,
                gated_conv=True
            ) for i in range(num_conv_blocks)
        ])

    def forward(
        self,
        x: torch.Tensor,
        skip_connection: torch.Tensor,
        temb: torch.Tensor
    ):
        x = torch.cat([x, skip_connection], dim=1)
        x = self.maxvit(x, temb)
        x = self.upsample(x)
        for conv in self.conv:
            x = conv(x, temb)
        return x
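
# Note the mirrored ordering relative to UnetDownBlock: the decoder attends
# at the lower resolution first, then upsamples and convolves, whereas the
# encoder pools first and attends last.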


class Synthesis(nn.Module):
    """Timestep-conditioned U-Net that fuses two warped feature pyramids."""

    def __init__(
        self,
        in_channels: int,
        channels: list[int],
        temb_channels: int,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        self.t_pos_encoding = SinusoidalPositionalEmbedding(temb_channels)

        # Stem: the input frame is concatenated with the full-resolution
        # warped features, giving 3 * in_channels + 4 channels.
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(3 * in_channels + 4, channels[0], kernel_size=3, padding=1),
            ResGatedBlock(
                in_channels=channels[0],
                out_channels=channels[0],
                emb_channels=temb_channels,
                gated_conv=True
            )
        ])

        self.down_blocks = nn.ModuleList([
            UnetDownBlock(
                channels[i],
                channels[i + 1],
                temb_channels,
                heads=heads,
                window_size=window_size,
                window_attn=window_attn,
                grid_attn=grid_attn,
                expansion_rate=expansion_rate,
                num_conv_blocks=num_conv_blocks,
                dropout=dropout
            ) for i in range(len(channels) - 1)
        ])

        self.middle_block = UnetMiddleBlock(
            in_channels=channels[-1],
            mid_channels=channels[-1],
            out_channels=channels[-1],
            temb_channels=temb_channels,
            heads=heads,
            window_size=window_size,
            window_attn=window_attn,
            grid_attn=grid_attn,
            expansion_rate=expansion_rate,
            dropout=dropout
        )

        self.up_blocks = nn.ModuleList([
            UnetUpBlock(
                channels[i + 1],
                channels[i],
                temb_channels,
                heads=heads,
                window_size=window_size,
                window_attn=window_attn,
                grid_attn=grid_attn,
                expansion_rate=expansion_rate,
                num_conv_blocks=num_conv_blocks,
                dropout=dropout
            ) for i in reversed(range(len(channels) - 1))
        ])

        self.output_blocks = nn.ModuleList([
            ResGatedBlock(
                in_channels=channels[0],
                out_channels=channels[0],
                emb_channels=temb_channels,
                gated_conv=True
            ),
            nn.Conv2d(channels[0], in_channels, kernel_size=3, padding=1)
        ])

    def forward(
        self,
        x: torch.Tensor,
        warp0: list[torch.Tensor],
        warp1: list[torch.Tensor],
        temb: torch.Tensor
    ):
        temb = temb.unsqueeze(-1).float()
        temb = self.t_pos_encoding(temb)

        x = self.input_blocks[0](torch.cat([x, warp0[0], warp1[0]], dim=1))
        x = self.input_blocks[1](x, temb)

        # Encoder: collect skip features at each scale.
        features = []
        for i, down_block in enumerate(self.down_blocks):
            x = down_block(x, warp0[i + 1], warp1[i + 1], temb)
            features.append(x)

        x = self.middle_block(x, temb)

        # Decoder: consume the skip features in reverse order.
        for i, up_block in enumerate(self.up_blocks):
            x = up_block(x, features[-(i + 1)], temb)

        x = self.output_blocks[0](x, temb)
        x = self.output_blocks[1](x)

        return x
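

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model. It assumes that
    # modules.basic_layers is importable, that Downsample/Upsample halve and
    # double the spatial resolution, and that the warp channel counts split
    # evenly between the two frames (in_channels + 2 per warp at full
    # resolution, channels[i] + 1 per warp at scale i + 1), as implied by the
    # concatenation arithmetic above. Sizes are chosen so the deepest feature
    # map (14 x 14) is divisible by the default window_size of 7.
    B, C, H, W = 1, 3, 56, 56
    channels = [32, 64, 128]
    model = Synthesis(in_channels=C, channels=channels, temb_channels=128)

    x = torch.randn(B, C, H, W)
    # Warp pyramids: one entry per scale, finest first.
    warp0 = [torch.randn(B, C + 2, H, W)]
    warp1 = [torch.randn(B, C + 2, H, W)]
    for i in range(len(channels) - 1):
        h, w = H >> (i + 1), W >> (i + 1)
        warp0.append(torch.randn(B, channels[i] + 1, h, w))
        warp1.append(torch.randn(B, channels[i] + 1, h, w))
    temb = torch.tensor([10])  # hypothetical diffusion timestep

    out = model(x, warp0, warp1, temb)
    print(out.shape)  # expected: torch.Size([1, 3, 56, 56])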