@@ -12,6 +12,7 @@ using DynamicPPL:
1212 AbstractVarInfo,
1313 getlogjoint_internal,
1414 link
15+ using LinearAlgebra: norm
1516using LogDensityProblems: logdensity, logdensity_and_gradient
1617using Random: AbstractRNG, default_rng
1718using Statistics: median
@@ -78,6 +79,51 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
7879 value_actual:: T
7980 grad_expected:: Vector{T}
8081 grad_actual:: Vector{T}
82+ atol:: T
83+ rtol:: T
84+ end
# Print a multi-line explanation of why an AD correctness check failed.
#
# The value and gradient comparisons are re-run with `isapprox`, using the tolerances
# stored on `e`. For every comparison that fails, the expected and actual quantities are
# printed, followed by the smallest `atol` / `rtol` that would have let it pass. Those
# thresholds follow from the criterion documented for `isapprox`:
#
#     norm(x - y) <= max(atol, rtol * max(norm(x), norm(y)))
#
# Two situations are reported explicitly rather than printing meaningless thresholds
# (or throwing from inside the error display):
#   * gradients of different lengths, for which the difference cannot be formed;
#   * a non-finite difference (e.g. `NaN` in the AD output), for which the thresholds
#     above evaluate to `NaN`/`Inf`.
#
# Arguments:
#   io : stream the message is written to
#   e  : the exception holding expected/actual values, gradients and tolerances
#
# Returns `nothing`.
function Base.showerror(io::IO, e::ADIncorrectException)
    value_passed = isapprox(e.value_expected, e.value_actual; atol=e.atol, rtol=e.rtol)
    # Comparing vectors of different lengths throws `DimensionMismatch`; treat that as a
    # failed gradient check so that displaying the error can never itself fail.
    same_length = length(e.grad_expected) == length(e.grad_actual)
    grad_passed =
        same_length && isapprox(e.grad_expected, e.grad_actual; atol=e.atol, rtol=e.rtol)
    s = if !value_passed && !grad_passed
        "value and gradient"
    elseif !value_passed
        "value"
    else
        "gradient"
    end
    println(io, " ADIncorrectException: The AD backend returned an incorrect $s.")
    println(io, " Testing was carried out with")
    println(io, " atol : $(e.atol) ")
    println(io, " rtol : $(e.rtol) ")
    if !value_passed
        value_diff = abs(e.value_expected - e.value_actual)
        println(io, " The value check failed because:")
        println(io, " expected value : $(e.value_expected) ")
        println(io, " actual value : $(e.value_actual) ")
        if isfinite(value_diff)
            # A failed check with a finite difference implies the two values differ, so
            # the larger magnitude is non-zero and the division is well defined.
            value_scale = max(abs(e.value_expected), abs(e.value_actual))
            println(io, " This value correctness check would have passed if either:")
            println(io, " atol ≥ $(value_diff) , or")
            println(io, " rtol ≥ $(value_diff / value_scale) ")
        else
            println(io, " The difference between these values is not finite, so the")
            println(io, " tolerances needed to pass cannot be computed.")
        end
    end
    if !grad_passed
        println(io, " The gradient check failed because:")
        println(io, " expected grad : $(e.grad_expected) ")
        println(io, " actual grad : $(e.grad_actual) ")
        if !same_length
            n_expected = length(e.grad_expected)
            n_actual = length(e.grad_actual)
            println(io, " The gradients have different lengths:")
            println(io, " expected length : $(n_expected)")
            println(io, " actual length : $(n_actual)")
        else
            grad_diff = norm(e.grad_expected - e.grad_actual)
            if isfinite(grad_diff)
                grad_scale = max(norm(e.grad_expected), norm(e.grad_actual))
                println(
                    io, " The gradient correctness check would have passed if either:"
                )
                println(io, " atol ≥ $(grad_diff) , or")
                println(io, " rtol ≥ $(grad_diff / grad_scale) ")
            else
                println(io, " The norm of the gradient difference is not finite, so the")
                println(io, " tolerances needed to pass cannot be computed.")
            end
        end
    end
    return nothing
