Skip to content

Commit 11960e5

Browse files
mhaurupenelopeysm
andauthored
Use ALL_MODELS rather than DEMO_MODELS (#1173)
* Use ALL_MODELS rather than DEMO_MODELS * Switch ALL_MODELS for DEMO_MODELS where necessary * Curb overzealous use of ALL_MODELS * Add non-BangBang invlink and link for StaticTransformation * Give demo_lkjchol a non-flat prior PDF * Mark demo_lkjchol JET test as broken * Mark an SVI test as broken * Work around Cholesky comparison bug * Resolve method ambiguities * Make a test more robust. Co-authored-by: Penelope Yong <penelopeysm@gmail.com> * Fix use of SimpleVarInfo{Dict} --------- Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
1 parent 6266f64 commit 11960e5

File tree

16 files changed

+114
-54
lines changed

16 files changed

+114
-54
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ DynamicPPL provides several demo models in the `DynamicPPL.TestUtils` submodule.
267267

268268
```@docs
269269
DynamicPPL.TestUtils.DEMO_MODELS
270+
DynamicPPL.TestUtils.ALL_MODELS
270271
```
271272

272273
For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.

src/abstract_varinfo.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,10 @@ end
837837
function link!!(
838838
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
839839
)
840+
# TODO(mhauru) This assumes that the user has defined the bijector using the same
841+
# variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user
842+
# interface, and it's also dangerous for any AbstractVarInfo types that may not respect
843+
# a particular ordering, such as SimpleVarInfo{Dict}.
840844
b = inverse(t.bijector)
841845
x = vi[:]
842846
y, logjac = with_logabsdet_jacobian(b, x)
@@ -866,7 +870,7 @@ end
866870
function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
867871
return link(default_transformation(model, vi), vi, vns, model)
868872
end
869-
function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
873+
function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
870874
return link!!(t, deepcopy(vi), model)
871875
end
872876

@@ -932,7 +936,7 @@ end
932936
function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
933937
return invlink(default_transformation(model, vi), vi, vns, model)
934938
end
935-
function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
939+
function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model)
936940
return invlink!!(t, deepcopy(vi), model)
937941
end
938942

src/test_utils/model_interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ Even though it is recommended to implement this by hand for a particular `Model`
9292
a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
9393
"""
9494
function varnames(model::Model)
95-
return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict())))))
95+
result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict())))))
96+
# Concretise the element type.
97+
return [x for x in result]
9698
end
9799

98100
"""

src/test_utils/models.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ function logprior_true_with_logabsdet_jacobian(
3434
x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
3535
return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
3636
end
37+
function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_dynamic_constraint)})
38+
m = rand(rng, Normal())
39+
x = rand(rng, truncated(Normal(); lower=m))
40+
return (m=m, x=x)
41+
end
3742

3843
"""
3944
demo_one_variable_multiple_constraints()
@@ -109,12 +114,12 @@ x ~ LKJCholesky(d, 1.0)
109114
```
110115
"""
111116
@model function demo_lkjchol(d::Int=2)
112-
x ~ LKJCholesky(d, 1.0)
117+
x ~ LKJCholesky(d, 1.5)
113118
return (x=x,)
114119
end
115120

116121
function logprior_true(model::Model{typeof(demo_lkjchol)}, x)
117-
return logpdf(LKJCholesky(model.args.d, 1.0), x)
122+
return logpdf(LKJCholesky(model.args.d, 1.5), x)
118123
end
119124

120125
function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x)
@@ -163,6 +168,9 @@ end
163168
function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m)
164169
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
165170
end
171+
function varnames(::Model{typeof(demo_static_transformation)})
172+
return [@varname(s), @varname(m)]
173+
end
166174
function logprior_true_with_logabsdet_jacobian(
167175
model::Model{typeof(demo_static_transformation)}, s, m
168176
)
@@ -557,22 +565,6 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)})
557565
return [@varname(s), @varname(m)]
558566
end
559567

