Skip to content

Commit d2f58d7

Browse files
committed
Merge branch 'mhauru/vnt-for-fastldf' into mhauru/arraylikeblock
2 parents 420a6b2 + 44be19d commit d2f58d7

21 files changed

+232
-189
lines changed

HISTORY.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@
22

33
## 0.40
44

5+
## 0.39.4
6+
7+
Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so these specialised functions were not needed).
8+
9+
## 0.39.3
10+
11+
`DynamicPPL.TestUtils.AD.run_ad` now generates much prettier output.
12+
In particular, when a test fails, it also tells you the tolerances needed to make it pass.
13+
14+
## 0.39.2
15+
16+
`returned(model, parameters...)` now accepts any arguments that can be wrapped in `InitFromParams` (previously it would only accept `NamedTuple`, `AbstractDict{<:VarName}`, or a chain).
17+
518
## 0.39.1
619

720
`LogDensityFunction` now allows you to call `logdensity_and_gradient(ldf, x)` with `AbstractVector`s `x` that are not plain Vectors (they will be converted internally before calculating the gradient).

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ DynamicPPL provides several demo models in the `DynamicPPL.TestUtils` submodule.
267267

268268
```@docs
269269
DynamicPPL.TestUtils.DEMO_MODELS
270+
DynamicPPL.TestUtils.ALL_MODELS
270271
```
271272

272273
For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,23 @@ function DynamicPPL.marginalize(
101101
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
102102
kwargs...,
103103
)
104-
# Determine the indices for the variables to marginalise out.
105-
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames))
106104
# Construct the marginal log-density model.
107-
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
105+
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
106+
# Determine the indices for the variables to marginalise out.
107+
varindices = mapreduce(vcat, marginalized_varnames) do vn
108+
if DynamicPPL.getoptic(vn) === identity
109+
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
110+
else
111+
ldf._varname_ranges[vn].range
112+
end
113+
end
108114
mld = MarginalLogDensities.MarginalLogDensity(
109-
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
115+
LogDensityFunctionWrapper(ldf, varinfo),
116+
varinfo[:],
117+
varindices,
118+
(),
119+
method;
120+
kwargs...,
110121
)
111122
return mld
112123
end

src/abstract_varinfo.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,10 @@ end
837837
function link!!(
838838
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
839839
)
840+
# TODO(mhauru) This assumes that the user has defined the bijector using the same
841+
# variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user
842+
# interface, and it's also dangerous for any AbstractVarInfo types that may not respect
843+
# a particular ordering, such as SimpleVarInfo{Dict}.
840844
b = inverse(t.bijector)
841845
x = vi[:]
842846
y, logjac = with_logabsdet_jacobian(b, x)
@@ -866,7 +870,7 @@ end
866870
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
867871
return link(default_transformation(model, vi), vi, vns, model)
868872
end
869-
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
873+
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
870874
return link!!(t, deepcopy(vi), model)
871875
end
872876

@@ -932,7 +936,7 @@ end
932936
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
933937
return invlink(default_transformation(model, vi), vi, vns, model)
934938
end
935-
function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
939+
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
936940
return invlink!!(t, deepcopy(vi), model)
937941
end
938942

src/model.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,10 +1153,16 @@ end
11531153
function predict end
11541154

