Skip to content

Commit f3f8e8f

Browse files
committed
Revert "Change default linker back to CVM"
This reverts commit 7cb92a9.
1 parent c9dd687 commit f3f8e8f

File tree

6 files changed

+114
-92
lines changed

6 files changed

+114
-92
lines changed

.github/workflows/test.yml

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ jobs:
6565
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
6666

6767
test:
68-
name: "tests : default-mode ${{ matrix.default-mode }} : py${{ matrix.python-version }} : ${{ matrix.os }} ${{ matrix.part }}"
68+
name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
6969
needs:
7070
- changes
7171
- style
@@ -74,9 +74,11 @@ jobs:
7474
strategy:
7575
fail-fast: false
7676
matrix:
77-
default-mode: [ "C", "NUMBA", "FAST_COMPILE" ]
77+
os: ["ubuntu-latest"]
7878
python-version: ["3.11", "3.14"]
79-
os: [ "ubuntu-latest" ]
79+
fast-compile: [0, 1]
80+
float32: [0, 1]
81+
install-numba: [0]
8082
install-jax: [0]
8183
install-torch: [0]
8284
install-mlx: [0]
@@ -92,38 +94,81 @@ jobs:
9294
- "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py"
9395
exclude:
9496
- python-version: "3.11"
95-
default-mode: "FAST_COMPILE"
97+
fast-compile: 1
98+
- python-version: "3.11"
99+
float32: 1
100+
- fast-compile: 1
101+
float32: 1
96102
include:
97-
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
98-
python-version: "3.12"
99-
default-mode: "C"
100-
- part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
103+
- os: "ubuntu-latest"
104+
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
101105
python-version: "3.12"
102-
default-mode: "C"
103-
- part: "tests/link/numba/test_slinalg.py"
104-
python-version: "3.13"
105-
default-mode: "C"
106-
- part: "tests/link/jax"
107-
install-jax: 1
106+
fast-compile: 0
107+
float32: 0
108+
install-numba: 0
109+
install-jax: 0
110+
install-torch: 0
111+
install-mlx: 0
112+
install-xarray: 0
113+
- install-numba: 1
114+
os: "ubuntu-latest"
115+
python-version: "3.11"
116+
fast-compile: 0
117+
float32: 0
118+
part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
119+
- install-numba: 1
120+
os: "ubuntu-latest"
108121
python-version: "3.14"
109-
default-mode: "C"
110-
- part: "tests/link/pytorch"
111-
install-torch: 1
122+
fast-compile: 0
123+
float32: 0
124+
part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
125+
- install-numba: 1
126+
os: "ubuntu-latest"
127+
python-version: "3.14"
128+
fast-compile: 0
129+
float32: 0
130+
part: "tests/link/numba/test_slinalg.py"
131+
- install-jax: 1
132+
os: "ubuntu-latest"
112133
python-version: "3.11"
113-
default-mode: "C"
114-
- part: "tests/xtensor"
115-
install-xarray: 1
134+
fast-compile: 0
135+
float32: 0
136+
part: "tests/link/jax"
137+
- install-jax: 1
138+
os: "ubuntu-latest"
116139
python-version: "3.14"
117-
default-mode: "C"
118-
- part: "tests/link/mlx"
119-
install-mlx: 1
120-
os: "macos-15"
140+
fast-compile: 0
141+
float32: 0
142+
part: "tests/link/jax"
143+
- install-torch: 1
144+
os: "ubuntu-latest"
121145
python-version: "3.11"
122-
default-mode: "C"
123-
- part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
124-
os: "macos-15"
146+
fast-compile: 0
147+
float32: 0
148+
part: "tests/link/pytorch"
149+
- install-xarray: 1
150+
os: "ubuntu-latest"
125151
python-version: "3.14"
126-
default-mode: "C"
152+
fast-compile: 0
153+
float32: 0
154+
part: "tests/xtensor"
155+
- os: "macos-15"
156+
python-version: "3.11"
157+
fast-compile: 0
158+
float32: 0
159+
install-mlx: 1
160+
install-numba: 0
161+
install-jax: 0
162+
install-torch: 0
163+
part: "tests/link/mlx"
164+
- os: "macos-15"
165+
python-version: "3.14"
166+
fast-compile: 0
167+
float32: 0
168+
install-numba: 0
169+
install-jax: 0
170+
install-torch: 0
171+
part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
127172

