From b765a23bc6f637b071ce9cce8a3fd61a66729cb5 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Mon, 8 Dec 2025 22:47:21 +0100
Subject: [PATCH 1/5] Work-around for numba pow failure was not robust enough

---
 pytensor/link/numba/dispatch/scalar.py | 17 ++++++++---------
 tests/link/numba/test_scalar.py        | 20 +++++++++++++++-----
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 76fae06fa6..6fe96180e7 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -170,18 +170,17 @@ def {binary_op_name}({input_signature}):
 @register_funcify_and_cache_key(Pow)
 def numba_funcify_Pow(op, node, **kwargs):
     pow_dtype = node.inputs[1].type.dtype
 
-    if pow_dtype.startswith("int"):
-        # Numba power fails when exponents are non 64-bit discrete integers and fastmath=True
-        # https://github.com/numba/numba/issues/9554
-        def pow(x, y):
-            return x ** np.asarray(y, dtype=np.int64).item()
-    else:
+    def pow(x, y):
+        return x**y
 
-        def pow(x, y):
-            return x**y
+    # Numba power fails when exponents are discrete integers and fastmath=True
+    # https://github.com/numba/numba/issues/9554
+    fastmath = False if np.dtype(pow_dtype).kind in "ibu" else None
 
-    return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
+    return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key(
+        op, cache_version=1
+    )
 
 
 @register_funcify_and_cache_key(Add)
diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index fd52a60f4c..040405cb51 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -183,13 +183,23 @@ def test_Softplus(dtype):
     )
 
 
-def test_discrete_power():
+@pytest.mark.parametrize(
+    "test_base",
+    [np.bool(True), np.int16(3), np.uint16(3), np.float32(0.5), np.float64(0.5)],
+)
+@pytest.mark.parametrize(
+    "test_exponent",
+    [np.bool(True), np.int16(2), np.uint16(2), np.float32(2.0), np.float64(2.0)],
+)
+def test_power_fastmath_bug(test_base, test_exponent):
     # Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
-    x = pt.scalar("x", dtype="float64")
-    exponent = pt.scalar("exponent", dtype="int8")
-    out = pt.power(x, exponent)
+    base = pt.scalar("base", dtype=test_base.dtype)
+    exponent = pt.scalar("exponent", dtype=test_exponent.dtype)
+    out = pt.power(base, exponent)
     compare_numba_and_py(
-        [x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
+        [base, exponent],
+        [out],
+        [test_base, test_exponent],
     )

From 6a57c9ce94678543b81c5b3ab4951e382b5aa632 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Mon, 8 Dec 2025 14:46:00 +0100
Subject: [PATCH 2/5] Numba linalg: Fix obj fallback raise -> return

---
 pytensor/link/numba/dispatch/slinalg.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index 4c2647defd..b954706645 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -246,7 +246,7 @@ def numba_funcify_Solve(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
 
     if A_dtype.kind == "c" or b_dtype.kind == "c":
-        raise generate_fallback_impl(op, node=node, **kwargs)
+        return generate_fallback_impl(op, node=node, **kwargs)
     must_cast_A = A_dtype != out_dtype
     if must_cast_A and config.compiler_verbose:
         print("Solve requires casting first input `A`")  # noqa: T201
@@ -320,7 +320,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
 
     if A_dtype.kind == "c" or b_dtype.kind == "c":
-        raise generate_fallback_impl(op, node=node, **kwargs)
+        return generate_fallback_impl(op, node=node, **kwargs)
     must_cast_A = A_dtype != out_dtype
     if must_cast_A and config.compiler_verbose:
         print("SolveTriangular requires casting first input `A`")  # noqa: T201
@@ -371,7 +371,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
     out_dtype = node.outputs[0].type.numpy_dtype
 
     if c_dtype.kind == "c" or b_dtype.kind == "c":
-        raise generate_fallback_impl(op, node=node, **kwargs)
+        return generate_fallback_impl(op, node=node, **kwargs)
     must_cast_c = c_dtype != out_dtype
     if must_cast_c and config.compiler_verbose:
         print("CholeskySolve requires casting first input `c`")  # noqa: T201

From f54e022d8d1f00cf2cae2c303020780a946561fd Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Mon, 8 Dec 2025 16:15:13 +0100
Subject: [PATCH 3/5] LU Op: Fix P out dtype for complex input

---
 pytensor/tensor/slinalg.py   | 7 ++++++-
 tests/tensor/test_slinalg.py | 6 ++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index 4361b59eb5..5ce5e8da12 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -508,7 +508,12 @@ def make_node(self, x):
             p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
             return Apply(self, inputs=[x], outputs=[p_indices, L, U])
 
-        P = tensor(shape=x.type.shape, dtype=out_dtype)
+        if out_dtype.startswith("complex"):
+            P_dtype = "float64" if out_dtype == "complex128" else "float32"
+        else:
+            P_dtype = out_dtype
+
+        P = tensor(shape=x.type.shape, dtype=P_dtype)
         return Apply(self, inputs=[x], outputs=[P, L, U])
 
     def perform(self, node, inputs, outputs):
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 4140331036..a3ccc68ff5 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -648,9 +648,9 @@ def test_lu_decomposition(
     dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
 
     A = tensor("A", shape=shape, dtype=dtype)
-    out = lu(A, permute_l=permute_l, p_indices=p_indices)
+    pt_out = lu(A, permute_l=permute_l, p_indices=p_indices)
 
-    f = function([A], out)
+    f = function([A], pt_out)
 
     rng = np.random.default_rng(utt.fetch_seed())
     x = rng.normal(size=shape).astype(config.floatX)
     if complex:
         x = x + 1j * rng.normal(size=shape).astype(config.floatX)
 
     out = f(x)
+    for numerical_out, symbolic_out in zip(out, pt_out):
+        assert numerical_out.dtype == symbolic_out.type.dtype
 
     if permute_l:
         PL, U = out

From 32837aa4694288f0d2f2ce7e1b47b91559799543 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Tue, 9 Dec 2025 01:04:26 +0100
Subject: [PATCH 4/5] Tridiagonal Solve: Fix dtype inference

The scipy `get_lapack_funcs` helper has no special handling for the
int32 `ipiv` variable, and always assumes it must be cast to float64.

---
 pytensor/tensor/_linalg/solve/tridiagonal.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytensor/tensor/_linalg/solve/tridiagonal.py b/pytensor/tensor/_linalg/solve/tridiagonal.py
index 0654d81cc7..a97d8eaf68 100644
--- a/pytensor/tensor/_linalg/solve/tridiagonal.py
+++ b/pytensor/tensor/_linalg/solve/tridiagonal.py
@@ -146,9 +146,8 @@ def make_node(self, dl, d, du, du2, ipiv, b):
            n = nb
 
        dummy_arrays = [
-            np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
+            np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, b)
        ]
-        # Seems to always be float64?
        out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
        if self.b_ndim == 1:
            output_shape = (n,)

From 3111d2372173d1d404e0c7082fc87b5989a51531 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Tue, 9 Dec 2025 01:05:03 +0100
Subject: [PATCH 5/5] Numba tridiagonal: avoid inference error when casting inputs

Numba doesn't infer the right types when the cast flags are read from a
static tuple, but does so with separate boolean variables.

---
 .../dispatch/linalg/solve/tridiagonal.py   | 28 +++++++++--------
 .../numba/linalg/solve/test_tridiagonal.py | 31 +++++++++++++++++++
 2 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
index 73d08c0df8..ded91df90e 100644
--- a/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+++ b/pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
@@ -356,8 +356,10 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
     overwrite_du = op.overwrite_du
     out_dtype = node.outputs[1].type.numpy_dtype
 
-    must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs)
-    if any(must_cast_inputs) and config.compiler_verbose:
+    cast_inputs = (cast_dl, cast_d, cast_du) = tuple(
+        inp.type.numpy_dtype != out_dtype for inp in node.inputs
+    )
+    if any(cast_inputs) and config.compiler_verbose:
         print("LUFactorTridiagonal requires casting at least one input")  # noqa: T201
 
     @numba_basic.numba_njit(cache=False)
     def lu_factor_tridiagonal(dl, d, du):
@@ -371,11 +373,11 @@ def lu_factor_tridiagonal(dl, d, du):
                 np.zeros(d.shape, dtype="int32"),
             )
 
-        if must_cast_inputs[0]:
+        if cast_d:
             d = d.astype(out_dtype)
-        if must_cast_inputs[1]:
+        if cast_dl:
             dl = dl.astype(out_dtype)
-        if must_cast_inputs[2]:
+        if cast_du:
             du = du.astype(out_dtype)
         dl, d, du, du2, ipiv, _ = _gttrf(
             dl,
@@ -402,7 +404,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
     overwrite_b = op.overwrite_b
     transposed = op.transposed
 
-    must_cast_inputs = tuple(
+    must_cast_inputs = (cast_dl, cast_d, cast_du, cast_du2, cast_ipiv, cast_b) = tuple(
         inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
         for i, inp in enumerate(node.inputs)
     )
@@ -417,17 +419,17 @@ def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
         else:
             return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
 
-        if must_cast_inputs[0]:
+        if cast_dl:
             dl = dl.astype(out_dtype)
-        if must_cast_inputs[1]:
+        if cast_d:
             d = d.astype(out_dtype)
-        if must_cast_inputs[2]:
+        if cast_du:
             du = du.astype(out_dtype)
-        if must_cast_inputs[3]:
+        if cast_du2:
             du2 = du2.astype(out_dtype)
-        if must_cast_inputs[4]:
-            ipiv = ipiv.astype("int32")
-        if must_cast_inputs[5]:
+        if cast_ipiv:
+            ipiv = ipiv.astype(np.int32)
+        if cast_b:
             b = b.astype(out_dtype)
         x, _ = _gttrs(
             dl,
diff --git a/tests/link/numba/linalg/solve/test_tridiagonal.py b/tests/link/numba/linalg/solve/test_tridiagonal.py
index 6b4f2babd0..1a4529c6c9 100644
--- a/tests/link/numba/linalg/solve/test_tridiagonal.py
+++ b/tests/link/numba/linalg/solve/test_tridiagonal.py
@@ -112,3 +112,34 @@ def test_tridiagonal_lu_solve(b_ndim, transposed, inplace):
         assert (res_non_contig == res).all()
         # b must be copied when not contiguous so it can't be inplaced
         assert (b_test == b_test_non_contig).all()
+
+
+def test_cast_needed():
+    dl = pt.vector("dl", shape=(4,), dtype="int16")
+    d = pt.vector("d", shape=(5,), dtype="float32")
+    du = pt.vector("du", shape=(4,), dtype="float64")
+    b = pt.vector("b", shape=(5,), dtype="float32")
+
+    lu_factor_outs = LUFactorTridiagonal()(dl, d, du)
+    for i, out in enumerate(lu_factor_outs):
+        if i == 4:
+            assert out.type.dtype == "int32"  # ipiv is int32
+        else:
+            assert out.type.dtype == "float64"
+
+    lu_solve_out = SolveLUFactorTridiagonal(b_ndim=1, transposed=False)(
+        *lu_factor_outs, b
+    )
+    assert lu_solve_out.type.dtype == "float64"
+
+    compare_numba_and_py(
+        [dl, d, du, b],
+        lu_solve_out,
+        test_inputs=[
+            np.array([1, 2, 3, 4], dtype="int16"),
+            np.array([1, 2, 3, 4, 5], dtype="float32"),
+            np.array([1, 2, 3, 4], dtype="float64"),
+            np.array([1, 2, 3, 4, 5], dtype="float32"),
+        ],
+        eval_obj_mode=False,
+    )
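
Reviewer note, not part of the patch series: a minimal standalone sketch of the
fastmath guard that [PATCH 1/5] introduces, assuming `numba_basic.numba_njit`
forwards `fastmath` to `numba.njit`. The helper `make_pow` and its names are
illustrative, not pytensor API; `numba.njit` and numpy dtype kinds are real
library APIs.

    import numpy as np
    from numba import njit

    def make_pow(exponent_dtype):
        # Mirror the patch: disable fastmath whenever the exponent dtype is
        # discrete (kind "b" = bool, "i" = signed int, "u" = unsigned int),
        # since numba's power can fail to compile under fastmath for such
        # exponents (https://github.com/numba/numba/issues/9554).
        fastmath = np.dtype(exponent_dtype).kind not in "ibu"

        @njit(fastmath=fastmath)
        def pow_fn(x, y):
            return x**y

        return pow_fn

    pow_int8 = make_pow(np.int8)  # compiled with fastmath disabled
    assert pow_int8(0.5, np.int8(2)) == 0.25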
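
Reviewer note, not part of the patch series: a sketch of the closure pattern
[PATCH 5/5] switches to. Per the commit message, numba fails to infer types
when the cast flags are read out of a static tuple (`must_cast_inputs[0]`)
but succeeds with separate boolean free variables; each boolean is then a
compile-time constant, so the dead branch is pruned and every array keeps a
single dtype. Names, dtypes, and `sum_diagonals` below are illustrative.

    import numpy as np
    from numba import njit

    out_dtype = np.float64
    # Unpack the cast decisions into standalone booleans, as the patch does;
    # numba specializes on these constants and removes the dead branches.
    cast_dl, cast_d = True, False

    @njit
    def sum_diagonals(dl, d):
        if cast_dl:
            dl = dl.astype(out_dtype)  # branch kept: dl is float64 from here on
        if cast_d:
            d = d.astype(out_dtype)  # branch pruned: d keeps its input dtype
        return dl.sum() + d.sum()

    print(sum_diagonals(np.arange(4, dtype=np.int16), np.arange(5.0)))  # 16.0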