diff --git a/HISTORY.md b/HISTORY.md index 3db6e21a5..d3650d30c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.39.3 + +`DynamicPPL.TestUtils.AD.run_ad` now generates much prettier output. +In particular, when a test fails, it also tells you the tolerances needed to make it pass. + ## 0.39.2 `returned(model, parameters...)` now accepts any arguments that can be wrapped in `InitFromParams` (previously it would only accept `NamedTuple`, `AbstractDict{<:VarName}`, or a chain). diff --git a/Project.toml b/Project.toml index 7b995f530..f2ffa2c63 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.39.2" +version = "0.39.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 8ee850877..a030b479e 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -12,6 +12,7 @@ using DynamicPPL: AbstractVarInfo, getlogjoint_internal, link +using LinearAlgebra: norm using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -78,6 +79,51 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception value_actual::T grad_expected::Vector{T} grad_actual::Vector{T} + atol::T + rtol::T +end +function Base.showerror(io::IO, e::ADIncorrectException) + value_passed = isapprox(e.value_expected, e.value_actual; atol=e.atol, rtol=e.rtol) + grad_passed = isapprox(e.grad_expected, e.grad_actual; atol=e.atol, rtol=e.rtol) + s = if !value_passed && !grad_passed + "value and gradient" + elseif !value_passed + "value" + else + "gradient" + end + println(io, "ADIncorrectException: The AD backend returned an incorrect $s.") + println(io, " Testing was carried out with") + println(io, " atol : $(e.atol)") + println(io, " rtol : $(e.rtol)") + # calculate what tolerances would have been needed to pass for value + if !value_passed + min_atol_needed_to_pass_value = abs(e.value_expected - e.value_actual) + min_rtol_needed_to_pass_value = + min_atol_needed_to_pass_value / max(abs(e.value_expected), abs(e.value_actual)) + println(io, " The value check failed because:") + println(io, " expected value : $(e.value_expected)") + println(io, " actual value : $(e.value_actual)") + println(io, " This value correctness check would have passed if either:") + println(io, " atol ≥ $(min_atol_needed_to_pass_value), or") + println(io, " rtol ≥ $(min_rtol_needed_to_pass_value)") + end + if !grad_passed + norm_expected = norm(e.grad_expected) + norm_actual = norm(e.grad_actual) + max_norm = max(norm_expected, norm_actual) + norm_diff = norm(e.grad_expected - e.grad_actual) + min_atol_needed_to_pass_grad = norm_diff + min_rtol_needed_to_pass_grad = norm_diff / max_norm + # min tolerances needed to pass overall + println(io, " The gradient check failed because:") + println(io, " expected grad : $(e.grad_expected)") + println(io, " actual grad : $(e.grad_actual)") + println(io, " The gradient correctness check would have passed if either:") + println(io, " atol ≥ $(min_atol_needed_to_pass_grad), or") + println(io, " rtol ≥ $(min_rtol_needed_to_pass_grad)") + end + return nothing end """ @@ -116,11 +162,26 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa "The gradient of logp (calculated using `adtype`)" grad_actual::Vector{Tresult} "If benchmarking was requested, the time taken by the AD backend to evaluate the gradient - of logp" + of logp (in seconds)" grad_time::Union{Nothing,Tresult} - "If benchmarking was requested, the time taken by the AD backend to evaluate logp" + "If benchmarking was requested, the time taken by the AD backend to evaluate logp (in + seconds)" primal_time::Union{Nothing,Tresult} end +function Base.show(io::IO, ::MIME"text/plain", result::ADResult) + printstyled(io, "ADResult\n"; bold=true) + println(io, " ├ model : $(result.model.f)") + println(io, " ├ adtype : $(result.adtype)") + println(io, " ├ value_actual : $(result.value_actual)") + println(io, " ├ value_expected : $(result.value_expected)") + println(io, " ├ grad_actual : $(result.grad_actual)") + println(io, " ├ grad_expected : $(result.grad_expected)") + if result.grad_time !== nothing && result.primal_time !== nothing + println(io, " ├ grad_time : $(result.grad_time) s") + println(io, " ├ primal_time : $(result.primal_time) s") + end + return println(io, " └ params : $(result.params)") +end """ run_ad( @@ -230,6 +291,14 @@ Everything else is optional, and can be categorised into several groups: we cannot know the magnitude of logp and its gradient a priori. The `atol` value is supplied to handle the case where gradients are equal to zero. +1. _Whether to benchmark._ + + By default, benchmarking is disabled. To enable it, set `benchmark=true`. + When enabled, the time taken to evaluate logp as well as its gradient is + measured using Chairmarks.jl, and the `ADResult` object returned will + contain `grad_time` and `primal_time` fields with the median times (in + seconds). + 1. _Whether to output extra logging information._ By default, this function prints messages when it runs. To silence it, set @@ -297,7 +366,7 @@ function run_ad( end # Perform testing verbose && println(" expected : $((value_true, grad_true))") - exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) + exc() = throw(ADIncorrectException(value, value_true, grad, grad_true, atol, rtol)) isapprox(value, value_true; atol=atol, rtol=rtol) || exc() isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc() end @@ -306,14 +375,18 @@ function run_ad( grad_time, primal_time = if benchmark logdensity(ldf, params) # Warm-up primal_benchmark = @be logdensity($ldf, $params) + print(" evaluation : ") + show(stdout, MIME("text/plain"), median(primal_benchmark)) + println() logdensity_and_gradient(ldf, params) # Warm-up grad_benchmark = @be logdensity_and_gradient($ldf, $params) + print(" gradient : ") + show(stdout, MIME("text/plain"), median(grad_benchmark)) + println() median_primal = median(primal_benchmark).time median_grad = median(grad_benchmark).time r(f) = round(f; sigdigits=4) - verbose && println( - "grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))", - ) + verbose && println(" grad / eval : $(r(median_grad / median_primal))") (median_grad, median_primal) else nothing, nothing