Skip to content

Commit 9b77567

Browse files
committed
Add tests for aten ignore modules
1 parent f7506d6 commit 9b77567

File tree

1 file changed

+48
-5
lines changed

1 file changed

+48
-5
lines changed

tests/common_test.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def input_constructor(input_res):
5353
macs, params = get_model_complexity_info(net, (3,),
5454
input_constructor=input_constructor,
5555
as_strings=False,
56-
print_per_layer_stat=False)
56+
print_per_layer_stat=False,
57+
backend=FLOPS_BACKEND.PYTORCH)
5758

5859
assert (macs, params) == (8, 8)
5960

@@ -73,7 +74,8 @@ def input_constructor(input_res):
7374
get_model_complexity_info(CustomLinear(), (3,),
7475
input_constructor=input_constructor,
7576
as_strings=False,
76-
print_per_layer_stat=False)
77+
print_per_layer_stat=False,
78+
backend=FLOPS_BACKEND.PYTORCH)
7779

7880
assert (macs, params) == (8, 8)
7981

@@ -89,7 +91,8 @@ def forward(self, x):
8991
macs, params = \
9092
get_model_complexity_info(CustomModel(), (3, 10, 10),
9193
as_strings=False,
92-
print_per_layer_stat=False)
94+
print_per_layer_stat=False,
95+
backend=FLOPS_BACKEND.PYTORCH)
9396
assert params == 0
9497
assert macs > 0
9598

@@ -99,7 +102,8 @@ def forward(self, x):
99102
macs, params = \
100103
get_model_complexity_info(CustomModel(), (3, 10, 10),
101104
as_strings=False,
102-
print_per_layer_stat=False)
105+
print_per_layer_stat=False,
106+
backend=FLOPS_BACKEND.PYTORCH)
103107
assert params == 0
104108
assert macs > 0
105109

@@ -114,7 +118,46 @@ def forward(self, x):
114118
macs, params = \
115119
get_model_complexity_info(CustomModel(), (10, ),
116120
as_strings=False,
117-
print_per_layer_stat=False)
121+
print_per_layer_stat=False,
122+
backend=FLOPS_BACKEND.PYTORCH)
118123

119124
assert params == 0
120125
assert macs > 0
126+
127+
def test_aten_ignore(self):
128+
class CustomModel(nn.Module):
129+
def __init__(self):
130+
super().__init__()
131+
132+
def forward(self, x):
133+
return x.matmul(x.t())
134+
135+
ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm]
136+
macs, params = \
137+
get_model_complexity_info(CustomModel(), (10, ), backend=FLOPS_BACKEND.ATEN,
138+
as_strings=False,
139+
print_per_layer_stat=False,
140+
ignore_modules=ignored_list)
141+
142+
assert params == 0
143+
assert macs == 0
144+
145+
def test_aten_custom(self):
146+
class CustomModel(nn.Module):
147+
def __init__(self):
148+
super().__init__()
149+
150+
def forward(self, x):
151+
return x.matmul(x.t())
152+
153+
reference = 42
154+
custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}
155+
156+
macs, params = \
157+
get_model_complexity_info(CustomModel(), (10, ), backend=FLOPS_BACKEND.ATEN,
158+
as_strings=False,
159+
print_per_layer_stat=False,
160+
custom_modules_hooks=custom_hooks)
161+
162+
assert params == 0
163+
assert macs == reference

0 commit comments

Comments
 (0)