Skip to content

Commit 5a976fd

Browse files
committed
Generalise eltype inference
1 parent c188b7e commit 5a976fd

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/fasteval.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ using DynamicPPL:
6060
AbstractContext,
6161
AbstractVarInfo,
6262
AccumulatorTuple,
63+
DynamicPPL,
6364
InitContext,
6465
InitFromParams,
6566
InitFromPrior,
@@ -125,17 +126,20 @@ function DynamicPPL.get_param_eltype(
125126
leaf_ctx = DynamicPPL.leafcontext(model.context)
126127
if leaf_ctx isa FastEvalVectorContext
127128
return eltype(leaf_ctx.params)
128-
elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams}
129-
return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params))
130-
elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}}
131-
# No need to enforce any particular eltype here, since new parameters are sampled
132-
return Any
129+
elseif leaf_ctx isa InitContext
130+
return _get_strategy_eltype(leaf_ctx.strategy)
133131
else
134132
error(
135133
"OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))",
136134
)
137135
end
138136
end
137+
_get_strategy_eltype(s::InitFromParams) = DynamicPPL.infer_nested_eltype(typeof(s.params))
138+
# No need to enforce any particular eltype here, since new parameters are sampled
139+
_get_strategy_eltype(::InitFromPrior) = Any
140+
_get_strategy_eltype(::InitFromUniform) = Any
141+
# Default fallback
142+
_get_strategy_eltype(::DynamicPPL.AbstractInitStrategy) = Any
139143

140144
"""
141145
RangeAndLinked

0 commit comments

Comments
 (0)