11using ExplainabilityMethods
22using ExplainabilityMethods: modify_params
3+ import ExplainabilityMethods: _modify_layer
34using Flux
45using LinearAlgebra
56using ReferenceTests
67using Random
78
89const RULES = Dict (
9- " ZeroRule" => ZeroRule,
10- " EpsilonRule" => EpsilonRule,
11- " GammaRule" => GammaRule,
12- " ZBoxRule" => ZBoxRule,
10+ " ZeroRule" => ZeroRule () ,
11+ " EpsilonRule" => EpsilonRule () ,
12+ " GammaRule" => GammaRule () ,
13+ " ZBoxRule" => ZBoxRule () ,
1314)
1415
1516# # Hand-written tests
5455
5556# # Test Dense layer
5657# Define Dense test input
57- ins = 20 # input dimension
58- outs = 10 # output dimension
59- aₖ = pseudorandn (ins )
58+ ins_dense = 20 # input dimension
59+ outs_dense = 10 # output dimension
60+ aₖ = pseudorandn (ins_dense )
6061
6162layers = Dict (
62- " Dense_relu" => Dense (ins, outs , relu; init= pseudorandn),
63- " Dense_identity" => Dense (Matrix (I, outs, ins ), false , identity),
63+ " Dense_relu" => Dense (ins_dense, outs_dense , relu; init= pseudorandn),
64+ " Dense_identity" => Dense (Matrix (I, outs_dense, ins_dense ), false , identity),
6465)
6566@testset " Dense" begin
66- for (rulename, ruletype) in RULES
67- rule = ruletype ()
67+ for (rulename, rule) in RULES
6868 @testset " $rulename " begin
6969 for (layername, layer) in layers
7070 @testset " $layername " begin
@@ -76,10 +76,10 @@ layers = Dict(
7676
7777 # println(Rₖ)
7878 if rulename == " Dense_identity"
79- # First `outs ` dimensions should propagate
79+ # First `outs_dense ` dimensions should propagate
8080 # activations as relevances, rest should be ≈ 0.
81- @test Rₖ[1 : outs ] ≈ aₖ[1 : outs ]
82- @test all (Rₖ[outs : end ] .< 1e-8 )
81+ @test Rₖ[1 : outs_dense ] ≈ aₖ[1 : outs_dense ]
82+ @test all (Rₖ[outs_dense : end ] .< 1e-8 )
8383 end
8484
8585 @test_reference " references/rules/$rulename /$layername .jld2" Dict (
@@ -103,8 +103,7 @@ equalpairs = Dict( # these pairs of layers are all equal
103103)
104104
105105@testset " PoolingLayers" begin
106- for (rulename, ruletype) in RULES
107- rule = ruletype ()
106+ for (rulename, rule) in RULES
108107 @testset " $rulename " begin
109108 for (layername, layers) in equalpairs
110109 @testset " $layername " begin
@@ -139,8 +138,7 @@ layers = Dict(
139138 " AlphaDropout" => AlphaDropout (0.2 ),
140139)
141140@testset " Other Layers" begin
142- for (rulename, ruletype) in RULES
143- rule = ruletype ()
141+ for (rulename, rule) in RULES
144142 @testset " $rulename " begin
145143 for (layername, layer) in layers
146144 @testset " $layername " begin
@@ -164,26 +162,33 @@ end
164162struct TestWrapper{T}
165163 layer:: T
166164end
167- (l:: TestWrapper )(x) = l. layer (x)
165+ (w:: TestWrapper )(x) = w. layer (x)
166+ _modify_layer (r:: AbstractLRPRule , w:: TestWrapper ) = _modify_layer (r, w. layer)
167+ (rule:: ZBoxRule )(w:: TestWrapper , aₖ, Rₖ₊₁) = rule (w. layer, aₖ, Rₖ₊₁)
168168
169169layers = Dict (
170170 " Conv" => (Conv ((3 , 3 ), 2 => 4 ; init= pseudorandn), aₖ),
171+ " Dense_relu" =>
172+ (Dense (ins_dense, outs_dense, relu; init= pseudorandn), pseudorandn (ins_dense)),
171173 " flatten" => (flatten, aₖ),
172- " Dense" => (Dense (20 , 10 , relu; init= pseudorandn), pseudorandn (20 )),
173174)
174175@testset " Custom layers" begin
175- for (layername, (layer, aₖ)) in layers
176- @testset " $layername " begin
177- rule = ZeroRule ()
178- wrapped_layer = TestWrapper (layer)
179- Rₖ₊₁ = wrapped_layer (aₖ)
180- Rₖ = rule (wrapped_layer, aₖ, Rₖ₊₁)
181-
182- @test typeof (Rₖ) == typeof (aₖ)
183- @test size (Rₖ) == size (aₖ)
184-
185- @test_reference " references/rules/ZeroRule/$layername .jld2" Dict (" R" => Rₖ) by =
186- (r, a) -> isapprox (r[" R" ], a[" R" ]; rtol= 0.02 )
176+ for (rulename, rule) in RULES
177+ @testset " $rulename " begin
178+ for (layername, (layer, aₖ)) in layers
179+ @testset " $layername " begin
180+ wrapped_layer = TestWrapper (layer)
181+ Rₖ₊₁ = wrapped_layer (aₖ)
182+ Rₖ = rule (wrapped_layer, aₖ, Rₖ₊₁)
183+
184+ @test typeof (Rₖ) == typeof (aₖ)
185+ @test size (Rₖ) == size (aₖ)
186+
187+ @test_reference " references/rules/$rulename /$layername .jld2" Dict (
188+ " R" => Rₖ
189+ ) by = (r, a) -> isapprox (r[" R" ], a[" R" ]; rtol= 0.02 )
190+ end
191+ end
187192 end
188193 end
189194end
0 commit comments