Skip to content

Commit a4a284b

Browse files
committed
Fixed a bug relating to constraints checking, and added and updated tests to detect this in the future
1 parent 30f8568 commit a4a284b

File tree

4 files changed

+33
-21
lines changed

4 files changed

+33
-21
lines changed

kernel_tuner/runners/simulation.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" The simulation runner for sequentially tuning the parameter space based on cached data """
1+
"""The simulation runner for sequentially tuning the parameter space based on cached data."""
22
import logging
33
from collections import namedtuple
44
from time import perf_counter
@@ -10,7 +10,7 @@
1010

1111

1212
class SimulationDevice(_SimulationDevice):
13-
""" Simulated device used by simulation runner """
13+
"""Simulated device used by simulation runner."""
1414

1515
@property
1616
def name(self):
@@ -27,10 +27,10 @@ def get_environment(self):
2727

2828

2929
class SimulationRunner(Runner):
30-
""" SimulationRunner is used for tuning with a single process/thread """
30+
"""SimulationRunner is used for tuning with a single process/thread."""
3131

3232
def __init__(self, kernel_source, kernel_options, device_options, iterations, observers):
33-
""" Instantiate the SimulationRunner
33+
"""Instantiate the SimulationRunner.
3434
3535
:param kernel_source: The kernel source
3636
:type kernel_source: kernel_tuner.core.KernelSource
@@ -46,7 +46,6 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob
4646
each kernel instance.
4747
:type iterations: int
4848
"""
49-
5049
self.quiet = device_options.quiet
5150
self.dev = SimulationDevice(1024, dict(device_name="Simulation"), self.quiet)
5251

@@ -66,7 +65,7 @@ def get_environment(self, tuning_options):
6665
return env
6766

6867
def run(self, parameter_space, tuning_options):
69-
""" Iterate through the entire parameter space using a single Python process
68+
"""Iterate through the entire parameter space using a single Python process.
7069
7170
:param parameter_space: The parameter space as an iterable.
7271
:type parameter_space: iterable
@@ -78,7 +77,6 @@ def run(self, parameter_space, tuning_options):
7877
:returns: A list of dictionaries for executed kernel configurations and their
7978
execution times.
8079
:rtype: dict()
81-
8280
"""
8381
logging.debug('simulation runner started for ' + self.kernel_options.kernel_name)
8482

kernel_tuner/util.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def check_block_size_params_names_list(block_size_names, tune_params):
249249
)
250250

251251

252-
def check_restrictions(restrictions, params: dict, verbose: bool):
252+
def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
253253
"""Check whether a specific instance meets the search space restrictions."""
254254
valid = True
255255
if callable(restrictions):
@@ -263,14 +263,19 @@ def check_restrictions(restrictions, params: dict, verbose: bool):
263263
if not restrict(params.values()):
264264
valid = False
265265
break
266+
continue
266267
# if it's a string, fill in the parameters and evaluate
267-
elif isinstance(restrict, str) and not eval(replace_param_occurrences(restrict, params)):
268-
valid = False
269-
break
268+
elif isinstance(restrict, str):
269+
if not eval(replace_param_occurrences(restrict, params)):
270+
valid = False
271+
break
272+
continue
270273
# if it's a function, call it
271-
elif callable(restrict) and not restrict(params):
272-
valid = False
273-
break
274+
elif callable(restrict):
275+
if not restrict(**params):
276+
valid = False
277+
break
278+
continue
274279
# if it's a tuple, use only the parameters in the second argument to call the restriction
275280
elif (isinstance(restrict, tuple) and len(restrict) == 2
276281
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
@@ -282,6 +287,7 @@ def check_restrictions(restrictions, params: dict, verbose: bool):
282287
if not restrict(**selected_params):
283288
valid = False
284289
break
290+
continue
285291
# otherwise, raise an error
286292
else:
287293
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")

test/test_searchspace.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from unittest.mock import patch
1010

1111
import numpy as np
12-
from constraint import ExactSumConstraint, FunctionConstraint
12+
from constraint import ExactSumConstraint
1313

1414
from kernel_tuner.interface import Options
1515
from kernel_tuner.searchspace import Searchspace
@@ -37,13 +37,13 @@
3737

3838
# each GPU must have at least one layer and the sum of all layers must not exceed the total number of layers
3939

40-
4140
def _min_func(gpu1, gpu2, gpu3, gpu4):
4241
return min([gpu1, gpu2, gpu3, gpu4]) >= 1
4342

4443

45-
# test three different types of restrictions: python-constraint, a function and a string
46-
restrict = [ExactSumConstraint(num_layers), FunctionConstraint(_min_func)]
44+
# test two different types of restrictions: a constraint and a callable
45+
assert callable(_min_func)
46+
restrict = [ExactSumConstraint(num_layers), _min_func]
4747

4848
# create the searchspace object
4949
searchspace = Searchspace(tune_params, restrict, max_threads)
@@ -79,6 +79,17 @@ def test_internal_representation():
7979
for index, dict_config in enumerate(searchspace.get_list_dict().keys()):
8080
assert dict_config == searchspace.list[index]
8181

82+
def test_check_restrictions():
83+
"""Test whether the outcome of restrictions is as expected when using check_restrictions."""
84+
from kernel_tuner.util import check_restrictions
85+
86+
param_config_false = {'x': 1, 'y': 4, 'z': "string_1" }
87+
param_config_true = {'x': 3, 'y': 4, 'z': "string_1" }
88+
89+
assert check_restrictions(simple_searchspace.restrictions, param_config_false, verbose=False) is False
90+
assert check_restrictions(simple_searchspace.restrictions, param_config_true, verbose=False) is True
91+
92+
8293
def test_against_bruteforce():
8394
"""Tests the default Searchspace framework against bruteforcing the searchspace."""
8495
compare_two_searchspace_objects(simple_searchspace, simple_searchspace_bruteforce)

test/test_util_functions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ def test_replace_param_occurrences():
226226

227227
def test_check_restrictions():
228228
params = {"a": 7, "b": 4, "c": 3}
229-
print(params.values())
230-
print(params.keys())
231229
restrictions = [
232230
["a==b+c"],
233231
["a==b+c", "b==b", "a-b==c"],
@@ -238,7 +236,6 @@ def test_check_restrictions():
238236
# test the call returns expected
239237
for r, e in zip(restrictions, expected):
240238
answer = check_restrictions(r, dict(zip(params.keys(), params.values())), False)
241-
print(answer)
242239
assert answer == e
243240

244241

0 commit comments

Comments
 (0)