Skip to content

Commit 870faab

Browse files
committed
Fix compiler output for sub-variables of immutable data
1 parent 6266f64 commit 870faab

File tree

5 files changed

+59
-11
lines changed

5 files changed

+59
-11
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.5
4+
5+
Fixed a bug which prevented passing immutable data (such as NamedTuples or ordinary structs) as arguments to DynamicPPL models, or fixing the model on such data.
6+
37
## 0.39.4
48

59
Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed).

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.4"
3+
version = "0.39.5"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/compiler.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,11 @@ function generate_tilde_literal(left, right)
463463
end
464464
end
465465

466+
assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs))
467+
function assign_or_set!!(lhs::Expr, rhs)
468+
return AbstractPPL.drop_escape(:($BangBang.@set!! $lhs = $rhs))
469+
end
470+
466471
"""
467472
generate_tilde(left, right)
468473
@@ -474,27 +479,36 @@ function generate_tilde(left, right)
474479

475480
# Otherwise it is determined by the model or its value,
476481
# if the LHS represents an observation
477-
@gensym vn isassumption value dist
482+
@gensym vn isassumption value dist supplied_val
478483

479484
return quote
480485
$dist = $right
481486
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
482487
$isassumption = $(DynamicPPL.isassumption(left, vn))
483488
if $(DynamicPPL.isfixed(left, vn))
484-
$left = $(DynamicPPL.getfixed_nested)(
485-
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
486-
)
489+
# $left may not be a simple varname, it might be x.a or x[1], hence we need to
490+
# use Accessors.@set to safely set it.
491+
# We need overwrite=true to make sure that the parent value `x` is overwritten
492+
# after this statement.
493+
$(assign_or_set!!(
494+
left,
495+
:($(DynamicPPL.getfixed_nested)(
496+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
497+
)),
498+
))
487499
elseif $isassumption
488500
$(generate_tilde_assume(left, dist, vn))
489501
else
490502
# If `vn` is not in `argnames`, then it's definitely been conditioned on (if
491503
# it's not in `argnames` and wasn't conditioned on, then `isassumption` would
492504
# be true).
493-
$left = if $(DynamicPPL.inargnames)($vn, __model__)
494-
# This is a no-op and looks redundant, but defining the compiler output this
495-
# way ensures that the variable `$left` is always defined. See
496-
# https://github.com/TuringLang/DynamicPPL.jl/pull/1110.
497-
$left
505+
# Note that it's important to always make sure that the variable `$supplied_val`
506+
# is defined (by putting $supplied_val outside the if/else block), otherwise
507+
# Libtask can trip up with variables that are only defined in one branch. See
508+
# eg. https://github.com/TuringLang/DynamicPPL.jl/pull/1110 for a discussion of
509+
# this.
510+
$supplied_val = if $(DynamicPPL.inargnames)($vn, __model__)
511+
$(maybe_view(left))
498512
else
499513
$(DynamicPPL.getconditioned_nested)(
500514
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
@@ -504,10 +518,11 @@ function generate_tilde(left, right)
504518
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
505519
__model__.context,
506520
$(DynamicPPL.check_tilde_rhs)($dist),
507-
$(maybe_view(left)),
521+
$supplied_val,
508522
$vn,
509523
__varinfo__,
510524
)
525+
$(assign_or_set!!(left, value))
511526
$value
512527
end
513528
end

test/compiler.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,4 +823,17 @@ module Issue537 end
823823
end
824824
@test_logs (:warn, r"threadsafe evaluation") eval(e3)
825825
end
826+
827+
@testset "Immutable data as model arguments" begin
828+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1176
829+
@model function nt(data)
830+
m ~ Normal()
831+
return data.x ~ Normal(m, 1.0)
832+
end
833+
data = (; x=5.0)
834+
retval, vi = DynamicPPL.init!!(nt(data), VarInfo())
835+
@test retval == 5.0
836+
@test vi isa VarInfo
837+
@test vi[@varname(m)] isa Real
838+
end
826839
end

test/contexts.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,22 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
312312
@test model_fixed().m != m
313313
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
314314
end
315+
316+
@testset "Can fix immutable data safely" begin
317+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1176#issuecomment-3648871018
318+
@model function ntfix()
319+
m ~ Normal()
320+
data = (; x=undef)
321+
data.x ~ Normal(m, 1.0)
322+
return data.x
323+
end
324+
fixm = DynamicPPL.fix(ntfix(), (; data=(; x=5.0)))
325+
retval, vi = DynamicPPL.init!!(fixm, VarInfo())
326+
# The fixed data should overwrite the NamedTuple that came before it
327+
@test retval == 5.0
328+
@test vi isa VarInfo
329+
@test vi[@varname(m)] isa Real
330+
end
315331
end
316332

317333
@testset "PrefixContext + Condition/FixedContext interactions" begin

0 commit comments

Comments
 (0)