Commit cd212d6

Document support for batches (#48)

1 parent d5b1736

File tree

2 files changed: +46 -8 lines changed

README.md

Lines changed: 22 additions & 7 deletions
````diff
@@ -9,7 +9,7 @@ ___
 
 Explainable AI in Julia using [Flux.jl](https://fluxml.ai).
 
-This package implements interpretability methods and visualizations for neural networks, similar to [Captum](https://github.com/pytorch/captum) for PyTorch and [iNNvestigate](https://github.com/albermax/innvestigate) for Keras models.
+This package implements interpretability methods and visualizations for neural networks, similar to [Captum][captum-repo] and [Zennit][zennit-repo] for PyTorch and [iNNvestigate][innvestigate-repo] for Keras models.
 
 ## Installation
 To install this package and its dependencies, open the Julia REPL and run
````
````diff
@@ -20,22 +20,31 @@ julia> ]add ExplainableAI
 ⚠️ This package is still in early development, expect breaking changes. ⚠️
 
 ## Example
-Let's use LRP to explain why an image of a cat gets classified as a cat:
+Let's use LRP to explain why an MNIST digit gets classified as a 9 using a small pre-trained LeNet5 model.
+If you want to follow along, the model can be found [here][model-bson-url].
 ```julia
 using ExplainableAI
 using Flux
-using Metalhead
+using MLDatasets
+using BSON: @load
 
 # Load model
-vgg = VGG19()
-model = strip_softmax(vgg.layers)
+@load "model.bson" model
+model = strip_softmax(model)
+
+# Load input
+x, _ = MNIST.testdata(Float32, 10)
+input = reshape(x, 28, 28, 1, :) # reshape to WHCN format
 
 # Run XAI method
 analyzer = LRP(model)
-expl = analyze(img, analyzer)
+expl = analyze(input, analyzer) # or: expl = analyzer(input)
 
 # Show heatmap
 heatmap(expl)
+
+# Or analyze & show heatmap directly
+heatmap(input, analyzer)
 ```
 ![][heatmap]
 
````

````diff
@@ -69,7 +78,7 @@ Contributions are welcome!
 > Adrian Hill acknowledges support by the Federal Ministry of Education and Research (BMBF) for the Berlin Institute for the Foundations of Learning and Data (BIFOLD) (01IS18037A).
 
 [banner-img]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/banner.png
-[heatmap]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmap.png
+[heatmap]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/mnist9.png
 
 [docs-stab-img]: https://img.shields.io/badge/docs-stable-blue.svg
 [docs-stab-url]: https://adrhill.github.io/ExplainableAI.jl/stable
````
````diff
@@ -88,3 +97,9 @@ Contributions are welcome!
 
 [doi-img]: https://zenodo.org/badge/337430397.svg
 [doi-url]: https://zenodo.org/badge/latestdoi/337430397
+
+[model-bson-url]: https://github.com/adrhill/ExplainableAI.jl/blob/master/docs/src/model.bson
+
+[captum-repo]: https://github.com/pytorch/captum
+[zennit-repo]: https://github.com/chr5tphr/zennit
+[innvestigate-repo]: https://github.com/albermax/innvestigate
````
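
As a side note on the WHCN reshape used in the README example above, here is a minimal standalone sketch (dummy random images stand in for `MLDatasets.MNIST.testdata`, which is an assumption for illustration) of how the singleton channel dimension is inserted:

```julia
# Minimal sketch of the WHCN reshape from the README example.
# Dummy random images stand in for MNIST test data (assumption).
x = rand(Float32, 28, 28, 10)     # 10 grayscale images: width × height × batch
input = reshape(x, 28, 28, 1, :)  # insert a singleton channel dim → WHCN layout
@assert size(input) == (28, 28, 1, 10)
```

Flux convolutional layers expect this width × height × channel × batch (WHCN) layout, which is why the extra channel dimension is needed even for grayscale inputs.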

docs/literate/example.jl

Lines changed: 24 additions & 1 deletion
````diff
@@ -48,7 +48,7 @@ input = reshape(x, 28, 28, 1, :);
 # ## Calling the analyzer
 # We can now select an analyzer of our choice
 # and call [`analyze`](@ref) to get an `Explanation`:
-analyzer = LRPZero(model)
+analyzer = LRP(model)
 expl = analyze(input, analyzer);
 
 # This `Explanation` bundles the following data:
````
````diff
@@ -79,6 +79,26 @@ heatmap(input, analyzer, 5)
 #md # expl = analyze(img, analyzer, 5)
 #md # ```
 
+# ## Input batches
+# ExplainableAI also supports input batches:
+batchsize = 100
+xs, _ = MNIST.testdata(Float32, 1:batchsize)
+batch = reshape(xs, 28, 28, 1, :) # reshape to WHCN format
+expl_batch = analyze(batch, analyzer);
+
+# This will once again return a single `Explanation` `expl_batch` for the entire batch.
+# Calling `heatmap` on `expl_batch` will detect the batch dimension and return a vector of heatmaps.
+#
+# Let's check if the digit at `index = 10` still matches.
+hs = heatmap(expl_batch)
+hs[index]
+
+# JuliaImages' `mosaic` can be used to return a tiled view of all heatmaps:
+mosaic(hs; nrow=10)
+
+# We can also evaluate a batch with respect to a specific output neuron, e.g. for the digit zero at index `1`:
+mosaic(heatmap(batch, analyzer, 1); nrow=10)
+
 # ## Automatic heatmap presets
 # Currently, the following analyzers are implemented:
 #
````
````diff
@@ -111,4 +131,7 @@ heatmap(expl; cs=ColorSchemes.jet)
 #
 heatmap(expl; reduce=:sum, normalize=:extrema, cs=ColorSchemes.inferno)
 
+# This also works with batches
+mosaic(heatmap(expl_batch; normalize=:extrema, cs=ColorSchemes.inferno); nrow=10)
+
 # For the full list of keyword arguments, refer to the [`heatmap`](@ref) documentation.
````
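
The batch handling documented in the hunks above can be illustrated with a small standalone sketch: a WHCN array splits into one 2-D map per sample along the fourth (batch) dimension, which is the behavior behind `heatmap` returning a vector of heatmaps for batched explanations. Here, dummy random data stands in for real LRP relevances (an assumption for illustration):

```julia
# Sketch: splitting a WHCN batch into per-sample 2-D maps.
# `rel` is dummy data standing in for batched LRP relevances (assumption).
rel = rand(Float32, 28, 28, 1, 100)
maps = [dropdims(r; dims=3) for r in eachslice(rel; dims=4)]
@assert length(maps) == 100
@assert size(first(maps)) == (28, 28)
```

Each element of `maps` corresponds to one sample, analogous to one heatmap per input in the batch.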
