|
1 | 1 | """ This module contains the functionality for running and compiling C functions """ |
2 | 2 |
|
| 3 | +from collections import namedtuple |
3 | 4 | import subprocess |
4 | 5 | import platform |
5 | 6 | import errno |
|
12 | 13 |
|
13 | 14 | from kernel_tuner.util import get_temp_filename, delete_temp_file, write_file |
14 | 15 |
|
| 16 | +dtype_map = {"int8": C.c_int8, |
| 17 | + "int16": C.c_int16, |
| 18 | + "int32": C.c_int32, |
| 19 | + "int64": C.c_int64, |
| 20 | + "float32": C.c_float, |
| 21 | + "float64": C.c_double} |
| 22 | + |
| 23 | +Argument = namedtuple("Argument", ["type", "shape"]) |
| 24 | + |
15 | 25 |
|
16 | 26 | class CFunctions(object): |
17 | 27 | """Class that groups the code for running and compiling C functions""" |
@@ -67,33 +77,23 @@ def ready_argument_list(self, arguments): |
67 | 77 | :returns: A list of arguments that can be passed to the C function. |
68 | 78 | :rtype: list() |
69 | 79 | """ |
70 | | - ctype_args = [] |
71 | | - |
72 | | - dtype_map = {"int8": C.c_char, |
73 | | - "int16": C.c_short, |
74 | | - "int32": C.c_int32, |
75 | | - "int64": C.c_int64, |
76 | | - "float32": C.c_float, |
77 | | - "float64": C.c_double} |
78 | | - np_to_c_type_map = {numpy.int32: C.c_int32, |
79 | | - numpy.int64: C.c_int64, |
80 | | - numpy.float32: C.c_float, |
81 | | - numpy.float64: C.c_double} |
82 | | - |
83 | | - for arg in arguments: |
| 80 | + ctype_args = [None for _ in arguments] |
| 81 | + self.arg_mapping = dict() |
| 82 | + |
| 83 | + for i, arg in enumerate(arguments): |
| 84 | + if not isinstance(arg, (numpy.ndarray, numpy.generic)): |
| 85 | + raise TypeError("Argument is not numpy ndarray or numpy scalar %s" % type(arg)) |
| 86 | + dtype_str = str(arg.dtype) |
| 87 | + arg_info = Argument(dtype_str, arg.shape) |
84 | 88 | if isinstance(arg, numpy.ndarray): |
85 | | - dtype_str = str(arg.dtype) |
86 | 89 | if dtype_str in dtype_map.keys(): |
87 | | - ctype_args.append(arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))) |
| 90 | + ctype_args[i] = arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str])) |
88 | 91 | else: |
89 | 92 | raise TypeError("unknown dtype for ndarray") |
90 | | - self.arg_mapping[str(ctype_args[-1])] = arg.shape |
91 | | - elif isinstance(arg, tuple(np_to_c_type_map.keys())): |
92 | | - ctype_args.append(np_to_c_type_map[type(arg)](arg)) |
93 | | - self.arg_mapping[str(ctype_args[-1])] = () |
94 | | - else: |
95 | | - raise TypeError("Argument is not numpy ndarray or numpy scalar %s" % type(arg)) |
96 | | - |
| 93 | + self.arg_mapping[str(ctype_args[i])] = arg_info |
| 94 | + elif isinstance(arg, numpy.generic): |
| 95 | + ctype_args[i] = dtype_map[dtype_str](arg) |
| 96 | + self.arg_mapping[str(i)] = arg_info |
97 | 97 | return ctype_args |
98 | 98 |
|
99 | 99 |
|
@@ -169,9 +169,9 @@ def compile(self, kernel_name, kernel_string): |
169 | 169 | delete_temp_file(filename+".so") |
170 | 170 | delete_temp_file(filename+".dylib") |
171 | 171 |
|
172 | | - |
173 | 172 | return func |
174 | 173 |
|
| 174 | + |
175 | 175 | def benchmark(self, func, c_args, threads, grid, times): |
176 | 176 | """runs the kernel repeatedly, returns averaged returned value |
177 | 177 |
|
@@ -289,7 +289,8 @@ def memcpy_dtoh(self, dest, src): |
289 | 289 | :param src: A ctypes pointer to some memory allocation |
290 | 290 | :type src: ctypes.pointer |
291 | 291 | """ |
292 | | - dest[:] = numpy.ctypeslib.as_array(src, shape=self.arg_mapping[str(src)]) |
| 292 | + arginfo = self.arg_mapping[str(src)] |
| 293 | + dest[:] = numpy.ctypeslib.as_array(src, shape=arginfo.shape) |
293 | 294 |
|
294 | 295 |
|
295 | 296 | def cleanup_lib(self): |
|
0 commit comments