Skip to content

Commit 0bfdcf7

Browse files
committed
Try to run full test suite in Numba backend
1 parent 02bf0a0 commit 0bfdcf7

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ jobs:
153153
- name: Install dependencies
154154
shell: micromamba-shell {0}
155155
run: |
156-
157156
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
158-
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
157+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"
159158
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
160159
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
161160
pip install pytest-sphinx

pytensor/compile/mode.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ def register_linker(name, linker):
6363
# If a string is passed as the optimizer argument in the constructor
6464
# for Mode, it will be used as the key to retrieve the real optimizer
6565
# in this dictionary
66-
exclude = []
67-
if not config.cxx:
68-
exclude = ["cxx_only"]
66+
67+
exclude = ["cxx_only", "BlasOpt"]
6968
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
7069
# Even if multiple merge optimizer call will be there, this shouldn't
7170
# impact performance.
@@ -342,6 +341,11 @@ def __setstate__(self, state):
342341
optimizer = predefined_optimizers[optimizer]
343342
if isinstance(optimizer, RewriteDatabaseQuery):
344343
self.provided_optimizer = optimizer
344+
345+
# Force numba-required rewrites if using NumbaLinker
346+
if isinstance(linker, NumbaLinker):
347+
optimizer = optimizer.including("numba")
348+
345349
self._optimizer = optimizer
346350
self.call_time = 0
347351
self.fn_time = 0
@@ -439,16 +443,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
439443
# string as the key
440444
# Use VM_linker to allow lazy evaluation by default.
441445
FAST_COMPILE = Mode(
442-
VMLinker(use_cloop=False, c_thunks=False),
443-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
446+
NumbaLinker(),
447+
# TODO: Fast_compile should just use python code, CHANGE ME!
448+
RewriteDatabaseQuery(
449+
include=["fast_compile", "numba"],
450+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
451+
),
452+
)
453+
FAST_RUN = Mode(
454+
NumbaLinker(),
455+
RewriteDatabaseQuery(
456+
include=["fast_run", "numba"],
457+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
458+
),
444459
)
445-
if config.cxx:
446-
FAST_RUN = Mode("cvm", "fast_run")
447-
else:
448-
FAST_RUN = Mode(
449-
"vm",
450-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
451-
)
452460

453461
JAX = Mode(
454462
JAXLinker(),
@@ -528,7 +536,7 @@ def get_mode(orig_string):
528536
# NanGuardMode use its own linker.
529537
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
530538
else:
531-
# TODO: Can't we look up the name and invoke it rather than using eval here?
539+
# TODO: Get rid of this? Or refactor?
532540
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
533541
elif string in predefined_modes:
534542
ret = predefined_modes[string]
@@ -557,6 +565,7 @@ def register_mode(name, mode):
557565
Add a `Mode` which can be referred to by `name` in `function`.
558566
559567
"""
568+
# TODO: Remove me
560569
if name in predefined_modes:
561570
raise ValueError(f"Mode name already taken: {name}")
562571
predefined_modes[name] = mode

pytensor/configdefaults.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,11 @@ def add_compile_configvars():
371371

372372
if rc == 0 and config.cxx != "":
373373
# Keep the default linker the same as the one for the mode FAST_RUN
374-
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
374+
linker_options = ["cvm", "c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc", "jax"]
375375
else:
376376
# g++ is not present or the user disabled it,
377377
# linker should default to python only.
378-
linker_options = ["py", "vm_nogc"]
378+
linker_options = ["py", "vm", "vm_nogc", "jax"]
379379
if type(config).cxx.is_default:
380380
# If the user provided an empty value for cxx, do not warn.
381381
_logger.warning(
@@ -388,7 +388,7 @@ def add_compile_configvars():
388388
config.add(
389389
"linker",
390390
"Default linker used if the pytensor flags mode is Mode",
391-
EnumStr("cvm", linker_options),
391+
EnumStr("numba", linker_options),
392392
in_c_key=False,
393393
)
394394

0 commit comments

Comments
 (0)