128173
steps:
129174
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
@@ -154,10 +199,11 @@ jobs:
154199
run: |
155200
156201
if [[ $OS == "macos-15" ]]; then
157-
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
202+
micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
158203
else
159-
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx mkl mkl-service;
204+
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
160205
fi
206+
pip install "numba>=0.63"
161207
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
162208
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
163209
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi
@@ -173,6 +219,7 @@ jobs:
173219
fi
174220
env:
175221
PYTHON_VERSION: ${{ matrix.python-version }}
222+
INSTALL_NUMBA: ${{ matrix.install-numba }}
176223
INSTALL_JAX: ${{ matrix.install-jax }}
177224
INSTALL_TORCH: ${{ matrix.install-torch}}
178225
INSTALL_XARRAY: ${{ matrix.install-xarray }}
@@ -182,8 +229,8 @@ jobs:
182229
- name: Run tests
183230
shell: micromamba-shell {0}
184231
run: |
185-
if [[ $default-mode == "FAST_COMPILE" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
186-
if [[ $default-mode == "NUMBA" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,fast_run_backend=NUMBA; fi
232+
if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
233+
if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
187234
export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
188235
python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
189236
env:
@@ -192,8 +239,8 @@ jobs:
192239
MKL_NUM_THREADS: 1
193240
OMP_NUM_THREADS: 1
194241
PART: ${{ matrix.part }}
195-
default-mode: ${{ matrix.default-mode }}
196-
242+
FAST_COMPILE: ${{ matrix.fast-compile }}
243+
FLOAT32: ${{ matrix.float32 }}
197244

198245
- name: Upload coverage file
199246
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

pytensor/compile/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
)
1818
from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput
1919
from pytensor.compile.mode import (
20-
CVM,
2120
FAST_COMPILE,
21+
FAST_RUN,
2222
JAX,
2323
NUMBA,
2424
OPT_FAST_COMPILE,
@@ -33,7 +33,6 @@
3333
PYTORCH,
3434
AddDestroyHandler,
3535
AddFeatureOptimizer,
36-
C,
3736
Mode,
3837
PrintCurrentFunctionGraph,
3938
get_default_mode,

pytensor/compile/mode.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import logging
77
import warnings
8-
from typing import Any, Literal
8+
from typing import Literal
99

1010
from pytensor.compile.function.types import Supervisor
1111
from pytensor.configdefaults import config
@@ -62,17 +62,20 @@ def register_linker(name, linker):
6262
predefined_linkers[name] = linker
6363

6464

65-
OPT_NONE = RewriteDatabaseQuery(include=[])
65+
exclude = []
66+
if not config.cxx:
67+
exclude = ["cxx_only"]
68+
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
6669
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
67-
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"])
70+
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
6871
# Even if multiple merge optimizer call will be there, this shouldn't
6972
# impact performance.
70-
OPT_MERGE = RewriteDatabaseQuery(include=["merge"])
71-
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"])
73+
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
74+
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
7275
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
7376

74-
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"])
75-
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"])
77+
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
78+
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
7679
OPT_STABILIZE.position_cutoff = 1.5000001
7780
OPT_NONE.name = "OPT_NONE"
7881
OPT_MINIMUM.name = "OPT_MINIMUM"
@@ -310,8 +313,6 @@ def __init__(
310313
):
311314
if linker is None:
312315
linker = config.linker
313-
if isinstance(linker, str) and linker == "auto":
314-
linker = "cvm" if config.cxx else "vm"
315316
if isinstance(optimizer, str) and optimizer == "default":
316317
optimizer = config.optimizer
317318

@@ -447,15 +448,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
447448
return new_mode
448449

449450

450-
C = Mode("c", "fast_run")
451-
CVM = Mode("cvm", "fast_run")
452-
VM = (Mode("vm", "fast_run"),)
453-
454451
NUMBA = Mode(
455452
NumbaLinker(),
456453
RewriteDatabaseQuery(include=["fast_run", "numba"]),
457454
)
458455

456+
FAST_COMPILE = Mode(
457+
NumbaLinker(),
458+
RewriteDatabaseQuery(include=["fast_compile"]),
459+
)
460+
FAST_RUN = NUMBA
461+
462+
C = Mode("c", "fast_run")
463+
CVM = Mode("cvm", "fast_run")
464+
459465
JAX = Mode(
460466
JAXLinker(),
461467
RewriteDatabaseQuery(include=["fast_run", "jax"]),
@@ -470,19 +476,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
470476
RewriteDatabaseQuery(include=["fast_run"]),
471477
)
472478

473-
FAST_COMPILE = Mode(
474-
VMLinker(use_cloop=False, c_thunks=False),
475-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
476-
)
477-
478-
fast_run_linkers_to_mode = {
479-
"cvm": CVM,
480-
"vm": VM,
481-
"numba": NUMBA,
482-
}
483479

484480
predefined_modes = {
485481
"FAST_COMPILE": FAST_COMPILE,
482+
"FAST_RUN": FAST_RUN,
486483
"C": C,
487484
"CVM": CVM,
488485
"JAX": JAX,
@@ -491,7 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
491488
"MLX": MLX,
492489
}
493490

494-
_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}
491+
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
495492

496493

497494
def get_mode(orig_string):
@@ -509,20 +506,10 @@ def get_mode(orig_string):
509506
if upper_string in predefined_modes:
510507
return predefined_modes[upper_string]
511508

512-
if upper_string == "FAST_RUN":
513-
linker = config.linker
514-
if linker == "auto":
515-
return CVM if config.cxx else VM
516-
return fast_run_linkers_to_mode[linker]
517-
518509
global _CACHED_RUNTIME_MODES
519510

520-
cache_key = ("MODE", config.linker) if upper_string == "MODE" else upper_string
521-
522-
try:
523-
return _CACHED_RUNTIME_MODES[cache_key]
524-
except KeyError:
525-
pass
511+
if upper_string in _CACHED_RUNTIME_MODES:
512+
return _CACHED_RUNTIME_MODES[upper_string]
526513

527514
# Need to define the mode for the first time
528515
if upper_string == "MODE":
@@ -548,7 +535,7 @@ def get_mode(orig_string):
548535
if config.optimizer_requiring:
549536
ret = ret.requiring(*config.optimizer_requiring.split(":"))
550537
# Cache the mode for next time
551-
_CACHED_RUNTIME_MODES[cache_key] = ret
538+
_CACHED_RUNTIME_MODES[upper_string] = ret
552539

553540
return ret
554541

pytensor/configdefaults.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -371,26 +371,24 @@ def add_compile_configvars():
371371
)
372372
del param
373373

374-
default_linker = "auto"
374+
default_linker = "numba"
375375

376376
if rc == 0 and config.cxx != "":
377377
# Keep the default linker the same as the one for the mode FAST_RUN
378378
linker_options = [
379-
"cvm",
380-
"c|py",
379+
"cvmc|py",
381380
"py",
382381
"c",
383382
"c|py_nogc",
384383
"vm",
385384
"vm_nogc",
386385
"cvm_nogc",
387-
"numba",
388386
"jax",
389387
]
390388
else:
391389
# g++ is not present or the user disabled it,
392390
# linker should default to python only.
393-
linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
391+
linker_options = ["py", "vm", "vm_nogc", "jax"]
394392
if type(config).cxx.is_default:
395393
# If the user provided an empty value for cxx, do not warn.
396394
_logger.warning(
@@ -402,8 +400,9 @@ def add_compile_configvars():
402400

403401
config.add(
404402
"linker",
405-
"Default linker used if the pytensor flags mode is Mode or FAST_RUN",
406-
EnumStr(default_linker, linker_options, mutable=True),
403+
"Default linker used if the pytensor flags mode is Mode",
404+
# Not mutable because the default mode is cached after the first use.
405+
EnumStr(default_linker, linker_options, mutable=False),
407406
in_c_key=False,
408407
)
409408

pytensor/configparser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class PyTensorConfigParser:
7676
unpickle_function: bool
7777
# add_compile_configvars
7878
mode: str
79-
fast_run_backend: str
8079
cxx: str
8180
linker: str
8281
allow_gc: bool

tests/compile/test_mode.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,9 @@ def test_modes(self):
7777

7878
# Linkers to use with regular Mode
7979
if config.cxx:
80-
linkers = [
81-
"py",
82-
"c|py",
83-
"c|py_nogc",
84-
"vm",
85-
"vm_nogc",
86-
"cvm",
87-
"cvm_nogc",
88-
"numba",
89-
]
80+
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"]
9081
else:
91-
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "numba"]
82+
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"]
9283
modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers]
9384

9485
for mode in modes:

0 commit comments

Comments
 (0)