import torch
import torch.nn as nn

from modules.basic_layers import (
    SinusoidalPositionalEmbedding,
    ResGatedBlock,
    MaxViTBlock,
    Downsample,
    Upsample
)


class UnetDownBlock(nn.Module):
    """Encoder stage: downsample, concatenate the two warped inputs, then
    apply gated residual convolutions followed by a MaxViT block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        self.pool = Downsample(
            in_channels = in_channels,
            out_channels = in_channels,
            use_conv = True
        )

        # After concatenation the first conv block sees the pooled features
        # plus two warp tensors of in_channels + 1 channels each.
        in_channels = 3 * in_channels + 2

        self.conv = nn.ModuleList([
            ResGatedBlock(
                in_channels = in_channels if i == 0 else out_channels,
                out_channels = out_channels,
                emb_channels = temb_channels,
                gated_conv = True
            )
            for i in range(num_conv_blocks)
        ])

        self.maxvit = MaxViTBlock(
            channels = out_channels,
            # latent_dim = out_channels // 6,
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
            emb_channels = temb_channels
        )

    def forward(
        self,
        x: torch.Tensor,
        warp0: torch.Tensor,
        warp1: torch.Tensor,
        temb: torch.Tensor
    ):
        x = self.pool(x)
        x = torch.cat([x, warp0, warp1], dim=1)
        for conv in self.conv:
            x = conv(x, temb)
        x = self.maxvit(x, temb)
        return x


class UnetMiddleBlock(nn.Module):
    """Bottleneck: ResGatedBlock -> MaxViTBlock -> ResGatedBlock."""

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        dropout: float = 0.0
    ):
        super().__init__()

        self.middle_blocks = nn.ModuleList([
            ResGatedBlock(
                in_channels = in_channels,
                out_channels = mid_channels,
                emb_channels = temb_channels,
                gated_conv = True
            ),
            MaxViTBlock(
                channels = mid_channels,
                # latent_dim = mid_channels // 6,
                heads = heads,
                window_size = window_size,
                window_attn = window_attn,
                grid_attn = grid_attn,
                expansion_rate = expansion_rate,
                dropout = dropout,
                emb_channels = temb_channels
            ),
            ResGatedBlock(
                in_channels = mid_channels,
                out_channels = out_channels,
                emb_channels = temb_channels,
                gated_conv = True
            )
        ])

    def forward(self, x, temb):
        for block in self.middle_blocks:
            x = block(x, temb)
        return x


class UnetUpBlock(nn.Module):
    """Decoder stage: concatenate the skip connection, apply a MaxViT block,
    upsample, then refine with gated residual convolutions."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int = 128,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        # The skip connection doubles the channel count before attention.
        in_channels = 2 * in_channels

        self.maxvit = MaxViTBlock(
            channels = in_channels,
            # latent_dim = in_channels // 6,
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
            emb_channels = temb_channels
        )

        self.upsample = Upsample(
            in_channels = in_channels,
            out_channels = in_channels,
            use_conv = True
        )

        self.conv = nn.ModuleList([
            ResGatedBlock(
                in_channels if i == 0 else out_channels,
                out_channels,
                emb_channels = temb_channels,
                gated_conv = True
            )
            for i in range(num_conv_blocks)
        ])

    def forward(
        self,
        x: torch.Tensor,
        skip_connection: torch.Tensor,
        temb: torch.Tensor
    ):
        x = torch.cat([x, skip_connection], dim=1)
        x = self.maxvit(x, temb)
        x = self.upsample(x)
        for conv in self.conv:
            x = conv(x, temb)
        return x
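# ---------------------------------------------------------------------------
# Warp pyramid contract (a sketch inferred from the channel arithmetic in
# this file, not from upstream documentation): Synthesis below expects warp0
# and warp1 to be lists of len(channels) tensors. Level 0 is full resolution
# with in_channels + 2 channels; level i + 1 is downsampled by 2 ** (i + 1)
# and carries channels[i] + 1 channels, matching the 3 * c + 4 stem width and
# the 3 * c + 2 width seen by the first conv of each UnetDownBlock.
# ---------------------------------------------------------------------------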
class Synthesis(nn.Module):
    """MaxViT U-Net that fuses two warped input pyramids under a sinusoidal
    timestep embedding and predicts the synthesized frame."""

    def __init__(
        self,
        in_channels: int,
        channels: list[int],
        temb_channels: int,
        heads: int = 1,
        window_size: int = 7,
        window_attn: bool = True,
        grid_attn: bool = True,
        expansion_rate: int = 4,
        num_conv_blocks: int = 2,
        dropout: float = 0.0
    ):
        super().__init__()

        self.t_pos_encoding = SinusoidalPositionalEmbedding(temb_channels)

        self.input_blocks = nn.ModuleList([
            nn.Conv2d(3 * in_channels + 4, channels[0], kernel_size=3, padding=1),
            ResGatedBlock(
                in_channels = channels[0],
                out_channels = channels[0],
                emb_channels = temb_channels,
                gated_conv = True
            )
        ])

        self.down_blocks = nn.ModuleList([
            UnetDownBlock(
                # 3 * channels[i] + 2,
                channels[i],
                channels[i + 1],
                temb_channels,
                heads = heads,
                window_size = window_size,
                window_attn = window_attn,
                grid_attn = grid_attn,
                expansion_rate = expansion_rate,
                num_conv_blocks = num_conv_blocks,
                dropout = dropout,
            )
            for i in range(len(channels) - 1)
        ])

        self.middle_block = UnetMiddleBlock(
            in_channels = channels[-1],
            mid_channels = channels[-1],
            out_channels = channels[-1],
            temb_channels = temb_channels,
            heads = heads,
            window_size = window_size,
            window_attn = window_attn,
            grid_attn = grid_attn,
            expansion_rate = expansion_rate,
            dropout = dropout,
        )

        self.up_blocks = nn.ModuleList([
            UnetUpBlock(
                channels[i + 1],
                channels[i],
                temb_channels,
                heads = heads,
                window_size = window_size,
                window_attn = window_attn,
                grid_attn = grid_attn,
                expansion_rate = expansion_rate,
                num_conv_blocks = num_conv_blocks,
                dropout = dropout,
            )
            for i in reversed(range(len(channels) - 1))
        ])

        self.output_blocks = nn.ModuleList([
            ResGatedBlock(
                in_channels = channels[0],
                out_channels = channels[0],
                emb_channels = temb_channels,
                gated_conv = True
            ),
            nn.Conv2d(channels[0], in_channels, kernel_size=3, padding=1)
        ])

    def forward(
        self,
        x: torch.Tensor,
        warp0: list[torch.Tensor],
        warp1: list[torch.Tensor],
        temb: torch.Tensor
    ):
        # Embed the per-sample timestep values.
        temb = temb.unsqueeze(-1).float()
        temb = self.t_pos_encoding(temb)

        # Stem: fuse the full-resolution warps with the input.
        x = self.input_blocks[0](torch.cat([x, warp0[0], warp1[0]], dim=1))
        x = self.input_blocks[1](x, temb)

        # Encoder: each stage consumes the next pyramid level and stores its
        # output for the skip connections.
        features = []
        for i, down_block in enumerate(self.down_blocks):
            x = down_block(x, warp0[i + 1], warp1[i + 1], temb)
            features.append(x)

        x = self.middle_block(x, temb)

        # Decoder: consume the stored skip connections in reverse order.
        for i, up_block in enumerate(self.up_blocks):
            x = up_block(x, features[-(i + 1)], temb)

        x = self.output_blocks[0](x, temb)
        x = self.output_blocks[1](x)
        return x
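# ---------------------------------------------------------------------------
# Smoke test: a minimal sketch, not part of the model. The shapes follow the
# warp pyramid contract inferred above; the (B,) integer timestep input and
# the 56 -> 28 -> 14 resolutions (all divisible by window_size = 7) are
# assumptions about modules.basic_layers, not documented requirements.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    in_channels, channels, temb_channels = 3, [32, 64, 128], 128
    model = Synthesis(in_channels, channels, temb_channels)

    b, size = 2, 56
    x = torch.randn(b, in_channels, size, size)
    temb = torch.randint(0, 1000, (b,))  # per-sample timestep values

    # Level 0: in_channels + 2 channels; level i + 1: channels[i] + 1 channels.
    warp0 = [torch.randn(b, in_channels + 2, size, size)]
    warp1 = [torch.randn(b, in_channels + 2, size, size)]
    for i in range(len(channels) - 1):
        size //= 2
        warp0.append(torch.randn(b, channels[i] + 1, size, size))
        warp1.append(torch.randn(b, channels[i] + 1, size, size))

    out = model(x, warp0, warp1, temb)
    print(out.shape)  # expected: torch.Size([2, 3, 56, 56])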