Skip to content

Commit 96f367e

Browse files
improve code quality
1 parent ea608b8 commit 96f367e

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

kernel_tuner/util.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,35 @@
55
import errno
66

77
def get_instance_string(params):
8+
""" combine the parameters to a string mostly used for debug output
9+
use of OrderedDict is advised
10+
"""
811
return "_".join([str(i) for i in params.values()])
912

1013
def get_config_string(params):
14+
""" return a compact string representation of a dictionary """
1115
return "".join([k + "=" + str(v) + ", " for k,v in params.items()])
1216

1317
def get_kernel_string(original_kernel):
18+
""" retrieves kernel string from a file if the string passed looks like filename
19+
if the string does look like a filename, but the file does not exist, it
20+
is assumed that the string is not a filename after all.
21+
"""
1422
kernel_string = original_kernel
1523
if looks_like_a_filename(original_kernel):
1624
kernel_string = read_file(original_kernel) or original_kernel
1725
return kernel_string
1826

1927
def delete_temp_file(filename):
28+
""" delete a temporary file, don't complain if is no longer exists """
2029
try:
2130
os.remove(filename)
2231
except OSError as e:
2332
if e.errno != errno.ENOENT:
2433
raise e
2534

2635
def get_temp_filename():
36+
""" return a string in the form of temp_X, where X is a large integer """
2737
random_large_int = numpy.random.randint(low=100, high=100000000000)
2838
return 'temp_' + str(random_large_int)
2939

@@ -33,7 +43,7 @@ def looks_like_a_filename(original_kernel):
3343
if isinstance(original_kernel, str):
3444
result = True
3545
#test if not too long
36-
if len(original_kernel) > 100:
46+
if len(original_kernel) > 250:
3747
result = False
3848
#test if not contains special characters
3949
for c in "();{}\\":
@@ -48,13 +58,15 @@ def looks_like_a_filename(original_kernel):
4858
return result
4959

5060
def read_file(filename):
61+
""" return the contents of the file named filename or None if file not found """
5162
if os.path.isfile(filename):
5263
with open(filename, 'r') as f:
5364
return f.read()
5465

5566
def write_file(filename, string):
56-
#ugly fix, hopefully we can find a better one
67+
""" dump the contents of string to a file called filename """
5768
import sys
69+
#ugly fix, hopefully we can find a better one
5870
if sys.version_info[0] >= 3:
5971
with open(filename, 'w', encoding="utf-8") as f:
6072
f.write(string)
@@ -176,6 +188,7 @@ def replace_param_occurrences(string, params):
176188
return string
177189

178190
def check_restrictions(restrictions, element, keys, verbose):
191+
""" check whether a specific instance meets the search space restrictions """
179192
params = dict(zip(keys, element))
180193
for restrict in restrictions:
181194
if not eval(replace_param_occurrences(restrict, params)):
@@ -186,25 +199,26 @@ def check_restrictions(restrictions, element, keys, verbose):
186199
return True
187200

188201
def check_argument_list(args):
202+
""" raise an exception if a kernel argument is of unsupported type """
189203
for (i, arg) in enumerate(args):
190204
if not isinstance(arg, (numpy.ndarray, numpy.generic)):
191205
raise TypeError("Argument at position " + str(i) + " of type: " + str(type(arg)) + " should be of type numpy.ndarray or numpy scalar")
192206

193207
def setup_block_and_grid(dev, problem_size, grid_div_z, grid_div_y, grid_div_x, params, instance_string, verbose):
194-
"""compute problem size, thread block and grid dimensions for this kernel"""
195-
threads = get_thread_block_dimensions(params)
196-
if numpy.prod(threads) > dev.max_threads:
197-
if verbose:
198-
print("skipping config", instance_string, "reason: too many threads per block")
199-
return None, None
200-
current_problem_size = get_problem_size(problem_size, params)
201-
grid = get_grid_dimensions(current_problem_size, params, grid_div_z, grid_div_y, grid_div_x)
202-
return threads, grid
208+
"""compute problem size, thread block and grid dimensions for this kernel"""
209+
threads = get_thread_block_dimensions(params)
210+
if numpy.prod(threads) > dev.max_threads:
211+
if verbose:
212+
print("skipping config", instance_string, "reason: too many threads per block")
213+
return None, None
214+
current_problem_size = get_problem_size(problem_size, params)
215+
grid = get_grid_dimensions(current_problem_size, params, grid_div_z, grid_div_y, grid_div_x)
216+
return threads, grid
203217

204218
def setup_kernel_strings(kernel_name, original_kernel, params, grid, instance_string):
205-
"""create configuration specific kernel string"""
206-
kernel_string = prepare_kernel_string(original_kernel, params, grid)
207-
name = kernel_name + "_" + instance_string
208-
kernel_string = kernel_string.replace(kernel_name, name)
209-
return name, kernel_string
219+
"""create configuration specific kernel string"""
220+
kernel_string = prepare_kernel_string(original_kernel, params, grid)
221+
name = kernel_name + "_" + instance_string
222+
kernel_string = kernel_string.replace(kernel_name, name)
223+
return name, kernel_string
210224

0 commit comments

Comments
 (0)