Skip to content

Commit d558b4c

Browse files
committed
Fix Numba tridiagonal casting
Numba doesn't infer the right type when the cast flags are read by indexing a static tuple, but it does when each flag is a separate boolean variable
1 parent 50bd652 commit d558b4c

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,10 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
356356
overwrite_du = op.overwrite_du
357357
out_dtype = node.outputs[1].type.numpy_dtype
358358

359-
must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs)
360-
if any(must_cast_inputs) and config.compiler_verbose:
359+
cast_inputs = (cast_dl, cast_d, cast_du) = tuple(
360+
inp.type.numpy_dtype != out_dtype for inp in node.inputs
361+
)
362+
if any(cast_inputs) and config.compiler_verbose:
361363
print("LUFactorTridiagonal requires casting at least one input") # noqa: T201
362364

363365
@numba_basic.numba_njit(cache=False)
@@ -371,11 +373,11 @@ def lu_factor_tridiagonal(dl, d, du):
371373
np.zeros(d.shape, dtype="int32"),
372374
)
373375

374-
if must_cast_inputs[0]:
376+
if cast_d:
375377
d = d.astype(out_dtype)
376-
if must_cast_inputs[1]:
378+
if cast_dl:
377379
dl = dl.astype(out_dtype)
378-
if must_cast_inputs[2]:
380+
if cast_du:
379381
du = du.astype(out_dtype)
380382
dl, d, du, du2, ipiv, _ = _gttrf(
381383
dl,
@@ -402,7 +404,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
402404
overwrite_b = op.overwrite_b
403405
transposed = op.transposed
404406

405-
must_cast_inputs = tuple(
407+
must_cast_inputs = (cast_dl, cast_d, cast_du, cast_du2, cast_ipiv, cast_b) = tuple(
406408
inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
407409
for i, inp in enumerate(node.inputs)
408410
)
@@ -417,17 +419,17 @@ def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
417419
else:
418420
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
419421

420-
if must_cast_inputs[0]:
422+
if cast_dl:
421423
dl = dl.astype(out_dtype)
422-
if must_cast_inputs[1]:
424+
if cast_d:
423425
d = d.astype(out_dtype)
424-
if must_cast_inputs[2]:
426+
if cast_du:
425427
du = du.astype(out_dtype)
426-
if must_cast_inputs[3]:
428+
if cast_du2:
427429
du2 = du2.astype(out_dtype)
428-
if must_cast_inputs[4]:
429-
ipiv = ipiv.astype("int32")
430-
if must_cast_inputs[5]:
430+
if cast_ipiv:
431+
ipiv = ipiv.astype(np.int32)
432+
if cast_b:
431433
b = b.astype(out_dtype)
432434
x, _ = _gttrs(
433435
dl,

tests/link/numba/linalg/solve/test_tridiagonal.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,34 @@ def test_tridiagonal_lu_solve(b_ndim, transposed, inplace):
112112
assert (res_non_contig == res).all()
113113
# b must be copied when not contiguous so it can't be inplaced
114114
assert (b_test == b_test_non_contig).all()
115+
116+
117+
def test_cast_needed():
118+
dl = pt.vector("dl", shape=(4,), dtype="int16")
119+
d = pt.vector("d", shape=(5,), dtype="float32")
120+
du = pt.vector("du", shape=(4,), dtype="float64")
121+
b = pt.vector("b", shape=(5,), dtype="float32")
122+
123+
lu_factor_outs = LUFactorTridiagonal()(dl, d, du)
124+
for i, out in enumerate(lu_factor_outs):
125+
if i == 4:
126+
assert out.type.dtype == "int32" # ipiv is int32
127+
else:
128+
assert out.type.dtype == "float64"
129+
130+
lu_solve_out = SolveLUFactorTridiagonal(b_ndim=1, transposed=False)(
131+
*lu_factor_outs, b
132+
)
133+
assert lu_solve_out.type.dtype == "float64"
134+
135+
compare_numba_and_py(
136+
[dl, d, du, b],
137+
lu_solve_out,
138+
test_inputs=[
139+
np.array([1, 2, 3, 4], dtype="int16"),
140+
np.array([1, 2, 3, 4, 5], dtype="float32"),
141+
np.array([1, 2, 3, 4], dtype="float64"),
142+
np.array([1, 2, 3, 4, 5], dtype="float32"),
143+
],
144+
eval_obj_mode=False,
145+
)

0 commit comments

Comments
 (0)