# # Getting started
# ## Preparing the model
# ExplainabilityMethods.jl can be used on any classifier.
# In this tutorial we will be using a pretrained VGG-19 model from
# [Metalhead.jl](https://github.com/FluxML/Metalhead.jl).
using ExplainabilityMethods
using Flux
using Metalhead
using Metalhead: weights

vgg = VGG19()
Flux.loadparams!(vgg, Metalhead.weights("vgg19"))
| 13 | + |
| 14 | +#md # !!! note "Pretrained weights" |
| 15 | +#md # This doc page was generated using Metalhead `v0.6.0`. |
| 16 | +#md # At the time you read this, Metalhead might already have implemented weight loading |
| 17 | +#md # via `VGG19(; pretrain=true)`, in which case `loadparams!` is not necessary. |
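#md #     In that case the setup would shorten to the following sketch
#md #     (hypothetical, depending on your Metalhead version):
#md #     ```julia
#md #     vgg = VGG19(; pretrain=true)  # assumes the `pretrain` keyword exists
#md #     ```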

# If softmax activations exist on the output, we need to strip them before analyzing:
model = strip_softmax(vgg.layers)

# We also need to load an image
using Images
using TestImages

img_raw = testimage("chelsea")

# which we preprocess for VGG-19:
include("../utils/preprocessing.jl")
img = preprocess(img_raw)
size(img)
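
#md # !!! note "Preprocessing"
#md #     The contents of `../utils/preprocessing.jl` are not shown here. A typical
#md #     VGG preprocessing pipeline (a sketch under that assumption, not the exact
#md #     helper used above) resizes the image to 224×224, normalizes it with the
#md #     ImageNet channel means and standard deviations, and reshapes it into a
#md #     WHCN array with a singleton batch dimension:
#md #     ```julia
#md #     using Images
#md #     function preprocess(img)
#md #         img = imresize(img, (224, 224))
#md #         x = Float32.(permutedims(channelview(RGB.(img)), (3, 2, 1))) # CHW -> WHC
#md #         μ = reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3)            # ImageNet means
#md #         σ = reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3)            # ImageNet stds
#md #         x = (x .- μ) ./ σ
#md #         return reshape(x, size(x)..., 1) # add batch dimension (WHCN)
#md #     end
#md #     ```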

# ## Calling the analyzer
# We can now select an analyzer of our choice
# and call `analyze` to get an explanation `expl`:
analyzer = LRPZero(model)
expl, out = analyze(img, analyzer);

#md # !!! note "Neuron selection"
#md #     To get an explanation with respect to a specific output neuron (e.g. class 42), call
#md #     ```julia
#md #     expl, out = analyze(img, analyzer, 42)
#md #     ```
#
# Finally, we can visualize the explanation through heatmapping:
heatmap(expl)

# Currently, the following analyzers are implemented:
#
# ```
# ├── Gradient
# ├── InputTimesGradient
# └── LRP
#     ├── LRPZero
#     ├── LRPEpsilon
#     └── LRPGamma
# ```
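#
# Each of these can be used in place of `LRPZero` above. For example, assuming the
# other constructors follow the same single-argument pattern, a gradient-based
# explanation could be obtained like this:
#md # ```julia
#md # analyzer = InputTimesGradient(model)
#md # expl, out = analyze(img, analyzer)
#md # heatmap(expl)
#md # ```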

# ## Custom rule composites
# If our model is a "flat" chain of Flux layers, we can assign LRP rules
# to each layer individually. For this purpose,
# ExplainabilityMethods exports the method `flatten_chain`:
model = flatten_chain(model)

#md # !!! note "Flattening models"
#md #     Not all models can be flattened, e.g. those using
#md #     `Parallel` and `SkipConnection` layers.
#
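# Since the rule list below needs exactly one rule per layer,
# it helps to first check the length of the flattened chain:
length(model)
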
# Now we assign a rule to each layer: a `ZBoxRule` for the input layer,
# a `GammaRule` for the next 15 layers, and a `ZeroRule` for the remaining layers
rules = [
    ZBoxRule(),                                  # input layer
    repeat([GammaRule()], 15)...,                # lower layers
    repeat([ZeroRule()], length(model) - 16)..., # remaining layers
]
# and use them to define a custom LRP analyzer:
analyzer = LRP(model, rules)
expl, out = analyze(img, analyzer)
heatmap(expl)

# ## Custom rules
# Let's define a rule that modifies the weights and biases of our layer on the forward pass.
# The rule has to be of type `AbstractLRPRule`.
struct MyCustomLRPRule <: AbstractLRPRule end

# It is then possible to dispatch on the utility functions `modify_layer`, `modify_params`
# and `modify_denominator` with our rule type `MyCustomLRPRule`
# to define custom rules without writing boilerplate code.
function modify_params(::MyCustomLRPRule, W, b)
    ρW = W + 0.1 * relu.(W) # amplify positive weights, similar to the γ-rule
    return ρW, b
end
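
# For instance, a hypothetical `modify_denominator` overload (a sketch; check the
# actual signature in the package source before relying on it) could add a small
# term to stabilize the division:
#md # ```julia
#md # modify_denominator(::MyCustomLRPRule, d) = d .+ 1.0f-9
#md # ```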

# We can directly use this rule to make an analyzer!
analyzer = LRP(model, MyCustomLRPRule())
expl, out = analyze(img, analyzer)
heatmap(expl)
| 95 | + |
| 96 | +#md # !!! note "PRs welcome" |
| 97 | +#md # If you implement a rule that's not included in ExplainabilityMethods, please make a PR to |
| 98 | +#md # [`src/lrp_rules.jl`](https://github.com/adrhill/ExplainabilityMethods.jl/blob/master/src/lrp_rules.jl)! |