end
82128
83129"""
@@ -116,11 +162,26 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
116162 " The gradient of logp (calculated using `adtype`)"
117163 grad_actual:: Vector{Tresult}
118164 " If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
119- of logp"
165+ of logp (in seconds) "
120166 grad_time:: Union{Nothing,Tresult}
121- " If benchmarking was requested, the time taken by the AD backend to evaluate logp"
167+ " If benchmarking was requested, the time taken by the AD backend to evaluate logp (in
168+ seconds)"
122169 primal_time:: Union{Nothing,Tresult}
123170end
# Multi-line (`text/plain`) display of an `ADResult`, drawn as a small tree.
#
# Writes a bold header followed by one row per field: the model's `f`, the AD type, the
# actual/expected values and gradients, and finally the parameters. The two timing rows
# are emitted only when both `grad_time` and `primal_time` are set.
#
# Returns `nothing`.
function Base.show(io::IO, ::MIME"text/plain", result::ADResult)
    # Each row is rendered into a standalone string first (exactly what string
    # interpolation does), then written to `io` followed by a newline.
    row(label, value, tail=" ") = println(io, string(label, value, tail))

    printstyled(io, " ADResult\n "; bold=true)
    row(" ├ model : ", result.model.f)
    row(" ├ adtype : ", result.adtype)
    row(" ├ value_actual : ", result.value_actual)
    row(" ├ value_expected : ", result.value_expected)
    row(" ├ grad_actual : ", result.grad_actual)
    row(" ├ grad_expected : ", result.grad_expected)

    has_timings = !(isnothing(result.grad_time) || isnothing(result.primal_time))
    if has_timings
        row(" ├ grad_time : ", result.grad_time, " s")
        row(" ├ primal_time : ", result.primal_time, " s")
    end

    row(" └ params : ", result.params)
    return nothing
end
124185
125186"""
126187 run_ad(
@@ -230,6 +291,14 @@ Everything else is optional, and can be categorised into several groups:
230291 we cannot know the magnitude of logp and its gradient a priori. The `atol`
231292 value is supplied to handle the case where gradients are equal to zero.
232293
294+ 1. _Whether to benchmark._
295+
296+ By default, benchmarking is disabled. To enable it, set `benchmark=true`.
297+ When enabled, the time taken to evaluate logp as well as its gradient is
298+ measured using Chairmarks.jl, and the `ADResult` object returned will
299+ contain `grad_time` and `primal_time` fields with the median times (in
300+ seconds).
301+
2333021. _Whether to output extra logging information._
234303
235304 By default, this function prints messages when it runs. To silence it, set
@@ -297,7 +366,7 @@ function run_ad(
297366 end
298367 # Perform testing
299368 verbose && println (" expected : $((value_true, grad_true)) " )
300- exc () = throw (ADIncorrectException (value, value_true, grad, grad_true))
369+ exc () = throw (ADIncorrectException (value, value_true, grad, grad_true, atol, rtol ))
301370 isapprox (value, value_true; atol= atol, rtol= rtol) || exc ()
302371 isapprox (grad, grad_true; atol= atol, rtol= rtol) || exc ()
303372 end
@@ -306,14 +375,18 @@ function run_ad(
306375 grad_time, primal_time = if benchmark
307376 logdensity (ldf, params) # Warm-up
308377 primal_benchmark = @be logdensity ($ ldf, $ params)
378+ print (" evaluation : " )
379+ show (stdout , MIME (" text/plain" ), median (primal_benchmark))
380+ println ()
309381 logdensity_and_gradient (ldf, params) # Warm-up
310382 grad_benchmark = @be logdensity_and_gradient ($ ldf, $ params)
383+ print (" gradient : " )
384+ show (stdout , MIME (" text/plain" ), median (grad_benchmark))
385+ println ()
311386 median_primal = median (primal_benchmark). time
312387 median_grad = median (grad_benchmark). time
313388 r (f) = round (f; sigdigits= 4 )
314- verbose && println (
315- " grad / primal : $(r (median_grad)) /$(r (median_primal)) = $(r (median_grad / median_primal)) " ,
316- )
389+ verbose && println (" grad / eval : $(r (median_grad / median_primal)) " )
317390 (median_grad, median_primal)
318391 else
319392 nothing , nothing
0 commit comments