import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference
import torch
import torch.distributions as dist

# Minimal reproduction: `foo` depends on two draws of `bar` that are combined
# with `torch.stack`. Compiling this model with BMGInference fails with
# "TypeError: expected Tensor as element 0 in argument 0, but got SampleNode"
# (see the traceback below): during graph accumulation the elements handed to
# `torch.stack` are SampleNode graph nodes rather than Tensors.


@bm.random_variable
def bar(i):
    # Independent standard-normal draw, one per index `i`.
    return dist.Normal(0.0, 1.0)


@bm.random_variable
def foo():
    # The sum of the two stacked `bar` samples is used as the mean of `foo`.
    return dist.Normal(torch.stack([bar(i) for i in range(2)]).sum(), 1.0)


samples = BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
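Running the snippet above fails with `TypeError: expected Tensor as element 0 in argument 0, but got SampleNode`. The traceback below was captured from a variant that passes the stacked vector to `dist.MultivariateNormal(..., torch.eye(2))` instead of summing it into a `dist.Normal`. In both cases the error is raised by the shared `torch.stack([bar(i) for i in range(2)])` call.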
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-358-e65cd9e99a94> in <module>
4 foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.stack([bar(i) for i in range(2)]), torch.eye(2)))
5 bar = bm.random_variable(lambda i: dist.Normal(0., 1.))
----> 6 BMGInference().infer(
7 queries=[foo()],
8 observations={},
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in infer(self, queries, observations, num_samples, num_chains, inference_type, skip_optimizations)
262 # TODO: Add verbose level
263 # TODO: Add logging
--> 264 samples, _ = self._infer(
265 queries,
266 observations,
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _infer(self, queries, observations, num_samples, num_chains, inference_type, produce_report, skip_optimizations)
182 self._pd = prof.ProfilerData()
183
--> 184 rt = self._accumulate_graph(queries, observations)
185 bmg = rt._bmg
186 report = pr.PerformanceReport()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/inference/bmg_inference.py in _accumulate_graph(self, queries, observations)
71 rt = BMGRuntime()
72 rt._pd = self._pd
---> 73 bmg = rt.accumulate_graph(queries, observations)
74 # TODO: Figure out a better way to pass this flag around
75 bmg._fix_observe_true = self._fix_observe_true
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in accumulate_graph(self, queries, observations)
719 self._bmg.add_observation(node, val)
720 for qrv in queries:
--> 721 node = self._rv_to_node(qrv)
722 q = self._bmg.add_query(node)
723 self._rv_to_query[qrv] = q
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in _rv_to_node(self, rv)
583 # RVID, and if we're in the second situation, we will not.
584
--> 585 value = self._context.call(rewritten_function, rv.arguments)
586 if isinstance(value, RVIdentifier):
587 # We have a rewritten function with a decorator already applied.
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/execution_context.py in call(self, func, args, kwargs)
92 self._stack.push(FunctionCall(func, args, kwargs))
93 try:
---> 94 return func(*args, **kwargs)
95 finally:
96 self._stack.pop()
<BMGJIT> in a1()
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/runtime.py in handle_function(self, function, arguments, kwargs)
510 function, arguments, kwargs
511 ):
--> 512 result = self._special_function_caller.do_special_call_maybe_stochastic(
513 function, arguments, kwargs
514 )
/data/sandcastle/boxes/fbsource/fbcode/beanmachine/beanmachine/ppl/compiler/special_function_caller.py in do_special_call_maybe_stochastic(self, func, args, kwargs)
629 new_args = (_get_ordinary_value(arg) for arg in args)
630 new_kwargs = {key: _get_ordinary_value(arg) for key, arg in kwargs.items()}
--> 631 return func(*new_args, **new_kwargs)
632
633 if _is_in_place_operator(func):
TypeError: expected Tensor as element 0 in argument 0, but got SampleNode
import beanmachine.ppl as bm
from beanmachine.ppl.inference import BMGInference
import torch
import torch.distributions as dist

# Working variant: identical model, except that the two `bar` samples are
# combined with `torch.tensor` instead of `torch.stack`. This runs under
# BMGInference. However, `torch.tensor` is not differentiable with respect to
# its inputs, so this is not a usable workaround for gradient-based inference
# such as VI or HMC.


@bm.random_variable
def bar(i):
    # Independent standard-normal draw, one per index `i`.
    return dist.Normal(0.0, 1.0)


@bm.random_variable
def foo():
    # The sum of the two `bar` samples is used as the mean of `foo`.
    return dist.Normal(torch.tensor([bar(i) for i in range(2)]).sum(), 1.0)


samples = BMGInference().infer(
    queries=[foo()],
    observations={},
    num_samples=1,
)
Issue Description
When other RVs are concatenated together using `torch.stack`, `BMGInference` fails to trace execution because it assumes that all arguments to `stack` are of type `Tensor`. The example runs fine if `stack` is replaced by `torch.tensor`. However, `torch.tensor` is not differentiable with respect to its arguments, which rules out gradient-based methods such as VI and HMC.
Steps to Reproduce
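raises `TypeError: expected Tensor as element 0 in argument 0, but got SampleNode` (full traceback above).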
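Expected Behavior
Successful execution with results identical to the `s/stack/tensor` variant, i.e. the snippet that combines the `bar` samples with `torch.tensor` instead of `torch.stack`.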