
Commit c74f50d

Add support for batches (#46)
1 parent 312c060 commit c74f50d

24 files changed: +295 -104 lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
@@ -55,9 +55,9 @@ aₖ = randn(Float32, insize)
 layers = Dict(
     "MaxPool" => (MaxPool((3, 3); pad=0), aₖ),
     "Conv" => (Conv((3, 3), 3 => 2), aₖ),
-    "Dense" => (Dense(in_dense, out_dense, relu), randn(Float32, in_dense)),
+    "Dense" => (Dense(in_dense, out_dense, relu), randn(Float32, in_dense, 1)),
     "WrappedDense" =>
-        (TestWrapper(Dense(in_dense, out_dense, relu)), randn(Float32, in_dense)),
+        (TestWrapper(Dense(in_dense, out_dense, relu)), randn(Float32, in_dense, 1)),
 )
 rules = Dict(
     "ZeroRule" => ZeroRule(),

src/ExplainableAI.jl

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 module ExplainableAI
 
 using Base.Iterators
-using LinearAlgebra
 using Flux
 using Zygote
 using Tullio
@@ -14,10 +13,10 @@ using ColorSchemes
 using Markdown
 using PrettyTables
 
+include("neuron_selection.jl")
 include("analyze_api.jl")
 include("flux.jl")
 include("utils.jl")
-include("neuron_selection.jl")
 include("gradient.jl")
 include("lrp_checks.jl")
 include("lrp_rules.jl")

src/analyze_api.jl

Lines changed: 41 additions & 8 deletions
@@ -1,6 +1,11 @@
 abstract type AbstractXAIMethod end
-# All analyzers are implemented such that they return an explanation and the model output:
-# (method::AbstractXAIMethod)(input, ns::AbstractNeuronSelector) -> (expl, output)
+# All analyzers are implemented such that they return an array of explanations:
+# (method::AbstractXAIMethod)(input, ns::AbstractNeuronSelector)::Vector{Explanation}
+
+const BATCHDIM_MISSING = ArgumentError(
+    """The input is a 1D vector and therefore missing the required batch dimension.
+    Call analyze with the keyword argument add_batch_dim=false."""
+)
 
 """
     analyze(input, method)
@@ -9,29 +14,57 @@ abstract type AbstractXAIMethod end
 Return raw classifier output and explanation.
 If `neuron_selection` is specified, the explanation will be calculated for that neuron.
 Otherwise, the output neuron with the highest activation is automatically chosen.
+
+## Keyword arguments
+- `add_batch_dim`: add batch dimension to the input without allocating. Default is `false`.
 """
 function analyze(
     input::AbstractArray{<:Real},
     method::AbstractXAIMethod,
-    neuron_selection::Integer;
+    neuron_selection::Union{Integer,Tuple{<:Integer}};
     kwargs...,
 )
-    return method(input, IndexNS(neuron_selection); kwargs...)
+    return _analyze(input, method, IndexSelector(neuron_selection); kwargs...)
 end
 
 function analyze(input::AbstractArray{<:Real}, method::AbstractXAIMethod; kwargs...)
-    return method(input, MaxActivationNS(); kwargs...)
+    return _analyze(input, method, MaxActivationSelector(); kwargs...)
+end
+
+function (method::AbstractXAIMethod)(
+    input::AbstractArray{<:Real},
+    neuron_selection::Union{Integer,Tuple{<:Integer}};
+    kwargs...,
+)
+    return _analyze(input, method, IndexSelector(neuron_selection); kwargs...)
 end
 function (method::AbstractXAIMethod)(input::AbstractArray{<:Real}; kwargs...)
-    return method(input, MaxActivationNS(); kwargs...)
+    return _analyze(input, method, MaxActivationSelector(); kwargs...)
+end
+
+# lower-level call to method
+function _analyze(
+    input::AbstractArray{T,N},
+    method::AbstractXAIMethod,
+    sel::AbstractNeuronSelector;
+    add_batch_dim::Bool=false,
+    kwargs...,
+) where {T<:Real,N}
+    if add_batch_dim
+        return method(batch_dim_view(input), sel; kwargs...)
+    end
+    N < 2 && throw(BATCHDIM_MISSING)
+    return method(input, sel; kwargs...)
 end
 
+# for convenience, the analyzer can be called directly
+
 # Explanations and outputs are returned in a wrapper.
 # Metadata such as the analyzer allows dispatching on functions like `heatmap`.
-struct Explanation{A,O,L}
+struct Explanation{A,O,I,L}
     attribution::A
     output::O
-    neuron_selection::Int
+    neuron_selection::I
     analyzer::Symbol
     layerwise_relevances::L
 end
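A hedged usage sketch of the batched `analyze` API introduced above (the model and input sizes are made up; `Gradient` is used as the analyzer):

```julia
using Flux, ExplainableAI

model = Chain(Dense(10, 5, relu), Dense(5, 3))
x = rand(Float32, 10, 16)              # 10 features, batch of 16

expl = analyze(x, Gradient(model))     # max-activation neuron per sample
expl2 = analyze(x, Gradient(model), 2) # neuron 2 for every sample
# expl.attribution has the same size as x

v = rand(Float32, 10)                  # a single 1D sample
# analyze(v, Gradient(model))          # would throw BATCHDIM_MISSING
expl3 = analyze(v, Gradient(model); add_batch_dim=true)  # non-allocating batch-dim view
```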

src/flux.jl

Lines changed: 2 additions & 2 deletions
@@ -35,12 +35,12 @@ has_output_softmax(x) = is_softmax(x)
 has_output_softmax(model::Chain) = has_output_softmax(model[end])
 
 """
-    check_ouput_softmax(model)
+    check_output_softmax(model)
 
 Check whether model has softmax activation on output.
 Return the model if it doesn't, throw error otherwise.
 """
-function check_ouput_softmax(model::Chain)
+function check_output_softmax(model::Chain)
     if has_output_softmax(model)
         throw(ArgumentError("""Model contains softmax activation on output.
         Call `strip_softmax` on your model first."""))
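A brief sketch of the check documented above (assuming the package's exported `strip_softmax`; the model is made up): analyzers reject models whose output layer is a softmax.

```julia
using Flux, ExplainableAI

m = Chain(Dense(4, 3), softmax)
# Gradient(m)                           # throws: softmax activation on output
analyzer = Gradient(strip_softmax(m))   # remove the softmax first, then analyze
```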

src/gradient.jl

Lines changed: 20 additions & 10 deletions
@@ -1,5 +1,15 @@
-function gradient_wrt_input(model, input::T, output_neuron)::T where {T}
-    return only(gradient((in) -> model(in)[output_neuron], input))
+function gradient_wrt_input(model, input::T, output_indices) where {T}
+    return only(gradient((in) -> model(in)[output_indices], input))
+end
+
+function gradients_wrt_batch(model, input::AbstractArray{T,N}, output_indices) where {T,N}
+    # To avoid computing a sparse jacobian, we compute individual gradients
+    # by mapping `gradient_wrt_input` on slices of the input along the batch dimension.
+    return mapreduce(
+        (gs...) -> cat(gs...; dims=N), zip(eachslice(input; dims=N), output_indices)
+    ) do (in, idx)
+        gradient_wrt_input(model, batch_dim_view(in), drop_batch_dim(idx))
+    end
 end
 
 """
@@ -9,13 +19,13 @@ Analyze model by calculating the gradient of a neuron activation with respect to
 """
 struct Gradient{C<:Chain} <: AbstractXAIMethod
     model::C
-    Gradient(model::Chain) = new{typeof(model)}(Flux.testmode!(check_ouput_softmax(model)))
+    Gradient(model::Chain) = new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
 end
 function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
     output = analyzer.model(input)
-    output_neuron = ns(output)
-    grad = gradient_wrt_input(analyzer.model, input, output_neuron)
-    return Explanation(grad, output, output_neuron, :Gradient, Nothing)
+    output_indices = ns(output)
+    grad = gradients_wrt_batch(analyzer.model, input, output_indices)
+    return Explanation(grad, output, output_indices, :Gradient, Nothing)
 end
 
 """
@@ -27,12 +37,12 @@ This gradient is then multiplied element-wise with the input.
 struct InputTimesGradient{C<:Chain} <: AbstractXAIMethod
     model::C
     function InputTimesGradient(model::Chain)
-        return new{typeof(model)}(Flux.testmode!(check_ouput_softmax(model)))
+        return new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
     end
 end
 function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
     output = analyzer.model(input)
-    output_neuron = ns(output)
-    attr = input .* gradient_wrt_input(analyzer.model, input, output_neuron)
-    return Explanation(attr, output, output_neuron, :InputTimesGradient, Nothing)
+    output_indices = ns(output)
+    attr = input .* gradients_wrt_batch(analyzer.model, input, output_indices)
+    return Explanation(attr, output, output_indices, :InputTimesGradient, Nothing)
 end
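To illustrate the pattern `gradients_wrt_batch` relies on, here is a minimal, self-contained sketch (Flux and Zygote assumed; `toy_model`, the 2×3 input, and the plain integer indices are made up — the package itself uses `CartesianIndex` selectors and the `batch_dim_view`/`drop_batch_dim` helpers):

```julia
using Flux, Zygote

toy_model = Chain(Dense(2, 4, relu), Dense(4, 3))
x = randn(Float32, 2, 3)                         # 2 features, batch of 3
idx = [argmax(toy_model(x)[:, b]) for b in 1:3]  # selected output neuron per sample

# One gradient per sample, concatenated along the batch dimension,
# instead of one sparse Jacobian over the whole batch:
grads = mapreduce((gs...) -> cat(gs...; dims=2), enumerate(eachcol(x))) do (b, xb)
    only(gradient(in -> toy_model(in)[idx[b]], xb))
end
size(grads) == size(x)  # true
```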

src/heatmap.jl

Lines changed: 33 additions & 24 deletions
@@ -2,9 +2,9 @@
 
 const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
     # Analyzer => (colorscheme, reduce, normalize)
-    :LRP => (ColorSchemes.bwr, :sum, :centered),
-    :InputTimesGradient => (ColorSchemes.bwr, :sum, :centered), # same as LRP
-    :Gradient => (ColorSchemes.grays, :norm, :extrema),
+    :LRP => (ColorSchemes.bwr, :sum, :centered), # attribution
+    :InputTimesGradient => (ColorSchemes.bwr, :sum, :centered), # attribution
+    :Gradient => (ColorSchemes.grays, :norm, :extrema), # gradient
 )
 
 """
@@ -34,35 +34,32 @@ Assumes Flux's WHCN convention (width, height, color channels, batch size).
 When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
 When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
 - `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
+- `unpack_singleton::Bool`: When heatmapping a batch with a single sample, setting `unpack_singleton=true`
+  will return an image instead of a Vector containing a single image.
 
 **Note:** these keyword arguments can't be used when calling `heatmap` with an analyzer.
 """
 function heatmap(
-    attr::AbstractArray;
+    attr::AbstractArray{T,N};
     cs::ColorScheme=ColorSchemes.bwr,
     reduce::Symbol=:sum,
     normalize::Symbol=:centered,
     permute::Bool=true,
-)
-    _size = size(attr)
-    length(_size) != 4 && throw(
+    unpack_singleton::Bool=true,
+) where {T,N}
+    N != 4 && throw(
         DomainError(
-            _size,
+            N,
             """heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
             Please reshape your attribution to match this format if your model doesn't adhere to this convention.""",
         ),
     )
-    _size[end] != 1 && throw(
-        DomainError(
-            _size[end],
-            """heatmap is only applicable to a single attribution, got a batch dimension of $(_size[end]).""",
-        ),
-    )
-
-    img = _normalize(dropdims(_reduce(dropdims(attr; dims=4), reduce); dims=3), normalize)
-    permute && (img = permutedims(img))
-    return ColorSchemes.get(cs, img)
+    if unpack_singleton && size(attr, 4) == 1
+        return _heatmap(attr[:, :, :, 1], cs, reduce, normalize, permute)
+    end
+    return map(a -> _heatmap(a, cs, reduce, normalize, permute), eachslice(attr; dims=4))
 end
+
 # Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
 function heatmap(expl::Explanation; permute::Bool=true, kwargs...)
     _cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl.analyzer]
@@ -79,6 +76,18 @@ function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
     return heatmap(analyze(input, analyzer, args...; kwargs...))
 end
 
+# Lower level function that is mapped along batch dimension
+function _heatmap(
+    attr::AbstractArray{T,3},
+    cs::ColorScheme,
+    reduce::Symbol,
+    normalize::Symbol,
+    permute::Bool,
+) where {T<:Real}
+    img = _normalize(dropdims(_reduce(attr, reduce); dims=3), normalize)
+    permute && (img = permutedims(img))
+    return ColorSchemes.get(cs, img)
+end
 
 # Normalize activations across pixels
 function _normalize(attr, method::Symbol)
@@ -97,15 +106,15 @@ function _normalize(attr, method::Symbol)
 end
 
 # Reduce attributions across color channels into a single scalar – assumes WHCN convention
-function _reduce(attr::T, method::Symbol) where {T}
-    if size(attr, 3) == 1 # nothing need to reduce
+function _reduce(attr::AbstractArray{T,3}, method::Symbol) where {T}
+    if size(attr, 3) == 1 # nothing to reduce
         return attr
+    elseif method == :sum
+        return reduce(+, attr; dims=3)
     elseif method == :maxabs
-        return maximum(abs, attr; dims=3)
+        return reduce((c...) -> maximum(abs.(c)), attr; dims=3, init=zero(T))
     elseif method == :norm
-        return mapslices(norm, attr; dims=3)::T
-    elseif method == :sum
-        return sum(attr; dims=3)
+        return reduce((c...) -> sqrt(sum(c .^ 2)), attr; dims=3, init=zero(T))
     end
     throw(
         ArgumentError(
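A hedged sketch of the batched `heatmap` behavior described in the docstring above (the 8×8×3×4 WHCN attribution array is made up; the resulting colors depend on the chosen scheme):

```julia
using ExplainableAI

attr = randn(Float32, 8, 8, 3, 4)      # width × height × channels × batch
imgs = heatmap(attr)                   # Vector of 4 images, one per sample
one_img = heatmap(attr[:, :, :, 1:1])  # singleton batch is unpacked to a single image
as_vec = heatmap(attr[:, :, :, 1:1]; unpack_singleton=false)  # Vector with one image
```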

src/lrp.jl

Lines changed: 4 additions & 4 deletions
@@ -25,7 +25,7 @@ struct LRP{R<:AbstractVector{<:AbstractLRPRule}} <: AbstractXAIMethod
     )
         model = flatten_model(model)
         if !skip_checks
-            check_ouput_softmax(model)
+            check_output_softmax(model)
             check_model(Val(:LRP), model; verbose=verbose)
         end
         if length(model.layers) != length(rules)
@@ -59,9 +59,9 @@ function (analyzer::LRP)(
     rels = similar.(acts)
 
     # Mask output neuron
-    output_neuron = ns(acts[end])
+    output_indices = ns(acts[end])
     rels[end] .= zero(T)
-    rels[end][output_neuron] = acts[end][output_neuron]
+    rels[end][output_indices] = acts[end][output_indices]
 
     # Backward pass through layers, applying LRP rules
     for (i, rule) in Iterators.reverse(enumerate(analyzer.rules))
@@ -71,7 +71,7 @@ function (analyzer::LRP)(
     return Explanation(
         first(rels),
         last(acts),
-        output_neuron,
+        output_indices,
         :LRP,
         ifelse(layerwise_relevances, rels, Nothing),
     )
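For intuition, a small sketch (with a made-up 4×2 output batch and hand-picked indices) of the masking step above: the output-layer relevance is zeroed and only the selected neuron of each sample keeps its activation.

```julia
acts_end = randn(Float32, 4, 2)                      # output activations, batch of 2
output_indices = [CartesianIndex(3, 1), CartesianIndex(1, 2)]  # one neuron per sample
rels_end = similar(acts_end)
rels_end .= 0
rels_end[output_indices] = acts_end[output_indices]  # relevance only at selected neurons
```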

src/lrp_checks.jl

Lines changed: 0 additions & 1 deletion
@@ -114,5 +114,4 @@ function check_model(::Val{:LRP}, c::Chain; verbose=true)
         )
         throw(ArgumentError("Unknown or unsupported activation functions found in model"))
     end
-    return false
 end

src/lrp_rules.jl

Lines changed: 4 additions & 24 deletions
@@ -1,27 +1,7 @@
-# Generic implementation of LRP according to [1, 2].
-# LRP-rules are implemented as structs of type `AbstractLRPRule`.
-# Through the magic of multiple dispatch, rule modifications such as LRP-γ and -ϵ
-# can be implemented by dispatching on the functions `modify_params` & `modify_denominator`,
-# which make use of the generalized LRP implementation shown in [1].
-#
-# If the relevance propagation falls outside of this scheme, custom low-level functions
-# ```julia
-# lrp!(::MyLRPRule, layer, Rₖ, aₖ, Rₖ₊₁) = ...
-# lrp!(::MyLRPRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁) = ...
-# lrp!(::AbstractLRPRule, layer::MyLayer, Rₖ, aₖ, Rₖ₊₁) = ...
-# ```
-# that inplace-update `Rₖ` can be implemented.
-# This is used for the ZBoxRule and for faster computations on common layers.
-#
-# References:
-# [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
-# [2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications
-
+# https://adrhill.github.io/ExplainableAI.jl/stable/generated/advanced_lrp/#How-it-works-internally
 abstract type AbstractLRPRule end
 
-# This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
-# It can be extended for new rules via `modify_denominator` and `modify_params`.
-# Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
+# Generic LRP rule. Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
 function lrp!(rule::R, layer::L, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule,L}
     lrp_autodiff!(rule, layer, Rₖ, aₖ, Rₖ₊₁)
     return nothing
@@ -50,8 +30,8 @@ end
 
 function lrp_dense!(rule::R, l, Rₖ, aₖ, Rₖ₊₁) where {R<:AbstractLRPRule}
     ρW, ρb = modify_params(rule, get_params(l)...)
-    ãₖ₊₁ = modify_denominator(rule, ρW * aₖ + ρb)
-    @tullio Rₖ[j] = aₖ[j] * ρW[k, j] / ãₖ₊₁[k] * Rₖ₊₁[k]
+    ãₖ₊₁ = modify_denominator(rule, ρW * aₖ .+ ρb)
+    @tullio Rₖ[j, b] = aₖ[j, b] * ρW[k, j] / ãₖ₊₁[k, b] * Rₖ₊₁[k, b]
     return nothing
 end
 
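The batched `@tullio` expression above sums over the output index `k` for every (input neuron `j`, batch sample `b`) pair. A plain-loop sketch of the same computation (made-up weights, bias, and activations; `modify_params`/`modify_denominator` are taken as the identity, as in the LRP-0 rule):

```julia
W = randn(Float32, 3, 4)      # k × j
bias = randn(Float32, 3)
aₖ = randn(Float32, 4, 2)     # j × batch
Rₖ₊₁ = randn(Float32, 3, 2)   # k × batch
ãₖ₊₁ = W * aₖ .+ bias         # k × batch, bias broadcast over the batch

Rₖ = zeros(Float32, size(aₖ))
for b in axes(aₖ, 2), j in axes(aₖ, 1), k in axes(Rₖ₊₁, 1)
    Rₖ[j, b] += aₖ[j, b] * W[k, j] / ãₖ₊₁[k, b] * Rₖ₊₁[k, b]
end
```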

src/neuron_selection.jl

Lines changed: 17 additions & 9 deletions
@@ -1,21 +1,29 @@
 abstract type AbstractNeuronSelector end
-(ns::AbstractNeuronSelector)(output::AbstractArray) = ns(drop_singleton_dims(output))
 
 """
-    MaxActivationNS()
+    MaxActivationSelector()
 
 Neuron selector that picks the output neuron with the highest activation.
 """
-struct MaxActivationNS <: AbstractNeuronSelector end
-(::MaxActivationNS)(output::AbstractVector) = argmax(output)
+struct MaxActivationSelector <: AbstractNeuronSelector end
+function (::MaxActivationSelector)(out::AbstractArray{T,N}) where {T,N}
+    N < 2 && throw(BATCHDIM_MISSING)
+    return Vector{CartesianIndex{N}}([argmax(out; dims=1:(N - 1))...])
+end
 
 """
-    IndexNS(index)
+    IndexSelector(index)
 
 Neuron selector that picks the output neuron at the given index.
 """
-struct IndexNS{T} <: AbstractNeuronSelector
-    index::T
-    IndexNS(index::Integer) = new{typeof(index)}(index)
+struct IndexSelector{I} <: AbstractNeuronSelector
+    index::I
+end
+function (s::IndexSelector{<:Integer})(out::AbstractArray{T,N}) where {T,N}
+    N < 2 && throw(BATCHDIM_MISSING)
+    return CartesianIndex{N}.(s.index, 1:size(out, N))
+end
+function (s::IndexSelector{I})(out::AbstractArray{T,N}) where {I,T,N}
+    N < 2 && throw(BATCHDIM_MISSING)
+    return CartesianIndex{N}.(s.index..., 1:size(out, N))
 end
-(ns::IndexNS)(output::AbstractVector) = ns.index
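A short sketch of what the new selectors return on a batched output (the 3×2 matrix is made up): both produce one `CartesianIndex` per sample, i.e. per column.

```julia
out = Float32[0.1 0.9; 0.7 0.2; 0.2 0.4]  # 3 classes × batch of 2

# Max activation per column, similar to what MaxActivationSelector computes:
vec(argmax(out; dims=1))                  # CartesianIndex(2, 1), CartesianIndex(1, 2)

# Fixed index per column, as in IndexSelector(3):
CartesianIndex{2}.(3, 1:size(out, 2))     # CartesianIndex(3, 1), CartesianIndex(3, 2)
```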
