From c1b20113ca67b8149f524fdc77cd852753f1b9e3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Nov 2025 22:08:21 +0100 Subject: [PATCH 1/6] Fix invalid py_file in numba cache locator --- pytensor/link/numba/cache.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/cache.py b/pytensor/link/numba/cache.py index a93ad09cd2..09d3c05830 100644 --- a/pytensor/link/numba/cache.py +++ b/pytensor/link/numba/cache.py @@ -1,6 +1,5 @@ from collections.abc import Callable from hashlib import sha256 -from pathlib import Path from pickle import dump from tempfile import NamedTemporaryFile from typing import Any @@ -64,8 +63,8 @@ def get_disambiguator(self): @classmethod def from_function(cls, py_func, py_file): """Create a locator instance for functions stored in CACHED_SRC_FUNCTIONS.""" - if config.numba__cache and py_func in CACHED_SRC_FUNCTIONS: - return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func]) + if py_func in CACHED_SRC_FUNCTIONS and config.numba__cache: + return cls(py_func, py_file, CACHED_SRC_FUNCTIONS[py_func]) # Register our locator at the front of Numba's locator list From 1979244748d3951ca9984240008cbaedc705bcff Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Nov 2025 19:41:27 +0100 Subject: [PATCH 2/6] Fix numba FunctionGraph cache key It's necessary to encode the edge information, not only the nodes and their ordering --- pytensor/link/numba/dispatch/basic.py | 51 ++++++++++------ pytensor/link/utils.py | 15 +++-- tests/link/numba/test_basic.py | 87 +++++++++++++++++++++++++-- 3 files changed, 122 insertions(+), 31 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 35f0af32d5..744abb0a17 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -9,7 +9,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from pytensor import config -from pytensor.graph.basic import Apply, Constant +from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump @@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph( ): # Collect cache keys of every Op/Constant in the FunctionGraph # so we can create a global cache key for the whole FunctionGraph + fgraph_can_be_cached = [True] cache_keys = [] toposort = fgraph.toposort() - clients = fgraph.clients - toposort_indices = {node: i for i, node in enumerate(toposort)} - # Add dummy output clients which are not included of the toposort - toposort_indices |= { - clients[out][0][0]: i - for i, out in enumerate(fgraph.outputs, start=len(toposort)) + toposort_coords: dict[Variable, tuple[int, int | str]] = { + inp: (0, i) for i, inp in enumerate(fgraph.inputs) + } + toposort_coords |= { + out: (i, j) + for i, node in enumerate(toposort, start=1) + for j, out in enumerate(node.outputs) } - def op_conversion_and_key_collection(*args, **kwargs): + def op_conversion_and_key_collection(op, *args, node, **kwargs): # Convert an Op to a funcified function and store the cache_key # We also Cache each Op so Numba can do less work next time it sees it - func, key = numba_funcify_ensure_cache(*args, **kwargs) - cache_keys.append(key) + func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs) + if key is None: + fgraph_can_be_cached[0] = False + else: + # Add graph coordinate information (input edges and 
node location) + cache_keys.append( + ( + tuple(toposort_coords[inp] for inp in node.inputs), + key, + ) + ) return func def type_conversion_and_key_collection(value, variable, **kwargs): # Convert a constant type to a numba compatible one and compute a cache key for it - # We need to know where in the graph the constants are used - # Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same # FIXME: It doesn't make sense to call type_conversion on non-constants, - # but that's what fgraph_to_python currently does. We appease it, but don't consider for caching + # but that's what fgraph_to_python currently does. + # We appease it, but don't consider for caching if isinstance(variable, Constant): - client_indices = tuple( - (toposort_indices[node], inp_idx) for node, inp_idx in clients[variable] - ) - cache_keys.append((client_indices, cache_key_for_constant(value))) + # Store unique key in toposort_coords. It will be included by whichever nodes make use of the constant + constant_cache_key = cache_key_for_constant(value) + assert constant_cache_key is not None + toposort_coords[variable] = (-1, constant_cache_key) return numba_typify(value, variable=variable, **kwargs) py_func = fgraph_to_python( @@ -537,12 +547,15 @@ def type_conversion_and_key_collection(value, variable, **kwargs): fgraph_name=fgraph_name, **kwargs, ) - if any(key is None for key in cache_keys): + if not fgraph_can_be_cached[0]: # If a single element couldn't be cached, we can't cache the whole FunctionGraph either fgraph_key = None else: + # Add graph coordinate information for fgraph outputs + fgraph_output_ancestors = tuple(toposort_coords[out] for out in fgraph.outputs) + # Compose individual cache_keys into a global key for the FunctionGraph fgraph_key = sha256( - f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode() + f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {fgraph_output_ancestors})".encode() ).hexdigest() return numba_njit(py_func), fgraph_key diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 6910601e3c..d13719a067 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -735,14 +735,6 @@ def fgraph_to_python( body_assigns = [] for node in order: - compiled_func = op_conversion_fn( - node.op, node=node, storage_map=storage_map, **kwargs - ) - - # Create a local alias with a unique name - local_compiled_func_name = unique_name(compiled_func) - global_env[local_compiled_func_name] = compiled_func - node_input_names = [] for inp in node.inputs: local_input_name = unique_name(inp) @@ -772,6 +764,13 @@ def fgraph_to_python( node_output_names = [unique_name(v) for v in node.outputs] + compiled_func = op_conversion_fn( + node.op, node=node, storage_map=storage_map, **kwargs + ) + # Create a local alias with a unique name + local_compiled_func_name = unique_name(compiled_func) + global_env[local_compiled_func_name] = compiled_func + assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" assign_comment_str = f"{indent(str(node), '# ')}" assign_block_str = f"{assign_comment_str}\n{assign_str}" diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 43e4885840..0054fda15a 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -7,8 +7,7 @@ import pytest import scipy -from pytensor.compile import SymbolicInput -from pytensor.tensor.utils import hash_from_ndarray +from pytensor.tensor import scalar_from_tensor 
numba = pytest.importorskip("numba") @@ -16,17 +15,23 @@ import pytensor.scalar as ps import pytensor.tensor as pt from pytensor import config, shared +from pytensor.compile import SymbolicInput from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.type import Type from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import cache_key_for_constant +from pytensor.link.numba.dispatch.basic import ( + cache_key_for_constant, + numba_funcify_and_cache_key, +) from pytensor.link.numba.linker import NumbaLinker -from pytensor.scalar.basic import ScalarOp, as_scalar +from pytensor.scalar.basic import Composite, ScalarOp, as_scalar from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.utils import hash_from_ndarray if TYPE_CHECKING: @@ -652,3 +657,77 @@ def impl(x): outs[2].owner.op, outs[2].owner ) assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2 + + +class TestFgraphCacheKey: + @staticmethod + def generate_and_validate_key(fg): + _, key = numba_funcify_and_cache_key(fg) + assert key is not None + _, key_again = numba_funcify_and_cache_key(fg) + assert key == key_again # Check its stable + return key + + def test_node_order(self): + x = pt.scalar("x") + log_x = pt.log(x) + graphs = [ + pt.exp(x) / log_x, + log_x / pt.exp(x), + pt.exp(log_x) / x, + x / pt.exp(log_x), + pt.exp(log_x) / log_x, + log_x / pt.exp(log_x), + ] + + keys = [] + for graph in graphs: + fg = FunctionGraph([x], [graph], clone=False) + keys.append(self.generate_and_validate_key(fg)) + # Check keys are unique + assert len(set(keys)) == len(graphs) + + # Extra unused input should alter the key, because it changes the function signature + y = pt.scalar("y") + for inputs in [[x, y], [y, x]]: + fg = FunctionGraph(inputs, [graphs[0]], clone=False) + keys.append(self.generate_and_validate_key(fg)) + assert len(set(keys)) == len(graphs) + 2 + + # Adding an input as an output should also change the key + for outputs in [ + [graphs[0], x], + [x, graphs[0]], + [x, x, graphs[0]], + [x, graphs[0], x], + [graphs[0], x, x], + ]: + fg = FunctionGraph([x], outputs, clone=False) + keys.append(self.generate_and_validate_key(fg)) + assert len(set(keys)) == len(graphs) + 2 + 5 + + def test_multi_output(self): + x = pt.scalar("x") + + xs = scalar_from_tensor(x) + out0, out1 = Elemwise(Composite([xs], [xs * 2, xs - 2]))(x) + + test_outs = [ + [out0], + [out1], + [out0, out1], + [out1, out0], + ] + keys = [] + for test_out in test_outs: + fg = FunctionGraph([x], test_out, clone=False) + keys.append(self.generate_and_validate_key(fg)) + assert len(set(keys)) == len(test_outs) + + def test_constant_output(self): + fg_pi = FunctionGraph([], [pt.constant(np.pi)]) + fg_e = FunctionGraph([], [pt.constant(np.e)]) + + assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key( + fg_e + ) From 8ffe065d74f1b9f85cfd6a33aa8a2b8d6ed82e83 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Nov 2025 23:06:06 +0100 Subject: [PATCH 3/6] Fix cache key of `numba_funcify_multiple_integer_vector_indexing` --- pytensor/link/numba/dispatch/subtensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 1e8078e477..9b65eddee5 100644 --- 
a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -409,6 +409,8 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
         op,
         func="multiple_integer_vector_indexing",
         y_is_broadcasted=y_is_broadcasted,
