Skip to content

Commit 608b699

Browse files
committed
Suppress noisy numba warnings
1 parent 9a9e9d8 commit 608b699

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numba
77
import numpy as np
8+
from numba import NumbaPerformanceWarning, NumbaWarning
89
from numba import njit as _njit
910
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
1011

@@ -23,6 +24,35 @@
2324
from pytensor.tensor.utils import hash_from_ndarray
2425

2526

27+
def _filter_numba_warnings():
28+
# Suppress large global arrays cache warning for internal functions
29+
# We have to add an ansi escape code for optional bold text by numba
30+
# TODO: We could avoid inlining large constants and pass them at runtime
31+
warnings.filterwarnings(
32+
"ignore",
33+
message=(
34+
"(\x1b\\[1m)*" # ansi escape code for bold text
35+
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
36+
),
37+
category=NumbaWarning,
38+
)
39+
40+
# Disable loud / incorrect warnings from Numba
41+
# https://github.com/numba/numba/issues/10086
42+
# TODO: Would be much better if we could disable only for our functions
43+
warnings.filterwarnings(
44+
"ignore",
45+
message=(
46+
"(\x1b\\[1m)*" # ansi escape code for bold text
47+
r"np\.dot\(\) is faster on contiguous arrays"
48+
),
49+
category=NumbaPerformanceWarning,
50+
)
51+
52+
53+
_filter_numba_warnings()
54+
55+
2656
def numba_njit(
2757
*args, fastmath=None, final_function: bool = False, **kwargs
2858
) -> Callable:

tests/link/numba/test_basic.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pytensor.graph.type import Type
2626
from pytensor.link.numba.dispatch import basic as numba_basic
2727
from pytensor.link.numba.dispatch.basic import (
28+
_filter_numba_warnings,
2829
cache_key_for_constant,
2930
numba_funcify_and_cache_key,
3031
)
@@ -455,14 +456,46 @@ def test_scalar_return_value_conversion():
455456
assert isinstance(x_fn(1.0), np.ndarray)
456457

457458

458-
@pytest.mark.filterwarnings("error")
459-
def test_cache_warning_suppressed():
460-
x = pt.vector("x", shape=(5,), dtype="float64")
461-
out = pt.psi(x) * 2
462-
fn = function([x], out, mode="NUMBA")
463-
464-
x_test = np.random.uniform(size=5)
465-
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
459+
class TestNumbaWarnings:
460+
def setup_method(self, method):
461+
# Pytest messes up with the package filters, reenable here for testing
462+
_filter_numba_warnings()
463+
464+
@pytest.mark.filterwarnings("error")
465+
def test_cache_pointer_func_warning_suppressed(self):
466+
x = pt.vector("x", shape=(5,), dtype="float64")
467+
out = pt.psi(x) * 2
468+
fn = function([x], out, mode="NUMBA")
469+
470+
x_test = np.random.uniform(size=5)
471+
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
472+
473+
@pytest.mark.filterwarnings("error")
474+
def test_cache_large_global_array_warning_suppressed(self):
475+
rng = np.random.default_rng(458)
476+
large_constant = rng.normal(size=(100000, 5))
477+
478+
x = pt.vector("x", shape=(5,), dtype="float64")
479+
out = x * large_constant
480+
fn = function([x], out, mode="NUMBA")
481+
482+
x_test = rng.uniform(size=5)
483+
np.testing.assert_allclose(fn(x_test), x_test * large_constant)
484+
485+
@pytest.mark.filterwarnings("error")
486+
def test_contiguous_array_dot_warning_suppressed(self):
487+
A = pt.matrix("A")
488+
b = pt.vector("b")
489+
out = pt.dot(A, b[:, None])
490+
# Cached functions won't reemit the warning, so we have to disable it
491+
with config.change_flags(numba__cache=False):
492+
fn = function([A, b], out, mode="NUMBA")
493+
494+
A_test = np.ones((5, 5))
495+
# Numba actually warns even on contiguous arrays: https://github.com/numba/numba/issues/10086
496+
# But either way we don't want this warning for users as they have little control over strides
497+
b_test = np.ones((10,))[::2]
498+
np.testing.assert_allclose(fn(A_test, b_test), np.dot(A_test, b_test[:, None]))
466499

467500

468501
@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))

0 commit comments

Comments
 (0)