Skip to content

Commit c10c885

Browse files
authored
fix and improve fft methods (#804)
1 parent 0d389af commit c10c885

File tree

3 files changed

+47
-26
lines changed

3 files changed

+47
-26
lines changed

src/Extras/fftBigFloat.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
# This is a Cooley-Tukey FFT algorithm for any number type.
2-
function fft_pow2(x::Vector{F}) where F
2+
function fft_pow2(x::StridedVector)
33
n = length(x)
4-
T = mapreduce(eltype,promote_type,x)
4+
T = mapreduce(eltype(x) isa Fun ? cfstype : eltype, promote_type,x)
55
@assert ispow2(n)
66
if n==1
7-
return x
7+
return convert(Vector, x)
88
elseif n==2
9-
return F[x[1]+x[2];x[1]-x[2]]
9+
return eltype(x)[x[1]+x[2], x[1]-x[2]]
1010
end
11-
even,odd = fft_pow2(x[1:2:end-1]),fft_pow2(x[2:2:end])
12-
twiddle = exp(-2im*convert(T,π)/n*collect(0:n-1))
13-
half1 = even + odd.*twiddle[1:div(n,2)]
14-
half2 = even + odd.*twiddle[div(n,2)+1:n]
15-
return vcat(half1,half2)
11+
even,odd = fft_pow2(@view x[1:2:end-1]), fft_pow2(@view x[2:2:end])
12+
ret = similar(x, n)
13+
for (halfind, ind) in enumerate(1:div(n,2))
14+
ret[ind] = even[halfind] + odd[halfind] * cis(-2convert(T,π)/n * (ind - 1))
15+
end
16+
for (halfind, ind) in enumerate(div(n,2)+1:n)
17+
ret[ind] = even[halfind] + odd[halfind] * cis(-2convert(T,π)/n * (ind - 1))
18+
end
19+
return ret
1620
end
17-
ifft_pow2(x::Vector) = conj(fft_pow2(conj(x)))/length(x)
21+
ifft_pow2(x::StridedVector) = conj(fft_pow2(conj(x)))/length(x)

src/Extras/fftGeneric.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,29 @@
1-
const BigFloats = Union{BigFloat,Complex{BigFloat}}
2-
3-
function fft(x::AbstractVector{F}) where F<:Fun
4-
n,T = length(x),mapreduce(eltype,promote_type,x)
1+
function fft(x::AbstractVector{<:Fun})
2+
n,T = length(x),mapreduce(cfstype,promote_type,x)
53
if ispow2(n) return fft_pow2(x) end
64
ks = range(zero(real(T)),stop=n-one(real(T)),length=n)
7-
Wks = exp(-im*convert(T,π)*ks.^2/n)
8-
xq,wq = x.*Wks,conj([exp(-im*convert(T,π)*n);reverse(Wks);Wks[2:end]])
9-
return Wks.*conv(xq,wq)[n+1:2n]
5+
Wks = cis.(-convert(T,π)/n .* ks.^2)
6+
xq,wq = x.*Wks,conj([cis(-convert(T,π)*n);reverse(Wks); @view Wks[2:end]])
7+
return Wks.* @view conv(xq,wq)[n+1:2n]
108
end
119

12-
ifft(x::AbstractVector{F}) where {F<:Fun} = conj(fft(conj(x)))/length(x)
13-
function ifft!(x::AbstractVector{F}) where F<:Fun
10+
ifft(x::AbstractVector{<:Fun}) = conj(fft(conj(x)))/length(x)
11+
function ifft!(x::AbstractVector{<:Fun})
1412
y = conj(fft(conj(x)))/length(x)
15-
x[:] = y
13+
x .= y
1614
return x
1715
end
1816

19-
function conv(u::StridedVector{F}, v::StridedVector) where F<:Fun
17+
nextpow2(n) = 2^ceil(Int, Base.log2(n))
18+
function conv(u::StridedVector{<:Fun}, v::StridedVector)
2019
nu,nv = length(u),length(v)
2120
n = nu + nv - 1
2221
np2 = nextpow2(n)
2322
pad!(u,np2),pad!(v,np2)
2423
y = ifft_pow2(fft_pow2(u).*fft_pow2(v))
2524
#TODO This would not handle Dual/ComplexDual numbers correctly
2625
T = promote_type(mapreduce(eltype,promote_type,u),mapreduce(eltype,promote_type,v))
27-
y = T<:Real ? real(y[1:n]) : y[1:n]
26+
y = T<:Real ? real(@view y[1:n]) : y[1:n]
2827
end
2928

3029
######################################################################
@@ -33,9 +32,9 @@ end
3332

3433
# plan_fft for BigFloats (covers Laurent svfft)
3534

36-
plan_fft(x::Vector{F}) where {F<:Fun} = fft
37-
plan_ifft(x::Vector{F}) where {F<:Fun} = ifft
38-
plan_ifft!(x::Vector{F}) where {F<:Fun} = ifft
35+
plan_fft(x::Vector{<:Fun}) = fft
36+
plan_ifft(x::Vector{<:Fun}) = ifft
37+
plan_ifft!(x::Vector{<:Fun}) = ifft
3938

4039

4140
# Fourier space plans for BigFloat

test/NumberTypeTest.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ApproxFun, ApproxFunOrthogonalPolynomials, Test
1+
using ApproxFun, ApproxFunOrthogonalPolynomials, Test, FFTW
22

33
@testset "BigFloat" begin
44
@testset "BigFloat constructor" begin
@@ -80,4 +80,22 @@ using ApproxFun, ApproxFunOrthogonalPolynomials, Test
8080
w = (1-x^2)^b
8181
@test w(BigFloat(1)/10) (1-(BigFloat(1)/10)^2)^b
8282
end
83+
84+
@testset "fft" begin
85+
@testset "Fun" begin
86+
f = Fun(x->1, Fourier())
87+
g = ApproxFun.fft(fill(f, 3))
88+
@test g[1] Fun(x->3, Fourier())
89+
@test g[2] Fun(x->0, Fourier()) atol=1e-10
90+
@test g[3] Fun(x->0, Fourier()) atol=1e-10
91+
end
92+
93+
@testset "BigFloat" begin
94+
for n in 1:10
95+
v = BigFloat[i for i in 1:n]
96+
fv = ApproxFunBase.fft(v)
97+
@test fv FFTW.fft(Float64.(v)) rtol=1e-8
98+
end
99+
end
100+
end
83101
end

0 commit comments

Comments
 (0)