
Commit 11ea2b7

Add LRP model checks and support unknown layers via AD fallback (#26)
* Add LRP model checks
* Add registration mechanism to support custom layers
* Deprecate `flatten_chain` in favor of `flatten_model`
* Add more rules tests
1 parent 614bdb7 commit 11ea2b7

25 files changed: +356 −50 lines changed

Project.toml

Lines changed: 5 additions & 2 deletions

@@ -8,22 +8,25 @@ ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 ColorSchemes = "3"
 Flux = "0.12"
 ImageCore = "0.8, 0.9"
-JLD2 = "0.4"
+PrettyTables = "1"
 Zygote = "0.6"
 julia = "1.6"

 [extras]
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
+Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["JLD2", "Random", "ReferenceTests", "Test"]
+test = ["JLD2", "Random", "ReferenceTests", "Suppressor", "Test"]

benchmark/Project.toml

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 [deps]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+ExplainabilityMethods = "cd722a4f-8d55-446b-8550-a4aabc9151ab"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"

benchmark/benchmarks.jl

Lines changed: 45 additions & 6 deletions

@@ -1,12 +1,12 @@
-using ExplainabilityMethods
-using ExplainabilityMethods: ANALYZERS
+using BenchmarkTools
 using Flux
+using ExplainabilityMethods

 on_CI = haskey(ENV, "GITHUB_ACTIONS")

 include("../test/vgg19.jl")
 vgg19 = VGG19(; pretrain=false)
-model = flatten_chain(strip_softmax(vgg19.layers))
+model = flatten_model(strip_softmax(vgg19.layers))
 img = rand(MersenneTwister(123), Float32, (224, 224, 3, 1))

 # Benchmark custom LRP composite
@@ -24,10 +24,49 @@ algs = Dict(

 # Define benchmark
 SUITE = BenchmarkGroup()
+SUITE["VGG"] = BenchmarkGroup([k for k in keys(algs)])
 for (name, alg) in algs
-    SUITE[name] = BenchmarkGroup(["construct analyzer", "analyze"])
-    SUITE[name]["construct analyzer"] = @benchmarkable alg($(model))
+    SUITE["VGG"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
+    SUITE["VGG"][name]["construct analyzer"] = @benchmarkable alg($(model))

     analyzer = alg(model)
-    SUITE[name]["analyze"] = @benchmarkable analyze($(img), $(analyzer))
+    SUITE["VGG"][name]["analyze"] = @benchmarkable analyze($(img), $(analyzer))
+end
+
+# Rules benchmarks – use wrapper to trigger AD fallback
+struct TestWrapper{T}
+    layer::T
+end
+(l::TestWrapper)(x) = l.layer(x)
+
+# generate input for conv layers
+insize = (128, 128, 3, 1)
+aₖ = randn(Float32, insize)
+
+layers = Dict(
+    "MaxPool" => (MaxPool((3, 3); pad=0), aₖ),
+    "MeanPool" => (MeanPool((3, 3); pad=0), aₖ),
+    "Conv" => (Conv((3, 3), 3 => 6), aₖ),
+    "flatten" => (flatten, aₖ),
+    "Dense" => (Dense(1000, 200, relu), randn(Float32, 1000)),
+)
+rules = Dict(
+    "ZeroRule" => ZeroRule(),
+    "EpsilonRule" => EpsilonRule(),
+    "GammaRule" => GammaRule(),
+    "ZBoxRule" => ZBoxRule(),
+)
+rulenames = [k for k in keys(rules)]
+
+for (layername, (layer, aₖ)) in layers
+    SUITE[layername] = BenchmarkGroup(rulenames)
+
+    for (rulename, ruletype) in rules
+        Rₖ₊₁ = layer(aₖ)
+        SUITE[layername][rulename] = BenchmarkGroup(["dispatch", "AD fallback"])
+        SUITE[layername][rulename]["dispatch"] = @benchmarkable rule($layer, $aₖ, $Rₖ₊₁)
+        SUITE[layername][rulename]["AD fallback"] = @benchmarkable rule(
+            $TestWrapper(layer), $aₖ, $Rₖ₊₁
+        )
+    end
 end
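For readers who want to run this suite locally, the standard PkgBenchmark workflow would look roughly like the sketch below (the output file name is an arbitrary choice, not part of this commit):

```julia
using PkgBenchmark

# Execute benchmark/benchmarks.jl and collect the timings stored in SUITE.
results = benchmarkpkg("ExplainabilityMethods")

# Write a human-readable summary of the results to disk.
export_markdown("benchmark_results.md", results)
```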

docs/literate/example.jl

Lines changed: 2 additions & 2 deletions

@@ -66,8 +66,8 @@ heatmap(expl)
 # ## Custom composites
 # If our model is a "flat" chain of Flux layers, we can assign LRP rules
 # to each layer individually. For this purpose,
-# ExplainabilityMethods exports the method [`flatten_chain`](@ref):
-model = flatten_chain(model)
+# ExplainabilityMethods exports the method [`flatten_model`](@ref):
+model = flatten_model(model)

 #md # !!! warning "Flattening models"
 #md #     Not all models can be flattened, e.g. those using
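To make the "custom composites" idea from this documentation change concrete, assigning one rule per layer of a flattened chain could look like the following sketch (the three-layer model and the rule choice are made up for illustration and not taken from the docs):

```julia
using Flux
using ExplainabilityMethods

model = Chain(
    Conv((3, 3), 3 => 8, relu),
    Flux.flatten,
    Dense(7200, 10),  # assuming 32×32×3 inputs
)
model = flatten_model(model)

# One rule per layer, in the same order as the layers of the flattened chain.
rules = [ZBoxRule(), ZeroRule(), EpsilonRule()]
analyzer = LRP(model, rules)
```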

docs/src/api.md

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ modify_denominator
 # Utilities
 ```@docs
 strip_softmax
-flatten_chain
+flatten_model
 ```

 # Index

src/ExplainabilityMethods.jl

Lines changed: 7 additions & 1 deletion

@@ -7,11 +7,15 @@ using ColorSchemes
 using ImageCore
 using Base.Iterators

+using Markdown
+using PrettyTables
+
 include("analyze_api.jl")
 include("flux.jl")
 include("utils.jl")
 include("neuron_selection.jl")
 include("gradient.jl")
+include("lrp_checks.jl")
 include("lrp_rules.jl")
 include("lrp.jl")
 include("heatmap.jl")
@@ -25,13 +29,15 @@ export LRP, LRPZero, LRPEpsilon, LRPGamma

 # LRP rules
 export AbstractLRPRule
+export LRP_CONFIG
 export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
 export modify_layer, modify_params, modify_denominator
+export check_model

 # heatmapping
 export heatmap

 # utils
-export strip_softmax, flatten_chain
+export strip_softmax, flatten_model, flatten_chain

 end # module

src/flux.jl

Lines changed: 22 additions & 13 deletions

@@ -1,30 +1,39 @@
 ## Group layers by type:
-const ConvLayers = Union{Conv,DepthwiseConv,ConvTranspose,CrossCor}
-const DropoutLayers = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
-const ReshapingLayers = Union{typeof(Flux.flatten)}
+const ConvLayer = Union{Conv,DepthwiseConv,ConvTranspose,CrossCor}
+const DropoutLayer = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
+const ReshapingLayer = Union{typeof(Flux.flatten)}
 # Pooling layers
-const MaxPoolLayers = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}
-const MeanPoolLayers = Union{MeanPool,AdaptiveMeanPool,GlobalMeanPool}
-const PoolingLayers = Union{MaxPoolLayers,MeanPoolLayers}
+const MaxPoolLayer = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}
+const MeanPoolLayer = Union{MeanPool,AdaptiveMeanPool,GlobalMeanPool}
+const PoolingLayer = Union{MaxPoolLayer,MeanPoolLayer}
+# Activation functions that are similar to ReLU
+const ReluLikeActivation = Union{
+    typeof(relu),typeof(gelu),typeof(swish),typeof(softplus),typeof(mish)
+}
+# Layers & activation functions supported by LRP
+const LRPSupportedLayer = Union{Dense,ConvLayer,DropoutLayer,ReshapingLayer,PoolingLayer}
+const LRPSupportedActivation = Union{typeof(identity),ReluLikeActivation}

-_flatten_chain(x) = x
-_flatten_chain(c::Chain) = [c.layers...]
+_flatten_model(x) = x
+_flatten_model(c::Chain) = [c.layers...]
 """
-    flatten_chain(c)
+    flatten_model(c)

 Flatten a Flux chain containing Flux chains.
 """
-function flatten_chain(chain::Chain)
+function flatten_model(chain::Chain)
     if any(isa.(chain.layers, Chain))
-        flatchain = Chain(vcat(_flatten_chain.(chain.layers)...)...)
-        return flatten_chain(flatchain)
+        flatchain = Chain(vcat(_flatten_model.(chain.layers)...)...)
+        return flatten_model(flatchain)
     end
     return chain
 end
+@deprecate flatten_chain(c) flatten_model(c)

 is_softmax(layer) = layer isa Union{typeof(softmax),typeof(softmax!)}
 has_output_softmax(x) = is_softmax(x)
 has_output_softmax(model::Chain) = has_output_softmax(model[end])
+
 """
     check_ouput_softmax(model)

@@ -46,7 +55,7 @@ Remove softmax activation on model output if it exists.
 """
 function strip_softmax(model::Chain)
     if has_output_softmax(model)
-        model = flatten_chain(model)
+        model = flatten_model(model)
         return Chain(model.layers[1:(end - 1)]...)
     end
     return model
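As a quick illustration of the renamed helper (a sketch; the layer sizes are made up and not taken from the commit):

```julia
using Flux
using ExplainabilityMethods

# A chain that nests another chain, e.g. a feature extractor plus a classifier head.
model = Chain(Chain(Dense(4, 8, relu), Dense(8, 8, relu)), Dense(8, 2))

flat = flatten_model(model)  # Chain(Dense(4, 8, relu), Dense(8, 8, relu), Dense(8, 2))
length(flat.layers)          # 3 – the nesting is gone

flatten_chain(model)         # still works, but now emits a deprecation warning
```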

src/lrp.jl

Lines changed: 25 additions & 14 deletions

@@ -1,10 +1,13 @@
 """
     LRP(c::Chain, r::AbstractLRPRule)
     LRP(c::Chain, rs::AbstractVector{<:AbstractLRPRule})
-    LRP(layers::AbstractVector{LRPLayer})

 Analyze model by applying Layer-Wise Relevance Propagation.

+# Keyword arguments
+- `skip_checks::Bool`: Skip checks whether model is compatible with LRP and contains output softmax. Default is `false`.
+- `verbose::Bool`: Select whether the model checks should print a summary on failure. Default is `true`.
+
 # References
 [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
 [2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications
@@ -14,26 +17,34 @@ struct LRP{R<:AbstractVector{<:AbstractLRPRule}} <: AbstractXAIMethod
     rules::R

     # Construct LRP analyzer by manually assigning a rule to each layer
-    function LRP(model::Chain, rules::AbstractVector{<:AbstractLRPRule})
-        check_ouput_softmax(model)
-        model = flatten_chain(model)
+    function LRP(
+        model::Chain,
+        rules::AbstractVector{<:AbstractLRPRule};
+        skip_checks=false,
+        verbose=true,
+    )
+        model = flatten_model(model)
+        if !skip_checks
+            check_ouput_softmax(model)
+            check_model(Val(:LRP), model; verbose=verbose)
+        end
        if length(model.layers) != length(rules)
            throw(ArgumentError("Length of rules doesn't match length of Flux chain."))
        end
        return new{typeof(rules)}(model, rules)
    end
-    # Construct LRP analyzer by assigning a single rule to all layers
-    function LRP(model::Chain, r::AbstractLRPRule)
-        check_ouput_softmax(model)
-        model = flatten_chain(model)
-        rules = repeat([r], length(model.layers))
-        return new{typeof(rules)}(model, rules)
-    end
+end
+
+# Construct LRP analyzer by assigning a single rule to all layers
+function LRP(model::Chain, r::AbstractLRPRule; kwargs...)
+    model = flatten_model(model)
+    rules = repeat([r], length(model.layers))
+    return LRP(model, rules; kwargs...)
 end
 # Additional constructors for convenience:
-LRPZero(model::Chain) = LRP(model, ZeroRule())
-LRPEpsilon(model::Chain) = LRP(model, EpsilonRule())
-LRPGamma(model::Chain) = LRP(model, GammaRule())
+LRPZero(model::Chain; kwargs...) = LRP(model, ZeroRule(); kwargs...)
+LRPEpsilon(model::Chain; kwargs...) = LRP(model, EpsilonRule(); kwargs...)
+LRPGamma(model::Chain; kwargs...) = LRP(model, GammaRule(); kwargs...)

 # The call to the LRP analyzer.
 function (analyzer::LRP)(input, ns::AbstractNeuronSelector; layerwise_relevances=false)
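To sketch how the new keyword arguments are meant to be used (the toy two-layer model here is made up, not part of the commit):

```julia
using Flux
using ExplainabilityMethods

model = Chain(Dense(784, 100, relu), Dense(100, 10))

analyzer = LRPZero(model)                             # model checks run by default
analyzer = LRPEpsilon(model; verbose=false)           # checks run, but stay silent on failure
analyzer = LRP(model, GammaRule(); skip_checks=true)  # opt out of the checks entirely
```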

src/lrp_checks.jl

Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
+module LRP_CONFIG
+using ExplainabilityMethods
+using ExplainabilityMethods: LRPSupportedLayer, LRPSupportedActivation
+"""
+    LRP_CONFIG.supports_layer(layer)
+
+Check whether LRP can be used on a layer or a Chain.
+To extend LRP to your own layers, define:
+```julia
+LRP_CONFIG.supports_layer(::MyLayer) = true
+```
+"""
+supports_layer(l) = false
+supports_layer(::LRPSupportedLayer) = true
+"""
+    LRP_CONFIG.supports_activation(σ)
+
+Check whether LRP can be used on a given activation function.
+To extend LRP to your own activation functions, define:
+```julia
+LRP_CONFIG.supports_activation(::MyActivation) = true
+```
+"""
+supports_activation(σ) = false
+supports_activation(::LRPSupportedActivation) = true
+end # LRP_CONFIG module
+
+_check_layer(::Val{:LRP}, layer) = LRP_CONFIG.supports_layer(layer)
+_check_layer(::Val{:LRP}, c::Chain) = all(_check_layer(Val(:LRP), l) for l in c)
+
+function _check_activation(::Val{:LRP}, layer)
+    hasproperty(layer, :σ) && return LRP_CONFIG.supports_activation(layer.σ)
+    return true
+end
+_check_activation(::Val{:LRP}, c::Chain) = all(_check_activation(Val(:LRP), l) for l in c)
+
+"""
+    check_model(method::Symbol, model; verbose=true)
+
+Check whether the given method can be used on the model.
+Currently, model checks are only implemented for LRP, using the symbol `:LRP`.
+
+# Example
+julia> check_model(:LRP, model)
+"""
+check_model(method::Symbol, model; kwargs...) = check_model(Val(method), model; kwargs...)
+function check_model(::Val{:LRP}, c::Chain; verbose=true)
+    layer_checks = collect(_check_layer.(Val(:LRP), c.layers))
+    activation_checks = collect(_check_activation.(Val(:LRP), c.layers))
+    passed_layer_checks = all(layer_checks)
+    passed_activation_checks = all(activation_checks)
+
+    passed_layer_checks && passed_activation_checks && return true
+
+    layer_names = [_print_name(l) for l in c]
+    activation_names = [_print_activation(l) for l in c]
+
+    verbose && _show_check_summary(
+        c, layer_names, layer_checks, activation_names, activation_checks
+    )
+    if !passed_layer_checks
+        verbose && display(
+            Markdown.parse(
+                """# Layers failed model check
+Found unknown layers `$(join(unique(layer_names[.!layer_checks]), ", "))`
+that are not supported by ExplainabilityMethods' LRP implementation yet.
+
+If you think the missing layer should be supported by default, please [submit an issue](https://github.com/adrhill/ExplainabilityMethods.jl/issues).
+
+These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument `skip_checks=true`.
+
+## Using custom layers
+If you implemented custom layers, register them via
+```julia
+LRP_CONFIG.supports_layer(::MyLayer) = true # for structs
+LRP_CONFIG.supports_activation(::typeof(mylayer)) = true # for functions
+```
+The default fallback for this layer will use Automatic Differentiation according to "Layer-Wise Relevance Propagation: An Overview".
+You can also define a fully LRP-custom rule for your layer by using the interface
+```julia
+function (rule::AbstractLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁)
+    # ...
+    return Rₖ
+end
+```
+This pattern can also be used to dispatch on specific rules.
+""",
+            ),
+        )
+        throw(ArgumentError("Unknown layers found in model"))
+    end
+    if !passed_activation_checks
+        verbose && display(
+            Markdown.parse(
+                """ # Activations failed model check
+Found layers with unknown or unsupported activation functions
+`$(join(unique(activation_names[.!activation_checks]), ", "))`.
+LRP assumes that the model is a "deep rectifier network" that only contains ReLU-like activation functions.
+
+If you think the missing activation function should be supported by default, please [submit an issue](https://github.com/adrhill/ExplainabilityMethods.jl/issues).
+
+These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument `skip_checks=true`.
+
+## Using custom activation functions
+If you use custom ReLU-like activation functions, register them via
+```julia
+LRP_CONFIG.supports_activation(::typeof(myfunction)) = true # for functions
+LRP_CONFIG.supports_activation(::MyActivation) = true # for structs
+```
+""",
+            ),
+        )
+        throw(ArgumentError("Unknown or unsupported activation functions found in model"))
+    end
+    return false
+end
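Putting the registration mechanism together, a minimal sketch (the `MyDoubler` layer is a made-up example, not part of the commit):

```julia
using Flux
using ExplainabilityMethods

# A custom layer that ExplainabilityMethods does not know about.
struct MyDoubler end
(::MyDoubler)(x) = 2 .* x

model = Chain(Dense(10, 5, relu), MyDoubler(), Dense(5, 2))

# check_model(:LRP, model)  # would throw an ArgumentError and print a summary

# Register the layer; relevance is then propagated through it via the AD fallback.
LRP_CONFIG.supports_layer(::MyDoubler) = true

check_model(:LRP, model)    # now returns true
analyzer = LRPZero(model)   # the model checks pass during construction
```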
