Description
`pytensor.scan` returns its outputs in the wrong order in the example below. This caused a failure in pymc-devs/pymc#8010.
import pytensor
from pytensor.tensor.type import TensorType
from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
import pytensor.tensor as pt
# Symbolic random-generator input; it seeds the recurrent RNG state of the scan below.
rng = random_generator_type(name="rng")
def step(rng):
    """Scan step: draw one sample from ``pt.random.normal`` using ``rng``.

    Returns ``(x, next_rng)``, in the same order as ``outputs_info=[None, rng]``
    in the ``scan`` call. The first output is the non-recurrent draw. The
    second is the updated RNG that is fed back into the next iteration.
    """
    # The RandomVariable node's outputs are unpacked here as (next_rng, draw).
    next_rng, x = pt.random.normal(rng=rng).owner.outputs
    return x, next_rng
# Run `step` 5 times. Only the second output (the RNG) is recurrent; the first
# output (the draw) has no initial state, hence `None` in `outputs_info`.
xs, final_rng = pytensor.scan(
    fn=step,
    outputs_info=[None, rng],
    n_steps=5,
    return_updates=False,
)

# Expected: scan returns outputs in the same order as `outputs_info`.
# The first assertion currently fails (AssertionError) because the two outputs
# come back swapped; the messages show which type ended up in each slot.
assert isinstance(xs.type, TensorType), (
    f"expected xs to have a TensorType, got {xs.type!r}"
)
assert isinstance(final_rng.type, RandomGeneratorType), (
    f"expected final_rng to have a RandomGeneratorType, got {final_rng.type!r}"
)
The outputs are returned in the reverse of the order given in `outputs_info`: `xs` receives the RNG output and `final_rng` receives the tensor of draws.