22
33const HEATMAPPING_PRESETS = Dict {Symbol,Tuple{ColorScheme,Symbol,Symbol}} (
44 # Analyzer => (colorscheme, reduce, normalize)
5- :LRP => (ColorSchemes. bwr, :sum , :centered ),
6- :InputTimesGradient => (ColorSchemes. bwr, :sum , :centered ), # same as LRP
7- :Gradient => (ColorSchemes. grays, :norm , :extrema ),
5+ :LRP => (ColorSchemes. bwr, :sum , :centered ), # attribution
6+ :InputTimesGradient => (ColorSchemes. bwr, :sum , :centered ), # attribution
7+ :Gradient => (ColorSchemes. grays, :norm , :extrema ), # gradient
88)
99
1010"""
@@ -34,35 +34,32 @@ Assumes Flux's WHCN convention (width, height, color channels, batch size).
3434 When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3535 When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
3636- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
37+ - `unpack_singleton::Bool`: When heatmapping a batch with a single sample, setting `unpack_singleton=true`
38+ will return an image instead of an Vector containing a single image.
3739
3840**Note:** these keyword arguments can't be used when calling `heatmap` with an analyzer.
3941"""
4042function heatmap (
41- attr:: AbstractArray ;
43+ attr:: AbstractArray{T,N} ;
4244 cs:: ColorScheme = ColorSchemes. bwr,
4345 reduce:: Symbol = :sum ,
4446 normalize:: Symbol = :centered ,
4547 permute:: Bool = true ,
46- )
47- _size = size (attr)
48- length (_size) != 4 && throw (
48+ unpack_singleton :: Bool = true ,
49+ ) where {T,N}
50+ N != 4 && throw (
4951 DomainError (
50- _size ,
52+ N ,
5153 """ heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
5254 Please reshape your attribution to match this format if your model doesn't adhere to this convention.""" ,
5355 ),
5456 )
55- _size[end ] != 1 && throw (
56- DomainError (
57- _size[end ],
58- """ heatmap is only applicable to a single attribution, got a batch dimension of $(_size[end ]) .""" ,
59- ),
60- )
61-
62- img = _normalize (dropdims (_reduce (dropdims (attr; dims= 4 ), reduce); dims= 3 ), normalize)
63- permute && (img = permutedims (img))
64- return ColorSchemes. get (cs, img)
57+ if unpack_singleton && size (attr, 4 ) == 1
58+ return _heatmap (attr[:, :, :, 1 ], cs, reduce, normalize, permute)
59+ end
60+ return map (a -> _heatmap (a, cs, reduce, normalize, permute), eachslice (attr; dims= 4 ))
6561end
62+
6663# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
6764function heatmap (expl:: Explanation ; permute:: Bool = true , kwargs... )
6865 _cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl. analyzer]
@@ -79,6 +76,18 @@ function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
7976 return heatmap (analyze (input, analyzer, args... ; kwargs... ))
8077end
8178
79+ # Lower level function that is mapped along batch dimension
80+ function _heatmap (
81+ attr:: AbstractArray{T,3} ,
82+ cs:: ColorScheme ,
83+ reduce:: Symbol ,
84+ normalize:: Symbol ,
85+ permute:: Bool ,
86+ ) where {T<: Real }
87+ img = _normalize (dropdims (_reduce (attr, reduce); dims= 3 ), normalize)
88+ permute && (img = permutedims (img))
89+ return ColorSchemes. get (cs, img)
90+ end
8291
8392# Normalize activations across pixels
8493function _normalize (attr, method:: Symbol )
@@ -97,15 +106,15 @@ function _normalize(attr, method::Symbol)
97106end
98107
99108# Reduce attributions across color channels into a single scalar – assumes WHCN convention
100- function _reduce (attr:: T , method:: Symbol ) where {T}
101- if size (attr, 3 ) == 1 # nothing need to reduce
109+ function _reduce (attr:: AbstractArray{T,3} , method:: Symbol ) where {T}
110+ if size (attr, 3 ) == 1 # nothing to reduce
102111 return attr
112+ elseif method == :sum
113+ return reduce (+ , attr; dims= 3 )
103114 elseif method == :maxabs
104- return maximum (abs, attr; dims= 3 )
115+ return reduce ((c ... ) -> maximum (abs .(c)) , attr; dims= 3 , init = zero (T) )
105116 elseif method == :norm
106- return mapslices (norm, attr; dims= 3 ):: T
107- elseif method == :sum
108- return sum (attr; dims= 3 )
117+ return reduce ((c... ) -> sqrt (sum (c .^ 2 )), attr; dims= 3 , init= zero (T))
109118 end
110119 throw (
111120 ArgumentError (
0 commit comments