|
| 1 | +""" Module for wrapper functions |
| 2 | +
|
| 3 | +This module contains functions that generate wrappers for functions, |
| 4 | +allowing them to be compiled and run using Kernel Tuner. |
| 5 | +
|
| 6 | +The first function in this module generates a wrapper for |
| 7 | +primitive-typed (templated) C++ functions, allowing them to be |
| 8 | +compiled and executed using Kernel Tuner. The plan is to later add |
| 9 | +functionality to also wrap device functions. |
| 10 | +
|
| 11 | +""" |
| 12 | + |
| 13 | +import numpy as np |
| 14 | + |
| 15 | +from kernel_tuner import util |
| 16 | + |
| 17 | + |
| 18 | +def cpp(function_name, kernel_source, args, convert_to_array=None): |
| 19 | + """ Generate a wrapper to call C++ functions from Python |
| 20 | +
|
| 21 | + This function allows Kernel Tuner to call templated C++ functions |
| 22 | + that use primitive data types (double, float, int, ...). |
| 23 | +
|
| 24 | + There is support to convert function arguments from plain pointers |
| 25 | + to array references. If this is needed, there should be a True value |
| 26 | + in convert_to_array in the location corresponding to the location in |
| 27 | + the args array. |
| 28 | +
|
| 29 | + For example, a Numpy array argument of type float64 and length 10 |
| 30 | + will be cast using: |
| 31 | + ``*reinterpret_cast<double(*)[10]>(arg)`` |
| 32 | + which allows it to be used to call a C++ that is defined as: |
| 33 | + ``template<typename T, int s>void my_function(T (&arg)[s], ...)`` |
| 34 | +
|
| 35 | + Arrays of size 1 will be converted to simple non-array references. |
| 36 | + False indicates that no conversion is performed. Conversion |
| 37 | + is only support for numpy array arguments. If convert_to_array is |
| 38 | + passed it should have the same length as the args array. |
| 39 | +
|
| 40 | + :param function_name: A string containing the name of the C++ function |
| 41 | + to be wrapped |
| 42 | + :type function_name: string |
| 43 | +
|
| 44 | + :param kernel_source: One of the sources for the kernel, could be a |
| 45 | + function that generates the kernel code, a string containing a filename |
| 46 | + that points to the kernel source, or just a string that contains the code. |
| 47 | + :type kernel_source: string or callable |
| 48 | +
|
| 49 | + :param args: A list of kernel arguments, use numpy arrays for |
| 50 | + arrays, use numpy.int32 or numpy.float32 for scalars. |
| 51 | + :type args: list |
| 52 | +
|
| 53 | + :param convert_to_array: A list of same length as args, containing |
| 54 | + True or False values indicating whether the corresponding argument |
| 55 | + in args should be cast to a reference to an array or not. |
| 56 | + :type convert_to_array: list (True or False) |
| 57 | +
|
| 58 | + :returns: A string containing the orignal code extended with the wrapper |
| 59 | + function. The wrapper has "extern C" binding and can be passed to |
| 60 | + other Kernel Tuner functions, for example run_kernel with lang="C". |
| 61 | + The name of the wrapper function will be the name of the function with |
| 62 | + a "_wrapper" postfix. |
| 63 | + :rtype: string |
| 64 | +
|
| 65 | + """ |
| 66 | + |
| 67 | + if convert_to_array and len(args) != len(convert_to_array): |
| 68 | + raise ValueError("convert_to_array length should be same as args") |
| 69 | + |
| 70 | + type_map = {"int8": "char", |
| 71 | + "int16": "short", |
| 72 | + "int32": "int", |
| 73 | + "float32": "float", |
| 74 | + "float64": "double"} |
| 75 | + |
| 76 | + def type_str(arg): |
| 77 | + if not str(arg.dtype) in type_map: |
| 78 | + raise Value("only primitive data types are supported by the C++ wrapper") |
| 79 | + typestring = type_map[str(arg.dtype)] |
| 80 | + if isinstance(arg, np.ndarray): |
| 81 | + typestring += " *" |
| 82 | + return typestring + " " |
| 83 | + |
| 84 | + signature = ",".join([type_str(arg) + "arg" + str(i) for i, arg in enumerate(args)]) |
| 85 | + |
| 86 | + if not convert_to_array: |
| 87 | + call_args = ",".join(["arg" + str(i) for i in range(len(args))]) |
| 88 | + else: |
| 89 | + call_args = [] |
| 90 | + for i, arg in enumerate(args): |
| 91 | + if convert_to_array[i]: |
| 92 | + if not isinstance(arg, np.ndarray): |
| 93 | + ValueError("conversion to array reference only supported for arguments that are numpy arrays, use length-1 numpy array to pass a scalar by reference") |
| 94 | + if np.prod(arg.shape) > 1: |
| 95 | + #convert pointer to a reference to an array |
| 96 | + arg_shape = "".join("[%d]" % i for i in arg.shape) |
| 97 | + arg_str = "*reinterpret_cast<" + type_map[str(arg.dtype)] + "(*)" + arg_shape + ">(arg" + str(i) + ")" |
| 98 | + else: |
| 99 | + #a reference is accepted rather than a pointer, just dereference |
| 100 | + arg_str = "*arg" + str(i) |
| 101 | + call_args.append(arg_str) |
| 102 | + #call_args = ",".join(["*reinterpret_cast<double(*)[9]>(arg" + str(i) + ")" for i in range(len(args))]) |
| 103 | + else: |
| 104 | + call_args.append("arg" + str(i)) |
| 105 | + call_args_str = ",".join(call_args) |
| 106 | + |
| 107 | + kernel_string = util.get_kernel_string(kernel_source) |
| 108 | + |
| 109 | + return """ |
| 110 | +
|
| 111 | + %s |
| 112 | +
|
| 113 | + extern "C" |
| 114 | + float %s_wrapper(%s) { |
| 115 | +
|
| 116 | + %s(%s); |
| 117 | +
|
| 118 | + return 0.0f; |
| 119 | + }""" % (kernel_string, function_name, signature, function_name, call_args_str) |
| 120 | + |
| 121 | + |
0 commit comments