Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.39.5

Fixed a bug which prevented passing immutable data (such as NamedTuples or ordinary structs) as arguments to DynamicPPL models, or fixing the model on such data.

## 0.39.4

Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed).
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.39.4"
version = "0.39.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
33 changes: 23 additions & 10 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,11 @@ function generate_tilde_literal(left, right)
end
end

assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs))
function assign_or_set!!(lhs::Expr, rhs)
return AbstractPPL.drop_escape(:($BangBang.@set!! $lhs = $rhs))
end

"""
generate_tilde(left, right)

Expand All @@ -474,27 +479,34 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption value dist
@gensym vn isassumption value dist supplied_val

return quote
$dist = $right
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# $left may not be a simple varname, it might be x.a or x[1], in which case we
# need to use BangBang.@set!! to safely set it.
$(assign_or_set!!(
left,
:($(DynamicPPL.getfixed_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)),
))
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, then it's definitely been conditioned on (if
# it's not in `argnames` and wasn't conditioned on, then `isassumption` would
# be true).
$left = if $(DynamicPPL.inargnames)($vn, __model__)
# This is a no-op and looks redundant, but defining the compiler output this
# way ensures that the variable `$left` is always defined. See
# https://github.com/TuringLang/DynamicPPL.jl/pull/1110.
$left
# Note that it's important to always make sure that the variable `$supplied_val`
# is defined (by putting $supplied_val outside the if/else block), otherwise
# Libtask can trip up with variables that are only defined in one branch. See
# eg. https://github.com/TuringLang/DynamicPPL.jl/pull/1110 for a discussion of
# this.
$supplied_val = if $(DynamicPPL.inargnames)($vn, __model__)
$(maybe_view(left))
else
$(DynamicPPL.getconditioned_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
Expand All @@ -504,10 +516,11 @@ function generate_tilde(left, right)
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
__model__.context,
$(DynamicPPL.check_tilde_rhs)($dist),
$(maybe_view(left)),
$supplied_val,
$vn,
__varinfo__,
)
$(assign_or_set!!(left, value))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might the fact that left is now assigned to the value that's gone through tilde_observe!! cause any changes? Maybe something about copies or allocations? I'm guessing not, just thinking out loud.

$value
end
end
Expand Down
13 changes: 13 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -823,4 +823,17 @@ module Issue537 end
end
@test_logs (:warn, r"threadsafe evaluation") eval(e3)
end

@testset "Immutable data as model arguments" begin
# https://github.com/TuringLang/DynamicPPL.jl/issues/1176
@model function nt(data)
m ~ Normal()
return data.x ~ Normal(m, 1.0)
end
data = (; x=5.0)
retval, vi = DynamicPPL.init!!(nt(data), VarInfo())
@test retval == 5.0
@test vi isa VarInfo
@test vi[@varname(m)] isa Real
end
end
16 changes: 16 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,22 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
end

@testset "Can fix immutable data safely" begin
# https://github.com/TuringLang/DynamicPPL.jl/issues/1176#issuecomment-3648871018
@model function ntfix()
m ~ Normal()
data = (; x=undef)
data.x ~ Normal(m, 1.0)
return data.x
end
fixm = DynamicPPL.fix(ntfix(), (; data=(; x=5.0)))
retval, vi = DynamicPPL.init!!(fixm, VarInfo())
# The fixed data should overwrite the NamedTuple that came before it
@test retval == 5.0
@test vi isa VarInfo
@test vi[@varname(m)] isa Real
end
end

@testset "PrefixContext + Condition/FixedContext interactions" begin
Expand Down