4545
4646from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , OPENAI_CLIP_MEAN , OPENAI_CLIP_STD
4747from timm .layers import trunc_normal_ , AvgPool2dSame , DropPath , Mlp , GlobalResponseNormMlp , \
48- LayerNorm2d , LayerNorm , create_conv2d , get_act_layer , make_divisible , to_ntuple
48+ LayerNorm2d , LayerNorm , RmsNorm2d , RmsNorm , create_conv2d , get_act_layer , get_norm_layer , make_divisible , to_ntuple
4949from timm .layers import NormMlpClassifierHead , ClassifierHead
5050from ._builder import build_model_with_cfg
5151from ._features import feature_take_indices
@@ -289,24 +289,27 @@ def __init__(
289289 super ().__init__ ()
290290 assert output_stride in (8 , 16 , 32 )
291291 kernel_sizes = to_ntuple (4 )(kernel_sizes )
292- if norm_layer is None :
293- norm_layer = LayerNorm2d
294- norm_layer_cl = norm_layer if conv_mlp else LayerNorm
292+ use_rms = isinstance (norm_layer , str ) and norm_layer .startswith ('rmsnorm' )
293+ if norm_layer is None or use_rms :
294+ norm_layer = RmsNorm2d if use_rms else LayerNorm2d
295+ norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm )
295296 if norm_eps is not None :
296297 norm_layer = partial (norm_layer , eps = norm_eps )
297298 norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
298299 else :
299300 assert conv_mlp ,\
300301 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
302+ norm_layer = get_norm_layer (norm_layer )
301303 norm_layer_cl = norm_layer
302304 if norm_eps is not None :
303305 norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
306+ act_layer = get_act_layer (act_layer )
304307
305308 self .num_classes = num_classes
306309 self .drop_rate = drop_rate
307310 self .feature_info = []
308311
309- assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' )
312+ assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' , 'overlap_act' )
310313 if stem_type == 'patch' :
311314 # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312315 self .stem = nn .Sequential (
@@ -316,11 +319,12 @@ def __init__(
316319 stem_stride = patch_size
317320 else :
318321 mid_chs = make_divisible (dims [0 ] // 2 ) if 'tiered' in stem_type else dims [0 ]
319- self .stem = nn .Sequential (
322+ self .stem = nn .Sequential (* filter ( None , [
320323 nn .Conv2d (in_chans , mid_chs , kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
324+ act_layer () if 'act' in stem_type else None ,
321325 nn .Conv2d (mid_chs , dims [0 ], kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
322326 norm_layer (dims [0 ]),
323- )
327+ ]) )
324328 stem_stride = 4
325329
326330 self .stages = nn .Sequential ()
@@ -592,6 +596,14 @@ def _cfgv2(url='', **kwargs):
592596 hf_hub_id = 'timm/' ,
593597 crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
594598
599+ 'convnext_zepto_rms.untrained' : _cfg (
600+ #hf_hub_id='timm/',
601+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
602+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
603+ 'convnext_zepto_rms_ols.untrained' : _cfg (
604+ # hf_hub_id='timm/',
605+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
606+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
595607 'convnext_atto.d2_in1k' : _cfg (
596608 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth' ,
597609 hf_hub_id = 'timm/' ,
@@ -600,6 +612,9 @@ def _cfgv2(url='', **kwargs):
600612 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth' ,
601613 hf_hub_id = 'timm/' ,
602614 test_input_size = (3 , 288 , 288 ), test_crop_pct = 0.95 ),
615+ 'convnext_atto_rms.untrained' : _cfg (
616+ #hf_hub_id='timm/',
617+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
603618 'convnext_femto.d1_in1k' : _cfg (
604619 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth' ,
605620 hf_hub_id = 'timm/' ,
@@ -968,6 +983,23 @@ def _cfgv2(url='', **kwargs):
968983})
969984
970985
@register_model
def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Zepto w/ RMSNorm (channels-first) and conv MLP.

    A very small timm-specific variant (NOTE: depths still being tweaked).
    """
    cfg = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
    cfg.update(kwargs)  # caller kwargs take precedence over the defaults
    return _create_convnext('convnext_zepto_rms', pretrained=pretrained, **cfg)
992+
993+
@register_model
def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Zepto w/ RMSNorm, conv MLP, and overlapping 3x3 conv stem
    with an activation between the two stem convs ('overlap_act').

    NOTE: depths still being tweaked per the original author comment.
    """
    model_args = dict(
        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d',
        stem_type='overlap_act')
    # FIX: variant name was misspelled 'convnext_zepto_rms_oas'; it must match
    # the registered default cfg key 'convnext_zepto_rms_ols' (and the function
    # name) or pretrained cfg resolution looks up the wrong entry.
    model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
1001+
1002+
9711003@register_model
9721004def convnext_atto (pretrained = False , ** kwargs ) -> ConvNeXt :
9731005 # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -984,6 +1016,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
9841016 return model
9851017
9861018
@register_model
def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
    """ConvNeXt-Atto w/ RMSNorm (channels-first) and conv MLP.

    Small timm-specific variant (NOTE: depths still being tweaked,
    roughly 3-4M params).
    """
    cfg = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
    cfg.update(kwargs)  # caller kwargs take precedence over the defaults
    return _create_convnext('convnext_atto_rms', pretrained=pretrained, **cfg)
1025+
1026+
9871027@register_model
9881028def convnext_femto (pretrained = False , ** kwargs ) -> ConvNeXt :
9891029 # timm femto variant
0 commit comments