Scan new API does not respect user-defined ordering of mapped (None) outputs and non-traceable SIT-SOT #1796

@ricardoV94

Description

This caused a failure in pymc-devs/pymc#8010. Minimal reproducer:

```python
import pytensor
from pytensor.tensor.type import TensorType
from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
import pytensor.tensor as pt

rng = random_generator_type("rng")

def step(rng):
    # A RandomVariable node's outputs are (next_rng, draw)
    next_rng, x = pt.random.normal(rng=rng).owner.outputs
    # Return order matches outputs_info: mapped output first, then the RNG
    return x, next_rng

xs, final_rng = pytensor.scan(
    fn=step,
    # None -> mapped (traced) output; rng -> non-traceable SIT-SOT
    outputs_info=[None, rng],
    n_steps=5,
    return_updates=False,
)
assert isinstance(xs.type, TensorType)  # AssertionError
assert isinstance(final_rng.type, RandomGeneratorType)
```

The outputs are returned in the reverse of the order defined in `outputs_info`: `xs` receives the RNG output and `final_rng` receives the draws.
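A minimal, library-independent sketch of the bookkeeping that keeps results aligned with `outputs_info`: if an implementation groups outputs by kind internally, it can record the resulting permutation and invert it before returning. The kind labels (`"mapped"`, `"untraced_sit_sot"`), the assumed internal ordering, and the helper names `internal_permutation` / `to_user_order` are all hypothetical, chosen only so the example reproduces the swap in the reproducer; this is not pytensor's actual implementation.

```python
from typing import Any, Sequence

# Hypothetical internal grouping, for illustration only: it assumes an
# implementation that emits non-traceable SIT-SOT outputs before mapped
# (outputs_info=None) outputs. This is NOT taken from pytensor's source.
INTERNAL_KIND_ORDER = ("untraced_sit_sot", "mapped")


def internal_permutation(kinds: Sequence[str]) -> list[int]:
    """User-facing indices, listed in the order the internal grouping emits them.

    sorted() is stable, so outputs of the same kind keep their relative order.
    """
    return sorted(range(len(kinds)), key=lambda i: INTERNAL_KIND_ORDER.index(kinds[i]))


def to_user_order(internal_outputs: Sequence[Any], kinds: Sequence[str]) -> list[Any]:
    """Invert the grouping so results line up with outputs_info again."""
    perm = internal_permutation(kinds)
    user_outputs: list[Any] = [None] * len(kinds)
    for internal_pos, user_pos in enumerate(perm):
        user_outputs[user_pos] = internal_outputs[internal_pos]
    return user_outputs


# outputs_info=[None, rng] corresponds to kinds ["mapped", "untraced_sit_sot"]
kinds = ["mapped", "untraced_sit_sot"]
internal = ["final_rng", "xs"]  # what a kind-grouped implementation would emit
print(to_user_order(internal, kinds))  # ['xs', 'final_rng']
```

Returning `internal` directly, without the inversion step, reproduces exactly the swap reported in this issue.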

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), scan
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests