[MIA'25] MambaMIM: Pre-training Mamba with State Space Token Interpolation and its Application to Medical Image Segmentation

MambaMIM




arXiv | GitHub | License: Apache 2.0

News

  • MambaMIM accepted by Medical Image Analysis (MIA'25)! 🥰
  • Weights released! 😎
  • Code released! 😘
  • Code and weights will be released soon! 😘
  • [2024/08/16] Paper released!

TODOs

  • Paper released
  • Code released
  • Weights released

Getting Started

Download weights

Name     | Resolution   | Intensities | Spacing            | Weights
MambaMIM | 96 x 96 x 96 | [-175, 250] | 1.5 x 1.5 x 1.5 mm | Google Drive (87 MB)
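
The table suggests the pre-training volumes were resampled to 1.5 x 1.5 x 1.5 mm spacing, clipped to the [-175, 250] intensity range, and cropped into 96 x 96 x 96 patches. A minimal MONAI preprocessing sketch consistent with these values (the exact pre-training transforms may differ) is:

# A preprocessing sketch consistent with the table above; not the official pipeline
from monai import transforms

preprocess = transforms.Compose([
    transforms.LoadImaged(keys=["image"]),
    transforms.EnsureChannelFirstd(keys=["image"]),
    transforms.Orientationd(keys=["image"], axcodes="RAS"),
    transforms.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    transforms.RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), num_samples=1, random_size=False),
])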

Prepare Environments

conda create -n mambamim python=3.9
conda activate mambamim
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
pip install packaging timm==0.5.4
pip install transformers==4.34.1 typed-argument-parser
pip install numpy==1.21.2 opencv-python==4.5.5.64 opencv-python-headless==4.5.5.64
pip install 'monai[all]'
pip install monai==1.2.0
# Install the pre-built wheels (cp38 / cu118 / torch 1.13 builds shown here); choose wheel files whose tags match your Python, CUDA, and PyTorch versions
pip install causal_conv1d-1.2.0.post2+cu118torch1.13cxx11abiTRUE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm-1.2.0.post1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
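
After installation, a quick sanity check (a minimal sketch, assuming the wheels above installed cleanly) is to import the compiled extensions and confirm CUDA is visible:

# Minimal sanity check for the environment above
import torch
import causal_conv1d  # verifies the compiled extension loads
import mamba_ssm      # verifies the compiled extension loads

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("causal_conv1d and mamba_ssm imported successfully")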

Prepare Datasets

We recommend converting your datasets into the nnUNet format:

└── MambaMIM
    β”œβ”€β”€ data
        β”œβ”€β”€ Dataset060_TotalSegmentator
            └── imagesTr
                β”œβ”€β”€ xxx_0000.nii.gz
                β”œβ”€β”€ ...
        β”œβ”€β”€ Dataset006_FLARE2022
            └── imagesTr
                β”œβ”€β”€ xxx_0000.nii.gz
                β”œβ”€β”€ ...
        └── Other_dataset
            └── imagesTr
                β”œβ”€β”€ xxx_0000.nii.gz
                β”œβ”€β”€ ...

An example dataset.json will be generated in ./data. Its content should look like the following:

{
    "training": [
        {
            "image": "./Dataset060_TotalSegmentator/imagesTr/xxx_0000.nii.gz"
        },
        {
            "image": "./Dataset006_FLARE2022/imagesTr/xxx_0000.nii.gz"
        }
    ]
}
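
If you need to (re)build dataset.json yourself, a minimal sketch (assuming the folder layout above and the ./data root) is:

# A minimal sketch for building dataset.json from the folder layout above
import json
from pathlib import Path

data_root = Path("./data")
entries = [
    {"image": "./" + p.relative_to(data_root).as_posix()}
    for p in sorted(data_root.glob("*/imagesTr/*_0000.nii.gz"))
]

with open(data_root / "dataset.json", "w") as f:
    json.dump({"training": entries}, f, indent=4)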

Start Training

MambaMIM

Run training on multiple GPUs:

# An example of training on 4 GPUs with DDP
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12351 main.py --exp_name=debug --data_path=./data  --model=mambamim --bs=16  --exp_dir=debug_mambamim_ddp_4

Run training on a single GPU:

# An example of training on a single GPU
python main.py --exp_name=debug --data_path=./data --model=mambamim --bs=4 --exp_dir=debug_mambamim

Fine-tuning

Load the pre-trained weights:

# An example of fine-tuning on BTCV (num_classes=14)
import torch

from models.network.hymamba import build_hybird

model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()

# Load the pre-trained weights; strict=False because the segmentation head is newly initialized
model_dict = torch.load("mambamim_mask75.pth")
msg = model.load_state_dict(model_dict, strict=False)
print("MambaMIM pre-trained weights loaded successfully!", msg)

The downstream fine-tuning pipeline can follow UNETR (https://github.com/Project-MONAI/research-contributions/tree/main/UNETR/BTCV).

Acknowledgements

This code uses helper functions from SparK and HySparK.

Citation

If the code, paper, or weights are helpful to your research, please cite:

@article{tang2024mambamim,
  title={MambaMIM: Pre-training Mamba with State Space Token-interpolation},
  author={Tang, Fenghe and Nian, Bingkun and Li, Yingtai and Yang, Jie and Liu, Wei and Zhou, S Kevin},
  journal={arXiv preprint arXiv:2408.08070},
  year={2024}
}

License

This project is released under the Apache 2.0 license. Please see the LICENSE file for more information.
