Skip to content

Commit 64b6018

Browse files
bugfix for issue #68
1 parent a430331 commit 64b6018

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

kernel_tuner/c.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" This module contains the functionality for running and compiling C functions """
22

3+
from collections import namedtuple
34
import subprocess
45
import platform
56
import errno
@@ -12,6 +13,15 @@
1213

1314
from kernel_tuner.util import get_temp_filename, delete_temp_file, write_file
1415

16+
dtype_map = {"int8": C.c_int8,
17+
"int16": C.c_int16,
18+
"int32": C.c_int32,
19+
"int64": C.c_int64,
20+
"float32": C.c_float,
21+
"float64": C.c_double}
22+
23+
Argument = namedtuple("Argument", ["type", "shape"])
24+
1525

1626
class CFunctions(object):
1727
"""Class that groups the code for running and compiling C functions"""
@@ -67,33 +77,23 @@ def ready_argument_list(self, arguments):
6777
:returns: A list of arguments that can be passed to the C function.
6878
:rtype: list()
6979
"""
70-
ctype_args = []
71-
72-
dtype_map = {"int8": C.c_char,
73-
"int16": C.c_short,
74-
"int32": C.c_int32,
75-
"int64": C.c_int64,
76-
"float32": C.c_float,
77-
"float64": C.c_double}
78-
np_to_c_type_map = {numpy.int32: C.c_int32,
79-
numpy.int64: C.c_int64,
80-
numpy.float32: C.c_float,
81-
numpy.float64: C.c_double}
82-
83-
for arg in arguments:
80+
ctype_args = [None for _ in arguments]
81+
self.arg_mapping = dict()
82+
83+
for i, arg in enumerate(arguments):
84+
if not isinstance(arg, (numpy.ndarray, numpy.generic)):
85+
raise TypeError("Argument is not numpy ndarray or numpy scalar %s" % type(arg))
86+
dtype_str = str(arg.dtype)
87+
arg_info = Argument(dtype_str, arg.shape)
8488
if isinstance(arg, numpy.ndarray):
85-
dtype_str = str(arg.dtype)
8689
if dtype_str in dtype_map.keys():
87-
ctype_args.append(arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str])))
90+
ctype_args[i] = arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str]))
8891
else:
8992
raise TypeError("unknown dtype for ndarray")
90-
self.arg_mapping[str(ctype_args[-1])] = arg.shape
91-
elif isinstance(arg, tuple(np_to_c_type_map.keys())):
92-
ctype_args.append(np_to_c_type_map[type(arg)](arg))
93-
self.arg_mapping[str(ctype_args[-1])] = ()
94-
else:
95-
raise TypeError("Argument is not numpy ndarray or numpy scalar %s" % type(arg))
96-
93+
self.arg_mapping[str(ctype_args[i])] = arg_info
94+
elif isinstance(arg, numpy.generic):
95+
ctype_args[i] = dtype_map[dtype_str](arg)
96+
self.arg_mapping[str(i)] = arg_info
9797
return ctype_args
9898

9999

@@ -169,9 +169,9 @@ def compile(self, kernel_name, kernel_string):
169169
delete_temp_file(filename+".so")
170170
delete_temp_file(filename+".dylib")
171171

172-
173172
return func
174173

174+
175175
def benchmark(self, func, c_args, threads, grid, times):
176176
"""runs the kernel repeatedly, returns averaged returned value
177177
@@ -289,7 +289,8 @@ def memcpy_dtoh(self, dest, src):
289289
:param src: A ctypes pointer to some memory allocation
290290
:type src: ctypes.pointer
291291
"""
292-
dest[:] = numpy.ctypeslib.as_array(src, shape=self.arg_mapping[str(src)])
292+
arginfo = self.arg_mapping[str(src)]
293+
dest[:] = numpy.ctypeslib.as_array(src, shape=arginfo.shape)
293294

294295

295296
def cleanup_lib(self):

test/test_c_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
except ImportError:
1010
from unittest.mock import patch, Mock
1111

12-
from kernel_tuner.c import CFunctions
12+
from kernel_tuner.c import CFunctions, Argument
1313

1414

1515
def test_ready_argument_list1():
@@ -120,7 +120,7 @@ def test_memcpy_dtoh():
120120
output = numpy.zeros_like(x)
121121

122122
cfunc = CFunctions()
123-
cfunc.arg_mapping = { str(x_c) : (4,) }
123+
cfunc.arg_mapping = { str(x_c) : Argument(str(x.dtype), (4,)) }
124124
cfunc.memcpy_dtoh(output, x_c)
125125

126126
print(a)

0 commit comments

Comments
 (0)