+        first_axis=first_axis,
+        last_axis=last_axis,
     )
     return ret_func, cache_key

From 6cd29cc51fdee10bc9e08c634bd7a74dff51b33a Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 19 Nov 2025 23:08:46 +0100
Subject: [PATCH 4/6] Fix cache of default subtensor

The implementation was specializing on repeated node inputs, as
`unique_names` would return the same name for repeated inputs. The cache key
didn't account for this. We also don't want to compile different functions
for different patterns of repeated inputs, as it doesn't translate to an
obvious handle for the compiler to specialize upon. If we wanted to inline
constants, that may make more sense.
---
 pytensor/link/numba/dispatch/subtensor.py | 30 +++++++++++------------
 1 file changed, 14 insertions(+), 16 deletions(-)

diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index 9b65eddee5..c7cc4cfd8e 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -18,7 +18,6 @@
     register_funcify_and_cache_key,
     register_funcify_default_op_cache_key,
 )
-from pytensor.link.utils import unique_name_generator
 from pytensor.tensor import TensorType
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
@@ -143,19 +142,14 @@ def subtensor_op_cache_key(op, **extra_fields):
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""

-    unique_names = unique_name_generator(
-        ["subtensor", "incsubtensor", "z"], suffix_sep="_"
-    )
-
-    def convert_indices(indices, entry):
-        if indices and isinstance(entry, Type):
-            rval = indices.pop(0)
-            return unique_names(rval)
+    def convert_indices(indice_names, entry):
+        if indice_names and isinstance(entry, Type):
+            return next(indice_names)
         elif isinstance(entry, slice):
             return (
-                f"slice({convert_indices(indices, entry.start)}, "
-                f"{convert_indices(indices, entry.stop)}, "
-                f"{convert_indices(indices, entry.step)})"
+                f"slice({convert_indices(indice_names, entry.start)}, "
+                f"{convert_indices(indice_names, entry.stop)}, "
+                f"{convert_indices(indice_names, entry.step)})"
             )
         elif isinstance(entry, type(None)):
             return "None"
