File size: 626 Bytes
717b269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from .vit import vit

def create_backbone(cfg):
    if cfg.MODEL.BACKBONE.TYPE == 'vit':
        return vit(cfg)
    elif cfg.MODEL.BACKBONE.TYPE == 'fast_vit':
        import torch 
        import sys 
        from timm.models import create_model
        #from models.modules.mobileone import reparameterize_model
        fast_vit = create_model("fastvit_ma36", drop_path_rate=0.2)
        checkpoint = torch.load('./pretrained_models/fastvit_ma36.pt')
        fast_vit.load_state_dict(checkpoint['state_dict'])
        return fast_vit
        
    else:
        raise NotImplementedError('Backbone type is not implemented')