Skip to content

Commit c27081f

Browse files
committed
Fixed an issue where single-value parameters could cause errors in bayes_opt, extended tests to test for this
1 parent a4a284b commit c27081f

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

kernel_tuner/strategies/bayes_opt.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,13 @@ def find_param_config_unvisited_index(self, param_config: tuple) -> int:
396396
return self.unvisited_cache.index(param_config)
397397

398398
def normalize_param_config(self, param_config: tuple) -> tuple:
399-
"""Normalizes a parameter configuration."""
400-
normalized = tuple(
401-
self.normalized_dict[self.param_names[index]][param_value] for index, param_value in enumerate(param_config)
402-
)
403-
return normalized
399+
"""Normalizes a parameter configuration. Skips over pruned values."""
400+
param_config = self.unprune_param_config(param_config)
401+
normalized = list()
402+
for index, param_value in enumerate(param_config):
403+
if self.removed_tune_params[index] is None:
404+
normalized.append(self.normalized_dict[self.param_names[index]][param_value])
405+
return tuple(normalized)
404406

405407
def denormalize_param_config(self, param_config: tuple) -> tuple:
406408
"""Denormalizes a parameter configuration."""

test/strategies/test_strategies.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,30 @@ def test_strategies(vector_add, strategy):
5151

5252
assert len(results) > 0
5353

54+
# check if the number of valid unique configurations is less than or equal to max_fevals
5455
if not strategy == "brute_force":
55-
# check if the number of valid unique configurations is less then max_fevals
56-
5756
tune_params = vector_add[-1]
5857
unique_results = {}
59-
6058
for result in results:
6159
x_int = ",".join([str(v) for k, v in result.items() if k in tune_params])
6260
if not isinstance(result["time"], util.InvalidConfig):
6361
unique_results[x_int] = result["time"]
64-
6562
assert len(unique_results) <= filter_options["max_fevals"]
63+
64+
# check whether the returned dictionaries contain exactly the expected keys and the appropriate type
65+
expected_items = {
66+
'block_size_x': int,
67+
'time': (float, int),
68+
'times': list,
69+
'compile_time': (float, int),
70+
'verification_time': (float, int),
71+
'benchmark_time': (float, int),
72+
'strategy_time': (float, int),
73+
'framework_time': (float, int),
74+
'timestamp': str
75+
}
76+
for res in results:
77+
assert len(res) == len(expected_items)
78+
for expected_key, expected_type in expected_items.items():
79+
assert expected_key in res
80+
assert isinstance(res[expected_key], expected_type)

0 commit comments

Comments
 (0)