From b60d5ad6c0faeb5b3323b6824ffea7298e20f170 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Dec 2025 18:59:54 +0000 Subject: [PATCH 01/11] Use ALL_MODELS rather than DEMO_MODELS --- docs/src/api.md | 1 + src/test_utils/model_interface.jl | 4 +++- src/test_utils/models.jl | 35 +++++++++++++++++-------------- src/test_utils/varinfo.jl | 9 +++++++- test/chains.jl | 2 +- test/contexts.jl | 4 ++-- test/debug_utils.jl | 2 +- test/ext/DynamicPPLJETExt.jl | 2 +- test/integration/enzyme/main.jl | 4 ++-- test/logdensityfunction.jl | 8 +++---- test/model.jl | 6 +++--- test/model_utils.jl | 2 +- test/pointwise_logdensities.jl | 2 +- test/simple_varinfo.jl | 4 ++-- test/varinfo.jl | 10 ++++----- test/varnamedvector.jl | 2 +- 16 files changed, 55 insertions(+), 42 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 193a6ce4c..686549e9b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -267,6 +267,7 @@ DynamicPPL provides several demo models in the `DynamicPPL.TestUtils` submodule. ```@docs DynamicPPL.TestUtils.DEMO_MODELS +DynamicPPL.TestUtils.ALL_MODELS ``` For every demo model, one can define the true log prior, log likelihood, and log joint probabilities. diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index cb949464e..19422bc86 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,7 +92,9 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) + result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) + # Concretise the element type. + return [x for x in result] end """ diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 8ffb7cbdf..95756c46e 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -34,6 +34,11 @@ function logprior_true_with_logabsdet_jacobian( x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x) return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp end +function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_dynamic_constraint)}) + m = rand(rng, Normal()) + x = rand(rng, truncated(Normal(); lower=m)) + return (m=m, x=x) +end """ demo_one_variable_multiple_constraints() @@ -163,6 +168,9 @@ end function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m) return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) end +function varnames(::Model{typeof(demo_static_transformation)}) + return [@varname(s), @varname(m)] +end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_static_transformation)}, s, m ) @@ -557,22 +565,6 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end -const DemoModels = Union{ - Model{typeof(demo_dot_assume_observe)}, - Model{typeof(demo_assume_index_observe)}, - Model{typeof(demo_assume_multivariate_observe)}, - Model{typeof(demo_dot_assume_observe_index)}, - Model{typeof(demo_assume_dot_observe)}, - Model{typeof(demo_assume_dot_observe_literal)}, - Model{typeof(demo_assume_observe_literal)}, - Model{typeof(demo_assume_multivariate_observe_literal)}, - Model{typeof(demo_dot_assume_observe_index_literal)}, - Model{typeof(demo_assume_submodel_observe_index_literal)}, - Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_observe_matrix_index)}, - Model{typeof(demo_assume_matrix_observe_matrix_index)}, -} - const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)}, Model{typeof(demo_assume_dot_observe_literal)}, @@ -758,3 +750,14 @@ const DEMO_MODELS = ( demo_dot_assume_observe_matrix_index(), demo_assume_matrix_observe_matrix_index(), ) + +""" +A tuple of all models defined in DynamicPPL.TestUtils. +""" +const ALL_MODELS = ( + DEMO_MODELS..., + demo_dynamic_constraint(), + demo_one_variable_multiple_constraints(), + demo_lkjchol(), + demo_static_transformation(), +) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 26e2aa7ca..6483b29e8 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -10,7 +10,14 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in """ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) for vn in vns - @test compare(vi[vn], get(vals, vn); kwargs...) + val = get(vals, vn) + # TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 + # Remove once the fix is all Julia versions we support. + if val isa Cholesky + @test compare(vi[vn].L, val.L; kwargs...) + else + @test compare(vi[vn], val; kwargs...) + end end end diff --git a/test/chains.jl b/test/chains.jl index 498e2e912..36c274b62 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -67,7 +67,7 @@ using Test end @testset "ParamsWithStats from LogDensityFunction" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS unlinked_vi = VarInfo(m) @testset "$islinked" for islinked in (false, true) vi = if islinked diff --git a/test/contexts.jl b/test/contexts.jl index ae7332a43..2aad7f52f 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -59,7 +59,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() ) @testset "$(name)" for (name, context) in contexts - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS DynamicPPL.TestUtils.test_context(context, model) end end @@ -179,7 +179,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end - @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS prefix_vn = @varname(my_prefix) context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) new_model = contextualize(model, context) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index f950f6b45..fb664fff1 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,5 +1,5 @@ @testset "check_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS issuccess, trace = check_model_and_trace(model, VarInfo(model)) # These models should all work. @test issuccess diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index c74beefdb..18ee50ffe 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -61,7 +61,7 @@ end @testset "demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index edfd67d18..b99607aeb 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -1,4 +1,4 @@ -using DynamicPPL.TestUtils: DEMO_MODELS +using DynamicPPL.TestUtils: ALL_MODELS using DynamicPPL.TestUtils.AD: run_ad using ADTypes: AutoEnzyme using Test: @test, @testset @@ -17,7 +17,7 @@ ADTYPES = ( ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES - @testset "$(model.f)" for model in DEMO_MODELS + @testset "$(model.f)" for model in ALL_MODELS @test run_ad(model, ad_type) isa Any end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 383d7593d..011cb22ce 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -16,7 +16,7 @@ using ReverseDiff: ReverseDiff using Mooncake: Mooncake @testset "LogDensityFunction: Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS @testset "$varinfo_func" for varinfo_func in [ DynamicPPL.untyped_varinfo, DynamicPPL.typed_varinfo, @@ -107,7 +107,7 @@ end end @testset "LogDensityFunction: Type stability" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS unlinked_vi = DynamicPPL.VarInfo(m) @testset "$islinked" for islinked in (false, true) vi = if islinked @@ -163,7 +163,7 @@ end ] @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) @@ -187,7 +187,7 @@ end end @testset "logdensity_and_gradient with views" begin - # This test ensures that you can call `logdensity_and_gradient` with an array + # This test ensures that you can call `logdensity_and_gradient` with an array # type that isn't the same as the one used in the gradient preparation. @model function f() x ~ Normal() diff --git a/test/model.jl b/test/model.jl index 6da5ea246..8105b84bc 100644 --- a/test/model.jl +++ b/test/model.jl @@ -57,7 +57,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS N = 200 chain = make_chain_from_prior(model, N) logpriors = logprior(model, chain) @@ -290,7 +290,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end @testset "TestUtils" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS x = DynamicPPL.TestUtils.rand_prior_true(model) # `rand_prior_true` should return a `NamedTuple`. @test x isa NamedTuple @@ -431,7 +431,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end @testset "values_as_in_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) diff --git a/test/model_utils.jl b/test/model_utils.jl index af695dbf2..1547ad276 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -1,6 +1,6 @@ @testset "model_utils.jl" begin @testset "value_iterator_from_chain" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$model" for model in DynamicPPL.TestUtils.ALL_MODELS # Check that the values generated by value_iterator_from_chain # match the values in the original chain chain = make_chain_from_prior(model, 10) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 780d45b46..821f6a04d 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,5 +1,5 @@ @testset "pointwise_logdensities.jl" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 488cb8941..1a7244c7d 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -87,7 +87,7 @@ end @testset "link!! & invlink!! on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS + DynamicPPL.TestUtils.ALL_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$name" for (name, vi) in ( ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), @@ -134,7 +134,7 @@ end @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS + DynamicPPL.TestUtils.ALL_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) diff --git a/test/varinfo.jl b/test/varinfo.jl index a1a1b370f..97cf724ae 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -318,7 +318,7 @@ end end @testset "returned on MCMCChains.Chains" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS chain = make_chain_from_prior(model, 10) # A simple way of checking that the computation is determinstic: run twice and compare. res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) @@ -460,7 +460,7 @@ end end @testset "values_as" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.ALL_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) vns = DynamicPPL.TestUtils.varnames(model) @@ -730,7 +730,7 @@ end end @testset "merge" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( model, @@ -827,7 +827,7 @@ end # NOTE: It is not yet clear if this is something we want from all varinfo types. # Hence, we only test the `VarInfo` types here. @testset "vector_getranges for `VarInfo`" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( @@ -867,7 +867,7 @@ end end @testset "issue #842" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + model = DynamicPPL.TestUtils.ALL_MODELS[1] varinfo = VarInfo(model) n = length(varinfo[:]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 3a327c147..9a4ef12c3 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -667,7 +667,7 @@ end end @testset "VarInfo + VarNamedVector" begin - models = DynamicPPL.TestUtils.DEMO_MODELS + models = DynamicPPL.TestUtils.ALL_MODELS @testset "$(model.f)" for model in models # NOTE: Need to set random seed explicitly to avoid using the same seed # for initialization as for sampling in the inner testset below. From 5f8cfe93a65471303c69d139ac67e9e7255ff4da Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Dec 2025 19:56:42 +0000 Subject: [PATCH 02/11] Switch ALL_MODELS for DEMO_MODELS where necessary --- test/model.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/model.jl b/test/model.jl index 8105b84bc..c878fd905 100644 --- a/test/model.jl +++ b/test/model.jl @@ -57,7 +57,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS N = 200 chain = make_chain_from_prior(model, N) logpriors = logprior(model, chain) @@ -290,7 +290,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end @testset "TestUtils" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS x = DynamicPPL.TestUtils.rand_prior_true(model) # `rand_prior_true` should return a `NamedTuple`. @test x isa NamedTuple From 6b2d19f28867348e877be085e4237ec40b24e928 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 11:19:40 +0000 Subject: [PATCH 03/11] Curb overzealous use of ALL_MODELS --- test/contexts.jl | 2 +- test/debug_utils.jl | 2 +- test/pointwise_logdensities.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 2aad7f52f..5ff632837 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -59,7 +59,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() ) @testset "$(name)" for (name, context) in contexts - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index fb664fff1..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,5 +1,5 @@ @testset "check_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS issuccess, trace = check_model_and_trace(model, VarInfo(model)) # These models should all work. @test issuccess diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 821f6a04d..780d45b46 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,5 +1,5 @@ @testset "pointwise_logdensities.jl" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. From eab69d32b7e319af13cc1f170403c630381b1a01 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 11:32:25 +0000 Subject: [PATCH 04/11] Add non-BangBang invlink and link for StaticTransformation --- src/abstract_varinfo.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ec5e1ea10..a1a5f7a95 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -869,6 +869,9 @@ end function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return link!!(t, deepcopy(vi), model) end +function link(t::StaticTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) +end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -935,6 +938,9 @@ end function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, deepcopy(vi), model) end +function invlink(t::StaticTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) +end """ maybe_invlink_before_eval!!([t::Transformation,] vi, model) From 77bb29ff1df25e588bbd742411b4c8bc50b33e87 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 11:32:49 +0000 Subject: [PATCH 05/11] Give demo_lkjchol a non-flat prior PDF --- src/test_utils/models.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 95756c46e..84e1f10d8 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -114,12 +114,12 @@ x ~ LKJCholesky(d, 1.0) ``` """ @model function demo_lkjchol(d::Int=2) - x ~ LKJCholesky(d, 1.0) + x ~ LKJCholesky(d, 1.5) return (x=x,) end function logprior_true(model::Model{typeof(demo_lkjchol)}, x) - return logpdf(LKJCholesky(model.args.d, 1.0), x) + return logpdf(LKJCholesky(model.args.d, 1.5), x) end function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x) From b2193cfe3169a6fefc65d83c860621752668fa76 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 11:34:20 +0000 Subject: [PATCH 06/11] Mark demo_lkjchol JET test as broken --- test/ext/DynamicPPLJETExt.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 18ee50ffe..e46c25113 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -62,6 +62,20 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS + if model.f === DynamicPPL.TestUtils.demo_lkjchol + # TODO(mhauru) + # The LKJCholesky model fails with JET. The problem is not with Turing but + # with Distributions, and ultimately this in LinearAlgebra: + # julia> v = @view rand(2,2)[:,1]; + # + # julia> JET.@report_call norm(v) + # ═════ 2 possible errors found ═════ + # blahblah + # The below trivial call to @test is just marking that there's something + # broken here. + @test false broken = true + continue + end # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation From 93c0ed66a26029294275e7c896cda2bee473f3c0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 11:44:05 +0000 Subject: [PATCH 07/11] Mark an SVI test as broken --- test/simple_varinfo.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 1a7244c7d..142ddb9a0 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -95,6 +95,15 @@ ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), ) + if name == "SVI{NamedTuple}" && + model.f === DynamicPPL.TestUtils.demo_one_variable_multiple_constraints + # TODO(mhauru) There's a bug in SimpleVarInfo{<:NamedTuple} for cases where + # a variable set with IndexLenses changes dimension under linking. This + # makes the link!! call crash. The below call to @test just marks the fact + # that there's something broken here. + @test false broken = true + continue + end for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end From 7b9304fa0fb4e958600d95acf3eab3ceb87af98e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 14:54:50 +0000 Subject: [PATCH 08/11] Work around Cholesky comparison bug --- test/simple_varinfo.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 142ddb9a0..4a2e1e11d 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -222,7 +222,15 @@ # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_eval[vn] == get(values_eval, vn) + # TODO(mhauru) Workaround for + # https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 + # Remove once the fix is all Julia versions we support. + val = get(values_eval, vn) + if val isa Cholesky + @test svi_eval[vn].L == val.L + else + @test svi_eval[vn] == val + end end # Compare log-probability computations. From c8e6a3f78fc7e0a147a83c7ed1d7a43294038079 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 15:40:01 +0000 Subject: [PATCH 09/11] Resolve method ambiguities --- src/abstract_varinfo.jl | 10 ++-------- src/threadsafe.jl | 20 ++++++++++++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index a1a5f7a95..8f995c515 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -866,10 +866,7 @@ end function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end -function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end -function link(t::StaticTransformation, vi::AbstractVarInfo, model::Model) +function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link!!(t, deepcopy(vi), model) end @@ -935,10 +932,7 @@ end function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end -function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end -function invlink(t::StaticTransformation, vi::AbstractVarInfo, model::Model) +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return invlink!!(t, deepcopy(vi), model) end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 0e906b6ca..6b59fa43b 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -85,12 +85,24 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) end -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) +function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, model) end -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) +function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model) +end + +function link( + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model +) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) +end + +function invlink( + t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model +) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. From d48ab407fe5c95541747359afd43eb11782e8dbd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 12 Dec 2025 15:53:35 +0000 Subject: [PATCH 10/11] Make a test more robust. Co-authored-by: Penelope Yong --- test/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 97cf724ae..b57e56997 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -867,7 +867,7 @@ end end @testset "issue #842" begin - model = DynamicPPL.TestUtils.ALL_MODELS[1] + model = DynamicPPL.TestUtils.demo_dot_assume_observe() varinfo = VarInfo(model) n = length(varinfo[:]) From 2acd9e37c1e531e900b1aeeceef2ca104157506b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 15 Dec 2025 10:04:39 +0000 Subject: [PATCH 11/11] Fix use of SimpleVarInfo{Dict} --- src/abstract_varinfo.jl | 4 ++++ src/test_utils/model_interface.jl | 2 +- test/contexts.jl | 2 +- test/simple_varinfo.jl | 10 +++++----- test/varinfo.jl | 6 +++--- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 8f995c515..898b6caf9 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -837,6 +837,10 @@ end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) + # TODO(mhauru) This assumes that the user has defined the bijector using the same + # variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user + # interface, and it's also dangerous for any AbstractVarInfo types that may not respect + # a particular ordering, such as SimpleVarInfo{Dict}. b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 19422bc86..e7fb16fbe 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,7 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) + result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict()))))) # Concretise the element type. return [x for x in result] end diff --git a/test/contexts.jl b/test/contexts.jl index 5ff632837..9621013ac 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -423,7 +423,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), ), ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), + ("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())), ] @model function test_init_model() diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 4a2e1e11d..42e377440 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -29,26 +29,26 @@ end @testset "Dict" begin - svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) + svi = SimpleVarInfo(OrderedDict(@varname(m) => 1.0)) @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) - svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) + svi = SimpleVarInfo(OrderedDict(@varname(m) => [1.0])) @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @test svi[@varname(m)][1] == svi[@varname(m[1])] - svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],))) + svi = SimpleVarInfo(OrderedDict(@varname(m) => (a=[1.0],))) @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m.a)) @test haskey(svi, @varname(m.a[1])) @test !haskey(svi, @varname(m.a[2])) @test !haskey(svi, @varname(m.a.b)) - svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0])) + svi = SimpleVarInfo(OrderedDict(@varname(m.a) => [1.0])) # Now we only have a variable `m.a` which is subsumed by `m`, # but we can't guarantee that we have the "entire" `m`. @test !haskey(svi, @varname(m)) @@ -90,7 +90,7 @@ DynamicPPL.TestUtils.ALL_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), + ("SVI{Dict}", SimpleVarInfo(OrderedDict{VarName,Any}())), ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), diff --git a/test/varinfo.jl b/test/varinfo.jl index 2a83e2f05..a7948cc32 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -101,7 +101,7 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict{VarName,Any}())) + test_base(SimpleVarInfo(OrderedDict{VarName,Any}())) test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @@ -130,7 +130,7 @@ end test_varinfo_logp!(vi) test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(Dict())) + test_varinfo_logp!(SimpleVarInfo(OrderedDict())) test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @@ -451,7 +451,7 @@ end test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(Dict{VarName,Any}()), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}`