Skip to content

Commit d764a5e

Browse files
committed
Update tests
1 parent 6d7445c commit d764a5e

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

tests/common_test.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,24 +90,26 @@ def input_constructor(input_res):
9090

9191
assert (macs, params) == (8, 8)
9292

93-
def test_func_interpolate_args(self):
93+
@pytest.mark.parametrize("out_size", [(20, 20), 20])
94+
def test_func_interpolate_args(self, out_size):
9495
class CustomModel(nn.Module):
9596
def __init__(self):
9697
super().__init__()
9798

9899
def forward(self, x):
99-
return nn.functional.interpolate(input=x, size=(20, 20),
100+
return nn.functional.interpolate(input=x, size=out_size,
100101
mode='bilinear', align_corners=False)
101102

102103
macs, params = \
103104
get_model_complexity_info(CustomModel(), (3, 10, 10),
104105
as_strings=False,
105106
print_per_layer_stat=False,
106107
backend=FLOPS_BACKEND.PYTORCH)
108+
107109
assert params == 0
108-
assert macs > 0
110+
assert macs == 1200
109111

110-
CustomModel.forward = lambda self, x: nn.functional.interpolate(x, size=(20, 20),
112+
CustomModel.forward = lambda self, x: nn.functional.interpolate(x, out_size,
111113
mode='bilinear')
112114

113115
macs, params = \
@@ -116,7 +118,18 @@ def forward(self, x):
116118
print_per_layer_stat=False,
117119
backend=FLOPS_BACKEND.PYTORCH)
118120
assert params == 0
119-
assert macs > 0
121+
assert macs == 1200
122+
123+
CustomModel.forward = lambda self, x: nn.functional.interpolate(x, scale_factor=2,
124+
mode='bilinear')
125+
126+
macs, params = \
127+
get_model_complexity_info(CustomModel(), (3, 10, 10),
128+
as_strings=False,
129+
print_per_layer_stat=False,
130+
backend=FLOPS_BACKEND.PYTORCH)
131+
assert params == 0
132+
assert macs == 1200
120133

121134
def test_ten_matmul(self, simple_model_mm):
122135
macs, params = \

0 commit comments

Comments
 (0)