Skip to content

Commit 50bd652

Browse files
committed
Tridiagonal Solve: Fix dtype inference
Scipy helper doesn't have special handling for ipiv int32 variable, and always assumes it must be cast to a float64
1 parent 54a6158 commit 50bd652

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

pytensor/tensor/_linalg/solve/tridiagonal.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,8 @@ def make_node(self, dl, d, du, du2, ipiv, b):
146146
n = nb
147147

148148
dummy_arrays = [
149-
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
149+
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, b)
150150
]
151-
# Seems to always be float64?
152151
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
153152
if self.b_ndim == 1:
154153
output_shape = (n,)

0 commit comments

Comments
 (0)