From 870faab7812138b2e15106fdb2451bb0549b0f5a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 13 Dec 2025 13:56:59 +0000 Subject: [PATCH 1/2] Fix compiler output for sub-variables of immutable data --- HISTORY.md | 4 ++++ Project.toml | 2 +- src/compiler.jl | 35 +++++++++++++++++++++++++---------- test/compiler.jl | 13 +++++++++++++ test/contexts.jl | 16 ++++++++++++++++ 5 files changed, 59 insertions(+), 11 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 5d7892d8f..db4fcbb16 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.39.5 + +Fixed a bug which prevented passing immutable data (such as NamedTuples or ordinary structs) as arguments to DynamicPPL models, or fixing the model on such data. + ## 0.39.4 Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed). diff --git a/Project.toml b/Project.toml index 240e154bb..e02ee621e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.39.4" +version = "0.39.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/compiler.jl b/src/compiler.jl index 1b4260121..19f5c9003 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -463,6 +463,11 @@ function generate_tilde_literal(left, right) end end +assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs)) +function assign_or_set!!(lhs::Expr, rhs) + return AbstractPPL.drop_escape(:($BangBang.@set!! $lhs = $rhs)) +end + """ generate_tilde(left, right) @@ -474,27 +479,36 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption value dist + @gensym vn isassumption value dist supplied_val return quote $dist = $right $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # $left may not be a simple varname, it might be x.a or x[1], hence we need to + # use Accessors.@set to safely set it. + # We need overwrite=true to make sure that the parent value `x` is overwritten + # after this statement. + $(assign_or_set!!( + left, + :($(DynamicPPL.getfixed_nested)( + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) + )), + )) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else # If `vn` is not in `argnames`, then it's definitely been conditioned on (if # it's not in `argnames` and wasn't conditioned on, then `isassumption` would # be true). - $left = if $(DynamicPPL.inargnames)($vn, __model__) - # This is a no-op and looks redundant, but defining the compiler output this - # way ensures that the variable `$left` is always defined. See - # https://github.com/TuringLang/DynamicPPL.jl/pull/1110. - $left + # Note that it's important to always make sure that the variable `$supplied_val` + # is defined (by putting $supplied_val outside the if/else block), otherwise + # Libtask can trip up with variables that are only defined in one branch. See + # eg. https://github.com/TuringLang/DynamicPPL.jl/pull/1110 for a discussion of + # this. + $supplied_val = if $(DynamicPPL.inargnames)($vn, __model__) + $(maybe_view(left)) else $(DynamicPPL.getconditioned_nested)( __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) @@ -504,10 +518,11 @@ function generate_tilde(left, right) $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __model__.context, $(DynamicPPL.check_tilde_rhs)($dist), - $(maybe_view(left)), + $supplied_val, $vn, __varinfo__, ) + $(assign_or_set!!(left, value)) $value end end diff --git a/test/compiler.jl b/test/compiler.jl index 9056f666a..c701bce29 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -823,4 +823,17 @@ module Issue537 end end @test_logs (:warn, r"threadsafe evaluation") eval(e3) end + + @testset "Immutable data as model arguments" begin + # https://github.com/TuringLang/DynamicPPL.jl/issues/1176 + @model function nt(data) + m ~ Normal() + return data.x ~ Normal(m, 1.0) + end + data = (; x=5.0) + retval, vi = DynamicPPL.init!!(nt(data), VarInfo()) + @test retval == 5.0 + @test vi isa VarInfo + @test vi[@varname(m)] isa Real + end end diff --git a/test/contexts.jl b/test/contexts.jl index ae7332a43..9e377dd46 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -312,6 +312,22 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test model_fixed().m != m @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) end + + @testset "Can fix immutable data safely" begin + # https://github.com/TuringLang/DynamicPPL.jl/issues/1176#issuecomment-3648871018 + @model function ntfix() + m ~ Normal() + data = (; x=undef) + data.x ~ Normal(m, 1.0) + return data.x + end + fixm = DynamicPPL.fix(ntfix(), (; data=(; x=5.0))) + retval, vi = DynamicPPL.init!!(fixm, VarInfo()) + # The fixed data should overwrite the NamedTuple that came before it + @test retval == 5.0 + @test vi isa VarInfo + @test vi[@varname(m)] isa Real + end end @testset "PrefixContext + Condition/FixedContext interactions" begin From 485f8d502e5d1600a850c68bd8dc494c772047a5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 15 Dec 2025 01:13:20 +0000 Subject: [PATCH 2/2] Update src/compiler.jl --- src/compiler.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 19f5c9003..f1e92e369 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -486,10 +486,8 @@ function generate_tilde(left, right) $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) - # $left may not be a simple varname, it might be x.a or x[1], hence we need to - # use Accessors.@set to safely set it. - # We need overwrite=true to make sure that the parent value `x` is overwritten - # after this statement. + # $left may not be a simple varname, it might be x.a or x[1], in which case we + # need to use BangBang.@set!! to safely set it. $(assign_or_set!!( left, :($(DynamicPPL.getfixed_nested)(