11# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
22
33const HEATMAPPING_PRESETS = Dict {Symbol,Tuple{ColorScheme,Symbol,Symbol}} (
4- # Analyzer => (colorscheme, reduce, normalize )
4+ # Analyzer => (colorscheme, reduce, rangescale )
55 :LRP => (ColorSchemes. bwr, :sum , :centered ), # attribution
66 :InputTimesGradient => (ColorSchemes. bwr, :sum , :centered ), # attribution
77 :Gradient => (ColorSchemes. grays, :norm , :extrema ), # gradient
@@ -29,7 +29,7 @@ Assumes Flux's WHCN convention (width, height, color channels, batch size).
2929 - `:maxabs`: compute `maximum(abs, x)` over the color channels in
3030 When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3131 When calling `heatmap` with an array, the default is `:sum`.
32- - `normalize ::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
32+ - `rangescale ::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
3333 Can be either `:extrema` or `:centered`.
3434 When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3535 When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
@@ -43,7 +43,7 @@ function heatmap(
4343 attr:: AbstractArray{T,N} ;
4444 cs:: ColorScheme = ColorSchemes. bwr,
4545 reduce:: Symbol = :sum ,
46- normalize :: Symbol = :centered ,
46+ rangescale :: Symbol = :centered ,
4747 permute:: Bool = true ,
4848 unpack_singleton:: Bool = true ,
4949) where {T,N}
@@ -55,18 +55,18 @@ function heatmap(
5555 ),
5656 )
5757 if unpack_singleton && size (attr, 4 ) == 1
58- return _heatmap (attr[:, :, :, 1 ], cs, reduce, normalize , permute)
58+ return _heatmap (attr[:, :, :, 1 ], cs, reduce, rangescale , permute)
5959 end
60- return map (a -> _heatmap (a, cs, reduce, normalize , permute), eachslice (attr; dims= 4 ))
60+ return map (a -> _heatmap (a, cs, reduce, rangescale , permute), eachslice (attr; dims= 4 ))
6161end
6262
6363# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
6464function heatmap (expl:: Explanation ; permute:: Bool = true , kwargs... )
65- _cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl. analyzer]
65+ _cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl. analyzer]
6666 return heatmap (
6767 expl. attribution;
6868 reduce= get (kwargs, :reduce , _reduce),
69- normalize = get (kwargs, :normalize , _normalize ),
69+ rangescale = get (kwargs, :rangescale , _rangescale ),
7070 cs= get (kwargs, :cs , _cs),
7171 permute= permute,
7272 )
@@ -81,28 +81,12 @@ function _heatmap(
8181 attr:: AbstractArray{T,3} ,
8282 cs:: ColorScheme ,
8383 reduce:: Symbol ,
84- normalize :: Symbol ,
84+ rangescale :: Symbol ,
8585 permute:: Bool ,
8686) where {T<: Real }
87- img = _normalize ( dropdims (_reduce (attr, reduce); dims= 3 ), normalize )
87+ img = dropdims (_reduce (attr, reduce); dims= 3 )
8888 permute && (img = permutedims (img))
89- return ColorSchemes. get (cs, img)
90- end
91-
92- # Normalize activations across pixels
93- function _normalize (attr, method:: Symbol )
94- if method == :centered
95- min, max = (- 1 , 1 ) .* maximum (abs, attr)
96- elseif method == :extrema
97- min, max = extrema (attr)
98- else
99- throw (
100- ArgumentError (
101- " Color scheme normalizer :$method not supported, `normalize` should be :extrema or :centered" ,
102- ),
103- )
104- end
105- return (attr .- min) / (max - min)
89+ return ColorSchemes. get (cs, img, rangescale)
10690end
10791
10892# Reduce attributions across color channels into a single scalar – assumes WHCN convention
0 commit comments