Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cb1a1e6
Numba CAReduce: respect acc_dtype
ricardoV94 Nov 21, 2025
a85af82
Fix AdvancedSubtensor static shape with newaxis
ricardoV94 Dec 9, 2025
d97df0a
Numba AdvancedIndexing: Complete support for integer (and mixed basic…
ricardoV94 Dec 7, 2025
65dc20a
construct_pfunc_ins_and_outs doesn't use mode
ricardoV94 Dec 9, 2025
3e436ab
More informative NotImplementedError
ricardoV94 Nov 21, 2025
6da237d
Try to run full test suite in Numba backend
ricardoV94 Nov 19, 2025
a8dfcb9
Rebalance test splits
ricardoV94 Nov 30, 2025
e300e00
Revert mode for tests that are C-specific
ricardoV94 Nov 24, 2025
cb7bc09
Ignore numba object mode warning
ricardoV94 Jun 13, 2024
5ed6978
Align numba reciprocal with C backend
ricardoV94 Nov 21, 2025
79f0254
Tweak test errors and requirements
ricardoV94 Nov 24, 2025
c8245eb
Tweak test tolerances
ricardoV94 Nov 19, 2025
1ac9f89
Tweak Blockwise/RandomVariable tests
ricardoV94 Nov 24, 2025
bdfe327
Tweak RandomGenerator tests
ricardoV94 Nov 30, 2025
2808a9e
Tweak Boolean AdvancedSubtensor tests
ricardoV94 Dec 9, 2025
9c19f11
Test SVD and Eig(h): allow benign sign change
ricardoV94 Dec 1, 2025
9c313c3
Fix passing M=None to function in Eye test
ricardoV94 Nov 19, 2025
e13e88e
Test wasn't actually covering local_0_dot_x rewrite
ricardoV94 Nov 24, 2025
b3bb833
Allow burn-in in memory leak test
ricardoV94 Nov 28, 2025
0bb06df
XFAIL conv tests of Ops without Python implementation
ricardoV94 Jun 13, 2024
19a2397
XFAIL/SKIP float16 tests
ricardoV94 Jun 7, 2024
aecaf62
XFAIL TypedList global constant
ricardoV94 Nov 25, 2025
7d4ac51
XFAIL/SKIP Sparse tests
ricardoV94 Nov 24, 2025
c3b70b3
Numba does not output numpy scalars
ricardoV94 Nov 19, 2025
e928b3e
Change default linker back to CVM
ricardoV94 Dec 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 43 additions & 84 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1

test:
name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
name: "mode ${{ matrix.default-mode }} : py${{ matrix.python-version }} : ${{ matrix.os }} : ${{ matrix.part[0] }}"
needs:
- changes
- style
Expand All @@ -74,100 +74,62 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
default-mode: ["C", "NUMBA", "FAST_COMPILE"]
python-version: ["3.11", "3.14"]
fast-compile: [0, 1]
float32: [0, 1]
install-numba: [0]
os: ["ubuntu-latest"]
install-jax: [0]
install-torch: [0]
install-mlx: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
- "tests/scan"
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting"
- "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
- "tests/tensor/test_math.py"
- "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv"
- "tests/tensor/rewriting"
- [ "*rest", "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor --ignore=tests/link/numba" ]
- [ "scan", "tests/scan" ]
- [ "tensor *rest", "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/linalg --ignore=tests/tensor/test_nlinalg.py --ignore=tests/tensor/test_slinalg.py --ignore=tests/tensor/test_pad.py" ]
- [ "tensor basic+elemwise", "tests/tensor/test_basic.py tests/tensor/test_elemwise.py" ]
- [ "tensor math", "tests/tensor/test_math.py" ]
- [ "tensor scipy+blas+conv+pad", "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv tests/tensor/test_pad.py" ]
- [ "tensor rewriting", "tests/tensor/rewriting" ]
- [ "tensor linalg", "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py" ]
exclude:
- python-version: "3.11"
fast-compile: 1
- python-version: "3.11"
float32: 1
- fast-compile: 1
float32: 1
default-mode: "FAST_COMPILE"
include:
- os: "ubuntu-latest"
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
- part: ["doctests", "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"]
default-mode: "C"
python-version: "3.12"
fast-compile: 0
float32: 0
install-numba: 0
install-jax: 0
install-torch: 0
install-mlx: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
- install-numba: 1
- part: ["numba link", "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"]
default-mode: "C"
python-version: "3.12"
os: "ubuntu-latest"
python-version: "3.14"
fast-compile: 0
float32: 0
part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
- install-numba: 1
- part: ["numba link slinalg", "tests/link/numba/test_slinalg.py"]
default-mode: "C"
python-version: "3.13"
os: "ubuntu-latest"
- part: ["jax link", "tests/link/jax"]
install-jax: 1
default-mode: "C"
python-version: "3.14"
fast-compile: 0
float32: 0
part: "tests/link/numba/test_slinalg.py"
- install-jax: 1
os: "ubuntu-latest"
- part: ["pytorch link", "tests/link/pytorch"]
install-torch: 1
default-mode: "C"
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-jax: 1
os: "ubuntu-latest"
- part: ["xtensor", "tests/xtensor"]
install-xarray: 1
default-mode: "C"
python-version: "3.14"
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-torch: 1
os: "ubuntu-latest"
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.14"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: "macos-15"
python-version: "3.11"
fast-compile: 0
float32: 0
- part: ["mlx link", "tests/link/mlx"]
install-mlx: 1
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/link/mlx"
- os: "macos-15"
default-mode: "C"
python-version: "3.11"
os: "macos-15"
- part: ["macos smoke test", "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"]
default-mode: "C"
python-version: "3.14"
fast-compile: 0
float32: 0
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
os: "macos-15"

steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
Expand Down Expand Up @@ -198,11 +160,10 @@ jobs:
run: |

if [[ $OS == "macos-15" ]]; then
micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
else
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx mkl mkl-service;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then pip install "numba>=0.63"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi
Expand All @@ -218,28 +179,26 @@ jobs:
fi
env:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_TORCH: ${{ matrix.install-torch }}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
INSTALL_MLX: ${{ matrix.install-mlx }}
OS: ${{ matrix.os}}

- name: Run tests
shell: micromamba-shell {0}
run: |
if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
if [[ $DEFAULT_MODE == "FAST_COMPILE" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
if [[ $DEFAULT_MODE == "NUMBA" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,linker=numba; fi
export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
env:
MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
MKL_THREADING_LAYER: GNU
MKL_NUM_THREADS: 1
OMP_NUM_THREADS: 1
PART: ${{ matrix.part }}
FAST_COMPILE: ${{ matrix.fast-compile }}
FLOAT32: ${{ matrix.float32 }}
PART: ${{ matrix.part[1] }}
DEFAULT_MODE: ${{ matrix.default-mode }}

- name: Upload coverage file
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ tag_prefix = "rel-"
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py"
testpaths = ["pytensor/", "tests/"]
xfail_strict = true
filterwarnings =[
'ignore:^Numba will use object mode to run.*perform method\.:UserWarning',
]

[tool.ruff]
line-length = 88
Expand Down
3 changes: 2 additions & 1 deletion pytensor/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
)
from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput
from pytensor.compile.mode import (
CVM,
FAST_COMPILE,
FAST_RUN,
JAX,
NUMBA,
OPT_FAST_COMPILE,
Expand All @@ -33,6 +33,7 @@
PYTORCH,
AddDestroyHandler,
AddFeatureOptimizer,
C,
Mode,
PrintCurrentFunctionGraph,
get_default_mode,
Expand Down
2 changes: 0 additions & 2 deletions pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,6 @@ def pfunc(
inputs, cloned_outputs = construct_pfunc_ins_and_outs(
params,
outputs,
mode,
updates,
givens,
no_default_updates,
Expand All @@ -479,7 +478,6 @@ def pfunc(
def construct_pfunc_ins_and_outs(
params,
outputs=None,
mode=None,
updates=None,
givens=None,
no_default_updates=False,
Expand Down
72 changes: 36 additions & 36 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import logging
import warnings
from typing import Literal
from typing import Any, Literal

from pytensor.compile.function.types import Supervisor
from pytensor.configdefaults import config
Expand Down Expand Up @@ -62,23 +62,17 @@ def register_linker(name, linker):
predefined_linkers[name] = linker


# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
OPT_NONE = RewriteDatabaseQuery(include=[])
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"])
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_MERGE = RewriteDatabaseQuery(include=["merge"])
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"])
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")

OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"])
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"])
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
OPT_MINIMUM.name = "OPT_MINIMUM"
Expand Down Expand Up @@ -316,6 +310,8 @@ def __init__(
):
if linker is None:
linker = config.linker
if isinstance(linker, str) and linker == "auto":
linker = "cvm" if config.cxx else "vm"
if isinstance(optimizer, str) and optimizer == "default":
optimizer = config.optimizer

Expand Down Expand Up @@ -451,24 +447,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
return new_mode


# If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
else:
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)

C = Mode("c", "fast_run")
C_VM = Mode("cvm", "fast_run")
CVM = Mode("cvm", "fast_run")
VM = (Mode("vm", "fast_run"),)

NUMBA = Mode(
NumbaLinker(),
Expand All @@ -489,19 +470,28 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
RewriteDatabaseQuery(include=["fast_run"]),
)

FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
)

fast_run_linkers_to_mode = {
"cvm": CVM,
"vm": VM,
"numba": NUMBA,
}

predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"C": C,
"C_VM": C_VM,
"CVM": CVM,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
"MLX": MLX,
}

_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}


def get_mode(orig_string):
Expand All @@ -519,10 +509,20 @@ def get_mode(orig_string):
if upper_string in predefined_modes:
return predefined_modes[upper_string]

if upper_string == "FAST_RUN":
linker = config.linker
if linker == "auto":
return CVM if config.cxx else VM
return fast_run_linkers_to_mode[linker]

global _CACHED_RUNTIME_MODES

if upper_string in _CACHED_RUNTIME_MODES:
return _CACHED_RUNTIME_MODES[upper_string]
cache_key = ("MODE", config.linker) if upper_string == "MODE" else upper_string

try:
return _CACHED_RUNTIME_MODES[cache_key]
except KeyError:
pass

# Need to define the mode for the first time
if upper_string == "MODE":
Expand All @@ -548,7 +548,7 @@ def get_mode(orig_string):
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
# Cache the mode for next time
_CACHED_RUNTIME_MODES[upper_string] = ret
_CACHED_RUNTIME_MODES[cache_key] = ret

return ret

Expand Down
Loading
Loading