Skip to content

Commit cf63356

Browse files
authored
Support multiplication by im (#1929)
* multiplication by im * test * fix * Float64 --> Float32
1 parent 19a94cf commit cf63356

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/TracedRNumber.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,15 @@ for (jlop, hloop) in (
242242
end
243243
end
244244

245+
function Base.:*(x::TracedRNumber{T}, z::Complex{Bool}) where {T<:Real}
246+
# this is to support multiplication by im (Complex{Bool}(false, true))
247+
z_re, z_im = real(z), imag(z)
248+
res_re = z_re ? x : zero(x)
249+
res_im = z_im ? x : zero(x)
250+
return Complex(res_re, res_im)
251+
end
252+
Base.:*(z::Complex{Bool}, x::TracedRNumber{T}) where {T<:Real} = x * z
253+
245254
# Based on https://github.com/JuliaLang/julia/blob/39255d47db7657950ff1c82137ecec5a70bae622/base/float.jl#L608-L617
246255
function Base.mod(
247256
@nospecialize(x::Reactant.TracedRNumber{T}), @nospecialize(y::Reactant.TracedRNumber{T})

test/complex.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ end
6060
@test isapprox(f(y), ComplexF32(2.0 - 1.0im))
6161
end
6262

63+
@testset "multiplication by im" begin
64+
x = ConcreteRNumber(42.0f0)
65+
66+
mul_im(x) = x * im
67+
mul_re(x) = x * Complex(true, false)
68+
mul_re_im(x) = x * Complex(true, true)
69+
mul_nothing(x) = x * Complex(false, false)
70+
71+
@test (@jit mul_im(x)) == Complex(0.0f0, 42.0f0)
72+
@test (@jit mul_re(x)) == Complex(42.0f0, 0.0f0)
73+
@test (@jit mul_re_im(x)) == Complex(42.0f0, 42.0f0)
74+
@test (@jit mul_nothing(x)) == Complex(0.0f0, 0.0f0)
75+
end
76+
6377
@testset "complex reduction" begin
6478
x = Reactant.TestUtils.construct_test_array(ComplexF32, 10, 10)
6579
x_ra = Reactant.to_rarray(x)

0 commit comments

Comments
 (0)