Skip to content

Commit 64eeee1

Browse files
committed
Try to run full test suite in Numba backend
1 parent 2caab13 commit 64eeee1

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ jobs:
202202
else
203203
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
204204
fi
205-
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
205+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57";
206206
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
207207
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
208208
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi

pytensor/compile/mode.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -445,37 +445,37 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
445445
return new_mode
446446

447447

448-
# If a string is passed as the mode argument in function or
449-
# FunctionMaker, the Mode will be taken from this dictionary using the
450-
# string as the key
451-
# Use VM_linker to allow lazy evaluation by default.
448+
numba_exclude = [
449+
"cxx_only",
450+
"BlasOpt",
451+
"local_careduce_fusion",
452+
"scan_save_mem_prealloc",
453+
]
454+
452455
FAST_COMPILE = Mode(
453-
VMLinker(use_cloop=False, c_thunks=False),
454-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
456+
NumbaLinker(),
457+
# TODO: Fast_compile should just use python code, CHANGE ME!
458+
RewriteDatabaseQuery(
459+
include=["fast_compile", "numba"],
460+
exclude=numba_exclude,
461+
),
462+
)
463+
FAST_RUN = Mode(
464+
NumbaLinker(),
465+
RewriteDatabaseQuery(include=["fast_run", "numba"], exclude=numba_exclude),
455466
)
456-
if config.cxx:
457-
FAST_RUN = Mode("cvm", "fast_run")
458-
else:
459-
FAST_RUN = Mode(
460-
"vm",
461-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
462-
)
463-
464-
C = Mode("c", "fast_run")
465-
C_VM = Mode("cvm", "fast_run")
466467

467468
NUMBA = Mode(
468469
NumbaLinker(),
469470
RewriteDatabaseQuery(
470471
include=["fast_run", "numba"],
471-
exclude=[
472-
"cxx_only",
473-
"BlasOpt",
474-
"local_careduce_fusion",
475-
"scan_save_mem_prealloc",
476-
],
472+
exclude=numba_exclude,
477473
),
478474
)
475+
del numba_exclude
476+
477+
C = Mode("c", "fast_run")
478+
C_VM = Mode("cvm", "fast_run")
479479

480480
JAX = Mode(
481481
JAXLinker(),
@@ -523,7 +523,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
523523
),
524524
)
525525

526-
526+
# If a string is passed as the mode argument in function or
527+
# FunctionMaker, the Mode will be taken from this dictionary using the
528+
# string as the key
527529
predefined_modes = {
528530
"FAST_COMPILE": FAST_COMPILE,
529531
"FAST_RUN": FAST_RUN,

pytensor/configdefaults.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,25 +371,24 @@ def add_compile_configvars():
371371
)
372372
del param
373373

374-
default_linker = "cvm"
374+
default_linker = "numba"
375375

376376
if rc == 0 and config.cxx != "":
377377
# Keep the default linker the same as the one for the mode FAST_RUN
378378
linker_options = [
379-
"c|py",
379+
"cvmc|py",
380380
"py",
381381
"c",
382382
"c|py_nogc",
383383
"vm",
384384
"vm_nogc",
385385
"cvm_nogc",
386-
"numba",
387386
"jax",
388387
]
389388
else:
390389
# g++ is not present or the user disabled it,
391390
# linker should default to python only.
392-
linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
391+
linker_options = ["py", "vm", "vm_nogc", "jax"]
393392
if type(config).cxx.is_default:
394393
# If the user provided an empty value for cxx, do not warn.
395394
_logger.warning(

0 commit comments

Comments
 (0)