|
66 | 66 | @test @jit(rfft(x_ra, (1, 2, 3))) ≈ rfft(x, (1, 2, 3)) |
67 | 67 | end |
68 | 68 | end |
| 69 | + |
| 70 | +@testset "Planned FFTs" begin |
| 71 | + @testset "Out-of-place [$(fft), size $(size)]" for size in ((16,), (16, 16)), |
| 72 | + (plan, fft) in ( |
| 73 | + (FFTW.plan_fft, FFTW.fft), |
| 74 | + (FFTW.plan_ifft, FFTW.ifft), |
| 75 | + (FFTW.plan_rfft, FFTW.rfft), |
| 76 | + ) |
| 77 | + |
| 78 | + x = randn(fft === FFTW.rfft ? Float32 : ComplexF32, size) |
| 79 | + x_r = Reactant.to_rarray(x) |
| 80 | + # We make a copy of the original array to make sure the operation does |
| 81 | + # not modify the input. |
| 82 | + copied_x_r = copy(x_r) |
| 83 | + |
| 84 | + planned_fft(x) = plan(x) * x |
| 85 | + compiled_planned_fft = @compile planned_fft(x_r) |
| 86 | + # Make sure the result is correct |
| 87 | + @test compiled_planned_fft(x_r) ≈ fft(x) |
| 88 | + # Make sure the operation is not in-place |
| 89 | + @test x_r == copied_x_r |
| 90 | + end |
| 91 | + |
| 92 | + @testset "In-place [$(fft!), size $(size)]" for size in ((16,), (16, 16)), |
| 93 | + (plan!, fft!) in ((FFTW.plan_fft!, FFTW.fft!), (FFTW.plan_ifft!, FFTW.ifft!)) |
| 94 | + |
| 95 | + x = randn(ComplexF32, size) |
| 96 | + x_r = Reactant.to_rarray(x) |
| 97 | + # We make a copy of the original array to make sure the operation |
| 98 | + # modifies the input. |
| 99 | + copied_x_r = copy(x_r) |
| 100 | + |
| 101 | + planned_fft!(x) = plan!(x) * x |
| 102 | + compiled_planned_fft! = @compile planned_fft!(x_r) |
| 103 | + planned_y_r = compiled_planned_fft!(x_r) |
| 104 | + # Make sure the result is correct |
| 105 | + @test planned_y_r ≈ fft!(x) |
| 106 | + # Make sure the operation is in-place |
| 107 | + @test planned_y_r ≈ x_r |
| 108 | + @test x_r ≉ copied_x_r |
| 109 | + end |
| 110 | +end |
0 commit comments