Skip to content

Commit 5fd4b73

Browse files
check argument list throws UserWarning instead of TypeError
1 parent 484f99c commit 5fd4b73

File tree

2 files changed

+23
-37
lines changed

2 files changed

+23
-37
lines changed

kernel_tuner/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def check_argument_list(kernel_name, kernel_string, args):
6565
# it is the right one
6666
return
6767
for errors in collected_errors:
68-
raise TypeError(errors[0])
68+
warnings.warn(errors[0], UserWarning)
69+
#raise TypeError(errors[0])
6970

7071
def check_tune_params_list(tune_params):
7172
""" raise an exception if a tune parameter has a forbidden name """

test/test_util_functions.py

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,22 @@ def test_get_device_interface3():
214214
lang = "blabla"
215215
core.DeviceInterface("", 0, 0, lang=lang)
216216

217+
def assert_user_warning(f, args, substring=None):
218+
with warnings.catch_warnings(record=True) as w:
219+
warnings.simplefilter("always")
220+
f(*args)
221+
assert len(w) == 1
222+
assert issubclass(w[-1].category, UserWarning)
223+
if substring:
224+
assert substring in str(w[-1].message)
225+
226+
def assert_no_user_warning(f, args):
227+
with warnings.catch_warnings(record=True) as w:
228+
warnings.simplefilter("always")
229+
f(*args)
230+
assert len(w) == 0
231+
232+
217233
def test_check_argument_list1():
218234
kernel_name = "test_kernel"
219235
kernel_string = """__kernel void test_kernel(int number, char * message, int * numbers) {
@@ -240,9 +256,7 @@ def test_check_argument_list2():
240256
}
241257
"""
242258
args = [numpy.byte(5), numpy.float64(4.6), numpy.int32([1, 2, 3]), numpy.uint64([3, 2, 111])]
243-
check_argument_list(kernel_name, kernel_string, args)
244-
#test that no exception is raised
245-
assert True
259+
assert_no_user_warning(check_argument_list, [kernel_name, kernel_string, args])
246260

247261
def test_check_argument_list3():
248262
kernel_name = "test_kernel"
@@ -251,16 +265,7 @@ def test_check_argument_list3():
251265
}
252266
"""
253267
args = [numpy.uint16(42), numpy.float16([3, 4, 6]), numpy.int32([300])]
254-
try:
255-
check_argument_list(kernel_name, kernel_string, args)
256-
print("Expected a TypeError to be raised")
257-
assert False
258-
except TypeError as expected_error:
259-
print(str(expected_error))
260-
assert "at position 2" in str(expected_error)
261-
except Exception:
262-
print("Expected a TypeError to be raised")
263-
assert False
268+
assert_user_warning(check_argument_list, [kernel_name, kernel_string, args], "at position 2")
264269

265270
def test_check_argument_list4():
266271
kernel_name = "test_kernel"
@@ -269,16 +274,7 @@ def test_check_argument_list4():
269274
}
270275
"""
271276
args = [numpy.uint16(42), numpy.float16([3, 4, 6]), numpy.int64([300]), numpy.ubyte(32)]
272-
try:
273-
check_argument_list(kernel_name, kernel_string, args)
274-
print("Expected a TypeError to be raised")
275-
assert False
276-
except TypeError as expected_error:
277-
print(str(expected_error))
278-
assert "do not match in size" in str(expected_error)
279-
except Exception:
280-
print("Expected a TypeError to be raised")
281-
assert False
277+
assert_user_warning(check_argument_list, [kernel_name, kernel_string, args], "do not match in size")
282278

283279
def test_check_argument_list5():
284280
kernel_name = "my_test_kernel"
@@ -298,13 +294,7 @@ def test_check_argument_list5():
298294
args = [numpy.array([1,2,3]).astype(numpy.float64),
299295
numpy.array([1,2,3]).astype(numpy.float32),
300296
numpy.int32(6), numpy.int32(7)]
301-
302-
try:
303-
check_argument_list(kernel_name, kernel_string, args)
304-
305-
except TypeError:
306-
print("Expected no TypeError to be raised")
307-
assert False
297+
assert_no_user_warning(check_argument_list, [kernel_name, kernel_string, args])
308298

309299
def test_check_argument_list6():
310300
kernel_name = "test_kernel"
@@ -333,12 +323,7 @@ def test_check_argument_list7():
333323
// /test_kernel
334324
"""
335325
args = [numpy.byte(5), numpy.float64(4.6), numpy.int32([1, 2, 3]), numpy.uint64([3, 2, 111])]
336-
try:
337-
check_argument_list(kernel_name, kernel_string, args)
338-
print("Expected a TypeError to be raised.")
339-
assert False
340-
except TypeError:
341-
assert True
326+
assert_user_warning(check_argument_list, [kernel_name, kernel_string, args])
342327

343328
def test_check_tune_params_list():
344329
tune_params = dict(zip(["one_thing", "led_to_another", "and_before_you_know_it",

0 commit comments

Comments
 (0)