Skip to content

Commit 9b1abf1

Browse files
committed
Try to run full test suite in Numba backend
1 parent 327cb79 commit 9b1abf1

File tree

3 files changed

+30
-31
lines changed

3 files changed

+30
-31
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ jobs:
202202
else
203203
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
204204
fi
205-
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
205+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57";
206206
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
207207
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
208208
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi

pytensor/compile/mode.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def register_linker(name, linker):
6565
# If a string is passed as the optimizer argument in the constructor
6666
# for Mode, it will be used as the key to retrieve the real optimizer
6767
# in this dictionary
68-
exclude = []
69-
if not config.cxx:
70-
exclude = ["cxx_only"]
68+
exclude = ["cxx_only", "BlasOpt"]
7169
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
7270
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
7371
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
@@ -445,37 +443,37 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
445443
return new_mode
446444

447445

448-
# If a string is passed as the mode argument in function or
449-
# FunctionMaker, the Mode will be taken from this dictionary using the
450-
# string as the key
451-
# Use VM_linker to allow lazy evaluation by default.
446+
numba_exclude = [
447+
"cxx_only",
448+
"BlasOpt",
449+
"local_careduce_fusion",
450+
"scan_save_mem_prealloc",
451+
]
452+
452453
FAST_COMPILE = Mode(
453-
VMLinker(use_cloop=False, c_thunks=False),
454-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
454+
NumbaLinker(),
455+
# TODO: Fast_compile should just use python code, CHANGE ME!
456+
RewriteDatabaseQuery(
457+
include=["fast_compile", "numba"],
458+
exclude=numba_exclude,
459+
),
460+
)
461+
FAST_RUN = Mode(
462+
NumbaLinker(),
463+
RewriteDatabaseQuery(include=["fast_run", "numba"], exclude=numba_exclude),
455464
)
456-
if config.cxx:
457-
FAST_RUN = Mode("cvm", "fast_run")
458-
else:
459-
FAST_RUN = Mode(
460-
"vm",
461-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
462-
)
463-
464-
C = Mode("c", "fast_run")
465-
C_VM = Mode("cvm", "fast_run")
466465

467466
NUMBA = Mode(
468467
NumbaLinker(),
469468
RewriteDatabaseQuery(
470469
include=["fast_run", "numba"],
471-
exclude=[
472-
"cxx_only",
473-
"BlasOpt",
474-
"local_careduce_fusion",
475-
"scan_save_mem_prealloc",
476-
],
470+
exclude=numba_exclude,
477471
),
478472
)
473+
del numba_exclude
474+
475+
C = Mode("c", "fast_run")
476+
C_VM = Mode("cvm", "fast_run")
479477

480478
JAX = Mode(
481479
JAXLinker(),
@@ -523,7 +521,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
523521
),
524522
)
525523

526-
524+
# If a string is passed as the mode argument in function or
525+
# FunctionMaker, the Mode will be taken from this dictionary using the
526+
# string as the key
527527
predefined_modes = {
528528
"FAST_COMPILE": FAST_COMPILE,
529529
"FAST_RUN": FAST_RUN,

pytensor/configdefaults.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,25 +371,24 @@ def add_compile_configvars():
371371
)
372372
del param
373373

374-
default_linker = "cvm"
374+
default_linker = "numba"
375375

376376
if rc == 0 and config.cxx != "":
377377
# Keep the default linker the same as the one for the mode FAST_RUN
378378
linker_options = [
379-
"c|py",
379+
"cvmc|py",
380380
"py",
381381
"c",
382382
"c|py_nogc",
383383
"vm",
384384
"vm_nogc",
385385
"cvm_nogc",
386-
"numba",
387386
"jax",
388387
]
389388
else:
390389
# g++ is not present or the user disabled it,
391390
# linker should default to python only.
392-
linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
391+
linker_options = ["py", "vm", "vm_nogc", "jax"]
393392
if type(config).cxx.is_default:
394393
# If the user provided an empty value for cxx, do not warn.
395394
_logger.warning(

0 commit comments

Comments
 (0)