 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
-    get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa
+from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, LayerType, create_attn, \
+    get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa, to_ntuple
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
@@ -286,7 +286,7 @@ def drop_blocks(drop_prob: float = 0.):
 
 
 def make_blocks(
-        block_fn: Union[BasicBlock, Bottleneck],
+        block_fns: Tuple[Union[BasicBlock, Bottleneck]],
         channels: Tuple[int, ...],
         block_repeats: Tuple[int, ...],
         inplanes: int,
@@ -304,7 +304,7 @@ def make_blocks(
     net_block_idx = 0
     net_stride = 4
     dilation = prev_dilation = 1
-    for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+    for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))):
         stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
         stride = 1 if stage_idx == 0 else 2
         if net_stride >= output_stride:
@@ -490,8 +490,9 @@ def __init__(
                 self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
+        block_fns = to_ntuple(len(channels))(block)
         stage_modules, stage_feature_info = make_blocks(
-            block,
+            block_fns,
             channels,
             layers,
             inplanes,
@@ -513,7 +514,7 @@ def __init__(
         self.feature_info.extend(stage_feature_info)
 
         # Head (Pooling and Classifier)
-        self.num_features = self.head_hidden_size = channels[-1] * block.expansion
+        self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         self.init_weights(zero_init_last=zero_init_last)
@@ -1301,6 +1302,11 @@ def _gcfg(url='', **kwargs):
         hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
         first_conv='conv1.0'),
+
+    'test_resnet.r160_in1k': _cfg(
+        #hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 160, 160), pool_size=(5, 5), first_conv='conv1.0'),
 })
 
 
@@ -2040,6 +2046,16 @@ def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
     return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
 
 
+@register_model
+def test_resnet(pretrained: bool = False, **kwargs) -> ResNet:
+    """Constructs a tiny ResNet test model.
+    """
+    model_args = dict(
+        block=[BasicBlock, BasicBlock, Bottleneck, BasicBlock], layers=(1, 1, 1, 1),
+        stem_width=16, stem_type='deep', avg_down=True, channels=(32, 48, 48, 96))
+    return _create_resnet('test_resnet', pretrained, **dict(model_args, **kwargs))
+
+
 register_model_deprecations(__name__, {
     'tv_resnet34': 'resnet34.tv_in1k',
     'tv_resnet50': 'resnet50.tv_in1k',
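A minimal usage sketch of the per-stage block change introduced above, assuming a timm build that includes this commit. It builds the new test_resnet model, whose block argument is a per-stage list that to_ntuple(len(channels))(block) passes through as a tuple, and checks that layer3 is the lone Bottleneck stage while the head width follows channels[-1] * block_fns[-1].expansion.

# Usage sketch, not part of the commit; assumes a timm install with this change.
import timm
import torch

model = timm.create_model('test_resnet', pretrained=False)

# block=[BasicBlock, BasicBlock, Bottleneck, BasicBlock] maps one class per stage.
print(type(model.layer1[0]).__name__)  # expected: BasicBlock
print(type(model.layer3[0]).__name__)  # expected: Bottleneck

# Head width follows the last stage: channels[-1] * block_fns[-1].expansion = 96 * 1.
print(model.num_features)  # expected: 96

x = torch.randn(1, 3, 160, 160)
print(model(x).shape)  # expected: torch.Size([1, 1000])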