@@ -90,24 +90,26 @@ def input_constructor(input_res):
9090
9191 assert (macs , params ) == (8 , 8 )
9292
93- def test_func_interpolate_args (self ):
93+ @pytest .mark .parametrize ("out_size" , [(20 , 20 ), 20 ])
94+ def test_func_interpolate_args (self , out_size ):
9495 class CustomModel (nn .Module ):
9596 def __init__ (self ):
9697 super ().__init__ ()
9798
9899 def forward (self , x ):
99- return nn .functional .interpolate (input = x , size = ( 20 , 20 ) ,
100+ return nn .functional .interpolate (input = x , size = out_size ,
100101 mode = 'bilinear' , align_corners = False )
101102
102103 macs , params = \
103104 get_model_complexity_info (CustomModel (), (3 , 10 , 10 ),
104105 as_strings = False ,
105106 print_per_layer_stat = False ,
106107 backend = FLOPS_BACKEND .PYTORCH )
108+
107109 assert params == 0
108- assert macs > 0
110+ assert macs == 1200
109111
110- CustomModel .forward = lambda self , x : nn .functional .interpolate (x , size = ( 20 , 20 ) ,
112+ CustomModel .forward = lambda self , x : nn .functional .interpolate (x , out_size ,
111113 mode = 'bilinear' )
112114
113115 macs , params = \
@@ -116,7 +118,18 @@ def forward(self, x):
116118 print_per_layer_stat = False ,
117119 backend = FLOPS_BACKEND .PYTORCH )
118120 assert params == 0
119- assert macs > 0
121+ assert macs == 1200
122+
123+ CustomModel .forward = lambda self , x : nn .functional .interpolate (x , scale_factor = 2 ,
124+ mode = 'bilinear' )
125+
126+ macs , params = \
127+ get_model_complexity_info (CustomModel (), (3 , 10 , 10 ),
128+ as_strings = False ,
129+ print_per_layer_stat = False ,
130+ backend = FLOPS_BACKEND .PYTORCH )
131+ assert params == 0
132+ assert macs == 1200
120133
121134 def test_ten_matmul (self , simple_model_mm ):
122135 macs , params = \
0 commit comments