diff --git a/docs/src/api.md b/docs/src/api.md index 193a6ce4c..686549e9b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -267,6 +267,7 @@ DynamicPPL provides several demo models in the `DynamicPPL.TestUtils` submodule. ```@docs DynamicPPL.TestUtils.DEMO_MODELS +DynamicPPL.TestUtils.ALL_MODELS ``` For every demo model, one can define the true log prior, log likelihood, and log joint probabilities. diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ec5e1ea10..898b6caf9 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -837,6 +837,10 @@ end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) + # TODO(mhauru) This assumes that the user has defined the bijector using the same + # variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user + # interface, and it's also dangerous for any AbstractVarInfo types that may not respect + # a particular ordering, such as SimpleVarInfo{Dict}. b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) @@ -866,7 +870,7 @@ end function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end -function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) +function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link!!(t, deepcopy(vi), model) end @@ -932,7 +936,7 @@ end function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end -function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, deepcopy(vi), model) end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index cb949464e..e7fb16fbe 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,7 +92,9 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) + result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict()))))) + # Concretise the element type. + return [x for x in result] end """ diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 8ffb7cbdf..84e1f10d8 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -34,6 +34,11 @@ function logprior_true_with_logabsdet_jacobian( x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x) return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp end +function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_dynamic_constraint)}) + m = rand(rng, Normal()) + x = rand(rng, truncated(Normal(); lower=m)) + return (m=m, x=x) +end """ demo_one_variable_multiple_constraints() @@ -109,12 +114,12 @@ x ~ LKJCholesky(d, 1.0) ``` """ @model function demo_lkjchol(d::Int=2) - x ~ LKJCholesky(d, 1.0) + x ~ LKJCholesky(d, 1.5) return (x=x,) end function logprior_true(model::Model{typeof(demo_lkjchol)}, x) - return logpdf(LKJCholesky(model.args.d, 1.0), x) + return logpdf(LKJCholesky(model.args.d, 1.5), x) end function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x) @@ -163,6 +168,9 @@ end function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m) return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) end +function varnames(::Model{typeof(demo_static_transformation)}) + return [@varname(s), @varname(m)] +end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_static_transformation)}, s, m ) @@ -557,22 +565,6 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end -const DemoModels = Union{ - Model{typeof(demo_dot_assume_observe)}, - Model{typeof(demo_assume_index_observe)}, - Model{typeof(demo_assume_multivariate_observe)}, - Model{typeof(demo_dot_assume_observe_index)}, - Model{typeof(demo_assume_dot_observe)}, - Model{typeof(demo_assume_dot_observe_literal)}, - Model{typeof(demo_assume_observe_literal)}, - Model{typeof(demo_assume_multivariate_observe_literal)}, - Model{typeof(demo_dot_assume_observe_index_literal)}, - Model{typeof(demo_assume_submodel_observe_index_literal)}, - Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_observe_matrix_index)}, - Model{typeof(demo_assume_matrix_observe_matrix_index)}, -} - const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)}, Model{typeof(demo_assume_dot_observe_literal)}, @@ -758,3 +750,14 @@ const DEMO_MODELS = ( demo_dot_assume_observe_matrix_index(), demo_assume_matrix_observe_matrix_index(), ) + +""" +A tuple of all models defined in DynamicPPL.TestUtils. +""" +const ALL_MODELS = ( + DEMO_MODELS..., + demo_dynamic_constraint(), + demo_one_variable_multiple_constraints(), + demo_lkjchol(), + demo_static_transformation(), +) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 26e2aa7ca..6483b29e8 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -10,7 +10,14 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in """ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) for vn in vns - @test compare(vi[vn], get(vals, vn); kwargs...) + val = get(vals, vn) + # TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 + # Remove once the fix is all Julia versions we support. + if val isa Cholesky + @test compare(vi[vn].L, val.L; kwargs...) + else + @test compare(vi[vn], val; kwargs...) + end end end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6d3acce6c..c7ab106a2 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -85,12 +85,24 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) +function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, model) end -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) +function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model) +end + +function link( + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model +) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) +end + +function invlink( + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model +) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. diff --git a/test/chains.jl b/test/chains.jl index 498e2e912..36c274b62 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -67,7 +67,7 @@ using Test end @testset "ParamsWithStats from LogDensityFunction" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS unlinked_vi = VarInfo(m) @testset "$islinked" for islinked in (false, true) vi = if islinked diff --git a/test/contexts.jl b/test/contexts.jl index ae7332a43..9621013ac 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -179,7 +179,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end - @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS prefix_vn = @varname(my_prefix) context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) new_model = contextualize(model, context) @@ -423,7 +423,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), ), ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), + ("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())), ] @model function test_init_model() diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index c74beefdb..e46c25113 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -61,7 +61,21 @@ end @testset "demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + if model.f === DynamicPPL.TestUtils.demo_lkjchol + # TODO(mhauru) + # The LKJCholesky model fails with JET. The problem is not with Turing but + # with Distributions, and ultimately this in LinearAlgebra: + # julia> v = @view rand(2,2)[:,1]; + # + # julia> JET.@report_call norm(v) + # ═════ 2 possible errors found ═════ + # blahblah + # The below trivial call to @test is just marking that there's something + # broken here. + @test false broken = true + continue + end # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index edfd67d18..b99607aeb 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -1,4 +1,4 @@ -using DynamicPPL.TestUtils: DEMO_MODELS +using DynamicPPL.TestUtils: ALL_MODELS using DynamicPPL.TestUtils.AD: run_ad using ADTypes: AutoEnzyme using Test: @test, @testset @@ -17,7 +17,7 @@ ADTYPES = ( ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES - @testset "$(model.f)" for model in DEMO_MODELS + @testset "$(model.f)" for model in ALL_MODELS @test run_ad(model, ad_type) isa Any end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 383d7593d..011cb22ce 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -16,7 +16,7 @@ using ReverseDiff: ReverseDiff using Mooncake: Mooncake @testset "LogDensityFunction: Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS @testset "$varinfo_func" for varinfo_func in [ DynamicPPL.untyped_varinfo, DynamicPPL.typed_varinfo, @@ -107,7 +107,7 @@ end end @testset "LogDensityFunction: Type stability" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS unlinked_vi = DynamicPPL.VarInfo(m) @testset "$islinked" for islinked in (false, true) vi = if islinked @@ -163,7 +163,7 @@ end ] @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) @@ -187,7 +187,7 @@ end end @testset "logdensity_and_gradient with views" begin - # This test ensures that you can call `logdensity_and_gradient` with an array + # This test ensures that you can call `logdensity_and_gradient` with an array # type that isn't the same as the one used in the gradient preparation. @model function f() x ~ Normal() diff --git a/test/model.jl b/test/model.jl index 6da5ea246..c878fd905 100644 --- a/test/model.jl +++ b/test/model.jl @@ -431,7 +431,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end @testset "values_as_in_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) diff --git a/test/model_utils.jl b/test/model_utils.jl index af695dbf2..1547ad276 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -1,6 +1,6 @@ @testset "model_utils.jl" begin @testset "value_iterator_from_chain" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$model" for model in DynamicPPL.TestUtils.ALL_MODELS # Check that the values generated by value_iterator_from_chain # match the values in the original chain chain = make_chain_from_prior(model, 10) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 488cb8941..42e377440 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -29,26 +29,26 @@ end @testset "Dict" begin - svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) + svi = SimpleVarInfo(OrderedDict(@varname(m) => 1.0)) @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) - svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) + svi = SimpleVarInfo(OrderedDict(@varname(m) => [1.0])) @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @test svi[@varname(m)][1] == svi[@varname(m[1])] - svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],))) + svi = SimpleVarInfo(OrderedDict(@varname(m) => (a=[1.0],))) @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m.a)) @test haskey(svi, @varname(m.a[1])) @test !haskey(svi, @varname(m.a[2])) @test !haskey(svi, @varname(m.a.b)) - svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0])) + svi = SimpleVarInfo(OrderedDict(@varname(m.a) => [1.0])) # Now we only have a variable `m.a` which is subsumed by `m`, # but we can't guarantee that we have the "entire" `m`. @test !haskey(svi, @varname(m)) @@ -87,14 +87,23 @@ end @testset "link!! & invlink!! on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS + DynamicPPL.TestUtils.ALL_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), + ("SVI{Dict}", SimpleVarInfo(OrderedDict{VarName,Any}())), ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), ) + if name == "SVI{NamedTuple}" && + model.f === DynamicPPL.TestUtils.demo_one_variable_multiple_constraints + # TODO(mhauru) There's a bug in SimpleVarInfo{<:NamedTuple} for cases where + # a variable set with IndexLenses changes dimension under linking. This + # makes the link!! call crash. The below call to @test just marks the fact + # that there's something broken here. + @test false broken = true + continue + end for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end @@ -134,7 +143,7 @@ end @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS + DynamicPPL.TestUtils.ALL_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) @@ -213,7 +222,15 @@ # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_eval[vn] == get(values_eval, vn) + # TODO(mhauru) Workaround for + # https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 + # Remove once the fix is all Julia versions we support. + val = get(values_eval, vn) + if val isa Cholesky + @test svi_eval[vn].L == val.L + else + @test svi_eval[vn] == val + end end # Compare log-probability computations. diff --git a/test/varinfo.jl b/test/varinfo.jl index 0d0ddc15d..a7948cc32 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -101,7 +101,7 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict{VarName,Any}())) + test_base(SimpleVarInfo(OrderedDict{VarName,Any}())) test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @@ -130,7 +130,7 @@ end test_varinfo_logp!(vi) test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(Dict())) + test_varinfo_logp!(SimpleVarInfo(OrderedDict())) test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @@ -318,7 +318,7 @@ end end @testset "returned on MCMCChains.Chains" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS chain = make_chain_from_prior(model, 10) # A simple way of checking that the computation is determinstic: run twice and compare. res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) @@ -451,7 +451,7 @@ end test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(Dict{VarName,Any}()), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` @@ -460,7 +460,7 @@ end end @testset "values_as" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.ALL_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) vns = DynamicPPL.TestUtils.varnames(model) @@ -730,7 +730,7 @@ end end @testset "merge" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( model, @@ -825,7 +825,7 @@ end end @testset "issue #842" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + model = DynamicPPL.TestUtils.demo_dot_assume_observe() varinfo = VarInfo(model) n = length(varinfo[:]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 3a327c147..9a4ef12c3 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -667,7 +667,7 @@ end end @testset "VarInfo + VarNamedVector" begin - models = DynamicPPL.TestUtils.DEMO_MODELS + models = DynamicPPL.TestUtils.ALL_MODELS @testset "$(model.f)" for model in models # NOTE: Need to set random seed explicitly to avoid using the same seed # for initialization as for sampling in the inner testset below.