560-
const DemoModels = Union{
561-
Model{typeof(demo_dot_assume_observe)},
562-
Model{typeof(demo_assume_index_observe)},
563-
Model{typeof(demo_assume_multivariate_observe)},
564-
Model{typeof(demo_dot_assume_observe_index)},
565-
Model{typeof(demo_assume_dot_observe)},
566-
Model{typeof(demo_assume_dot_observe_literal)},
567-
Model{typeof(demo_assume_observe_literal)},
568-
Model{typeof(demo_assume_multivariate_observe_literal)},
569-
Model{typeof(demo_dot_assume_observe_index_literal)},
570-
Model{typeof(demo_assume_submodel_observe_index_literal)},
571-
Model{typeof(demo_dot_assume_observe_submodel)},
572-
Model{typeof(demo_dot_assume_observe_matrix_index)},
573-
Model{typeof(demo_assume_matrix_observe_matrix_index)},
574-
}
575-
576568
const UnivariateAssumeDemoModels = Union{
577569
Model{typeof(demo_assume_dot_observe)},
578570
Model{typeof(demo_assume_dot_observe_literal)},
@@ -758,3 +750,14 @@ const DEMO_MODELS = (
758750
demo_dot_assume_observe_matrix_index(),
759751
demo_assume_matrix_observe_matrix_index(),
760752
)
753+
754+
"""
755+
A tuple of all models defined in DynamicPPL.TestUtils.
756+
"""
757+
const ALL_MODELS = (
758+
DEMO_MODELS...,
759+
demo_dynamic_constraint(),
760+
demo_one_variable_multiple_constraints(),
761+
demo_lkjchol(),
762+
demo_static_transformation(),
763+
)

src/test_utils/varinfo.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in
1010
"""
1111
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
1212
for vn in vns
13-
@test compare(vi[vn], get(vals, vn); kwargs...)
13+
val = get(vals, vn)
14+
# TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
15+
# Remove once the fix is all Julia versions we support.
16+
if val isa Cholesky
17+
@test compare(vi[vn].L, val.L; kwargs...)
18+
else
19+
@test compare(vi[vn], val; kwargs...)
20+
end
1421
end
1522
end
1623

src/threadsafe.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,24 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
8585
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
8686
end
8787

88-
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
89-
return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...)
88+
function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
89+
return Accessors.@set vi.varinfo = link(t, vi.varinfo, model)
9090
end
9191

92-
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
93-
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...)
92+
function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
93+
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model)
94+
end
95+
96+
function link(
97+
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
98+
)
99+
return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model)
100+
end
101+
102+
function invlink(
103+
t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
104+
)
105+
return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model)
94106
end
95107

96108
# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.

test/chains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ using Test
6767
end
6868

6969
@testset "ParamsWithStats from LogDensityFunction" begin
70-
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
70+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
7171
unlinked_vi = VarInfo(m)
7272
@testset "$islinked" for islinked in (false, true)
7373
vi = if islinked

test/contexts.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
179179
@test new_ctx == FixedContext((b=4,), ConditionContext((a=1,)))
180180
end
181181

182-
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
182+
@testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
183183
prefix_vn = @varname(my_prefix)
184184
context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext())
185185
new_model = contextualize(model, context)
@@ -423,7 +423,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
423423
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
424424
),
425425
("SVI+NamedTuple", SimpleVarInfo()),
426-
("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())),
426+
("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())),
427427
]
428428

429429
@model function test_init_model()

test/ext/DynamicPPLJETExt.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,21 @@
6161
end
6262

6363
@testset "demo models" begin
64-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
64+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
65+
if model.f === DynamicPPL.TestUtils.demo_lkjchol
66+
# TODO(mhauru)
67+
# The LKJCholesky model fails with JET. The problem is not with Turing but
68+
# with Distributions, and ultimately this in LinearAlgebra:
69+
# julia> v = @view rand(2,2)[:,1];
70+
#
71+
# julia> JET.@report_call norm(v)
72+
# ═════ 2 possible errors found ═════
73+
# blahblah
74+
# The below trivial call to @test is just marking that there's something
75+
# broken here.
76+
@test false broken = true
77+
continue
78+
end
6579
# Use debug logging below.
6680
varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model)
6781
# Check that the inferred varinfo is indeed suitable for evaluation

test/integration/enzyme/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DynamicPPL.TestUtils: DEMO_MODELS
1+
using DynamicPPL.TestUtils: ALL_MODELS
22
using DynamicPPL.TestUtils.AD: run_ad
33
using ADTypes: AutoEnzyme
44
using Test: @test, @testset
@@ -17,7 +17,7 @@ ADTYPES = (
1717
)
1818

1919
@testset "$ad_key" for (ad_key, ad_type) in ADTYPES
20-
@testset "$(model.f)" for model in DEMO_MODELS
20+
@testset "$(model.f)" for model in ALL_MODELS
2121
@test run_ad(model, ad_type) isa Any
2222
end
2323
end

0 commit comments

Comments
 (0)