11551155
"""
1156-
returned(model::Model, parameters::NamedTuple)
1157-
returned(model::Model, parameters::AbstractDict{<:VarName})
1156+
returned(model::Model, parameters...)
11581157
1159-
Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
1158+
Initialise a `model` using the given `parameters` and return the model's return value. The
1159+
parameters must be provided in a format that can be wrapped in an `InitFromParams`, i.e.,
1160+
`InitFromParams(parameters..., nothing)` must be a valid `AbstractInitStrategy` (where
1161+
`nothing` is the fallback strategy to use if parameters are not provided).
1162+
1163+
As far as DynamicPPL is concerned, `parameters` can be either a single `NamedTuple` or an
1164+
`AbstractDict{<:VarName}`; however this method is left flexible to allow for other packages
1165+
that wish to extend `InitFromParams`.
11601166
11611167
# Example
11621168
```jldoctest
@@ -1177,7 +1183,7 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
11771183
(mp1 = 3.0,)
11781184
```
11791185
"""
1180-
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
1186+
function returned(model::Model, parameters...)
11811187
# Note: we can't use `fix(model, parameters)` because
11821188
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
11831189
return first(
@@ -1186,7 +1192,7 @@ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarN
11861192
DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple()),
11871193
# Use `nothing` as the fallback to ensure that any missing parameters cause an
11881194
# error
1189-
InitFromParams(parameters, nothing),
1195+
InitFromParams(parameters..., nothing),
11901196
),
11911197
)
11921198
end

src/test_utils/ad.jl

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using DynamicPPL:
1212
AbstractVarInfo,
1313
getlogjoint_internal,
1414
link
15+
using LinearAlgebra: norm
1516
using LogDensityProblems: logdensity, logdensity_and_gradient
1617
using Random: AbstractRNG, default_rng
1718
using Statistics: median
@@ -78,6 +79,51 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
7879
value_actual::T
7980
grad_expected::Vector{T}
8081
grad_actual::Vector{T}
82+
atol::T
83+
rtol::T
84+
end
85+
function Base.showerror(io::IO, e::ADIncorrectException)
86+
value_passed = isapprox(e.value_expected, e.value_actual; atol=e.atol, rtol=e.rtol)
87+
grad_passed = isapprox(e.grad_expected, e.grad_actual; atol=e.atol, rtol=e.rtol)
88+
s = if !value_passed && !grad_passed
89+
"value and gradient"
90+
elseif !value_passed
91+
"value"
92+
else
93+
"gradient"
94+
end
95+
println(io, "ADIncorrectException: The AD backend returned an incorrect $s.")
96+
println(io, " Testing was carried out with")
97+
println(io, " atol : $(e.atol)")
98+
println(io, " rtol : $(e.rtol)")
99+
# calculate what tolerances would have been needed to pass for value
100+
if !value_passed
101+
min_atol_needed_to_pass_value = abs(e.value_expected - e.value_actual)
102+
min_rtol_needed_to_pass_value =
103+
min_atol_needed_to_pass_value / max(abs(e.value_expected), abs(e.value_actual))
104+
println(io, " The value check failed because:")
105+
println(io, " expected value : $(e.value_expected)")
106+
println(io, " actual value : $(e.value_actual)")
107+
println(io, " This value correctness check would have passed if either:")
108+
println(io, " atol ≥ $(min_atol_needed_to_pass_value), or")
109+
println(io, " rtol ≥ $(min_rtol_needed_to_pass_value)")
110+
end
111+
if !grad_passed
112+
norm_expected = norm(e.grad_expected)
113+
norm_actual = norm(e.grad_actual)
114+
max_norm = max(norm_expected, norm_actual)
115+
norm_diff = norm(e.grad_expected - e.grad_actual)
116+
min_atol_needed_to_pass_grad = norm_diff
117+
min_rtol_needed_to_pass_grad = norm_diff / max_norm
118+
# min tolerances needed to pass overall
119+
println(io, " The gradient check failed because:")
120+
println(io, " expected grad : $(e.grad_expected)")
121+
println(io, " actual grad : $(e.grad_actual)")
122+
println(io, " The gradient correctness check would have passed if either:")
123+
println(io, " atol ≥ $(min_atol_needed_to_pass_grad), or")
124+
println(io, " rtol ≥ $(min_rtol_needed_to_pass_grad)")
125+
end
126+
return nothing
81127
end
82128

