@@ -214,6 +214,22 @@ def test_get_device_interface3():
214214 lang = "blabla"
215215 core .DeviceInterface ("" , 0 , 0 , lang = lang )
216216
217+ def assert_user_warning (f , args , substring = None ):
218+ with warnings .catch_warnings (record = True ) as w :
219+ warnings .simplefilter ("always" )
220+ f (* args )
221+ assert len (w ) == 1
222+ assert issubclass (w [- 1 ].category , UserWarning )
223+ if substring :
224+ assert substring in str (w [- 1 ].message )
225+
226+ def assert_no_user_warning (f , args ):
227+ with warnings .catch_warnings (record = True ) as w :
228+ warnings .simplefilter ("always" )
229+ f (* args )
230+ assert len (w ) == 0
231+
232+
217233def test_check_argument_list1 ():
218234 kernel_name = "test_kernel"
219235 kernel_string = """__kernel void test_kernel(int number, char * message, int * numbers) {
@@ -240,9 +256,7 @@ def test_check_argument_list2():
240256 }
241257 """
242258 args = [numpy .byte (5 ), numpy .float64 (4.6 ), numpy .int32 ([1 , 2 , 3 ]), numpy .uint64 ([3 , 2 , 111 ])]
243- check_argument_list (kernel_name , kernel_string , args )
244- #test that no exception is raised
245- assert True
259+ assert_no_user_warning (check_argument_list , [kernel_name , kernel_string , args ])
246260
247261def test_check_argument_list3 ():
248262 kernel_name = "test_kernel"
@@ -251,16 +265,7 @@ def test_check_argument_list3():
251265 }
252266 """
253267 args = [numpy .uint16 (42 ), numpy .float16 ([3 , 4 , 6 ]), numpy .int32 ([300 ])]
254- try :
255- check_argument_list (kernel_name , kernel_string , args )
256- print ("Expected a TypeError to be raised" )
257- assert False
258- except TypeError as expected_error :
259- print (str (expected_error ))
260- assert "at position 2" in str (expected_error )
261- except Exception :
262- print ("Expected a TypeError to be raised" )
263- assert False
268+ assert_user_warning (check_argument_list , [kernel_name , kernel_string , args ], "at position 2" )
264269
265270def test_check_argument_list4 ():
266271 kernel_name = "test_kernel"
@@ -269,16 +274,7 @@ def test_check_argument_list4():
269274 }
270275 """
271276 args = [numpy .uint16 (42 ), numpy .float16 ([3 , 4 , 6 ]), numpy .int64 ([300 ]), numpy .ubyte (32 )]
272- try :
273- check_argument_list (kernel_name , kernel_string , args )
274- print ("Expected a TypeError to be raised" )
275- assert False
276- except TypeError as expected_error :
277- print (str (expected_error ))
278- assert "do not match in size" in str (expected_error )
279- except Exception :
280- print ("Expected a TypeError to be raised" )
281- assert False
277+ assert_user_warning (check_argument_list , [kernel_name , kernel_string , args ], "do not match in size" )
282278
283279def test_check_argument_list5 ():
284280 kernel_name = "my_test_kernel"
@@ -298,13 +294,7 @@ def test_check_argument_list5():
298294 args = [numpy .array ([1 ,2 ,3 ]).astype (numpy .float64 ),
299295 numpy .array ([1 ,2 ,3 ]).astype (numpy .float32 ),
300296 numpy .int32 (6 ), numpy .int32 (7 )]
301-
302- try :
303- check_argument_list (kernel_name , kernel_string , args )
304-
305- except TypeError :
306- print ("Expected no TypeError to be raised" )
307- assert False
297+ assert_no_user_warning (check_argument_list , [kernel_name , kernel_string , args ])
308298
309299def test_check_argument_list6 ():
310300 kernel_name = "test_kernel"
@@ -333,12 +323,7 @@ def test_check_argument_list7():
333323 // /test_kernel
334324 """
335325 args = [numpy .byte (5 ), numpy .float64 (4.6 ), numpy .int32 ([1 , 2 , 3 ]), numpy .uint64 ([3 , 2 , 111 ])]
336- try :
337- check_argument_list (kernel_name , kernel_string , args )
338- print ("Expected a TypeError to be raised." )
339- assert False
340- except TypeError :
341- assert True
326+ assert_user_warning (check_argument_list , [kernel_name , kernel_string , args ])
342327
343328def test_check_tune_params_list ():
344329 tune_params = dict (zip (["one_thing" , "led_to_another" , "and_before_you_know_it" ,
0 commit comments