22abstract type AbstractLRPRule end
33
44# Generic LRP rule. Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
5- function lrp! (rule:: R , layer:: L , Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule ,L}
6- lrp_autodiff! (rule, layer, Rₖ , aₖ, Rₖ₊₁)
5+ function lrp! (Rₖ, rule:: R , layer:: L , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule ,L}
6+ lrp_autodiff! (Rₖ, rule, layer , aₖ, Rₖ₊₁)
77 return nothing
88end
99
1010function lrp_autodiff! (
11- rule :: R , layer :: L , Rₖ :: T1 , aₖ:: T1 , Rₖ₊₁:: T2
11+ Rₖ :: T1 , rule :: R , layer :: L , aₖ:: T1 , Rₖ₊₁:: T2
1212) where {R<: AbstractLRPRule ,L,T1,T2}
1313 layerᵨ = modify_layer (rule, layer)
1414 c:: T1 = only (
@@ -23,21 +23,21 @@ function lrp_autodiff!(
2323end
2424
2525# For linear layer types such as Dense layers, using autodiff is overkill.
26- function lrp! (rule:: R , layer:: Dense , Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
27- lrp_dense! (rule, layer, Rₖ , aₖ, Rₖ₊₁)
26+ function lrp! (Rₖ, rule:: R , layer:: Dense , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
27+ lrp_dense! (Rₖ, rule, layer , aₖ, Rₖ₊₁)
2828 return nothing
2929end
3030
31- function lrp_dense! (rule:: R , l, Rₖ , aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
31+ function lrp_dense! (Rₖ, rule:: R , l, aₖ, Rₖ₊₁) where {R<: AbstractLRPRule }
3232 ρW, ρb = modify_params (rule, get_params (l)... )
3333 ãₖ₊₁ = modify_denominator (rule, ρW * aₖ .+ ρb)
3434 @tullio Rₖ[j, b] = aₖ[j, b] * ρW[k, j] / ãₖ₊₁[k, b] * Rₖ₊₁[k, b]
3535 return nothing
3636end
3737
3838# Other special cases that are dispatched on layer type:
39- lrp! (:: AbstractLRPRule , :: DropoutLayer , Rₖ , aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
40- lrp! (:: AbstractLRPRule , :: ReshapingLayer , Rₖ , aₖ, Rₖ₊₁) = (Rₖ .= reshape (Rₖ₊₁, size (aₖ)))
39+ lrp! (Rₖ, :: AbstractLRPRule , :: DropoutLayer , aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
40+ lrp! (Rₖ, :: AbstractLRPRule , :: ReshapingLayer , aₖ, Rₖ₊₁) = (Rₖ .= reshape (Rₖ₊₁, size (aₖ)))
4141
4242# To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
4343# If this isn't done, the following fallbacks are used by default:
@@ -75,7 +75,7 @@ Constructor for LRP-0 rule. Commonly used on upper layers.
7575struct ZeroRule <: AbstractLRPRule end
7676
7777"""
78- GammaRule(; γ=0.25)
78+ GammaRule([ γ=0.25] )
7979
8080Constructor for LRP-``γ`` rule. Commonly used on lower layers.
8181
@@ -84,16 +84,17 @@ Arguments:
8484"""
8585struct GammaRule{T} <: AbstractLRPRule
8686 γ:: T
87- GammaRule (; γ= 0.25 ) = new {Float32} (γ)
87+ GammaRule (γ= 0.25f0 ) = new {Float32} (γ)
8888end
8989function modify_params (r:: GammaRule , W, b)
90- ρW = W + r. γ * relu .(W)
91- ρb = b + r. γ * relu .(b)
90+ T = eltype (W)
91+ ρW = W + convert (T, r. γ) * relu .(W)
92+ ρb = b + convert (T, r. γ) * relu .(b)
9293 return ρW, ρb
9394end
9495
9596"""
96- EpsilonRule(; ϵ=1f-6 )
97+ EpsilonRule([ϵ=1.0f-6] )
9798
9899Constructor for LRP-``ϵ`` rule. Commonly used on middle layers.
99100
@@ -102,7 +103,7 @@ Arguments:
102103"""
103104struct EpsilonRule{T} <: AbstractLRPRule
104105 ϵ:: T
105- EpsilonRule (; ϵ= 1.0f-6 ) = new {Float32} (ϵ)
106+ EpsilonRule (ϵ= 1.0f-6 ) = new {Float32} (ϵ)
106107end
107108modify_denominator (r:: EpsilonRule , d) = stabilize_denom (d, r. ϵ)
108109
@@ -122,8 +123,8 @@ struct ZBoxRule{T} <: AbstractLRPRule
122123end
123124
124125# The ZBoxRule requires its own implementation of relevance propagation.
125- lrp! (r:: ZBoxRule , layer:: Dense , Rₖ, aₖ, Rₖ₊₁) = lrp_zbox! (r, layer, Rₖ , aₖ, Rₖ₊₁)
126- lrp! (r:: ZBoxRule , layer:: Conv , Rₖ, aₖ, Rₖ₊₁) = lrp_zbox! (r, layer, Rₖ , aₖ, Rₖ₊₁)
126+ lrp! (Rₖ, r:: ZBoxRule , layer:: Dense , aₖ, Rₖ₊₁) = lrp_zbox! (Rₖ, r, layer , aₖ, Rₖ₊₁)
127+ lrp! (Rₖ, r:: ZBoxRule , layer:: Conv , aₖ, Rₖ₊₁) = lrp_zbox! (Rₖ, r, layer , aₖ, Rₖ₊₁)
127128
128129_zbox_bound (T, c:: Real , in_size) = fill (convert (T, c), in_size)
129130function _zbox_bound (T, A:: AbstractArray , in_size)
@@ -135,7 +136,7 @@ function _zbox_bound(T, A::AbstractArray, in_size)
135136 return convert .(T, A)
136137end
137138
138- function lrp_zbox! (r :: ZBoxRule , layer :: L , Rₖ :: T1 , aₖ:: T1 , Rₖ₊₁:: T2 ) where {L,T1,T2}
139+ function lrp_zbox! (Rₖ :: T1 , r :: ZBoxRule , layer :: L , aₖ:: T1 , Rₖ₊₁:: T2 ) where {L,T1,T2}
139140 T = eltype (aₖ)
140141 in_size = size (aₖ)
141142 l = _zbox_bound (T, r. low, in_size)
0 commit comments