33import array_api_strict
44
55
6- @pytest .mark .parametrize ("func_name" , ("fft" , "ifft" , "fftn" , "ifftn" , "irfft" ,
7- "irfftn" , "hfft" , "fftshift" , "ifftshift" ))
6+ @pytest .mark .parametrize (
7+ "func_name" ,
8+ (
9+ "fft" ,
10+ "ifft" ,
11+ "fftn" ,
12+ "ifftn" ,
13+ "irfft" ,
14+ "irfftn" ,
15+ "hfft" ,
16+ "fftshift" ,
17+ "ifftshift" ,
18+ ),
19+ )
820def test_fft_device_support_complex (func_name ):
921 func = getattr (array_api_strict .fft , func_name )
10- x = array_api_strict .asarray ([1 , 2. ],
11- dtype = array_api_strict .complex64 ,
12- device = array_api_strict .Device ("device1" ))
22+ x = array_api_strict .asarray (
23+ [1 , 2.0 ],
24+ dtype = array_api_strict .complex64 ,
25+ device = array_api_strict .Device ("device1" ),
26+ )
1327 y = func (x )
1428
1529 assert x .device == y .device
@@ -18,8 +32,7 @@ def test_fft_device_support_complex(func_name):
1832@pytest .mark .parametrize ("func_name" , ("rfft" , "rfftn" , "ihfft" ))
1933def test_fft_device_support_real (func_name ):
2034 func = getattr (array_api_strict .fft , func_name )
21- x = array_api_strict .asarray ([1 , 2. ],
22- device = array_api_strict .Device ("device1" ))
35+ x = array_api_strict .asarray ([1 , 2.0 ], device = array_api_strict .Device ("device1" ))
2336 y = func (x )
2437
25- assert x .device == y .device
38+ assert x .device == y .device
0 commit comments