Skip to content

Commit d468e9e

Browse files
authored
Add layers and layer tests for LRP (#12)
* Introduce Union types for Flux layers
* Update tests
1 parent 9c16fc5 commit d468e9e

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

60 files changed

+123
-35
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ jobs:
1010
test:
1111
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
1212
runs-on: ${{ matrix.os }}
13+
continue-on-error: ${{ matrix.version == 'nightly' }}
1314
strategy:
1415
fail-fast: false
1516
matrix:

src/flux.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
## Group layers by type:
2+
const ConvLayers = Union{Conv,DepthwiseConv,ConvTranspose,CrossCor}
3+
const DropoutLayers = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
4+
const ReshapingLayers = Union{typeof(Flux.flatten)}
5+
# Pooling layers
6+
const MaxPoolLayers = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}
7+
const MeanPoolLayers = Union{MeanPool,AdaptiveMeanPool,GlobalMeanPool}
8+
const PoolingLayers = Union{MaxPoolLayers,MeanPoolLayers}
9+
110
_flatten_chain(x) = x
211
_flatten_chain(c::Chain) = [c.layers...]
312
"""

src/lrp_rules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ abstract type AbstractLRPRule end
2020
# This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
2121
# It can be extended for new rules via `modify_denominator` and `modify_layer`,
2222
# which in turn uses `modify_params`.
23-
function (rule::AbstractLRPRule)(layer::Union{Dense,Conv,MaxPool,MeanPool}, aₖ, Rₖ₊₁)
23+
function (rule::AbstractLRPRule)(layer::Union{Dense,ConvLayers,PoolingLayers}, aₖ, Rₖ₊₁)
2424
layerᵨ = modify_layer(rule, layer)
2525
function fwpass(a)
2626
z = layerᵨ(a)
@@ -33,8 +33,8 @@ function (rule::AbstractLRPRule)(layer::Union{Dense,Conv,MaxPool,MeanPool}, aₖ
3333
end
3434

3535
# Special cases are dispatched on layer type:
36-
(rule::AbstractLRPRule)(::Dropout, aₖ, Rₖ₊₁) = Rₖ₊₁
37-
(rule::AbstractLRPRule)(::typeof(Flux.flatten), aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
36+
(rule::AbstractLRPRule)(::DropoutLayers, aₖ, Rₖ₊₁) = Rₖ₊₁
37+
(rule::AbstractLRPRule)(::ReshapingLayers, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
3838

3939
"""
4040
modify_layer(rule, layer)
1.06 KB
Binary file not shown.
1.06 KB
Binary file not shown.
1.06 KB
Binary file not shown.
1.06 KB
Binary file not shown.
1.06 KB
Binary file not shown.
1.06 KB
Binary file not shown.
File renamed without changes.

0 commit comments

Comments
 (0)