 # can be implemented by dispatching on the functions `modify_params` & `modify_denominator`,
 # which make use of the generalized LRP implementation shown in [1].
 #
-# If the relevance propagation falls outside of this scheme, a custom function
+# If the relevance propagation falls outside of this scheme, custom functions
 # ```julia
 # (::MyLRPRule)(layer, aₖ, Rₖ₊₁) = ...
+# (::MyLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
+# (::AbstractLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁) = ...
 # ```
-# can be implemented. This is used for the ZBoxRule.
+# that return `Rₖ` can be implemented.
+# This is used for the ZBoxRule and for faster computations on common layers.
 #
 # References:
 # [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
-# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond:
-#     A Review of Methods and Applications
+# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications

 abstract type AbstractLRPRule end

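To make the dispatch mechanism from the header comment concrete, here is a minimal hypothetical sketch of a new rule; the name `MyGammaRule` and the constant 0.25 are illustrative and not part of this package, loosely following the γ-rule of [1]:

```julia
# Hypothetical sketch, not part of this file: a γ-style rule that favors
# positively weighted contributions.
struct MyGammaRule <: AbstractLRPRule end

# modify_params receives the layer parameters W, b and returns ρ(W), ρ(b):
modify_params(::MyGammaRule, W, b) = (W .+ 0.25 .* max.(0, W), b)
```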
 # This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
 # It can be extended for new rules via `modify_denominator` and `modify_params`.
 # Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
-function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
+(rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁) = lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
+
+function lrp_autodiff(rule, layer, aₖ, Rₖ₊₁)
     layerᵨ = _modify_layer(rule, layer)
     function fwpass(a)
         z = layerᵨ(a)
@@ -30,7 +34,16 @@ function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
     return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
 end

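For reference, this autodiff fallback realizes the generic LRP rule from [1],

$$R_j = \sum_k \frac{a_j \, \rho(w_{jk})}{\epsilon + \sum_{0,j} a_j \, \rho(w_{jk})} \, R_k,$$

where ρ corresponds to `modify_params`, the stabilized denominator to `modify_denominator`, and the index 0 accounts for the bias term; `lrp_dense` below writes out the same rule as an explicit einsum for Dense layers.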
-# Special cases are dispatched on layer type:
+# For linear layer types such as Dense, using autodiff is overkill.
+(rule::AbstractLRPRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_dense(rule, layer, aₖ, Rₖ₊₁)
+
+function lrp_dense(rule, l, aₖ, Rₖ₊₁)
+    ρW, ρb = modify_params(rule, get_params(l)...)
+    ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
+    return @tullio Rₖ[j] := aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
+end
+
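A usage sketch for the Dense fast path; the layer, the data, and `MyGammaRule` from the sketch above are illustrative, and `modify_denominator` is assumed to have a sensible fallback for new rules:

```julia
using Flux

layer = Dense(3 => 2, relu)        # toy layer
aₖ    = rand(Float32, 3)           # activations entering the layer
Rₖ₊₁  = rand(Float32, 2)           # relevance from the layer above
Rₖ    = MyGammaRule()(layer, aₖ, Rₖ₊₁)  # dispatches to lrp_dense
```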
+# Other special cases are dispatched on layer type:
 (::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
 (::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))

@@ -104,7 +117,10 @@ Commonly used on the first layer for pixel input.
 struct ZBoxRule <: AbstractLRPRule end

 # The ZBoxRule requires its own implementation of relevance propagation.
-function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
+(rule::ZBoxRule)(layer::Dense, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+(rule::ZBoxRule)(layer::Conv, aₖ, Rₖ₊₁) = lrp_zbox(layer, aₖ, Rₖ₊₁)
+
+function lrp_zbox(layer, aₖ, Rₖ₊₁)
     W, b = get_params(layer)
     l, h = fill.(extrema(aₖ), (size(aₖ),))

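For context, the zB-rule from [1] that `lrp_zbox` implements for box-constrained inputs l ≤ aₖ ≤ h (here estimated from the extrema of aₖ) is

$$R_j = \sum_k \frac{a_j w_{jk} - l_j w_{jk}^{+} - h_j w_{jk}^{-}}{\sum_j \left( a_j w_{jk} - l_j w_{jk}^{+} - h_j w_{jk}^{-} \right)} \, R_k,$$

with $w^{+} = \max(0, w)$ and $w^{-} = \min(0, w)$.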