Skip to content

Commit 5704ec0

Browse files
authored
Faster stabilize_denom (#47)
1 parent c74f50d commit 5704ec0

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

src/lrp_rules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ modify_params(::AbstractLRPRule, W, b) = (W, b) # general fallback
5454
5555
Function that modifies zₖ on the forward pass, e.g. for numerical stability.
5656
"""
57-
modify_denominator(::AbstractLRPRule, d) = stabilize_denom(d; eps=1.0f-9) # general fallback
57+
modify_denominator(::AbstractLRPRule, d) = stabilize_denom(d, 1.0f-9) # general fallback
5858

5959
"""
6060
modify_layer(rule, layer)
@@ -104,7 +104,7 @@ struct EpsilonRule{T} <: AbstractLRPRule
104104
ϵ::T
105105
EpsilonRule(; ϵ=1.0f-6) = new{Float32}(ϵ)
106106
end
107-
modify_denominator(r::EpsilonRule, d) = stabilize_denom(d; eps=r.ϵ)
107+
modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
108108

109109
"""
110110
ZBoxRule()
@@ -131,7 +131,7 @@ function lrp_zbox!(layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
131131
f⁻::T2 = layer⁻(h)
132132

133133
z = f - f⁺ - f⁻
134-
s = Zygote.@ignore safedivide(Rₖ₊₁, z; eps=1e-9)
134+
s = Zygote.@ignore safedivide(Rₖ₊₁, z, 1e-9)
135135
z s
136136
end
137137
Rₖ .= aₖ .* c + l .* cₗ + h .* cₕ

src/utils.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
"""
2-
stabilize_denom(d; eps = 1f-6)
2+
stabilize_denom(d, [eps = 1f-9])
33
44
Replace zero terms of a matrix `d` with `eps`.
55
"""
6-
stabilize_denom(d::Real; eps=1.0f-9) = ifelse(d 0, d + sign(d) * eps, d)
7-
function stabilize_denom(d::AbstractArray; eps=1.0f-9)
8-
return d + ((d .≈ 0) + sign.(d)) * eps
6+
function stabilize_denom(d::T, eps=T(1.0f-9)) where {T}
7+
iszero(d) && (return T(eps))
8+
return d + sign(d) * T(eps)
99
end
10+
stabilize_denom(D::AbstractArray{T}, eps=T(1.0f-9)) where {T} = stabilize_denom.(D, eps)
1011

1112
"""
12-
safedivide(a, b; eps = 1f-6)
13+
safedivide(a, b, [eps = 1f-6])
1314
1415
Elementwise division of two matrices avoiding near zero terms
1516
in the denominator by replacing them with `± eps`.
1617
"""
17-
safedivide(a, b; eps=1.0f-9) = a ./ stabilize_denom(b; eps=eps)
18+
function safedivide(a::AbstractArray{T}, b::AbstractArray{T}, eps=T(1.0f-9)) where {T}
19+
return a ./ stabilize_denom(b, T(eps))
20+
end
1821

1922
"""
2023
batch_dim_view(A)

test/test_utils.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Flux
2-
using ExplainableAI: flatten_model, has_output_softmax, check_output_softmax
3-
using ExplainableAI: stabilize_denom
2+
using ExplainableAI:
3+
flatten_model, has_output_softmax, check_output_softmax, stabilize_denom
44

55
# flatten_model
66
@test flatten_model(Chain(Chain(Chain(abs)), sqrt, Chain(relu))) == Chain(abs, sqrt, relu)
@@ -22,5 +22,8 @@ using ExplainableAI: stabilize_denom
2222
Chain(Chain(abs), Chain(Chain(softmax)), sqrt) # don't do anything if there is no softmax at the end
2323

2424
# stabilize_denom
25-
A = [1.0 0; -0 -1.0e-25]
26-
@test stabilize_denom(A; eps=1e-3) [1.001 1e-3; 1e-3 -1e-3]
25+
A = [1.0 0.0 1.0e-25; -1.0 -0.0 -1.0e-25]
26+
S = @inferred stabilize_denom(A, 1e-3)
27+
@test S [1.001 1e-3 1e-3; -1.001 1e-3 -1e-3]
28+
S = @inferred stabilize_denom(Float32.(A), 1e-2)
29+
@test S [1.01 1f-2 1f-2; -1.01 1f-2 -1f-2]

0 commit comments

Comments
 (0)