2121from typing import Callable , List , Tuple , Union
2222
2323
24- from timm .models import is_model , list_models , get_pretrained_cfg
24+ from timm .models import is_model , list_models , get_pretrained_cfg , get_arch_pretrained_cfgs
2525
2626
2727parser = argparse .ArgumentParser (description = 'Per-model process launcher' )
@@ -98,23 +98,44 @@ def _get_model_cfgs(
9898 num_classes = None ,
9999 expand_train_test = False ,
100100 include_crop = True ,
101+ expand_arch = False ,
101102):
102- model_cfgs = []
103- for n in model_names :
104- pt_cfg = get_pretrained_cfg (n )
105- if num_classes is not None and getattr (pt_cfg , 'num_classes' , 0 ) != num_classes :
106- continue
107- model_cfgs .append ((n , pt_cfg .input_size [- 1 ], pt_cfg .crop_pct ))
108- if expand_train_test and pt_cfg .test_input_size is not None :
109- if pt_cfg .test_crop_pct is not None :
110- model_cfgs .append ((n , pt_cfg .test_input_size [- 1 ], pt_cfg .test_crop_pct ))
103+ model_cfgs = set ()
104+
105+ for name in model_names :
106+ if expand_arch :
107+ pt_cfgs = get_arch_pretrained_cfgs (name ).values ()
108+ else :
109+ pt_cfg = get_pretrained_cfg (name )
110+ pt_cfgs = [pt_cfg ] if pt_cfg is not None else []
111+
112+ for cfg in pt_cfgs :
113+ if cfg .input_size is None :
114+ continue
115+ if num_classes is not None and getattr (cfg , 'num_classes' , 0 ) != num_classes :
116+ continue
117+
118+ # Add main configuration
119+ size = cfg .input_size [- 1 ]
120+ if include_crop :
121+ model_cfgs .add ((name , size , cfg .crop_pct ))
111122 else :
112- model_cfgs .append ((n , pt_cfg .test_input_size [- 1 ], pt_cfg .crop_pct ))
123+ model_cfgs .add ((name , size ))
124+
125+ # Add test configuration if required
126+ if expand_train_test and cfg .test_input_size is not None :
127+ test_size = cfg .test_input_size [- 1 ]
128+ if include_crop :
129+ test_crop = cfg .test_crop_pct or cfg .crop_pct
130+ model_cfgs .add ((name , test_size , test_crop ))
131+ else :
132+ model_cfgs .add ((name , test_size ))
133+
134+ # Format the output
113135 if include_crop :
114- model_cfgs = [(n , {'img-size' : r , 'crop-pct' : cp }) for n , r , cp in sorted (model_cfgs )]
136+ return [(n , {'img-size' : r , 'crop-pct' : cp }) for n , r , cp in sorted (model_cfgs )]
115137 else :
116- model_cfgs = [(n , {'img-size' : r }) for n , r , cp in sorted (model_cfgs )]
117- return model_cfgs
138+ return [(n , {'img-size' : r }) for n , r in sorted (model_cfgs )]
118139
119140
120141def main ():
@@ -132,17 +153,16 @@ def main():
132153 model_cfgs = _get_model_cfgs (model_names , num_classes = 1000 , expand_train_test = True )
133154 elif args .model_list == 'all_res' :
134155 model_names = list_models ()
135- model_cfgs = _get_model_cfgs (model_names , expand_train_test = True , include_crop = False )
156+ model_cfgs = _get_model_cfgs (model_names , expand_train_test = True , include_crop = False , expand_arch = True )
136157 elif not is_model (args .model_list ):
137158 # model name doesn't exist, try as wildcard filter
138159 model_names = list_models (args .model_list )
139160 model_cfgs = [(n , None ) for n in model_names ]
140161
141162 if not model_cfgs and os .path .exists (args .model_list ):
142163 with open (args .model_list ) as f :
143- model_cfgs = []
144164 model_names = [line .rstrip () for line in f ]
145- _get_model_cfgs (
165+ model_cfgs = _get_model_cfgs (
146166 model_names ,
147167 #num_classes=1000,
148168 expand_train_test = True ,
0 commit comments