@@ -6,7 +6,7 @@ on_CI = haskey(ENV, "GITHUB_ACTIONS")
66
77include (" ../test/vgg11.jl" )
88vgg11 = VGG11 (; pretrain= false )
9- model = flatten_model (strip_softmax (vgg19 . layers))
9+ model = flatten_model (strip_softmax (vgg11 . layers))
1010img = rand (MersenneTwister (123 ), Float32, (224 , 224 , 3 , 1 ))
1111
1212# Benchmark custom LRP composite
@@ -23,11 +23,13 @@ algs = Dict(
2323)
2424
2525# Define benchmark
26+ contruct_analyzer (alg, model) = alg (model) # for use with @benchmarkable macro
27+
2628SUITE = BenchmarkGroup ()
2729SUITE[" VGG" ] = BenchmarkGroup ([k for k in keys (algs)])
2830for (name, alg) in algs
2931 SUITE[" VGG" ][name] = BenchmarkGroup ([" construct analyzer" , " analyze" ])
30- SUITE[" VGG" ][name][" construct analyzer" ] = @benchmarkable alg ( $ (model))
32+ SUITE[" VGG" ][name][" construct analyzer" ] = @benchmarkable contruct_analyzer ( $ (alg), $ (model))
3133
3234 analyzer = alg (model)
3335 SUITE[" VGG" ][name][" analyze" ] = @benchmarkable analyze ($ (img), $ (analyzer))
@@ -58,15 +60,15 @@ rules = Dict(
5860)
5961rulenames = [k for k in keys (rules)]
6062
63+ test_rule (rule, layer, aₖ, Rₖ₊₁) = rule (layer, aₖ, Rₖ₊₁) # for use with @benchmarkable macro
64+
6165for (layername, (layer, aₖ)) in layers
6266 SUITE[layername] = BenchmarkGroup (rulenames)
67+ Rₖ₊₁ = layer (aₖ)
6368
64- for (rulename, ruletype) in rules
65- Rₖ₊₁ = layer (aₖ)
69+ for (rulename, rule) in rules
6670 SUITE[layername][rulename] = BenchmarkGroup ([" dispatch" , " AD fallback" ])
67- SUITE[layername][rulename][" dispatch" ] = @benchmarkable rule ($ layer, $ aₖ, $ Rₖ₊₁)
68- SUITE[layername][rulename][" AD fallback" ] = @benchmarkable rule (
69- $ TestWrapper (layer), $ aₖ, $ Rₖ₊₁
70- )
71+ SUITE[layername][rulename][" dispatch" ] = @benchmarkable test_rule ($ (rule), $ (layer), $ (aₖ), $ (Rₖ₊₁))
72+ SUITE[layername][rulename][" AD fallback" ] = @benchmarkable test_rule ($ (rule), $ (TestWrapper (layer)), $ (aₖ), $ (Rₖ₊₁))
7173 end
7274end
0 commit comments