 from typing import Optional

 import torch
-import torch.nn as nn
+from torch import nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -122,6 +122,7 @@ def __init__(
             self,
             dim,
             num_classes=1000,
+            pool_type='avg',
             act_layer=nn.GELU,
             mlp_ratio=4,
             norm_layer=LayerNorm,
@@ -130,17 +131,25 @@ def __init__(
     ):
         super().__init__()
         hidden_features = int(mlp_ratio * dim)
+        self.pool_type = pool_type
+
+        self.norm1 = norm_layer(dim)
         self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
         self.act = act_layer()
-        self.norm = norm_layer(hidden_features)
+        self.norm2 = norm_layer(hidden_features)
         self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)

-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
+        if self.pool_type == 'avg':
+            x = x.mean((1, 2))
+        x = self.norm1(x)
         x = self.fc1(x)
         x = self.act(x)
-        x = self.norm(x)
+        x = self.norm2(x)
         x = self.head_dropout(x)
+        if pre_logits:
+            return x
         x = self.fc2(x)
         return x

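With pooling and the final norm folded into the head, `MlpHead` now consumes the channels-last feature map directly and can return pre-logit features. A minimal usage sketch, assuming the 576-dim final stage of the small config and the NHWC layout implied by the `mean((1, 2))` pooling; `2304 = int(4 * 576)`:

    head = MlpHead(576, num_classes=1000, pool_type='avg')
    x = torch.randn(2, 7, 7, 576)      # NHWC features, as from forward_features()
    logits = head(x)                   # (2, 1000): pool -> norm1 -> fc1 -> act -> norm2 -> dropout -> fc2
    feats = head(x, pre_logits=True)   # (2, 2304): returns after dropout, skipping fc2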
@@ -208,7 +217,7 @@ def __init__(
             expansion_ratio=8 / 3,
             kernel_size=7,
             conv_ratio=1.0,
-            downsample: bool = False,
+            downsample: str = '',
             ls_init_value: Optional[float] = None,
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
@@ -218,8 +227,10 @@ def __init__(
         dim_out = dim_out or dim
         self.grad_checkpointing = False

-        if downsample:
+        if downsample == 'conv':
             self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
+        elif downsample == 'conv_nf':
+            self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
         else:
             assert dim == dim_out
             self.downsample = nn.Identity()
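The boolean flag becomes a three-way string, so each stage can choose its downsample flavour. A hedged construction sketch, assuming the enclosing stage class is `MambaOutStage` (the class name is not visible in this hunk):

    MambaOutStage(96, 96, depth=3, downsample='')           # nn.Identity(); requires dim == dim_out
    MambaOutStage(96, 192, depth=4, downsample='conv')      # Downsample, the original conv downsampling
    MambaOutStage(96, 192, depth=4, downsample='conv_nf')   # DownsampleNormFirst: norm ahead of the strided conv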
@@ -276,10 +287,10 @@ def __init__(
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
+            downsample='conv',
             drop_path_rate=0.,
             drop_rate=0.,
-            output_norm=LayerNorm,
-            head_fn=MlpHead,
+            head_fn='default',
             **kwargs,
     ):
         super().__init__()
@@ -312,7 +323,7 @@ def __init__(
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
-                downsample=i > 0,
+                downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
@@ -322,9 +333,25 @@ def __init__(
             prev_dim = dim
             cur += depths[i]

-        self.norm = output_norm(prev_dim)
-
-        self.head = head_fn(prev_dim, num_classes, drop_rate=drop_rate)
+        if head_fn == 'default':
+            # specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
+            self.head = MlpHead(
+                prev_dim,
+                num_classes,
+                pool_type='avg',
+                drop_rate=drop_rate,
+                norm_layer=norm_layer,
+            )
+        else:
+            # more typical norm -> pool -> fc -> act -> fc
+            self.head = ClNormMlpClassifierHead(
+                prev_dim,
+                num_classes,
+                hidden_size=int(prev_dim * 4),
+                pool_type='avg',
+                norm_layer=norm_layer,
+                drop_rate=drop_rate,
+            )

         self.apply(self._init_weights)

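A hedged sketch of the new string-based head selection (other constructor arguments left at this file's defaults). Note that any value other than 'default' falls through to the `ClNormMlpClassifierHead` branch; 'norm_mlp' is simply the conventional spelling used by the rw configs below:

    m1 = MambaOut(head_fn='default')    # MlpHead: pool -> norm1 -> fc1 -> act -> norm2 -> fc2
    m2 = MambaOut(head_fn='norm_mlp')   # ClNormMlpClassifierHead: norm -> pool -> fc -> act -> fc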
@@ -336,7 +363,7 @@ def _init_weights(self, m):

     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'norm'}
+        return {}

     def forward_features(self, x):
         x = self.stem(x)
@@ -345,9 +372,7 @@ def forward_features(self, x):
         return x

     def forward_head(self, x, pre_logits: bool = False):
-        x = x.mean((1, 2))
-        x = self.norm(x)
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x

     def forward(self, x):
@@ -366,6 +391,10 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm1.')
+        elif k.startswith('head.norm.'):
+            k = k.replace('head.norm.', 'head.norm2.')
         out_dict[k] = v

     return out_dict
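Because the standalone trunk norm moved into the head, checkpoint keys need rerouting when loading original weights. A self-contained sketch of the remap added above, on two illustrative keys:

    remapped = {}
    for k, v in {'norm.weight': 0, 'head.norm.weight': 1}.items():
        if k.startswith('norm.'):
            k = k.replace('norm.', 'head.norm1.')       # old final trunk norm -> head input norm
        elif k.startswith('head.norm.'):
            k = k.replace('head.norm.', 'head.norm2.')  # old mid-head norm -> head hidden norm
        remapped[k] = v
    assert set(remapped) == {'head.norm1.weight', 'head.norm2.weight'}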
@@ -443,7 +472,9 @@ def mambaout_small_rw(pretrained=False, **kwargs):
         depths=[3, 4, 27, 3],
         dims=[96, 192, 384, 576],
         stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))

@@ -455,5 +486,6 @@ def mambaout_base_rw(pretrained=False, **kwargs):
         dims=(128, 256, 512, 768),
         stem_mid_norm=False,
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
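Assuming these registrations ship in a timm build, the updated variants construct as usual through the factory:

    import timm
    m = timm.create_model('mambaout_base_rw', pretrained=False)
    print(type(m.head).__name__)   # ClNormMlpClassifierHead, since head_fn='norm_mlp'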