 MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
+from collections import OrderedDict
 from typing import Optional
 
 import torch
@@ -120,7 +121,7 @@ class MlpHead(nn.Module):
 
     def __init__(
             self,
-            dim,
+            in_features,
             num_classes=1000,
             pool_type='avg',
             act_layer=nn.GELU,
@@ -130,27 +131,47 @@ def __init__(
             bias=True,
     ):
         super().__init__()
-        hidden_features = int(mlp_ratio * dim)
+        if mlp_ratio is not None:
+            hidden_size = int(mlp_ratio * in_features)
+        else:
+            hidden_size = None
         self.pool_type = pool_type
+        self.in_features = in_features
+        self.hidden_size = hidden_size or in_features
+
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+                ('norm', norm_layer(hidden_size))
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.num_features = in_features
+            self.pre_logits = nn.Identity()
 
-        self.norm1 = norm_layer(dim)
-        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
-        self.act = act_layer()
-        self.norm2 = norm_layer(hidden_features)
-        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)
 
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        if reset_other:
+            self.norm = nn.Identity()
+            self.pre_logits = nn.Identity()
+            self.num_features = self.in_features
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward(self, x, pre_logits: bool = False):
         if self.pool_type == 'avg':
             x = x.mean((1, 2))
-        x = self.norm1(x)
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.norm2(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
         x = self.head_dropout(x)
         if pre_logits:
             return x
-        x = self.fc2(x)
+        x = self.fc(x)
         return x
 
 
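Note: a minimal usage sketch (not part of the diff) of the reworked head, assuming MlpHead is importable from timm.models.mambaout and passing nn.LayerNorm explicitly; it shows the norm -> pre_logits -> fc path and the reset() helper.

import torch
import torch.nn as nn
from timm.models.mambaout import MlpHead  # assumed import path

head = MlpHead(in_features=768, num_classes=1000, mlp_ratio=4, norm_layer=nn.LayerNorm)
x = torch.randn(2, 7, 7, 768)        # NHWC features; pool_type='avg' means over dims (1, 2)
logits = head(x)                     # norm -> pre_logits (fc/act/norm) -> dropout -> fc
feats = head(x, pre_logits=True)     # stops before the final fc, returns hidden_size features
print(logits.shape, feats.shape)     # torch.Size([2, 1000]) torch.Size([2, 3072])

head.reset(10, reset_other=True)     # norm/pre_logits -> Identity, new fc on in_features
print(head(x).shape)                 # torch.Size([2, 10])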
@@ -284,6 +305,7 @@ def __init__(
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
             conv_ratio=1.0,
+            expansion_ratio=8/3,
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
@@ -303,6 +325,7 @@ def __init__(
 
         num_stage = len(depths)
         self.num_stage = num_stage
+        self.feature_info = []
 
         self.stem = Stem(
             in_chans,
@@ -313,16 +336,20 @@ def __init__(
         )
         prev_dim = dims[0]
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        self.stages = nn.ModuleList()
         cur = 0
+        curr_stride = 4
+        self.stages = nn.Sequential()
         for i in range(num_stage):
             dim = dims[i]
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            curr_stride *= stride
             stage = MambaOutStage(
                 dim=prev_dim,
                 dim_out=dim,
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
+                expansion_ratio=expansion_ratio,
                 downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
@@ -331,6 +358,8 @@ def __init__(
             )
             self.stages.append(stage)
             prev_dim = dim
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+            self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]
             cur += depths[i]
 
         if head_fn == 'default':
@@ -352,6 +381,8 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
             )
+        self.num_features = prev_dim
+        self.hidden_size = self.head.num_features
 
         self.apply(self._init_weights)
 
@@ -362,13 +393,31 @@ def _init_weights(self, m):
                 nn.init.constant_(m.bias, 0)
 
     @torch.jit.ignore
-    def no_weight_decay(self):
-        return {}
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
-        for s in self.stages:
-            x = s(x)
+        x = self.stages(x)
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
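Note: a short sketch of what the new hooks expose, assuming the model is built through timm's registry and that 'mambaout_tiny' is one of the registered variants.

import timm
import torch

model = timm.create_model('mambaout_tiny', pretrained=False)  # assumed example variant
model.set_grad_checkpointing(True)    # flips grad_checkpointing on every stage
model.reset_classifier(10)            # rebuilds head.fc via MlpHead.reset()
print(model.get_classifier())         # the new nn.Linear(..., 10)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 10])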
@@ -391,10 +440,14 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        # remap head names
         if k.startswith('norm.'):
-            k = k.replace('norm.', 'head.norm1.')
-        elif k.startswith('head.norm.'):
-            k = k.replace('head.norm.', 'head.norm2.')
+            # this is moving to head since it's after the pooling
+            k = k.replace('norm.', 'head.norm.')
+        elif k.startswith('head.'):
+            k = k.replace('head.fc1.', 'head.pre_logits.fc.')
+            k = k.replace('head.norm.', 'head.pre_logits.norm.')
+            k = k.replace('head.fc2.', 'head.fc.')
         out_dict[k] = v
 
     return out_dict
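Note: a standalone sketch applying the same substitutions to a few assumed upstream key names, to make the resulting module paths concrete.

import re

def remap_key(k: str) -> str:
    # mirrors the per-key logic of checkpoint_filter_fn above
    k = k.replace('downsample_layers.0.', 'stem.')
    k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
    k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
    if k.startswith('norm.'):
        k = k.replace('norm.', 'head.norm.')
    elif k.startswith('head.'):
        k = k.replace('head.fc1.', 'head.pre_logits.fc.')
        k = k.replace('head.norm.', 'head.pre_logits.norm.')
        k = k.replace('head.fc2.', 'head.fc.')
    return k

for k in ('norm.weight', 'head.fc1.weight', 'head.fc2.bias', 'stages.2.5.fc1.weight'):
    print(k, '->', remap_key(k))
# norm.weight -> head.norm.weight
# head.fc1.weight -> head.pre_logits.fc.weight
# head.fc2.bias -> head.fc.bias
# stages.2.5.fc1.weight -> stages.2.blocks.5.fc1.weight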
@@ -405,7 +458,7 @@ def _cfg(url='', **kwargs):
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': 1.0, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
         **kwargs
     }
 
@@ -422,7 +475,8 @@ def _cfg(url='', **kwargs):
     'mambaout_base': _cfg(
         url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
     'mambaout_small_rw': _cfg(),
-    'mambaout_base_rw': _cfg(),
+    'mambaout_base_slim_rw': _cfg(),
+    'mambaout_base_plus_rw': _cfg(),
 }
 
 
@@ -480,12 +534,29 @@ def mambaout_small_rw(pretrained=False, **kwargs):
 
 
 @register_model
-def mambaout_base_rw(pretrained=False, **kwargs):
+def mambaout_base_slim_rw(pretrained=False, **kwargs):
     model_args = dict(
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
+        expansion_ratio=2.5,
+        conv_ratio=1.25,
         stem_mid_norm=False,
+        downsample='conv_nf',
+        ls_init_value=1e-6,
+        head_fn='norm_mlp',
+    )
+    return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def mambaout_base_plus_rw(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(3, 4, 27, 3),
+        dims=(128, 256, 512, 768),
+        expansion_ratio=3.0,
+        stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
         head_fn='norm_mlp',
     )
-    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
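Note: a sketch of instantiating the newly registered variants; no pretrained weights are wired up for these *_rw configs in this diff, so pretrained=False.

import timm
import torch

model = timm.create_model('mambaout_base_plus_rw', pretrained=False)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
for fi in model.feature_info:                    # populated per stage in __init__ above
    print(fi)  # e.g. {'num_chs': 128, 'reduction': 4, 'module': 'stages.0'}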