Skip to content

Commit 27aea8f

Browse files
committed
Fix LU P out dtype for complex input
1 parent 6416fbb commit 27aea8f

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

pytensor/tensor/slinalg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,12 @@ def make_node(self, x):
508508
p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
509509
return Apply(self, inputs=[x], outputs=[p_indices, L, U])
510510

511-
P = tensor(shape=x.type.shape, dtype=out_dtype)
511+
if out_dtype.startswith("complex"):
512+
P_dtype = "float64" if out_dtype == "complex128" else "float32"
513+
else:
514+
P_dtype = out_dtype
515+
516+
P = tensor(shape=x.type.shape, dtype=P_dtype)
512517
return Apply(self, inputs=[x], outputs=[P, L, U])
513518

514519
def perform(self, node, inputs, outputs):

tests/tensor/test_slinalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,16 +648,18 @@ def test_lu_decomposition(
648648
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
649649

650650
A = tensor("A", shape=shape, dtype=dtype)
651-
out = lu(A, permute_l=permute_l, p_indices=p_indices)
651+
pt_out = lu(A, permute_l=permute_l, p_indices=p_indices)
652652

653-
f = function([A], out)
653+
f = function([A], pt_out)
654654

655655
rng = np.random.default_rng(utt.fetch_seed())
656656
x = rng.normal(size=shape).astype(config.floatX)
657657
if complex:
658658
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
659659

660660
out = f(x)
661+
for numerical_out, symbolic_out in zip(out, pt_out):
662+
assert numerical_out.dtype == symbolic_out.type.dtype
661663

662664
if permute_l:
663665
PL, U = out

0 commit comments

Comments
 (0)