Skip to content

Commit 5962903

Browse files
authored
Improve logging and display methods for run_ad (#1168)
1 parent 06ab810 commit 5962903

File tree

3 files changed

+85
-7
lines changed

3 files changed

+85
-7
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,10 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.3
4+
5+
`DynamicPPL.TestUtils.AD.run_ad` now generates much prettier output.
6+
In particular, when a test fails, it also tells you the tolerances needed to make it pass.
7+
38
## 0.39.2
49

510
`returned(model, parameters...)` now accepts any arguments that can be wrapped in `InitFromParams` (previously it would only accept `NamedTuple`, `AbstractDict{<:VarName}`, or a chain).

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.2"
3+
version = "0.39.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/test_utils/ad.jl

Lines changed: 79 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ using DynamicPPL:
1212
AbstractVarInfo,
1313
getlogjoint_internal,
1414
link
15+
using LinearAlgebra: norm
1516
using LogDensityProblems: logdensity, logdensity_and_gradient
1617
using Random: AbstractRNG, default_rng
1718
using Statistics: median
@@ -78,6 +79,51 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
7879
value_actual::T
7980
grad_expected::Vector{T}
8081
grad_actual::Vector{T}
82+
atol::T
83+
rtol::T
84+
end
85+
function Base.showerror(io::IO, e::ADIncorrectException)
86+
value_passed = isapprox(e.value_expected, e.value_actual; atol=e.atol, rtol=e.rtol)
87+
grad_passed = isapprox(e.grad_expected, e.grad_actual; atol=e.atol, rtol=e.rtol)
88+
s = if !value_passed && !grad_passed
89+
"value and gradient"
90+
elseif !value_passed
91+
"value"
92+
else
93+
"gradient"
94+
end
95+
println(io, "ADIncorrectException: The AD backend returned an incorrect $s.")
96+
println(io, " Testing was carried out with")
97+
println(io, " atol : $(e.atol)")
98+
println(io, " rtol : $(e.rtol)")
99+
# calculate what tolerances would have been needed to pass for value
100+
if !value_passed
101+
min_atol_needed_to_pass_value = abs(e.value_expected - e.value_actual)
102+
min_rtol_needed_to_pass_value =
103+
min_atol_needed_to_pass_value / max(abs(e.value_expected), abs(e.value_actual))
104+
println(io, " The value check failed because:")
105+
println(io, " expected value : $(e.value_expected)")
106+
println(io, " actual value : $(e.value_actual)")
107+
println(io, " This value correctness check would have passed if either:")
108+
println(io, " atol ≥ $(min_atol_needed_to_pass_value), or")
109+
println(io, " rtol ≥ $(min_rtol_needed_to_pass_value)")
110+
end
111+
if !grad_passed
112+
norm_expected = norm(e.grad_expected)
113+
norm_actual = norm(e.grad_actual)
114+
max_norm = max(norm_expected, norm_actual)
115+
norm_diff = norm(e.grad_expected - e.grad_actual)
116+
min_atol_needed_to_pass_grad = norm_diff
117+
min_rtol_needed_to_pass_grad = norm_diff / max_norm
118+
# min tolerances needed to pass overall
119+
println(io, " The gradient check failed because:")
120+
println(io, " expected grad : $(e.grad_expected)")
121+
println(io, " actual grad : $(e.grad_actual)")
122+
println(io, " The gradient correctness check would have passed if either:")
123+
println(io, " atol ≥ $(min_atol_needed_to_pass_grad), or")
124+
println(io, " rtol ≥ $(min_rtol_needed_to_pass_grad)")
125+
end
126+
return nothing
81127
end
82128

83129
"""
@@ -116,11 +162,26 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
116162
"The gradient of logp (calculated using `adtype`)"
117163
grad_actual::Vector{Tresult}
118164
"If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
119-
of logp"
165+
of logp (in seconds)"
120166
grad_time::Union{Nothing,Tresult}
121-
"If benchmarking was requested, the time taken by the AD backend to evaluate logp"
167+
"If benchmarking was requested, the time taken by the AD backend to evaluate logp (in
168+
seconds)"
122169
primal_time::Union{Nothing,Tresult}
123170
end
171+
function Base.show(io::IO, ::MIME"text/plain", result::ADResult)
172+
printstyled(io, "ADResult\n"; bold=true)
173+
println(io, " ├ model : $(result.model.f)")
174+
println(io, " ├ adtype : $(result.adtype)")
175+
println(io, " ├ value_actual : $(result.value_actual)")
176+
println(io, " ├ value_expected : $(result.value_expected)")
177+
println(io, " ├ grad_actual : $(result.grad_actual)")
178+
println(io, " ├ grad_expected : $(result.grad_expected)")
179+
if result.grad_time !== nothing && result.primal_time !== nothing
180+
println(io, " ├ grad_time : $(result.grad_time) s")
181+
println(io, " ├ primal_time : $(result.primal_time) s")
182+
end
183+
return println(io, " └ params : $(result.params)")
184+
end
124185

125186
"""
126187
run_ad(
@@ -230,6 +291,14 @@ Everything else is optional, and can be categorised into several groups:
230291
we cannot know the magnitude of logp and its gradient a priori. The `atol`
231292
value is supplied to handle the case where gradients are equal to zero.
232293
294+
1. _Whether to benchmark._
295+
296+
By default, benchmarking is disabled. To enable it, set `benchmark=true`.
297+
When enabled, the time taken to evaluate logp as well as its gradient is
298+
measured using Chairmarks.jl, and the `ADResult` object returned will
299+
contain `grad_time` and `primal_time` fields with the median times (in
300+
seconds).
301+
233302
1. _Whether to output extra logging information._
234303
235304
By default, this function prints messages when it runs. To silence it, set
@@ -297,7 +366,7 @@ function run_ad(
297366
end
298367
# Perform testing
299368
verbose && println(" expected : $((value_true, grad_true))")
300-
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
369+
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true, atol, rtol))
301370
isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
302371
isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
303372
end
@@ -306,14 +375,18 @@ function run_ad(
306375
grad_time, primal_time = if benchmark
307376
logdensity(ldf, params) # Warm-up
308377
primal_benchmark = @be logdensity($ldf, $params)
378+
print(" evaluation : ")
379+
show(stdout, MIME("text/plain"), median(primal_benchmark))
380+
println()
309381
logdensity_and_gradient(ldf, params) # Warm-up
310382
grad_benchmark = @be logdensity_and_gradient($ldf, $params)
383+
print(" gradient : ")
384+
show(stdout, MIME("text/plain"), median(grad_benchmark))
385+
println()
311386
median_primal = median(primal_benchmark).time
312387
median_grad = median(grad_benchmark).time
313388
r(f) = round(f; sigdigits=4)
314-
verbose && println(
315-
"grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
316-
)
389+
verbose && println(" grad / eval : $(r(median_grad / median_primal))")
317390
(median_grad, median_primal)
318391
else
319392
nothing, nothing

0 commit comments

Comments (0)