@@ -15,26 +15,28 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
1515 y1, bk1 = rrule (CFG, copy∘ broadcasted, BS1, > , rand (3 ), rand (3 ))
1616 @test y1 isa AbstractArray{Bool}
1717 @test all (d -> d isa AbstractZero, bk1 (99 ))
18-
18+
1919 y2, bk2 = rrule (CFG, copy∘ broadcasted, BT1, isinteger, Tuple (rand (3 )))
2020 @test y2 isa Tuple{Bool,Bool,Bool}
2121 @test all (d -> d isa AbstractZero, bk2 (99 ))
2222 end
2323
2424 @testset " split 2: derivatives" begin
2525 test_rrule (copy∘ broadcasted, BS1, log, rand (3 ) .+ 1 )
26- test_rrule (copy∘ broadcasted, BT1, log, Tuple (rand (3 ) .+ 1 ))
26+ # `check_inferred` doesn't accept the `Union` returned from ProjectTo as of
27+ # ChainRuleCore 1.15.4 https://github.com/JuliaDiff/ChainRulesCore.jl/issues/586
28+ test_rrule (copy∘ broadcasted, BT1, log, Tuple (rand (3 ) .+ 1 ); check_inferred= false )
2729
2830 # Two args uses StructArrays
2931 test_rrule (copy∘ broadcasted, BS1, atan, rand (3 ), rand (3 ))
3032 test_rrule (copy∘ broadcasted, BS2, atan, rand (3 ), rand (4 )' )
3133 test_rrule (copy∘ broadcasted, BS1, atan, rand (3 ), rand ())
3234 test_rrule (copy∘ broadcasted, BT1, atan, rand (3 ), Tuple (rand (1 )))
3335 test_rrule (copy∘ broadcasted, BT1, atan, Tuple (rand (3 )), Tuple (rand (3 )), check_inferred = VERSION > v " 1.7" )
34-
36+
3537 # test_rrule(copy∘broadcasted, *, BS1, rand(3), Ref(rand())) # don't know what I was testing
3638 end
37-
39+
3840 @testset " split 3: forwards" begin
3941 # In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
4042 test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ))
@@ -57,14 +59,14 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
5759 test_rrule (copy∘ broadcasted, BS2, Multiplier (rand ()), rand (3 ), rand (4 )' , check_inferred= false ) # Union{ZeroTangent, Tangent{Multiplier{...
5860 @test_skip test_rrule (copy∘ broadcasted, BS1, Multiplier (rand ()), rand (3 ), 5.0im , check_inferred= false ) # ProjectTo(f) fails to remove the imaginary part of Multiplier's gradient
5961 test_rrule (copy∘ broadcasted, BS1, make_two_vec, rand (3 ), check_inferred= false )
60-
62+
6163 # Non-diff components -- note that with BroadcastStyle, Ref is from e.g. Broadcast.broadcastable(nothing)
6264 test_rrule (copy∘ broadcasted, BS2, first∘ tuple, rand (3 ), Ref (:sym ), rand (4 )' , check_inferred= false )
6365 test_rrule (copy∘ broadcasted, BS2, last∘ tuple, rand (3 ), Ref (nothing ), rand (4 )' , check_inferred= false )
6466 test_rrule (copy∘ broadcasted, BS1, |> , rand (3 ), Ref (sin), check_inferred= false )
6567 _call (f, x... ) = f (x... )
6668 test_rrule (copy∘ broadcasted, BS2, _call, Ref (atan), rand (3 ), rand (4 )' , check_inferred= false )
67-
69+
6870 test_rrule (copy∘ broadcasted, BS1, getindex, [rand (3 ) for _ in 1 : 2 ], [3 ,1 ], check_inferred= false )
6971 test_rrule (copy∘ broadcasted, BS1, getindex, [rand (3 ) for _ in 1 : 2 ], (3 ,1 ), check_inferred= false )
7072 test_rrule (copy∘ broadcasted, BS1, getindex, [rand (3 ) for _ in 1 : 2 ], Ref (CartesianIndex (2 )), check_inferred= false )
@@ -86,20 +88,20 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
8688 @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), 1.0 * im)
8789 @gpu test_rrule (copy∘ broadcasted, + , rand (3 ), true )
8890 @gpu_broken test_rrule (copy∘ broadcasted, + , rand (3 ), Tuple (rand (3 )))
89-
91+
9092 @gpu test_rrule (copy∘ broadcasted, - , rand (3 ), rand (3 ))
9193 @gpu test_rrule (copy∘ broadcasted, - , rand (3 ), rand (4 )' )
9294 @gpu test_rrule (copy∘ broadcasted, - , rand (3 ))
9395 test_rrule (copy∘ broadcasted, - , Tuple (rand (3 )))
94-
96+
9597 @gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand (3 ))
9698 @gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand ())
9799 @gpu test_rrule (copy∘ broadcasted, * , rand (), rand (3 ))
98100
99101 test_rrule (copy∘ broadcasted, * , rand (3 ) .+ im, rand (3 ) .+ 2im )
100102 test_rrule (copy∘ broadcasted, * , rand (3 ) .+ im, rand () + 3im )
101103 test_rrule (copy∘ broadcasted, * , rand () + im, rand (3 ) .+ 4im )
102-
104+
103105 @test_skip test_rrule (copy∘ broadcasted, * , im, rand (3 )) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}})
104106 @test_skip test_rrule (copy∘ broadcasted, * , rand (3 ), im) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}})
105107 y4, bk4 = rrule (CFG, copy∘ broadcasted, * , im, [1 ,2 ,3.0 ])
@@ -113,16 +115,16 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
113115
114116 @gpu test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ), Val (2 ))
115117 @gpu test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ) .+ im, Val (2 ))
116-
118+
117119 @gpu test_rrule (copy∘ broadcasted, / , rand (3 ), rand ())
118120 @gpu test_rrule (copy∘ broadcasted, / , rand (3 ) .+ im, rand () + 3im )
119121 end
120122 @testset " identity etc" begin
121123 test_rrule (copy∘ broadcasted, identity, rand (3 ))
122-
124+
123125 test_rrule (copy∘ broadcasted, Float32, rand (3 ), rtol= 1e-4 )
124126 test_rrule (copy∘ broadcasted, ComplexF32, rand (3 ), rtol= 1e-4 )
125-
127+
126128 test_rrule (copy∘ broadcasted, float, rand (3 ))
127129 end
128130 @testset " complex" begin
@@ -136,7 +138,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
136138
137139 test_rrule (copy∘ broadcasted, imag, rand (3 ))
138140 test_rrule (copy∘ broadcasted, imag, rand (3 ) .+ im .* rand .())
139-
141+
140142 test_rrule (copy∘ broadcasted, complex, rand (3 ))
141143 end
142144 end
@@ -173,9 +175,9 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
173175 test_rrule (copy∘ broadcasted, complex, rand ())
174176 end
175177 end
176-
178+
177179 @testset " bugs" begin
178180 @test ChainRules. unbroadcast ((1 , 2 , [3 ]), [4 , 5 , [6 ]]) isa Tangent # earlier, NTuple demanded same type
179181 @test ChainRules. unbroadcast (broadcasted (- , (1 , 2 ), 3 ), (4 , 5 )) == (4 , 5 ) # earlier, called ndims(::Tuple)
180182 end
181- end
183+ end
0 commit comments