 from torch import nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -318,10 +318,12 @@ def __init__(
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.output_fmt = 'NHWC'
         if not isinstance(depths, (list, tuple)):
             depths = [depths]  # it means the model has only one stage
         if not isinstance(dims, (list, tuple)):
             dims = [dims]
+        act_layer = get_act_layer(act_layer)

         num_stage = len(depths)
         self.num_stage = num_stage
@@ -456,7 +458,7 @@ def checkpoint_filter_fn(state_dict, model):
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 1.0, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
         **kwargs
@@ -477,6 +479,7 @@ def _cfg(url='', **kwargs):
     'mambaout_small_rw': _cfg(),
     'mambaout_base_slim_rw': _cfg(),
     'mambaout_base_plus_rw': _cfg(),
+    'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
 }


@@ -554,9 +557,26 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
         expansion_ratio=3.0,
+        conv_ratio=1.5,
         stem_mid_norm=False,
         downsample='conv_nf',
         ls_init_value=1e-6,
+        act_layer='silu',
         head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def test_mambaout(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(1, 1, 3, 1),
+        dims=(16, 32, 48, 64),
+        expansion_ratio=3,
+        stem_mid_norm=False,
+        downsample='conv_nf',
+        ls_init_value=1e-4,
+        act_layer='silu',
+        head_fn='norm_mlp',
+    )
+    return _create_mambaout('test_mambaout', pretrained=pretrained, **dict(model_args, **kwargs))
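Since the diff registers a new lightweight `test_mambaout` variant with a 160x160 default config, a minimal sanity-check sketch (not part of this commit) would instantiate it through timm's standard `create_model` factory and run a dummy forward pass:

import torch
import timm

# Build the newly registered test variant; no pretrained weights exist for it
model = timm.create_model('test_mambaout', pretrained=False)
model.eval()

# Dummy batch at the 160x160 input size declared in the added default_cfg
x = torch.randn(1, 3, 160, 160)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 1000]) with the default num_classes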