Skip to content

Commit 4ec4908

Browse files
authored
Update heatmapping normalizer (#57)
* Update heatmapping normalizer to use ColorSchemes 3.18 * Update and rename heatmap tests * Fix for Julia 1.0 compat
1 parent 8641dfb commit 4ec4908

16 files changed

+22
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1717
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1818

1919
[compat]
20-
ColorSchemes = "3"
20+
ColorSchemes = "3.18"
2121
Distributions = "0.25"
2222
Flux = "0.12, 0.13"
2323
ImageCore = "0.8, 0.9"

docs/literate/example.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ heatmap(input, analyzer)
130130
using ColorSchemes
131131
heatmap(expl; cs=ColorSchemes.jet)
132132
#
133-
heatmap(expl; reduce=:sum, normalize=:extrema, cs=ColorSchemes.inferno)
133+
heatmap(expl; reduce=:sum, rangescale=:extrema, cs=ColorSchemes.inferno)
134134

135135
# This also works with batches
136-
mosaic(heatmap(expl_batch; normalize=:extrema, cs=ColorSchemes.inferno); nrow=10)
136+
mosaic(heatmap(expl_batch; rangescale=:extrema, cs=ColorSchemes.inferno); nrow=10)
137137

138138
# For the full list of keyword arguments, refer to the [`heatmap`](@ref) documentation.

src/heatmap.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
22

33
const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
4-
# Analyzer => (colorscheme, reduce, normalize)
4+
# Analyzer => (colorscheme, reduce, rangescale)
55
:LRP => (ColorSchemes.bwr, :sum, :centered), # attribution
66
:InputTimesGradient => (ColorSchemes.bwr, :sum, :centered), # attribution
77
:Gradient => (ColorSchemes.grays, :norm, :extrema), # gradient
@@ -29,7 +29,7 @@ Assumes Flux's WHCN convention (width, height, color channels, batch size).
2929
- `:maxabs`: compute `maximum(abs, x)` over the color channels in
3030
When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3131
When calling `heatmap` with an array, the default is `:sum`.
32-
- `normalize::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
32+
- `rangescale::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
3333
Can be either `:extrema` or `:centered`.
3434
When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3535
When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
@@ -43,7 +43,7 @@ function heatmap(
4343
attr::AbstractArray{T,N};
4444
cs::ColorScheme=ColorSchemes.bwr,
4545
reduce::Symbol=:sum,
46-
normalize::Symbol=:centered,
46+
rangescale::Symbol=:centered,
4747
permute::Bool=true,
4848
unpack_singleton::Bool=true,
4949
) where {T,N}
@@ -55,18 +55,18 @@ function heatmap(
5555
),
5656
)
5757
if unpack_singleton && size(attr, 4) == 1
58-
return _heatmap(attr[:, :, :, 1], cs, reduce, normalize, permute)
58+
return _heatmap(attr[:, :, :, 1], cs, reduce, rangescale, permute)
5959
end
60-
return map(a -> _heatmap(a, cs, reduce, normalize, permute), eachslice(attr; dims=4))
60+
return map(a -> _heatmap(a, cs, reduce, rangescale, permute), eachslice(attr; dims=4))
6161
end
6262

6363
# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
6464
function heatmap(expl::Explanation; permute::Bool=true, kwargs...)
65-
_cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl.analyzer]
65+
_cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl.analyzer]
6666
return heatmap(
6767
expl.attribution;
6868
reduce=get(kwargs, :reduce, _reduce),
69-
normalize=get(kwargs, :normalize, _normalize),
69+
rangescale=get(kwargs, :rangescale, _rangescale),
7070
cs=get(kwargs, :cs, _cs),
7171
permute=permute,
7272
)
@@ -81,28 +81,12 @@ function _heatmap(
8181
attr::AbstractArray{T,3},
8282
cs::ColorScheme,
8383
reduce::Symbol,
84-
normalize::Symbol,
84+
rangescale::Symbol,
8585
permute::Bool,
8686
) where {T<:Real}
87-
img = _normalize(dropdims(_reduce(attr, reduce); dims=3), normalize)
87+
img = dropdims(_reduce(attr, reduce); dims=3)
8888
permute && (img = permutedims(img))
89-
return ColorSchemes.get(cs, img)
90-
end
91-
92-
# Normalize activations across pixels
93-
function _normalize(attr, method::Symbol)
94-
if method == :centered
95-
min, max = (-1, 1) .* maximum(abs, attr)
96-
elseif method == :extrema
97-
min, max = extrema(attr)
98-
else
99-
throw(
100-
ArgumentError(
101-
"Color scheme normalizer :$method not supported, `normalize` should be :extrema or :centered",
102-
),
103-
)
104-
end
105-
return (attr .- min) / (max - min)
89+
return ColorSchemes.get(cs, img, rangescale)
10690
end
10791

10892
# Reduce attributions across color channels into a single scalar – assumes WHCN convention

test/references/heatmaps/reduce_maxabs_normalize_centered.txt renamed to test/references/heatmaps/reduce_maxabs_rangescale_centered.txt

File renamed without changes.

test/references/heatmaps/reduce_maxabs_normalize_centered2.txt renamed to test/references/heatmaps/reduce_maxabs_rangescale_centered2.txt

File renamed without changes.

test/references/heatmaps/reduce_maxabs_normalize_extrema.txt renamed to test/references/heatmaps/reduce_maxabs_rangescale_extrema.txt

File renamed without changes.

test/references/heatmaps/reduce_maxabs_normalize_extrema2.txt renamed to test/references/heatmaps/reduce_maxabs_rangescale_extrema2.txt

File renamed without changes.

test/references/heatmaps/reduce_norm_normalize_centered.txt renamed to test/references/heatmaps/reduce_norm_rangescale_centered.txt

File renamed without changes.

test/references/heatmaps/reduce_norm_normalize_centered2.txt renamed to test/references/heatmaps/reduce_norm_rangescale_centered2.txt

File renamed without changes.

test/references/heatmaps/reduce_norm_normalize_extrema.txt renamed to test/references/heatmaps/reduce_norm_rangescale_extrema.txt

File renamed without changes.

0 commit comments

Comments
 (0)