@@ -500,7 +500,7 @@ def prune_intermediate_layers(
500500 take_indices , max_index = feature_take_indices (len (self .blocks ), indices )
501501 self .blocks = self .blocks [:max_index + 1 ] # truncate blocks
502502 if prune_head :
503- self .head .reset (0 , reset_other = True )
503+ self .head .reset (0 , reset_other = prune_norm )
504504 return take_indices
505505
506506 def forward_features (self , x : torch .Tensor ) -> torch .Tensor :
@@ -556,6 +556,10 @@ def _cfg(url='', **kwargs):
556556 min_input_size = (3 , 256 , 256 ),
557557 input_size = (3 , 1024 , 1024 ), pool_size = (32 , 32 ),
558558 ),
559+ "hieradet_small.untrained" : _cfg (
560+ num_classes = 1000 ,
561+ input_size = (3 , 256 , 256 ), pool_size = (8 , 8 ),
562+ ),
559563})
560564
561565
@@ -604,12 +608,6 @@ def sam2_hiera_small(pretrained=False, **kwargs):
604608 return _create_hiera_det ('sam2_hiera_small' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
605609
606610
607- # @register_model
608- # def sam2_hiera_base(pretrained=False, **kwargs):
609- # model_args = dict()
610- # return _create_hiera_det('sam2_hiera_base', pretrained=pretrained, **dict(model_args, **kwargs))
611-
612-
613611@register_model
614612def sam2_hiera_base_plus (pretrained = False , ** kwargs ):
615613 model_args = dict (embed_dim = 112 , num_heads = 2 , global_pos_size = (14 , 14 ))
@@ -626,3 +624,15 @@ def sam2_hiera_large(pretrained=False, **kwargs):
626624 window_spec = (8 , 4 , 16 , 8 ),
627625 )
628626 return _create_hiera_det ('sam2_hiera_large' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
627+
628+
629+ @register_model
630+ def hieradet_small (pretrained = False , ** kwargs ):
631+ model_args = dict (stages = (1 , 2 , 11 , 2 ), global_att_blocks = (7 , 10 , 13 ), window_spec = (8 , 4 , 16 , 8 ))
632+ return _create_hiera_det ('hieradet_small' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
633+
634+
635+ # @register_model
636+ # def hieradet_base(pretrained=False, **kwargs):
637+ # model_args = dict(window_spec=(8, 4, 16, 8))
638+ # return _create_hiera_det('hieradet_base', pretrained=pretrained, **dict(model_args, **kwargs))
0 commit comments