@@ -12,6 +12,7 @@ using DynamicPPL:
1212 AbstractVarInfo,
1313 getlogjoint_internal,
1414 link
15+ using LinearAlgebra: norm
1516using LogDensityProblems: logdensity, logdensity_and_gradient
1617using Random: AbstractRNG, default_rng
1718using Statistics: median
@@ -78,6 +79,51 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
7879 value_actual:: T
7980 grad_expected:: Vector{T}
8081 grad_actual:: Vector{T}
82+ atol:: T
83+ rtol:: T
84+ end
85+ function Base. showerror (io:: IO , e:: ADIncorrectException )
86+ value_passed = isapprox (e. value_expected, e. value_actual; atol= e. atol, rtol= e. rtol)
87+ grad_passed = isapprox (e. grad_expected, e. grad_actual; atol= e. atol, rtol= e. rtol)
88+ s = if ! value_passed && ! grad_passed
89+ " value and gradient"
90+ elseif ! value_passed
91+ " value"
92+ else
93+ " gradient"
94+ end
95+ println (io, " The AD backend returned an incorrect $s ." )
96+ println (io, " Testing was carried out with" )
97+ println (io, " atol : $(e. atol) " )
98+ println (io, " rtol : $(e. rtol) " )
99+ # calculate what tolerances would have been needed to pass for value
100+ if ! value_passed
101+ min_atol_needed_to_pass_value = abs (e. value_expected - e. value_actual)
102+ min_rtol_needed_to_pass_value =
103+ min_atol_needed_to_pass_value / max (abs (e. value_expected), abs (e. value_actual))
104+ println (io, " The value check failed because:" )
105+ println (io, " expected value : $(e. value_expected) " )
106+ println (io, " actual value : $(e. value_actual) " )
107+ println (io, " This value correctness check would have passed if either:" )
108+ println (io, " atol ≥ $(min_atol_needed_to_pass_value) , or" )
109+ println (io, " rtol ≥ $(min_rtol_needed_to_pass_value) " )
110+ end
111+ if ! grad_passed
112+ norm_expected = norm (e. grad_expected)
113+ norm_actual = norm (e. grad_actual)
114+ max_norm = max (norm_expected, norm_actual)
115+ norm_diff = norm (e. grad_expected - e. grad_actual)
116+ min_atol_needed_to_pass_grad = norm_diff
117+ min_rtol_needed_to_pass_grad = norm_diff / max_norm
118+ # min tolerances needed to pass overall
119+ println (io, " The gradient check failed because:" )
120+ println (io, " expected grad : $(e. value_expected) " )
121+ println (io, " actual grad : $(e. value_actual) " )
122+ println (io, " The gradient correctness check would have passed if either:" )
123+ println (io, " atol ≥ $(min_atol_needed_to_pass_grad) , or" )
124+ println (io, " rtol ≥ $(min_rtol_needed_to_pass_grad) " )
125+ end
126+ return nothing
81127end
82128
83129"""
@@ -306,7 +352,7 @@ function run_ad(
306352 end
307353 # Perform testing
308354 verbose && println (" expected : $((value_true, grad_true)) " )
309- exc () = throw (ADIncorrectException (value, value_true, grad, grad_true))
355+ exc () = throw (ADIncorrectException (value, value_true, grad, grad_true, atol, rtol ))
310356 isapprox (value, value_true; atol= atol, rtol= rtol) || exc ()
311357 isapprox (grad, grad_true; atol= atol, rtol= rtol) || exc ()
312358 end
@@ -315,16 +361,18 @@ function run_ad(
315361 grad_time, primal_time = if benchmark
316362 logdensity (ldf, params) # Warm-up
317363 primal_benchmark = @be logdensity ($ ldf, $ params)
318- @info " Evaluation median benchmark: $(median (primal_benchmark)) "
364+ print (" evaluation : " )
365+ show (stdout , MIME (" text/plain" ), median (primal_benchmark))
366+ println ()
319367 logdensity_and_gradient (ldf, params) # Warm-up
320368 grad_benchmark = @be logdensity_and_gradient ($ ldf, $ params)
321- @info " Gradient median benchmark: $(median (grad_benchmark)) "
369+ print (" gradient : " )
370+ show (stdout , MIME (" text/plain" ), median (grad_benchmark))
371+ println ()
322372 median_primal = median (primal_benchmark). time
323373 median_grad = median (grad_benchmark). time
324374 r (f) = round (f; sigdigits= 4 )
325- verbose && println (
326- " grad / primal : $(r (median_grad)) /$(r (median_primal)) = $(r (median_grad / median_primal)) " ,
327- )
375+ verbose && println (" grad / eval : $(r (median_grad / median_primal)) " )
328376 (median_grad, median_primal)
329377 else
330378 nothing , nothing
0 commit comments