"""
    AbstractLRPRule

Abstract supertype of all LRP (layer-wise relevance propagation) rules.
Concrete rules subtype this and customize propagation via `modify_params`
and `modify_denominator`, or by adding their own functor method.
"""
abstract type AbstractLRPRule end
1919
# This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
# It can be extended for new rules via `modify_denominator` and `modify_params`.
# Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
# Generic LRP rule: propagate relevance Rₖ₊₁ back through `layer` using
# automatic differentiation. `s` is detached from the gradient computation,
# so the pullback corresponds to the standard z-rule formulation.
function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
    modified_layer = _modify_layer(rule, layer)
    function fwpass(a)
        z = modified_layer(a)
        s = Zygote.dropgrad(Rₖ₊₁ ./ modify_denominator(rule, z))
        return z ⋅ s
    end
    c = only(gradient(fwpass, aₖ))
    return aₖ .* c  # Rₖ
end
3432
# Special cases are dispatched on layer type:
# Dropout-style layers pass relevance through unchanged.
(::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
# Reshaping layers propagate relevance back in the shape of their input.
function (::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁)
    return reshape(Rₖ₊₁, size(aₖ))
end
3836
# To implement new rules, define the two custom functions `modify_params`
# and `modify_denominator`. If a rule doesn't implement them, the following
# fallbacks are used by default:
"""
    modify_params(rule, W, b)

Function that modifies weights and biases before applying relevance propagation.
"""
function modify_params(::AbstractLRPRule, W, b)
    return (W, b)  # general fallback: leave parameters untouched
end
5745
"""
    modify_denominator(rule, d)

Function that modifies zₖ on the forward pass, e.g. for numerical stability.
"""
function modify_denominator(::AbstractLRPRule, d)
    return stabilize_denom(d; eps=1.0f-9)  # general fallback
end
6452
# Helper that applies `modify_params` to a layer's parameters:
_modify_layer(::AbstractLRPRule, layer) = layer  # no-op for parameterless layers
function _modify_layer(rule::AbstractLRPRule, layer::Union{Dense,Conv})
    W, b = get_params(layer)
    ρW, ρb = modify_params(rule, W, b)
    return set_params(layer, ρW, ρb)
end
58+
6559"""
6660 ZeroRule()
6761
@@ -111,11 +105,11 @@ struct ZBoxRule <: AbstractLRPRule end
111105
112106# The ZBoxRule requires its own implementation of relevance propagation.
113107function (rule:: ZBoxRule )(layer:: Union{Dense,Conv} , aₖ, Rₖ₊₁)
114- layer, layer⁺, layer⁻ = modify_layer (rule, layer)
108+ W, b = get_params (layer)
109+ l, h = fill .(extrema (aₖ), (size (aₖ),))
115110
116- onemat = ones (eltype (aₖ), size (aₖ))
117- l = onemat * minimum (aₖ)
118- h = onemat * maximum (aₖ)
111+ layer⁺ = set_params (layer, max .(0 , W), max .(0 , b)) # W⁺, b⁺
112+ layer⁻ = set_params (layer, min .(0 , W), min .(0 , b)) # W⁻, b⁻
119113
120114 # Forward pass
121115 function fwpass (a, l, h)
@@ -128,20 +122,5 @@ function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
128122 return z ⋅ s
129123 end
130124 c, cₗ, cₕ = gradient (fwpass, aₖ, l, h) # w.r.t. three inputs
131-
132- # Backward pass
133- Rₖ = aₖ .* c + l .* cₗ + h .* cₕ
134- return Rₖ
135- end
136-
137- function modify_layer (:: ZBoxRule , l:: Union{Dense,Conv} )
138- W, b = get_weights (l)
139- W⁻ = min .(0 , W)
140- W⁺ = max .(0 , W)
141- b⁻ = min .(0 , b)
142- b⁺ = max .(0 , b)
143-
144- l⁺ = set_weights (l, W⁺, b⁺)
145- l⁻ = set_weights (l, W⁻, b⁻)
146- return l, l⁺, l⁻
125+ return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
147126end
0 commit comments