Skip to content

Commit b9340a3

Browse files
committed
Refactor tests
1 parent 9b77567 commit b9340a3

File tree

1 file changed

+18
-27
lines changed

1 file changed

+18
-27
lines changed

tests/common_test.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@ class TestOperations:
1111
def default_input_image_size(self):
1212
return (3, 224, 224)
1313

14+
@pytest.fixture
15+
def simple_model_mm(self):
16+
class CustomModel(nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
20+
def forward(self, x):
21+
return x.matmul(x.t())
22+
23+
return CustomModel()
24+
25+
1426
@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
1527
def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
1628
net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
@@ -107,54 +119,33 @@ def forward(self, x):
107119
assert params == 0
108120
assert macs > 0
109121

110-
def test_ten_matmul(self):
111-
class CustomModel(nn.Module):
112-
def __init__(self):
113-
super().__init__()
114-
115-
def forward(self, x):
116-
return x.matmul(x.t())
117-
122+
def test_ten_matmul(self, simple_model_mm):
118123
macs, params = \
119-
get_model_complexity_info(CustomModel(), (10, ),
124+
get_model_complexity_info(simple_model_mm, (10, ),
120125
as_strings=False,
121126
print_per_layer_stat=False,
122127
backend=FLOPS_BACKEND.PYTORCH)
123128

124129
assert params == 0
125130
assert macs > 0
126131

127-
def test_aten_ignore(self):
128-
class CustomModel(nn.Module):
129-
def __init__(self):
130-
super().__init__()
131-
132-
def forward(self, x):
133-
return x.matmul(x.t())
134-
132+
def test_aten_ignore(self, simple_model_mm):
135133
ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm]
136134
macs, params = \
137-
get_model_complexity_info(CustomModel(), (10, ), backend=FLOPS_BACKEND.ATEN,
135+
get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
138136
as_strings=False,
139137
print_per_layer_stat=False,
140138
ignore_modules=ignored_list)
141139

142140
assert params == 0
143141
assert macs == 0
144142

145-
def test_aten_custom(self):
146-
class CustomModel(nn.Module):
147-
def __init__(self):
148-
super().__init__()
149-
150-
def forward(self, x):
151-
return x.matmul(x.t())
152-
143+
def test_aten_custom(self, simple_model_mm):
153144
reference = 42
154145
custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}
155146

156147
macs, params = \
157-
get_model_complexity_info(CustomModel(), (10, ), backend=FLOPS_BACKEND.ATEN,
148+
get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
158149
as_strings=False,
159150
print_per_layer_stat=False,
160151
custom_modules_hooks=custom_hooks)

0 commit comments

Comments
 (0)