1414 float32 ,
1515 complex64 ,
1616)
17- from ._array_object import Array , CPU_DEVICE
17+ from ._array_object import Array , ALL_DEVICES
1818from ._data_type_functions import astype
1919from ._flags import requires_extension
2020
@@ -36,7 +36,7 @@ def fft(
3636 """
3737 if x .dtype not in _complex_floating_dtypes :
3838 raise TypeError ("Only complex floating-point dtypes are allowed in fft" )
39- res = Array ._new (np .fft .fft (x ._array , n = n , axis = axis , norm = norm ))
39+ res = Array ._new (np .fft .fft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
4040 # Note: np.fft functions improperly upcast float32 and complex64 to
4141 # complex128
4242 if x .dtype == complex64 :
@@ -59,7 +59,7 @@ def ifft(
5959 """
6060 if x .dtype not in _complex_floating_dtypes :
6161 raise TypeError ("Only complex floating-point dtypes are allowed in ifft" )
62- res = Array ._new (np .fft .ifft (x ._array , n = n , axis = axis , norm = norm ))
62+ res = Array ._new (np .fft .ifft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
6363 # Note: np.fft functions improperly upcast float32 and complex64 to
6464 # complex128
6565 if x .dtype == complex64 :
@@ -82,7 +82,7 @@ def fftn(
8282 """
8383 if x .dtype not in _complex_floating_dtypes :
8484 raise TypeError ("Only complex floating-point dtypes are allowed in fftn" )
85- res = Array ._new (np .fft .fftn (x ._array , s = s , axes = axes , norm = norm ))
85+ res = Array ._new (np .fft .fftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
8686 # Note: np.fft functions improperly upcast float32 and complex64 to
8787 # complex128
8888 if x .dtype == complex64 :
@@ -105,7 +105,7 @@ def ifftn(
105105 """
106106 if x .dtype not in _complex_floating_dtypes :
107107 raise TypeError ("Only complex floating-point dtypes are allowed in ifftn" )
108- res = Array ._new (np .fft .ifftn (x ._array , s = s , axes = axes , norm = norm ))
108+ res = Array ._new (np .fft .ifftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
109109 # Note: np.fft functions improperly upcast float32 and complex64 to
110110 # complex128
111111 if x .dtype == complex64 :
@@ -128,7 +128,7 @@ def rfft(
128128 """
129129 if x .dtype not in _real_floating_dtypes :
130130 raise TypeError ("Only real floating-point dtypes are allowed in rfft" )
131- res = Array ._new (np .fft .rfft (x ._array , n = n , axis = axis , norm = norm ))
131+ res = Array ._new (np .fft .rfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
132132 # Note: np.fft functions improperly upcast float32 and complex64 to
133133 # complex128
134134 if x .dtype == float32 :
@@ -151,7 +151,7 @@ def irfft(
151151 """
152152 if x .dtype not in _complex_floating_dtypes :
153153 raise TypeError ("Only complex floating-point dtypes are allowed in irfft" )
154- res = Array ._new (np .fft .irfft (x ._array , n = n , axis = axis , norm = norm ))
154+ res = Array ._new (np .fft .irfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
155155 # Note: np.fft functions improperly upcast float32 and complex64 to
156156 # complex128
157157 if x .dtype == complex64 :
@@ -174,7 +174,7 @@ def rfftn(
174174 """
175175 if x .dtype not in _real_floating_dtypes :
176176 raise TypeError ("Only real floating-point dtypes are allowed in rfftn" )
177- res = Array ._new (np .fft .rfftn (x ._array , s = s , axes = axes , norm = norm ))
177+ res = Array ._new (np .fft .rfftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
178178 # Note: np.fft functions improperly upcast float32 and complex64 to
179179 # complex128
180180 if x .dtype == float32 :
@@ -197,7 +197,7 @@ def irfftn(
197197 """
198198 if x .dtype not in _complex_floating_dtypes :
199199 raise TypeError ("Only complex floating-point dtypes are allowed in irfftn" )
200- res = Array ._new (np .fft .irfftn (x ._array , s = s , axes = axes , norm = norm ))
200+ res = Array ._new (np .fft .irfftn (x ._array , s = s , axes = axes , norm = norm ), device = x . device )
201201 # Note: np.fft functions improperly upcast float32 and complex64 to
202202 # complex128
203203 if x .dtype == complex64 :
@@ -220,7 +220,7 @@ def hfft(
220220 """
221221 if x .dtype not in _complex_floating_dtypes :
222222 raise TypeError ("Only complex floating-point dtypes are allowed in hfft" )
223- res = Array ._new (np .fft .hfft (x ._array , n = n , axis = axis , norm = norm ))
223+ res = Array ._new (np .fft .hfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
224224 # Note: np.fft functions improperly upcast float32 and complex64 to
225225 # complex128
226226 if x .dtype == complex64 :
@@ -243,7 +243,7 @@ def ihfft(
243243 """
244244 if x .dtype not in _real_floating_dtypes :
245245 raise TypeError ("Only real floating-point dtypes are allowed in ihfft" )
246- res = Array ._new (np .fft .ihfft (x ._array , n = n , axis = axis , norm = norm ))
246+ res = Array ._new (np .fft .ihfft (x ._array , n = n , axis = axis , norm = norm ), device = x . device )
247247 # Note: np.fft functions improperly upcast float32 and complex64 to
248248 # complex128
249249 if x .dtype == float32 :
@@ -257,9 +257,9 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar
257257
258258 See its docstring for more information.
259259 """
260- if device not in [ CPU_DEVICE , None ] :
260+ if device not in ALL_DEVICES :
261261 raise ValueError (f"Unsupported device { device !r} " )
262- return Array ._new (np .fft .fftfreq (n , d = d ))
262+ return Array ._new (np .fft .fftfreq (n , d = d ), device = device )
263263
264264@requires_extension ('fft' )
265265def rfftfreq (n : int , / , * , d : float = 1.0 , device : Optional [Device ] = None ) -> Array :
@@ -268,9 +268,9 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A
268268
269269 See its docstring for more information.
270270 """
271- if device not in [ CPU_DEVICE , None ] :
271+ if device not in ALL_DEVICES :
272272 raise ValueError (f"Unsupported device { device !r} " )
273- return Array ._new (np .fft .rfftfreq (n , d = d ))
273+ return Array ._new (np .fft .rfftfreq (n , d = d ), device = device )
274274
275275@requires_extension ('fft' )
276276def fftshift (x : Array , / , * , axes : Union [int , Sequence [int ]] = None ) -> Array :
@@ -281,7 +281,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
281281 """
282282 if x .dtype not in _floating_dtypes :
283283 raise TypeError ("Only floating-point dtypes are allowed in fftshift" )
284- return Array ._new (np .fft .fftshift (x ._array , axes = axes ))
284+ return Array ._new (np .fft .fftshift (x ._array , axes = axes ), device = x . device )
285285
286286@requires_extension ('fft' )
287287def ifftshift (x : Array , / , * , axes : Union [int , Sequence [int ]] = None ) -> Array :
@@ -292,7 +292,7 @@ def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
292292 """
293293 if x .dtype not in _floating_dtypes :
294294 raise TypeError ("Only floating-point dtypes are allowed in ifftshift" )
295- return Array ._new (np .fft .ifftshift (x ._array , axes = axes ))
295+ return Array ._new (np .fft .ifftshift (x ._array , axes = axes ), device = x . device )
296296
297297__all__ = [
298298 "fft" ,
0 commit comments