83129
"""
@@ -116,11 +162,26 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
116162
"The gradient of logp (calculated using `adtype`)"
117163
grad_actual::Vector{Tresult}
118164
"If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
119-
of logp"
165+
of logp (in seconds)"
120166
grad_time::Union{Nothing,Tresult}
121-
"If benchmarking was requested, the time taken by the AD backend to evaluate logp"
167+
"If benchmarking was requested, the time taken by the AD backend to evaluate logp (in
168+
seconds)"
122169
primal_time::Union{Nothing,Tresult}
123170
end
171+
function Base.show(io::IO, ::MIME"text/plain", result::ADResult)
172+
printstyled(io, "ADResult\n"; bold=true)
173+
println(io, " ├ model : $(result.model.f)")
174+
println(io, " ├ adtype : $(result.adtype)")
175+
println(io, " ├ value_actual : $(result.value_actual)")
176+
println(io, " ├ value_expected : $(result.value_expected)")
177+
println(io, " ├ grad_actual : $(result.grad_actual)")
178+
println(io, " ├ grad_expected : $(result.grad_expected)")
179+
if result.grad_time !== nothing && result.primal_time !== nothing
180+
println(io, " ├ grad_time : $(result.grad_time) s")
181+
println(io, " ├ primal_time : $(result.primal_time) s")
182+
end
183+
return println(io, " └ params : $(result.params)")
184+
end
124185

125186
"""
126187
run_ad(
@@ -230,6 +291,14 @@ Everything else is optional, and can be categorised into several groups:
230291
we cannot know the magnitude of logp and its gradient a priori. The `atol`
231292
value is supplied to handle the case where gradients are equal to zero.
232293
294+
1. _Whether to benchmark._
295+
296+
By default, benchmarking is disabled. To enable it, set `benchmark=true`.
297+
When enabled, the time taken to evaluate logp as well as its gradient is
298+
measured using Chairmarks.jl, and the `ADResult` object returned will
299+
contain `grad_time` and `primal_time` fields with the median times (in
300+
seconds).
301+
233302
1. _Whether to output extra logging information._
234303
235304
By default, this function prints messages when it runs. To silence it, set
@@ -297,7 +366,7 @@ function run_ad(
297366
end
298367
# Perform testing
299368
verbose && println(" expected : $((value_true, grad_true))")
300-
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
369+
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true, atol, rtol))
301370
isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
302371
isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
303372
end
@@ -306,14 +375,18 @@ function run_ad(
306375
grad_time, primal_time = if benchmark
307376
logdensity(ldf, params) # Warm-up
308377
primal_benchmark = @be logdensity($ldf, $params)
378+
print(" evaluation : ")
379+
show(stdout, MIME("text/plain"), median(primal_benchmark))
380+
println()
309381
logdensity_and_gradient(ldf, params) # Warm-up
310382
grad_benchmark = @be logdensity_and_gradient($ldf, $params)
383+
print(" gradient : ")
384+
show(stdout, MIME("text/plain"), median(grad_benchmark))
385+
println()
311386
median_primal = median(primal_benchmark).time
312387
median_grad = median(grad_benchmark).time
313388
r(f) = round(f; sigdigits=4)
314-
verbose && println(
315-
"grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
316-
)
389+
verbose && println(" grad / eval : $(r(median_grad / median_primal))")
317390
(median_grad, median_primal)
318391
else
319392
nothing, nothing

src/test_utils/model_interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ Even though it is recommended to implement this by hand for a particular `Model`
9292
a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
9393
"""
9494
function varnames(model::Model)
95-
return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict())))))
95+
result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict())))))
96+
# Concretise the element type.
97+
return [x for x in result]
9698
end
9799

98100
"""

src/test_utils/models.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ function logprior_true_with_logabsdet_jacobian(
3434
x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
3535
return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
3636
end
37+
function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_dynamic_constraint)})
38+
m = rand(rng, Normal())
39+
x = rand(rng, truncated(Normal(); lower=m))
40+
return (m=m, x=x)
41+
end
3742

