Skip to content

Commit d0cace0

Browse files
giordanoavik-pal
andauthored
Add support for FFT plans (#1931)
* Add support for FFT plans * Fix in-place plan Co-authored-by: Avik Pal <avikpal@mit.edu> * Use `copyto!` --------- Co-authored-by: Avik Pal <avikpal@mit.edu>
1 parent ccb28c2 commit d0cace0

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

ext/ReactantAbstractFFTsExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ for op in (:rfft, :fft, :ifft)
4040
invperm(perm),
4141
)
4242
end
43+
44+
# Out-of-place plan
45+
plan_name = Symbol("Reactant", uppercase(string(op)), "Plan")
46+
plan_f = Symbol("plan_", op)
47+
@eval struct $(plan_name){T} <: AbstractFFTs.Plan{T} end
48+
@eval AbstractFFTs.$(plan_f)(::Reactant.TracedRArray{T}) where {T} = $(plan_name){T}()
49+
@eval Base.:*(::$(plan_name){T}, x::Reactant.TracedRArray{T}) where {T} =
50+
AbstractFFTs.$(op)(x)
51+
52+
# In-place plan
53+
if op !== :rfft
54+
plan_name! = Symbol("Reactant", uppercase(string(op)), "InPlacePlan")
55+
plan_f! = Symbol("plan_", op, "!")
56+
@eval struct $(plan_name!){T} <: AbstractFFTs.Plan{T} end
57+
@eval AbstractFFTs.$(plan_f!)(::Reactant.TracedRArray{T}) where {T} =
58+
$(plan_name!){T}()
59+
@eval Base.:*(::$(plan_name!){T}, x::Reactant.TracedRArray{T}) where {T} =
60+
copyto!(x, AbstractFFTs.$(op)(x))
61+
end
4362
end
4463

4564
for op in (:irfft,)

test/integration/fft.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,45 @@ end
6666
@test @jit(rfft(x_ra, (1, 2, 3))) rfft(x, (1, 2, 3))
6767
end
6868
end
69+
70+
@testset "Planned FFTs" begin
71+
@testset "Out-of-place [$(fft), size $(size)]" for size in ((16,), (16, 16)),
72+
(plan, fft) in (
73+
(FFTW.plan_fft, FFTW.fft),
74+
(FFTW.plan_ifft, FFTW.ifft),
75+
(FFTW.plan_rfft, FFTW.rfft),
76+
)
77+
78+
x = randn(fft === FFTW.rfft ? Float32 : ComplexF32, size)
79+
x_r = Reactant.to_rarray(x)
80+
# We make a copy of the original array to make sure the operation does
81+
# not modify the input.
82+
copied_x_r = copy(x_r)
83+
84+
planned_fft(x) = plan(x) * x
85+
compiled_planned_fft = @compile planned_fft(x_r)
86+
# Make sure the result is correct
87+
@test compiled_planned_fft(x_r) fft(x)
88+
# Make sure the operation is not in-place
89+
@test x_r == copied_x_r
90+
end
91+
92+
@testset "In-place [$(fft!), size $(size)]" for size in ((16,), (16, 16)),
93+
(plan!, fft!) in ((FFTW.plan_fft!, FFTW.fft!), (FFTW.plan_ifft!, FFTW.ifft!))
94+
95+
x = randn(ComplexF32, size)
96+
x_r = Reactant.to_rarray(x)
97+
# We make a copy of the original array to make sure the operation
98+
# modifies the input.
99+
copied_x_r = copy(x_r)
100+
101+
planned_fft!(x) = plan!(x) * x
102+
compiled_planned_fft! = @compile planned_fft!(x_r)
103+
planned_y_r = compiled_planned_fft!(x_r)
104+
# Make sure the result is correct
105+
@test planned_y_r fft!(x)
106+
# Make sure the operation is in-place
107+
@test planned_y_r x_r
108+
@test x_r copied_x_r
109+
end
110+
end

0 commit comments

Comments
 (0)