Skip to content
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ DynamicPPL provides several demo models in the `DynamicPPL.TestUtils` submodule.

```@docs
DynamicPPL.TestUtils.DEMO_MODELS
DynamicPPL.TestUtils.ALL_MODELS
```

For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.
Expand Down
8 changes: 6 additions & 2 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,10 @@ end
function link!!(
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
)
# TODO(mhauru) This assumes that the user has defined the bijector using the same
# variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user
# interface, and it's also dangerous for any AbstractVarInfo types that may not respect
# a particular ordering, such as SimpleVarInfo{Dict}.
b = inverse(t.bijector)
x = vi[:]
y, logjac = with_logabsdet_jacobian(b, x)
Expand Down Expand Up @@ -866,7 +870,7 @@ end
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return link(default_transformation(model, vi), vi, vns, model)
end
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return link!!(t, deepcopy(vi), model)
end

Expand Down Expand Up @@ -932,7 +936,7 @@ end
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
return invlink(default_transformation(model, vi), vi, vns, model)
end
function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
return invlink!!(t, deepcopy(vi), model)
end

Expand Down
4 changes: 3 additions & 1 deletion src/test_utils/model_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ Even though it is recommended to implement this by hand for a particular `Model`
a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
"""
function varnames(model::Model)
return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict())))))
result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict())))))
# Concretise the element type.
return [x for x in result]
end

"""
Expand Down
39 changes: 21 additions & 18 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ function logprior_true_with_logabsdet_jacobian(
x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
end
function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_dynamic_constraint)})
m = rand(rng, Normal())
x = rand(rng, truncated(Normal(); lower=m))
return (m=m, x=x)
end

"""
demo_one_variable_multiple_constraints()
Expand Down Expand Up @@ -109,12 +114,12 @@ x ~ LKJCholesky(d, 1.0)
```
"""
@model function demo_lkjchol(d::Int=2)
x ~ LKJCholesky(d, 1.0)
x ~ LKJCholesky(d, 1.5)
return (x=x,)
end

function logprior_true(model::Model{typeof(demo_lkjchol)}, x)
return logpdf(LKJCholesky(model.args.d, 1.0), x)
return logpdf(LKJCholesky(model.args.d, 1.5), x)
end

function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x)
Expand Down Expand Up @@ -163,6 +168,9 @@ end
function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m)
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
end
function varnames(::Model{typeof(demo_static_transformation)})
return [@varname(s), @varname(m)]
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_static_transformation)}, s, m
)
Expand Down Expand Up @@ -557,22 +565,6 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)})
return [@varname(s), @varname(m)]
end

const DemoModels = Union{
Model{typeof(demo_dot_assume_observe)},
Model{typeof(demo_assume_index_observe)},
Model{typeof(demo_assume_multivariate_observe)},
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_dot_observe_literal)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Model{typeof(demo_dot_assume_observe_matrix_index)},
Model{typeof(demo_assume_matrix_observe_matrix_index)},
}

const UnivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_dot_observe_literal)},
Expand Down Expand Up @@ -758,3 +750,14 @@ const DEMO_MODELS = (
demo_dot_assume_observe_matrix_index(),
demo_assume_matrix_observe_matrix_index(),
)