3843
"""
3944
demo_one_variable_multiple_constraints()
@@ -109,12 +114,12 @@ x ~ LKJCholesky(d, 1.0)
109114
```
110115
"""
111116
@model function demo_lkjchol(d::Int=2)
112-
x ~ LKJCholesky(d, 1.0)
117+
x ~ LKJCholesky(d, 1.5)
113118
return (x=x,)
114119
end
115120

116121
function logprior_true(model::Model{typeof(demo_lkjchol)}, x)
117-
return logpdf(LKJCholesky(model.args.d, 1.0), x)
122+
return logpdf(LKJCholesky(model.args.d, 1.5), x)
118123
end
119124

120125
function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x)
@@ -163,6 +168,9 @@ end
163168
function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m)
164169
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
165170
end
171+
function varnames(::Model{typeof(demo_static_transformation)})
172+
return [@varname(s), @varname(m)]
173+
end
166174
function logprior_true_with_logabsdet_jacobian(
167175
model::Model{typeof(demo_static_transformation)}, s, m
168176
)
@@ -557,22 +565,6 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)})
557565
return [@varname(s), @varname(m)]
558566
end
559567

560-
const DemoModels = Union{
561-
Model{typeof(demo_dot_assume_observe)},
562-
Model{typeof(demo_assume_index_observe)},
563-
Model{typeof(demo_assume_multivariate_observe)},
564-
Model{typeof(demo_dot_assume_observe_index)},
565-
Model{typeof(demo_assume_dot_observe)},
566-
Model{typeof(demo_assume_dot_observe_literal)},
567-
Model{typeof(demo_assume_observe_literal)},
568-
Model{typeof(demo_assume_multivariate_observe_literal)},
569-
Model{typeof(demo_dot_assume_observe_index_literal)},
570-
Model{typeof(demo_assume_submodel_observe_index_literal)},
571-
Model{typeof(demo_dot_assume_observe_submodel)},
572-
Model{typeof(demo_dot_assume_observe_matrix_index)},
573-
Model{typeof(demo_assume_matrix_observe_matrix_index)},
574-
}
575-
576568
const UnivariateAssumeDemoModels = Union{
577569
Model{typeof(demo_assume_dot_observe)},
578570
Model{typeof(demo_assume_dot_observe_literal)},
@@ -758,3 +750,14 @@ const DEMO_MODELS = (
758750
demo_dot_assume_observe_matrix_index(),
759751
demo_assume_matrix_observe_matrix_index(),
760752
)
753+
754+
"""
755+
A tuple of all models defined in DynamicPPL.TestUtils.
756+
"""
757+
const ALL_MODELS = (
758+
DEMO_MODELS...,
759+
demo_dynamic_constraint(),
760+
demo_one_variable_multiple_constraints(),
761+
demo_lkjchol(),
762+
demo_static_transformation(),
763+
)

src/test_utils/varinfo.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in
1010
"""
1111
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
1212
for vn in vns
13-
@test compare(vi[vn], get(vals, vn); kwargs...)
13+
val = get(vals, vn)
14+
# TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
15+
# Remove once the fix is in all Julia versions we support.
16+
if val isa Cholesky
17+
@test compare(vi[vn].L, val.L; kwargs...)
18+
else
19+
@test compare(vi[vn], val; kwargs...)
20+
end
1421
end
1522
end
1623

src/threadsafe.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,24 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
8585
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
8686
end
8787

88-
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
89-
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)
88+
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
89+
return Accessors.@set vi.varinfo = link(t, vi.varinfo, model)
9090
end
9191

92-
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
93-
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)
92+
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
93+
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model)
94+
end
95+
96+
function link(
97+
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
98+
)
99+
return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model)
100+
end
101+
102+
function invlink(
103+
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
104+
)
105+
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model)
94106
end
95107

96108
# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
@@ -155,10 +167,6 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
155167
end
156168

157169
vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
158-
vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn)
159-
function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
160-
return vector_getranges(vi.varinfo, vns)
161-
end
162170

163171
isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
164172
function BangBang.empty!!(vi::ThreadSafeVarInfo)

0 commit comments

Comments
 (0)