Skip to content

Commit e980b23

Browse files
committed
2 parents c27081f + 97ed8ca commit e980b23

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

kernel_tuner/strategies/common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,14 @@ def snap_to_nearest_config(x, tune_params):
203203
"""Helper func that for each param selects the closest actual value."""
204204
params = []
205205
for i, k in enumerate(tune_params.keys()):
206-
values = np.array(tune_params[k])
207-
idx = np.abs(values - x[i]).argmin()
206+
values = tune_params[k]
207+
208+
# if `x[i]` is in `values`, use that value, otherwise find the closest match
209+
if x[i] in values:
210+
idx = values.index(x[i])
211+
else:
212+
idx = np.argmin([abs(v - x[i]) for v in values])
213+
208214
params.append(values[idx])
209215
return params
210216

test/test_common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ def test_snap_to_nearest_config():
5050
tune_params['x'] = [0, 1, 2, 3, 4, 5]
5151
tune_params['y'] = [0, 1, 2, 3, 4, 5]
5252
tune_params['z'] = [0, 1, 2, 3, 4, 5]
53+
tune_params['w'] = ['a', 'b', 'c']
5354

54-
x = [-5.7, 3.14, 1e6]
55-
expected = [0, 3, 5]
55+
x = [-5.7, 3.14, 1e6, 'b']
56+
expected = [0, 3, 5, 'b']
5657

5758
answer = common.snap_to_nearest_config(x, tune_params)
5859
assert answer == expected

0 commit comments

Comments
 (0)