Skip to content

Commit b54a554

Browse files
committed
Improve AD printing
1 parent 554aba8 commit b54a554

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

src/test_utils/ad.jl

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using DynamicPPL:
1212
AbstractVarInfo,
1313
getlogjoint_internal,
1414
link
15+
using LinearAlgebra: norm
1516
using LogDensityProblems: logdensity, logdensity_and_gradient
1617
using Random: AbstractRNG, default_rng
1718
using Statistics: median
@@ -78,6 +79,51 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
7879
value_actual::T
7980
grad_expected::Vector{T}
8081
grad_actual::Vector{T}
82+
atol::T
83+
rtol::T
84+
end
85+
function Base.showerror(io::IO, e::ADIncorrectException)
86+
value_passed = isapprox(e.value_expected, e.value_actual; atol=e.atol, rtol=e.rtol)
87+
grad_passed = isapprox(e.grad_expected, e.grad_actual; atol=e.atol, rtol=e.rtol)
88+
s = if !value_passed && !grad_passed
89+
"value and gradient"
90+
elseif !value_passed
91+
"value"
92+
else
93+
"gradient"
94+
end
95+
println(io, "The AD backend returned an incorrect $s.")
96+
println(io, " Testing was carried out with")
97+
println(io, " atol : $(e.atol)")
98+
println(io, " rtol : $(e.rtol)")
99+
# calculate what tolerances would have been needed to pass for value
100+
if !value_passed
101+
min_atol_needed_to_pass_value = abs(e.value_expected - e.value_actual)
102+
min_rtol_needed_to_pass_value =
103+
min_atol_needed_to_pass_value / max(abs(e.value_expected), abs(e.value_actual))
104+
println(io, " The value check failed because:")
105+
println(io, " expected value : $(e.value_expected)")
106+
println(io, " actual value : $(e.value_actual)")
107+
println(io, " This value correctness check would have passed if either:")
108+
println(io, " atol ≥ $(min_atol_needed_to_pass_value), or")
109+
println(io, " rtol ≥ $(min_rtol_needed_to_pass_value)")
110+
end
111+
if !grad_passed
112+
norm_expected = norm(e.grad_expected)
113+
norm_actual = norm(e.grad_actual)
114+
max_norm = max(norm_expected, norm_actual)
115+
norm_diff = norm(e.grad_expected - e.grad_actual)
116+
min_atol_needed_to_pass_grad = norm_diff
117+
min_rtol_needed_to_pass_grad = norm_diff / max_norm
118+
# min tolerances needed to pass overall
119+
println(io, " The gradient check failed because:")
120+
println(io, " expected grad : $(e.value_expected)")
121+
println(io, " actual grad : $(e.value_actual)")
122+
println(io, " The gradient correctness check would have passed if either:")
123+
println(io, " atol ≥ $(min_atol_needed_to_pass_grad), or")
124+
println(io, " rtol ≥ $(min_rtol_needed_to_pass_grad)")
125+
end
126+
return nothing
81127
end
82128

83129
"""
@@ -306,7 +352,7 @@ function run_ad(
306352
end
307353
# Perform testing
308354
verbose && println(" expected : $((value_true, grad_true))")
309-
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
355+
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true, atol, rtol))
310356
isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
311357
isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
312358
end
@@ -315,16 +361,18 @@ function run_ad(
315361
grad_time, primal_time = if benchmark
316362
logdensity(ldf, params) # Warm-up
317363
primal_benchmark = @be logdensity($ldf, $params)
318-
@info "Evaluation median benchmark: $(median(primal_benchmark))"
364+
print(" evaluation : ")
365+
show(stdout, MIME("text/plain"), median(primal_benchmark))
366+
println()
319367
logdensity_and_gradient(ldf, params) # Warm-up
320368
grad_benchmark = @be logdensity_and_gradient($ldf, $params)
321-
@info "Gradient median benchmark: $(median(grad_benchmark))"
369+
print(" gradient : ")
370+
show(stdout, MIME("text/plain"), median(grad_benchmark))
371+
println()
322372
median_primal = median(primal_benchmark).time
323373
median_grad = median(grad_benchmark).time
324374
r(f) = round(f; sigdigits=4)
325-
verbose && println(
326-
"grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
327-
)
375+
verbose && println(" grad / eval : $(r(median_grad / median_primal))")
328376
(median_grad, median_primal)
329377
else
330378
nothing, nothing

0 commit comments

Comments
 (0)