1- function gradient_wrt_input (model, input, output_indices )
2- return only ( gradient ((in) -> model (in)[output_indices] , input) )
3- end
1+ function gradient_wrt_input (model, input, ns :: AbstractNeuronSelector )
2+ output, back = Zygote . pullback (model , input)
3+ output_indices = ns (output)
44
5- function gradients_wrt_batch (model, input:: AbstractArray{T,N} , output_indices) where {T,N}
6- # To avoid computing a sparse jacobian, we compute individual gradients
7- # by calling `gradient_wrt_input` on slices of the input along the batch dimension.
8- out = similar (input)
9- inds_before_N = ntuple (Returns (:), N - 1 )
10- for (i, ax) in enumerate (axes (input, N))
11- view (out, inds_before_N... , ax, :) .= gradient_wrt_input (
12- model, view (input, inds_before_N... , ax, :), drop_batch_index (output_indices[i])
13- )
14- end
15- return out
5+ # Compute VJP w.r.t. full model output, selecting vector s.t. it masks output neurons
6+ v = zero (output)
7+ v[output_indices] .= 1
8+ grad = only (back (v))
9+ return grad, output, output_indices
1610end
1711
1812"""
@@ -25,9 +19,7 @@ struct Gradient{C<:Chain} <: AbstractXAIMethod
2519 Gradient (model:: Chain ) = new {typeof(model)} (Flux. testmode! (check_output_softmax (model)))
2620end
2721function (analyzer:: Gradient )(input, ns:: AbstractNeuronSelector )
28- output = analyzer. model (input)
29- output_indices = ns (output)
30- grad = gradients_wrt_batch (analyzer. model, input, output_indices)
22+ grad, output, output_indices = gradient_wrt_input (analyzer. model, input, ns)
3123 return Explanation (grad, output, output_indices, :Gradient , nothing )
3224end
3325
@@ -44,9 +36,8 @@ struct InputTimesGradient{C<:Chain} <: AbstractXAIMethod
4436 end
4537end
4638function (analyzer:: InputTimesGradient )(input, ns:: AbstractNeuronSelector )
47- output = analyzer. model (input)
48- output_indices = ns (output)
49- attr = input .* gradients_wrt_batch (analyzer. model, input, output_indices)
39+ grad, output, output_indices = gradient_wrt_input (analyzer. model, input, ns)
40+ attr = input .* grad
5041 return Explanation (attr, output, output_indices, :InputTimesGradient , nothing )
5142end
5243
0 commit comments