module LRP_CONFIG
using ExplainabilityMethods
using ExplainabilityMethods: LRPSupportedLayer, LRPSupportedActivation

"""
    LRP_CONFIG.supports_layer(layer)

Check whether LRP can be used on a given layer type.
To extend LRP to your own layers, define:
```julia
LRP_CONFIG.supports_layer(::MyLayer) = true
```
"""
supports_layer(l) = false
supports_layer(::LRPSupportedLayer) = true

"""
    LRP_CONFIG.supports_activation(σ)

Check whether LRP can be used on a given activation function.
To extend LRP to your own activation functions, define:
```julia
LRP_CONFIG.supports_activation(::MyActivation) = true
```
"""
supports_activation(σ) = false
supports_activation(::LRPSupportedActivation) = true
end # LRP_CONFIG module
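
# A minimal registration sketch for the LRP_CONFIG interface above. This is
# hypothetical user code, not part of this file: `MyDoubleLayer` and `myrelu`
# are made-up names for a custom layer struct and a ReLU-like activation.
#
#     struct MyDoubleLayer end            # custom layer defined as a struct
#     (::MyDoubleLayer)(x) = 2 .* x
#
#     myrelu(x) = max.(0, x)              # custom ReLU-like activation
#
#     LRP_CONFIG.supports_layer(::MyDoubleLayer) = true
#     LRP_CONFIG.supports_activation(::typeof(myrelu)) = true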

_check_layer(::Val{:LRP}, layer) = LRP_CONFIG.supports_layer(layer)
_check_layer(::Val{:LRP}, c::Chain) = all(_check_layer(Val(:LRP), l) for l in c)

function _check_activation(::Val{:LRP}, layer)
    hasproperty(layer, :σ) && return LRP_CONFIG.supports_activation(layer.σ)
    return true
end
_check_activation(::Val{:LRP}, c::Chain) = all(_check_activation(Val(:LRP), l) for l in c)
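
# Rough behavior of the helpers above, as comments (assuming Flux is loaded and
# that `Dense` layers and the `relu` activation are among the supported defaults):
#
#     _check_layer(Val(:LRP), Dense(2, 2, relu))       # true
#     _check_activation(Val(:LRP), Dense(2, 2, relu))  # true, `relu` is ReLU-like
#     _check_activation(Val(:LRP), Flux.flatten)       # true, no `σ` field to check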

"""
    check_model(method::Symbol, model; verbose=true)

Check whether the given method can be used on the model.
Currently, model checks are only implemented for LRP, using the symbol `:LRP`.

# Example
```julia-repl
julia> check_model(:LRP, model)
```
"""
check_model(method::Symbol, model; kwargs...) = check_model(Val(method), model; kwargs...)
function check_model(::Val{:LRP}, c::Chain; verbose=true)
    layer_checks = collect(_check_layer.(Val(:LRP), c.layers))
    activation_checks = collect(_check_activation.(Val(:LRP), c.layers))
    passed_layer_checks = all(layer_checks)
    passed_activation_checks = all(activation_checks)

    passed_layer_checks && passed_activation_checks && return true

    layer_names = [_print_name(l) for l in c]
    activation_names = [_print_activation(l) for l in c]

    verbose && _show_check_summary(
        c, layer_names, layer_checks, activation_names, activation_checks
    )
    if !passed_layer_checks
        verbose && display(
            Markdown.parse(
                """# Layers failed model check
                Found unknown layers `$(join(unique(layer_names[.!layer_checks]), ", "))`
                that are not supported by ExplainabilityMethods' LRP implementation yet.

                If you think the missing layer should be supported by default, please [submit an issue](https://github.com/adrhill/ExplainabilityMethods.jl/issues).

                These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument `skip_checks=true`.

                ## Using custom layers
                If you implemented custom layers, register them via
                ```julia
                LRP_CONFIG.supports_layer(::MyLayer) = true         # for structs
                LRP_CONFIG.supports_layer(::typeof(mylayer)) = true # for functions
                ```
                The default fallback for this layer will use automatic differentiation according to "Layer-Wise Relevance Propagation: An Overview".
                You can also define a fully custom LRP rule for your layer by implementing the interface
                ```julia
                function (rule::AbstractLRPRule)(layer::MyLayer, aₖ, Rₖ₊₁)
                    # ...
                    return Rₖ
                end
                ```
                This pattern can also be used to dispatch on specific rules.
                """,
            ),
        )
        throw(ArgumentError("Unknown layers found in model"))
    end
    if !passed_activation_checks
        verbose && display(
            Markdown.parse(
                """# Activations failed model check
                Found layers with unknown or unsupported activation functions
                `$(join(unique(activation_names[.!activation_checks]), ", "))`.
                LRP assumes that the model is a "deep rectifier network" that only contains ReLU-like activation functions.

                If you think the missing activation function should be supported by default, please [submit an issue](https://github.com/adrhill/ExplainabilityMethods.jl/issues).

                These model checks can be skipped at your own risk by setting the LRP-analyzer keyword argument `skip_checks=true`.

                ## Using custom activation functions
                If you use custom ReLU-like activation functions, register them via
                ```julia
                LRP_CONFIG.supports_activation(::typeof(myfunction)) = true # for functions
                LRP_CONFIG.supports_activation(::MyActivation) = true       # for structs
                ```
                """,
            ),
        )
        throw(ArgumentError("Unknown or unsupported activation functions found in model"))
    end
    return false
end
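
# A usage sketch for `check_model` (hypothetical model; assumes Flux is loaded,
# that `Dense` and `relu` are supported by default, and that `tanh` is not
# ReLU-like and therefore fails the activation check):
#
#     using Flux
#
#     model = Chain(Dense(784, 100, relu), Dense(100, 10))
#     check_model(:LRP, model)                      # returns true
#
#     model_tanh = Chain(Dense(784, 100, tanh), Dense(100, 10))
#     check_model(:LRP, model_tanh)                 # prints a check summary, then throws
#     check_model(:LRP, model_tanh; verbose=false)  # throws without printing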