@@ -53,7 +53,8 @@ def input_constructor(input_res):
5353 macs , params = get_model_complexity_info (net , (3 ,),
5454 input_constructor = input_constructor ,
5555 as_strings = False ,
56- print_per_layer_stat = False )
56+ print_per_layer_stat = False ,
57+ backend = FLOPS_BACKEND .PYTORCH )
5758
5859 assert (macs , params ) == (8 , 8 )
5960
@@ -73,7 +74,8 @@ def input_constructor(input_res):
7374 get_model_complexity_info (CustomLinear (), (3 ,),
7475 input_constructor = input_constructor ,
7576 as_strings = False ,
76- print_per_layer_stat = False )
77+ print_per_layer_stat = False ,
78+ backend = FLOPS_BACKEND .PYTORCH )
7779
7880 assert (macs , params ) == (8 , 8 )
7981
@@ -89,7 +91,8 @@ def forward(self, x):
8991 macs , params = \
9092 get_model_complexity_info (CustomModel (), (3 , 10 , 10 ),
9193 as_strings = False ,
92- print_per_layer_stat = False )
94+ print_per_layer_stat = False ,
95+ backend = FLOPS_BACKEND .PYTORCH )
9396 assert params == 0
9497 assert macs > 0
9598
@@ -99,7 +102,8 @@ def forward(self, x):
99102 macs , params = \
100103 get_model_complexity_info (CustomModel (), (3 , 10 , 10 ),
101104 as_strings = False ,
102- print_per_layer_stat = False )
105+ print_per_layer_stat = False ,
106+ backend = FLOPS_BACKEND .PYTORCH )
103107 assert params == 0
104108 assert macs > 0
105109
@@ -114,7 +118,46 @@ def forward(self, x):
114118 macs , params = \
115119 get_model_complexity_info (CustomModel (), (10 , ),
116120 as_strings = False ,
117- print_per_layer_stat = False )
121+ print_per_layer_stat = False ,
122+ backend = FLOPS_BACKEND .PYTORCH )
118123
119124 assert params == 0
120125 assert macs > 0
126+
def test_aten_ignore(self):
    """ATen backend: ops listed in ignore_modules must contribute zero MACs.

    Builds a parameter-free model whose forward is a single matmul, then
    asks the ATen backend to skip both aten.matmul and aten.mm, so the
    reported complexity is expected to be exactly zero.
    """
    class CustomModel(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.matmul(x.t())

    ops_to_skip = [torch.ops.aten.matmul, torch.ops.aten.mm]
    macs, params = get_model_complexity_info(
        CustomModel(), (10, ),
        backend=FLOPS_BACKEND.ATEN,
        as_strings=False,
        print_per_layer_stat=False,
        ignore_modules=ops_to_skip)

    assert params == 0
    assert macs == 0
144+
def test_aten_custom(self):
    """ATen backend: a custom hook for aten.mm overrides the default count.

    Registers a hook that reports a fixed MAC value (42) for aten.mm and
    checks that this value is what get_model_complexity_info returns for a
    parameter-free matmul-only model.
    """
    class CustomModel(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.matmul(x.t())

    reference = 42
    # Hook signature is (inputs, outputs); it ignores both and returns
    # the fixed reference count.
    hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}

    macs, params = get_model_complexity_info(
        CustomModel(), (10, ),
        backend=FLOPS_BACKEND.ATEN,
        as_strings=False,
        print_per_layer_stat=False,
        custom_modules_hooks=hooks)

    assert params == 0
    assert macs == reference
0 commit comments