Skip to content

Commit 312c060

Browse files
authored
Support direct calls to the analyzer (#45)
* Support direct calls to the analyzer * Creating an LRP analyzer without rules uses LRP-0
1 parent eac651c commit 312c060

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

src/analyze_api.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ end
2222
function analyze(input::AbstractArray{<:Real}, method::AbstractXAIMethod; kwargs...)
2323
return method(input, MaxActivationNS(); kwargs...)
2424
end
25+
function (method::AbstractXAIMethod)(input::AbstractArray{<:Real}; kwargs...)
26+
return method(input, MaxActivationNS(); kwargs...)
27+
end
2528

2629
# Explanations and outputs are returned in a wrapper.
2730
# Metadata such as the analyzer allows dispatching on functions like `heatmap`.

src/lrp.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ function LRP(model::Chain, r::AbstractLRPRule; kwargs...)
4242
return LRP(model, rules; kwargs...)
4343
end
4444
# Additional constructors for convenience:
45+
LRP(model::Chain; kwargs...) = LRP(model, ZeroRule(); kwargs...)
4546
LRPZero(model::Chain; kwargs...) = LRP(model, ZeroRule(); kwargs...)
4647
LRPEpsilon(model::Chain; kwargs...) = LRP(model, EpsilonRule(); kwargs...)
4748
LRPGamma(model::Chain; kwargs...) = LRP(model, GammaRule(); kwargs...)

test/test_vgg11.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@ end
2828
function test_vgg11(name, method; kwargs...)
2929
analyzer = method(model)
3030
@testset "$name" begin
31+
# Reference test attribution
3132
print("Timing $name...\t")
3233
@time expl = analyze(img, analyzer; kwargs...)
3334
attr = expl.attribution
34-
3535
@test size(attr) == size(img)
3636
@test_reference "references/vgg11/$(name).jld2" Dict("expl" => attr) by =
3737
(r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
3838

39+
# Test direct call of analyzer
40+
expl2 = analyzer(img; kwargs...)
41+
@test expl.attribution expl2.attribution
42+
43+
# Test direct call of heatmap
3944
h1 = heatmap(expl)
4045
h2 = heatmap(img, analyzer; kwargs...)
4146
@test h1 h2
@@ -75,3 +80,9 @@ end
7580
@testset "Layerwise relevances" begin
7681
test_vgg11("LRPZero", LRPZero; layerwise_relevances=true)
7782
end
83+
84+
# Test LRP constructor with no rules
85+
a1 = LRP(model)
86+
a2 = LRPZero(model)
87+
@test a1.model == a2.model
88+
@test a1.rules == a2.rules

0 commit comments

Comments
 (0)