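"""CSPTinyLayer: a CSP-style (Cross Stage Partial) layer.

The input feature map is split along the channel dimension; one half is passed
through either a stack of TinyBlocks or a Hugging Face MambaModel (when
``ssm=True``), and the two halves are then concatenated and fused by a 1x1
transition convolution.
"""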
import torch
import torch.nn as nn
from transformers import MambaConfig, MambaModel

from .tiny_block import TinyBlock
# from .conmamba import ConMamba
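
# NOTE: transformers' MambaModel uses the fused CUDA kernels from the optional
# `mamba-ssm` / `causal-conv1d` packages when they are installed, and otherwise
# falls back to a slower pure-PyTorch implementation.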


class CSPTinyLayer(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks, ssm=False):
        super(CSPTinyLayer, self).__init__()
        self.ssm = ssm
        # Split the input channels into two halves (assumes in_channels is even,
        # so both halves have split_channels channels)
        self.split_channels = in_channels // 2

        if self.ssm:
            # Mamba blocks; the token embedding table is unused because the
            # forward pass feeds `inputs_embeds` directly
            configuration = MambaConfig(
                vocab_size=0,
                hidden_size=self.split_channels,
                num_hidden_layers=num_blocks,
            )
            self.mamba_blocks = MambaModel(configuration)
            # mamba_config = {
            #     'd_state': self.split_channels,
            #     'expand': 2,
            #     'd_conv': 4,
            #     'bidirectional': True
            # }
            # self.mamba_blocks = ConMamba(
            #     num_blocks=num_blocks,
            #     channels=self.split_channels,
            #     height=8,
            #     width=8,
            #     mamba_config=mamba_config
            # )
        else:
            # TinyBlocks
            self.tiny_blocks = nn.Sequential(
                *[TinyBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)]
            )

        # Transition layer to adjust channel dimensions
        self.transition = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Split input into two parts
        p1 = x[:, :self.split_channels, :, :]
        p2 = x[:, self.split_channels:, :, :]

        if self.ssm:
            # Flatten the spatial dimensions into a sequence to fit Mamba
            B, C, H, W = p2.shape
            p2 = p2.permute(0, 2, 3, 1)   # [B, H, W, C]
            p2 = p2.reshape(B, H * W, C)  # [B, L, C], L = H * W
            # Process p2 through the Mamba blocks
            p2_out = self.mamba_blocks(inputs_embeds=p2).last_hidden_state
            # p2_out = self.mamba_blocks(p2)
            # Reshape back to the original layout
            p2_out = p2_out.reshape(B, H, W, -1)
            p2_out = p2_out.permute(0, 3, 1, 2)  # [B, C, H, W]
        else:
            # Process p2 through the TinyBlocks
            p2_out = self.tiny_blocks(p2)

        # Concatenate p1 and processed p2
        concatenated = torch.cat((p1, p2_out), dim=1)
        # Apply transition layer
        out = self.transition(concatenated)
        return out
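

# --- Usage sketch (illustrative, not part of the original model code) --------
# A minimal example of how CSPTinyLayer might be composed with a downsampling
# convolution inside a backbone stage. `build_toy_stage` is a hypothetical
# helper added here only for illustration.
def build_toy_stage(in_channels, out_channels, num_blocks, ssm=False):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        CSPTinyLayer(out_channels, out_channels, num_blocks, ssm=ssm),
    )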


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = CSPTinyLayer(32, 32, 2, True).to(device)
    print(model)

    dummy_input = torch.randn(256, 32, 8, 8).to(device)
    output = model(dummy_input)
    print(output.shape)
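    # Sanity check: the 1x1, stride-1 transition conv preserves the spatial
    # size, so with in_channels == out_channels == 32 the output shape should
    # match the input shape
    assert output.shape == (256, 32, 8, 8)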