import torch
import torch.nn as nn

from .tiny_block import TinyBlock
from transformers import MambaConfig, MambaModel
# from .conmamba import ConMamba


class CSPTinyLayer(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, ssm=False):
        super(CSPTinyLayer, self).__init__()
        self.ssm = ssm

        # Split channels
        self.split_channels = in_channels // 2

        if self.ssm:
            # Mamba Blocks
            configuration = MambaConfig(vocab_size=0, hidden_size=self.split_channels, num_hidden_layers=num_blocks)
            self.mamba_blocks = MambaModel(configuration)
            # mamba_config = {
            #     'd_state': self.split_channels,
            #     'expand': 2,
            #     'd_conv': 4,
            #     'bidirectional': True
            # }
            # self.mamba_blocks = ConMamba(
            #     num_blocks=num_blocks,
            #     channels=self.split_channels,
            #     height=8,
            #     width=8,
            #     mamba_config=mamba_config
            # )
        else:
            # TinyBlocks
            self.tiny_blocks = nn.Sequential(
                *[TinyBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)]
            )

        # Transition layer to adjust channel dimensions
        self.transition = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Split input into two parts
        p1 = x[:, :self.split_channels, :, :]
        p2 = x[:, self.split_channels:, :, :]

        if self.ssm:
            # Reshape to fit Mamba
            B, C, H, W = p2.shape
            p2 = p2.permute(0, 2, 3, 1)   # [B, H, W, C]
            p2 = p2.reshape(B, H * W, C)  # [B, L, C], L = H * W

            # Process p2 through MambaBlocks
            p2_out = self.mamba_blocks(inputs_embeds=p2).last_hidden_state
            # p2_out = self.mamba_blocks(p2)

            # Reshape back to original dimension
            p2_out = p2_out.reshape(B, H, W, -1)
            p2_out = p2_out.permute(0, 3, 1, 2)  # [B, C, H, W]
        else:
            # Process p2 through TinyBlocks
            p2_out = self.tiny_blocks(p2)

        # Concatenate p1 and processed p2
        concatenated = torch.cat((p1, p2_out), dim=1)

        # Apply transition layer
        out = self.transition(concatenated)
        return out


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = CSPTinyLayer(32, 32, 2, True).to(device)
    print(model)

    dummy_input = torch.randn(256, 32, 8, 8).to(device)
    output = model(dummy_input)
    print(output.shape)
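    # A minimal extra check (a sketch, not part of the original test): with
    # ssm=False the second split goes through the TinyBlock branch instead of
    # Mamba. TinyBlock lives in the local .tiny_block module and is assumed
    # here to preserve the [B, C, H, W] shape, as the Sequential stacking and
    # the 1x1 transition conv above imply.
    model_conv = CSPTinyLayer(32, 32, 2, ssm=False).to(device)
    output_conv = model_conv(dummy_input)
    print(output_conv.shape)  # expected: torch.Size([256, 32, 8, 8])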