@@ -166,13 +160,15 @@ def convert_indices(indices, entry):
         op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
     )
     index_start_idx = 1 + int(set_or_inc)
-
-    input_names = [unique_names(v, force_unique=True) for v in node.inputs]
     op_indices = list(node.inputs[index_start_idx:])
     idx_list = getattr(op, "idx_list", None)

+    idx_names = [f"idx_{i}" for i in range(len(op_indices))]
+    input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]
+
+    idx_names_iterator = iter(idx_names)
     indices_creation_src = (
-        tuple(convert_indices(op_indices, idx) for idx in idx_list)
+        tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list)
         if idx_list
         else tuple(input_names[index_start_idx:])
     )
@@ -220,7 +216,9 @@ def {function_name}({", ".join(input_names)}):
         function_name=function_name,
         global_env=globals() | {"np": np},
     )
-    cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor")
+    cache_key = subtensor_op_cache_key(
+        op, func="numba_funcify_default_subtensor", version=1
+    )
     return
numba_basic.numba_njit(func, boundscheck=True), cache_key From 9a9e9d80169c3ca3b5b04e54c310f6705369a663 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 20 Nov 2025 00:18:00 +0100 Subject: [PATCH 5/6] Remove uses of unique_name_generator in numba dispatch It's more readable and avoids potential bugs when force_unique is not set to True --- pytensor/link/numba/dispatch/scalar.py | 34 +++++++------------- pytensor/link/numba/dispatch/tensor_basic.py | 34 ++++---------------- 2 files changed, 18 insertions(+), 50 deletions(-) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 4a4d9b319d..50af695a2e 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -14,7 +14,6 @@ from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.utils import ( get_name_for_object, - unique_name_generator, ) from pytensor.scalar.basic import ( Add, @@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs): scalar_func_numba = generate_fallback_impl(op, node, **kwargs) scalar_op_fn_name = get_name_for_object(scalar_func_numba) - + prefix = "x" if scalar_func_name != "x" else "y" + input_names = [f"{prefix}{i}" for i in range(len(node.inputs))] + input_signature = ", ".join(input_names) global_env = {"scalar_func_numba": scalar_func_numba} if input_inner_dtypes is None and output_inner_dtype is None: - unique_names = unique_name_generator( - [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_" - ) - input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs) if not has_pyx_skip_dispatch: scalar_op_src = f""" -def {scalar_op_fn_name}({input_names}): - return scalar_func_numba({input_names}) +def {scalar_op_fn_name}({input_signature}): + return scalar_func_numba({input_signature}) """ else: scalar_op_src = f""" -def {scalar_op_fn_name}({input_names}): - return scalar_func_numba({input_names}, np.intc(1)) +def {scalar_op_fn_name}({input_signature}): + return scalar_func_numba({input_signature}, np.intc(1)) """ else: @@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}): for i, i_dtype in enumerate(input_inner_dtypes) } global_env.update(input_tmp_dtype_names) - - unique_names = unique_name_generator( - [scalar_op_fn_name, "scalar_func_numba", *global_env.keys()], - suffix_sep="_", - ) - - input_names = [unique_names(v, force_unique=True) for v in node.inputs] converted_call_args = ", ".join( f"direct_cast({i_name}, {i_tmp_dtype_name})" for i_name, i_tmp_dtype_name in zip( @@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}): ) if not has_pyx_skip_dispatch: scalar_op_src = f""" -def {scalar_op_fn_name}({", ".join(input_names)}): +def {scalar_op_fn_name}({input_signature}): return direct_cast(scalar_func_numba({converted_call_args}), output_dtype) """ else: scalar_op_src = f""" -def {scalar_op_fn_name}({", ".join(input_names)}): +def {scalar_op_fn_name}({input_signature}): return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype) """ scalar_op_fn = compile_numba_function_src( scalar_op_src, scalar_op_fn_name, - {**globals(), **global_env}, + globals() | global_env, ) # Functions that call a function pointer can't be cached @@ -157,8 +147,8 @@ def switch(condition, x, y): def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): """Create a Numba-compatible N-ary function from a binary function.""" - unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_") - 
input_names = [unique_names(v, force_unique=True) for v in inputs] + var_prefix = "x" if binary_op_name != "x" else "y" + input_names = [f"{var_prefix}{i}" for i in range(len(inputs))] input_signature = ", ".join(input_names) output_expr = binary_op.join(input_names) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 9553273f99..3c7bda9a15 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -10,7 +10,6 @@ register_funcify_and_cache_key, register_funcify_default_op_cache_key, ) -from pytensor.link.utils import unique_name_generator from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -28,15 +27,7 @@ @register_funcify_default_op_cache_key(AllocEmpty) def numba_funcify_AllocEmpty(op, node, **kwargs): - global_env = { - "np": np, - "dtype": np.dtype(op.dtype), - } - - unique_names = unique_name_generator( - ["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" - ) - shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs] + shape_var_names = [f"sh{i}" for i in range(len(node.inputs))] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( @@ -56,7 +47,7 @@ def allocempty({", ".join(shape_var_names)}): """ alloc_fn = compile_numba_function_src( - alloc_def_src, "allocempty", {**globals(), **global_env} + alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)} ) return numba_basic.numba_njit(alloc_fn) @@ -64,13 +55,7 @@ def allocempty({", ".join(shape_var_names)}): @register_funcify_and_cache_key(Alloc) def numba_funcify_Alloc(op, node, **kwargs): - global_env = {"np": np} - - unique_names = unique_name_generator( - ["np", "alloc", "val_np", "val", "scalar_shape", "res"], - suffix_sep="_", - ) - shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]] + shape_var_names = [f"sh{i}" for i in range(len(node.inputs) - 1)] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( @@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}): alloc_fn = compile_numba_function_src( alloc_def_src, "alloc", - {**globals(), **global_env}, + globals() | {"np": np}, ) cache_key = sha256( @@ -207,14 +192,7 @@ def eye(N, M, k): @register_funcify_default_op_cache_key(MakeVector) def numba_funcify_MakeVector(op, node, **kwargs): dtype = np.dtype(op.dtype) - - global_env = {"np": np, "dtype": dtype} - - unique_names = unique_name_generator( - ["np"], - suffix_sep="_", - ) - input_names = [unique_names(v, force_unique=True) for v in node.inputs] + input_names = [f"x{i}" for i in range(len(node.inputs))] def create_list_string(x): args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else [])) @@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}): makevector_fn = compile_numba_function_src( makevector_def_src, "makevector", - {**globals(), **global_env}, + globals() | {"np": np, "dtype": dtype}, ) return numba_basic.numba_njit(makevector_fn) From 608b699ba84322254ad3b713f839bb649f4c2a37 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Nov 2025 15:55:32 +0100 Subject: [PATCH 6/6] Suppress noisy numba warnings --- pytensor/link/numba/dispatch/basic.py | 30 ++++++++++++++++ tests/link/numba/test_basic.py | 49 ++++++++++++++++++++++----- 2 files changed, 71 insertions(+), 8 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 
744abb0a17..299479af07 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -5,6 +5,7 @@ import numba import numpy as np +from numba import NumbaPerformanceWarning, NumbaWarning from numba import njit as _njit from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 @@ -23,6 +24,35 @@ from pytensor.tensor.utils import hash_from_ndarray +def _filter_numba_warnings(): + # Suppress large global arrays cache warning for internal functions + # We have to add an ansi escape code for optional bold text by numba + # TODO: We could avoid inlining large constants and pass them at runtime + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + 'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals' + ), + category=NumbaWarning, + ) + + # Disable loud / incorrect warnings from Numba + # https://github.com/numba/numba/issues/10086 + # TODO: Would be much better if we could disable only for our functions + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + r"np\.dot\(\) is faster on contiguous arrays" + ), + category=NumbaPerformanceWarning, + ) + + +_filter_numba_warnings() + + def numba_njit( *args, fastmath=None, final_function: bool = False, **kwargs ) -> Callable: diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 0054fda15a..e6d84ea4d0 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -25,6 +25,7 @@ from pytensor.graph.type import Type from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( + _filter_numba_warnings, cache_key_for_constant, numba_funcify_and_cache_key, ) @@ -455,14 +456,46 @@ def test_scalar_return_value_conversion(): assert isinstance(x_fn(1.0), np.ndarray) -@pytest.mark.filterwarnings("error") -def test_cache_warning_suppressed(): - x = pt.vector("x", shape=(5,), dtype="float64") - out = pt.psi(x) * 2 - fn = function([x], out, mode="NUMBA") - - x_test = np.random.uniform(size=5) - np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) +class TestNumbaWarnings: + def setup_method(self, method): + # Pytest messes up with the package filters, reenable here for testing + _filter_numba_warnings() + + @pytest.mark.filterwarnings("error") + def test_cache_pointer_func_warning_suppressed(self): + x = pt.vector("x", shape=(5,), dtype="float64") + out = pt.psi(x) * 2 + fn = function([x], out, mode="NUMBA") + + x_test = np.random.uniform(size=5) + np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) + + @pytest.mark.filterwarnings("error") + def test_cache_large_global_array_warning_suppressed(self): + rng = np.random.default_rng(458) + large_constant = rng.normal(size=(100000, 5)) + + x = pt.vector("x", shape=(5,), dtype="float64") + out = x * large_constant + fn = function([x], out, mode="NUMBA") + + x_test = rng.uniform(size=5) + np.testing.assert_allclose(fn(x_test), x_test * large_constant) + + @pytest.mark.filterwarnings("error") + def test_contiguous_array_dot_warning_suppressed(self): + A = pt.matrix("A") + b = pt.vector("b") + out = pt.dot(A, b[:, None]) + # Cached functions won't reemit the warning, so we have to disable it + with config.change_flags(numba__cache=False): + fn = function([A, b], out, mode="NUMBA") + + A_test = np.ones((5, 5)) + # Numba actually warns even on contiguous arrays: https://github.com/numba/numba/issues/10086 + # But either way 
we don't want this warning for users as they have little control over strides + b_test = np.ones((10,))[::2] + np.testing.assert_allclose(fn(A_test, b_test), np.dot(A_test, b_test[:, None])) @pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))