"""
A tuple of all models defined in DynamicPPL.TestUtils.
"""
const ALL_MODELS = (
DEMO_MODELS...,
demo_dynamic_constraint(),
demo_one_variable_multiple_constraints(),
demo_lkjchol(),
demo_static_transformation(),
)
9 changes: 8 additions & 1 deletion src/test_utils/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in
"""
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
for vn in vns
@test compare(vi[vn], get(vals, vn); kwargs...)
val = get(vals, vn)
# TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
# Remove once the fix is all Julia versions we support.
if val isa Cholesky
@test compare(vi[vn].L, val.L; kwargs...)
else
@test compare(vi[vn], val; kwargs...)
end
end
end

Expand Down
20 changes: 16 additions & 4 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,24 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
end

function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, model)
end

function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model)
end

function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
)
return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model)
end

function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
)
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model)
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
Expand Down
2 changes: 1 addition & 1 deletion test/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using Test
end

@testset "ParamsWithStats from LogDensityFunction" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
unlinked_vi = VarInfo(m)
@testset "$islinked" for islinked in (false, true)
vi = if islinked
Expand Down
4 changes: 2 additions & 2 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
@test new_ctx == FixedContext((b=4,), ConditionContext((a=1,)))
end

@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
prefix_vn = @varname(my_prefix)
context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext())
new_model = contextualize(model, context)
Expand Down Expand Up @@ -423,7 +423,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
),
("SVI+NamedTuple", SimpleVarInfo()),
("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())),
("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())),
]

@model function test_init_model()
Expand Down
16 changes: 15 additions & 1 deletion test/ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,21 @@
end

@testset "demo models" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
if model.f === DynamicPPL.TestUtils.demo_lkjchol
# TODO(mhauru)
# The LKJCholesky model fails with JET. The problem is not with Turing but
# with Distributions, and ultimately this in LinearAlgebra:
# julia> v = @view rand(2,2)[:,1];
#
# julia> JET.@report_call norm(v)
# ═════ 2 possible errors found ═════
# blahblah
# The below trivial call to @test is just marking that there's something
# broken here.
@test false broken = true
continue
end
# Use debug logging below.
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
# Check that the inferred varinfo is indeed suitable for evaluation
Expand Down
4 changes: 2 additions & 2 deletions test/integration/enzyme/main.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DynamicPPL.TestUtils: DEMO_MODELS
using DynamicPPL.TestUtils: ALL_MODELS
using DynamicPPL.TestUtils.AD: run_ad
using ADTypes: AutoEnzyme
using Test: @test, @testset
Expand All @@ -17,7 +17,7 @@ ADTYPES = (
)

@testset "$ad_key" for (ad_key, ad_type) in ADTYPES
@testset "$(model.f)" for model in DEMO_MODELS
@testset "$(model.f)" for model in ALL_MODELS
@test run_ad(model, ad_type) isa Any
end
end
8 changes: 4 additions & 4 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ReverseDiff: ReverseDiff
using Mooncake: Mooncake

@testset "LogDensityFunction: Correctness" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
@testset "$varinfo_func" for varinfo_func in [
DynamicPPL.untyped_varinfo,
DynamicPPL.typed_varinfo,
Expand Down Expand Up @@ -107,7 +107,7 @@ end
end

@testset "LogDensityFunction: Type stability" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
unlinked_vi = DynamicPPL.VarInfo(m)
@testset "$islinked" for islinked in (false, true)
vi = if islinked
Expand Down Expand Up @@ -163,7 +163,7 @@ end
]

@testset "Correctness" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
varinfo = VarInfo(m)
linked_varinfo = DynamicPPL.link(varinfo, m)
f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo)
Expand All @@ -187,7 +187,7 @@ end
end

@testset "logdensity_and_gradient with views" begin
# This test ensures that you can call `logdensity_and_gradient` with an array
# This test ensures that you can call `logdensity_and_gradient` with an array
# type that isn't the same as the one used in the gradient preparation.
@model function f()
x ~ Normal()
Expand Down
2 changes: 1 addition & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end

@testset "values_as_in_model" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
Expand Down
2 changes: 1 addition & 1 deletion test/model_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testset "model_utils.jl" begin
@testset "value_iterator_from_chain" begin
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$model" for model in DynamicPPL.TestUtils.ALL_MODELS
# Check that the values generated by value_iterator_from_chain
# match the values in the original chain
chain = make_chain_from_prior(model, 10)
Expand Down
33 changes: 25 additions & 8 deletions test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@
end

@testset "Dict" begin
svi = SimpleVarInfo(Dict(@varname(m) => 1.0))
svi = SimpleVarInfo(OrderedDict(@varname(m) => 1.0))
@test getlogjoint(svi) == 0.0
@test haskey(svi, @varname(m))
@test !haskey(svi, @varname(m[1]))

svi = SimpleVarInfo(Dict(@varname(m) => [1.0]))
svi = SimpleVarInfo(OrderedDict(@varname(m) => [1.0]))
@test getlogjoint(svi) == 0.0
@test haskey(svi, @varname(m))
@test haskey(svi, @varname(m[1]))
@test !haskey(svi, @varname(m[2]))
@test svi[@varname(m)][1] == svi[@varname(m[1])]

svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],)))
svi = SimpleVarInfo(OrderedDict(@varname(m) => (a=[1.0],)))
@test haskey(svi, @varname(m))
@test haskey(svi, @varname(m.a))
@test haskey(svi, @varname(m.a[1]))
@test !haskey(svi, @varname(m.a[2]))
@test !haskey(svi, @varname(m.a.b))

svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0]))
svi = SimpleVarInfo(OrderedDict(@varname(m.a) => [1.0]))
# Now we only have a variable `m.a` which is subsumed by `m`,
# but we can't guarantee that we have the "entire" `m`.
@test !haskey(svi, @varname(m))
Expand Down Expand Up @@ -87,14 +87,23 @@
end

@testset "link!! & invlink!! on $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
DynamicPPL.TestUtils.ALL_MODELS
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@testset "$name" for (name, vi) in (
("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())),
("SVI{Dict}", SimpleVarInfo(OrderedDict{VarName,Any}())),
("SVI{NamedTuple}", SimpleVarInfo(values_constrained)),
("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())),
("TypedVarInfo", DynamicPPL.typed_varinfo(model)),
)
if name == "SVI{NamedTuple}" &&
model.f === DynamicPPL.TestUtils.demo_one_variable_multiple_constraints
# TODO(mhauru) There's a bug in SimpleVarInfo{<:NamedTuple} for cases where
# a variable set with IndexLenses changes dimension under linking. This
# makes the link!! call crash. The below call to @test just marks the fact
# that there's something broken here.
@test false broken = true
continue
end
for vn in DynamicPPL.TestUtils.varnames(model)
vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn)
end
Expand Down Expand Up @@ -134,7 +143,7 @@
end

@testset "SimpleVarInfo on $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
DynamicPPL.TestUtils.ALL_MODELS
# We might need to pre-allocate for the variable `m`, so we need
# to see whether this is the case.
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model))
Expand Down Expand Up @@ -213,7 +222,15 @@

# Values should not have changed.
for vn in DynamicPPL.TestUtils.varnames(model)
@test svi_eval[vn] == get(values_eval, vn)
# TODO(mhauru) Workaround for
# https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
# Remove once the fix is all Julia versions we support.
val = get(values_eval, vn)
if val isa Cholesky
@test svi_eval[vn].L == val.L
else
@test svi_eval[vn] == val
end
end

# Compare log-probability computations.
Expand Down
Loading