Skip to content

Commit 31fe668

Browse files
add user input checks for run_kernel
1 parent 1789186 commit 31fe668

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

kernel_tuner/interface.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,22 +297,11 @@ def tune_kernel(kernel_name, kernel_string, problem_size, arguments,
297297
if log:
298298
logging.basicConfig(filename=kernel_name + datetime.now().strftime('%Y%m%d-%H:%M:%S') + '.log', level=log)
299299

300-
# see if the kernel arguments have correct type
301-
if not callable(kernel_string):
302-
if isinstance(kernel_string, list):
303-
for file in kernel_string:
304-
util.check_argument_list(kernel_name, util.get_kernel_string(file), arguments)
305-
else:
306-
util.check_argument_list(kernel_name, util.get_kernel_string(kernel_string), arguments)
307-
else:
308-
logging.debug("Checking of arguments list not supported yet for code generators.")
300+
_check_user_input(kernel_name, kernel_string, arguments, block_size_names)
309301

310302
# check for forbidden names in tune parameters
311303
util.check_tune_params_list(tune_params)
312304

313-
# check for types and length of block_size_names
314-
util.check_block_size_names(block_size_names)
315-
316305
# check whether block_size_names are used as expected
317306
util.check_block_size_params_names_list(block_size_names, tune_params)
318307

@@ -435,6 +424,8 @@ def run_kernel(kernel_name, kernel_string, problem_size, arguments,
435424
lang=None, device=0, platform=0, cmem_args=None, compiler=None, compiler_options=None,
436425
block_size_names=None, quiet=False):
437426

427+
_check_user_input(kernel_name, kernel_string, arguments, block_size_names)
428+
438429
#sort options into separate dicts
439430
opts = locals()
440431
kernel_options = Options([(k, opts[k]) for k in _kernel_options.keys()])
@@ -487,3 +478,18 @@ def run_kernel(kernel_name, kernel_string, problem_size, arguments,
487478

488479

489480
run_kernel.__doc__ = _run_kernel_docstring
481+
482+
def _check_user_input(kernel_name, kernel_string, arguments, block_size_names):
483+
# see if the kernel arguments have correct type
484+
if not callable(kernel_string):
485+
if isinstance(kernel_string, list):
486+
for file in kernel_string:
487+
util.check_argument_list(kernel_name, util.get_kernel_string(file), arguments)
488+
else:
489+
util.check_argument_list(kernel_name, util.get_kernel_string(kernel_string), arguments)
490+
else:
491+
logging.debug("Checking of arguments list not supported yet for code generators.")
492+
493+
# check for types and length of block_size_names
494+
util.check_block_size_names(block_size_names)
495+

0 commit comments

Comments
 (0)