55MetaFormer (https://github.com/sail-sg/metaformer),
66InceptionNeXt (https://github.com/sail-sg/inceptionnext)
77"""
8- from functools import partial
98from typing import Optional
109
1110import torch
1211import torch .nn as nn
13- import torch .nn .functional as F
14- from timm .models .layers import trunc_normal_ , DropPath , LayerNorm
15- from .vision_transformer import LayerScale
16- from ._manipulate import checkpoint_seq
17- from timm .models .registry import register_model
12+
1813from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
14+ from timm .layers import trunc_normal_ , DropPath , LayerNorm , LayerScale
15+ from ._builder import build_model_with_cfg
16+ from ._manipulate import checkpoint_seq
17+ from ._registry import register_model
1918
2019
2120class Stem (nn .Module ):
@@ -275,6 +274,7 @@ def __init__(
275274 act_layer = nn .GELU ,
276275 conv_ratio = 1.0 ,
277276 kernel_size = 7 ,
277+ stem_mid_norm = True ,
278278 ls_init_value = None ,
279279 drop_path_rate = 0. ,
280280 drop_rate = 0. ,
@@ -293,7 +293,13 @@ def __init__(
293293 num_stage = len (depths )
294294 self .num_stage = num_stage
295295
296- self .stem = Stem (in_chans , dims [0 ], act_layer = act_layer , norm_layer = norm_layer )
296+ self .stem = Stem (
297+ in_chans ,
298+ dims [0 ],
299+ mid_norm = stem_mid_norm ,
300+ act_layer = act_layer ,
301+ norm_layer = norm_layer ,
302+ )
297303 prev_dim = dims [0 ]
298304 dp_rates = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
299305 self .stages = nn .ModuleList ()
@@ -338,7 +344,7 @@ def forward_features(self, x):
338344 x = s (x )
339345 return x
340346
341- def forward_head (self , x ):
347+ def forward_head (self , x , pre_logits : bool = False ):
342348 x = x .mean ((1 , 2 ))
343349 x = self .norm (x )
344350 x = self .head (x )
@@ -350,6 +356,21 @@ def forward(self, x):
350356 return x
351357
352358
def checkpoint_filter_fn(state_dict, model):
    """Rename checkpoint keys from the original layout to this module's layout.

    The following rewrites are applied to every key, in this order:

    * ``downsample_layers.0.`` -> ``stem.``
    * ``stages.<i>.<j>`` -> ``stages.<i>.blocks.<j>``
    * ``downsample_layers.<i>`` -> ``stages.<i>.downsample``

    Keys that match none of the patterns are copied through unchanged.

    Args:
        state_dict: Mapping of parameter names to values. If it contains a
            ``'model'`` entry, that entry is used as the state dict instead.
        model: Not used by this function.

    Returns:
        A new dict holding the same values under the remapped keys.
    """
    import re

    if 'model' in state_dict:
        state_dict = state_dict['model']

    # The dots are escaped on purpose: an unescaped '.' matches any character,
    # which lets the regex backtrack and split a multi-digit index
    # (e.g. 'stages.123' would be rewritten as 'stages.1.blocks.3').
    block_pattern = re.compile(r'stages\.([0-9]+)\.([0-9]+)')
    downsample_pattern = re.compile(r'downsample_layers\.([0-9]+)')

    out_dict = {}
    for k, v in state_dict.items():
        # The stem rename must run first so that 'downsample_layers.0.' is not
        # picked up by the generic downsample pattern below.
        k = k.replace('downsample_layers.0.', 'stem.')
        k = block_pattern.sub(r'stages.\1.blocks.\2', k)
        k = downsample_pattern.sub(r'stages.\1.downsample', k)
        out_dict[k] = v

    return out_dict
372+
373+
353374def _cfg (url = '' , ** kwargs ):
354375 return {
355376 'url' : url ,
@@ -376,105 +397,63 @@ def _cfg(url='', **kwargs):
376397}
377398
378399
def _create_mambaout(variant, pretrained=False, **kwargs):
    """Construct a MambaOut model through ``build_model_with_cfg``.

    Args:
        variant: Variant name forwarded to ``build_model_with_cfg``.
        pretrained: Forwarded positionally to ``build_model_with_cfg``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``build_model_with_cfg`` returns.
    """
    # Feature-extraction settings handed to the builder.
    feature_cfg = dict(out_indices=(0, 1, 2, 3), flatten_sequential=True)
    return build_model_with_cfg(
        MambaOut,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
408+
409+
379410# a series of MambaOut models
@register_model
def mambaout_femto(pretrained=False, **kwargs):
    """Create the ``mambaout_femto`` variant.

    Default ``depths`` and ``dims`` are set here; any matching entry in
    ``kwargs`` takes precedence over these defaults.
    """
    defaults = {'depths': (3, 3, 9, 3), 'dims': (48, 96, 192, 288)}
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_femto', pretrained=pretrained, **merged)
393415
394416# Kobe Memorial Version with 24 Gated CNN blocks
@register_model
def mambaout_kobe(pretrained=False, **kwargs):
    """Create the ``mambaout_kobe`` variant.

    Default ``depths`` and ``dims`` are set here; any matching entry in
    ``kwargs`` takes precedence over these defaults.
    """
    defaults = {'depths': [3, 3, 15, 3], 'dims': [48, 96, 192, 288]}
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **merged)
408421
@register_model
def mambaout_tiny(pretrained=False, **kwargs):
    """Create the ``mambaout_tiny`` variant.

    Default ``depths`` and ``dims`` are set here; any matching entry in
    ``kwargs`` takes precedence over these defaults.
    """
    defaults = {'depths': [3, 3, 9, 3], 'dims': [96, 192, 384, 576]}
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **merged)
421426
422427
@register_model
def mambaout_small(pretrained=False, **kwargs):
    """Create the ``mambaout_small`` variant.

    Default ``depths`` and ``dims`` are set here; any matching entry in
    ``kwargs`` takes precedence over these defaults.
    """
    defaults = {'depths': [3, 4, 27, 3], 'dims': [96, 192, 384, 576]}
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_small', pretrained=pretrained, **merged)
435432
436433
@register_model
def mambaout_base(pretrained=False, **kwargs):
    """Create the ``mambaout_base`` variant.

    Default ``depths`` and ``dims`` are set here; any matching entry in
    ``kwargs`` takes precedence over these defaults.
    """
    defaults = {'depths': [3, 4, 27, 3], 'dims': [128, 256, 512, 768]}
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_base', pretrained=pretrained, **merged)
449438
450439
@register_model
def mambaout_small_rw(pretrained=False, **kwargs):
    """Create the ``mambaout_small_rw`` variant.

    Besides ``depths`` and ``dims``, this variant defaults to
    ``stem_mid_norm=False`` and ``ls_init_value=1e-6``. Any matching entry
    in ``kwargs`` takes precedence over these defaults.
    """
    defaults = {
        'depths': [3, 4, 27, 3],
        'dims': [96, 192, 384, 576],
        'stem_mid_norm': False,
        'ls_init_value': 1e-6,
    }
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **merged)
465449
466450
@register_model
def mambaout_base_rw(pretrained=False, **kwargs):
    """Create the ``mambaout_base_rw`` variant.

    Besides ``depths`` and ``dims``, this variant defaults to
    ``stem_mid_norm=False`` and ``ls_init_value=1e-6``. Any matching entry
    in ``kwargs`` takes precedence over these defaults.
    """
    defaults = {
        'depths': (3, 4, 27, 3),
        'dims': (128, 256, 512, 768),
        'stem_mid_norm': False,
        'ls_init_value': 1e-6,
    }
    merged = {**defaults, **kwargs}
    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **merged)
0 commit comments