@@ -11,6 +11,18 @@ class TestOperations:
     def default_input_image_size(self):
         return (3, 224, 224)
 
+    @pytest.fixture
+    def simple_model_mm(self):
+        class CustomModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.matmul(x.t())
+
+        return CustomModel()
+
+
     @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
     def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
@@ -107,54 +119,33 @@ def forward(self, x):
         assert params == 0
         assert macs > 0
 
-    def test_ten_matmul(self):
-        class CustomModel(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return x.matmul(x.t())
-
+    def test_ten_matmul(self, simple_model_mm):
         macs, params = \
-            get_model_complexity_info(CustomModel(), (10, ),
+            get_model_complexity_info(simple_model_mm, (10, ),
                                       as_strings=False,
                                       print_per_layer_stat=False,
                                       backend=FLOPS_BACKEND.PYTORCH)
 
         assert params == 0
         assert macs > 0
 
-    def test_aten_ignore(self):
-        class CustomModel(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return x.matmul(x.t())
-
+    def test_aten_ignore(self, simple_model_mm):
         ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm]
         macs, params = \
-            get_model_complexity_info(CustomModel(), (10, ), backend=FLOPS_BACKEND.ATEN,
+            get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
                                       as_strings=False,
                                       print_per_layer_stat=False,
                                       ignore_modules=ignored_list)
 
         assert params == 0
         assert macs == 0
 
-    def test_aten_custom(self):
-        class CustomModel(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return x.matmul(x.t())
-
+    def test_aten_custom(self, simple_model_mm):
         reference = 42
         custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}
 
         macs, params = \
-            get_model_complexity_info(CustomModel(), (10, ), backend=FLOPS_BACKEND.ATEN,
+            get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
                                       as_strings=False,
                                       print_per_layer_stat=False,
                                       custom_modules_hooks=custom_hooks)
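For reference, below is a minimal standalone sketch of the pattern the refactored tests now share: a parameter-free matmul model handed to get_model_complexity_info, here under the ATEN backend with a custom per-op hook as in test_aten_custom. Extracting the model into the simple_model_mm fixture removes the three duplicated CustomModel definitions while leaving each test's call unchanged. The import paths are assumptions based on the ptflops API used above; adjust them to your installed version.

import torch
import torch.nn as nn

# Assumed import paths; in the test file these names are already in scope.
from ptflops import get_model_complexity_info
from ptflops import FLOPS_BACKEND


class MatMulModel(nn.Module):
    """Parameter-free model whose forward is a single x @ x.T."""

    def forward(self, x):
        return x.matmul(x.t())


if __name__ == '__main__':
    # Mirrors test_aten_custom: override the MAC count reported for
    # aten.mm with a constant, so the total comes from the hook alone.
    custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: 42}
    macs, params = get_model_complexity_info(
        MatMulModel(), (10,),
        as_strings=False,
        print_per_layer_stat=False,
        backend=FLOPS_BACKEND.ATEN,
        custom_modules_hooks=custom_hooks)
    print(macs, params)  # the model has no parameters, so params == 0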