Skip to content

Commit a2f539f

Browse files
committed
Add a few more test model defs in prep for weight upload
1 parent 6ab2af6 commit a2f539f

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

timm/models/_efficientnet_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._efficientnet_blocks import *
2121
from ._manipulate import named_modules
2222

23-
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
23+
__all__ = ["EfficientNetBuilder", "BlockArgs", "decode_arch_def", "efficientnet_init_weights",
2424
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
2525

2626
_logger = logging.getLogger(__name__)

timm/models/efficientnet.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
from torch.utils.checkpoint import checkpoint
4545

4646
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
47-
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType
47+
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
48+
GroupNormAct, LayerNormAct2d, EvoNorm2dS0
4849
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
4950
from ._efficientnet_blocks import SqueezeExcite
5051
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
@@ -1808,6 +1809,14 @@ def _cfg(url='', **kwargs):
18081809
hf_hub_id='timm/',
18091810
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
18101811
input_size=(3, 160, 160), pool_size=(5, 5)),
1812+
"test_efficientnet_ln.r160_in1k": _cfg(
1813+
#hf_hub_id='timm/',
1814+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
1815+
input_size=(3, 160, 160), pool_size=(5, 5)),
1816+
"test_efficientnet_evos.r160_in1k": _cfg(
1817+
#hf_hub_id='timm/',
1818+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
1819+
input_size=(3, 160, 160), pool_size=(5, 5)),
18111820
})
18121821

18131822

@@ -2802,6 +2811,21 @@ def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
28022811
'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
28032812
return model
28042813

2814+
2815+
@register_model
2816+
def test_efficientnet_ln(pretrained=False, **kwargs) -> EfficientNet:
2817+
model = _gen_test_efficientnet(
2818+
'test_efficientnet_ln', pretrained=pretrained, norm_layer=LayerNormAct2d, **kwargs)
2819+
return model
2820+
2821+
2822+
@register_model
2823+
def test_efficientnet_evos(pretrained=False, **kwargs) -> EfficientNet:
2824+
model = _gen_test_efficientnet(
2825+
'test_efficientnet_evos', pretrained=pretrained, norm_layer=partial(EvoNorm2dS0, group_size=8), **kwargs)
2826+
return model
2827+
2828+
28052829
register_model_deprecations(__name__, {
28062830
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
28072831
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',

timm/models/vision_transformer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,6 +2015,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
20152015
'test_vit.r160_in1k': _cfg(
20162016
hf_hub_id='timm/',
20172017
input_size=(3, 160, 160), crop_pct=0.875),
2018+
'test_vit2.r160_in1k': _cfg(
2019+
#hf_hub_id='timm/',
2020+
input_size=(3, 160, 160), crop_pct=0.875),
2021+
'test_vit3.r160_in1k': _cfg(
2022+
#hf_hub_id='timm/',
2023+
input_size=(3, 160, 160), crop_pct=0.875),
20182024
}
20192025

20202026
_quick_gelu_cfgs = [
@@ -3216,6 +3222,26 @@ def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
32163222
return model
32173223

32183224

3225+
def test_vit2(pretrained: bool = False, **kwargs) -> VisionTransformer:
3226+
""" ViT Test
3227+
"""
3228+
model_args = dict(
3229+
patch_size=16, embed_dim=64, depth=8, num_heads=2, mlp_ratio=3,
3230+
class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5)
3231+
model = _create_vision_transformer('test_vit2', pretrained=pretrained, **dict(model_args, **kwargs))
3232+
return model
3233+
3234+
3235+
def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
3236+
""" ViT Test
3237+
"""
3238+
model_args = dict(
3239+
patch_size=16, embed_dim=96, depth=10, num_heads=3, mlp_ratio=2,
3240+
class_token=False, reg_tokens=1, global_pool='map', init_values=1e-5)
3241+
model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs))
3242+
return model
3243+
3244+
32193245
register_model_deprecations(__name__, {
32203246
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
32213247
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',

0 commit comments

Comments
 (0)