Skip to content

Commit 6f911dc

Browse files
committed
Eig impls: Keep complex outputs in Blockwise and Numba
1 parent eb13bda commit 6f911dc

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ def numba_funcify_Eig(op, node, **kwargs):
8181

8282
@numba_basic.numba_njit
8383
def eig(x):
84-
return np.linalg.eig(inputs_cast(x))
84+
w, v = np.linalg.eig(inputs_cast(x))
85+
return w.astype(w_dtype), v.astype(w_dtype)
8586

86-
return eig
87+
cache_version = 1
88+
return eig, cache_version
8789

8890

8991
@register_funcify_default_op_cache_key(Eigh)

pytensor/tensor/nlinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ class Eig(Op):
325325
"""
326326

327327
__props__: tuple[str, ...] = ()
328-
gufunc_spec = ("numpy.linalg.eig", 1, 2)
328+
# Can't use numpy directly in Blockwise, because of the dynamic dtype
329+
# gufunc_spec = ("numpy.linalg.eig", 1, 2)
329330
gufunc_signature = "(m,m)->(m),(m,m)"
330331

331332
def make_node(self, x):

tests/tensor/test_blockwise.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
vector,
2525
)
2626
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
27-
from pytensor.tensor.nlinalg import MatrixInverse
27+
from pytensor.tensor.nlinalg import MatrixInverse, eig
2828
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
2929
from pytensor.tensor.signal import convolve1d
3030
from pytensor.tensor.slinalg import (
@@ -763,3 +763,22 @@ def perform(self, node, inputs, outputs):
763763
add_supervisor_to_fgraph(fgraph, [In(inp, mutable=True) for inp in fgraph.inputs])
764764
rewrite_graph(fgraph, include=("inplace",))
765765
assert fgraph.outputs[0].owner.op.destroy_map == {1: [1]}
766+
767+
768+
def test_eig_blockwise():
769+
x = tensor("x", shape=(2, 3, 3), dtype="float64")
770+
eigen_values, eigen_vectors = eig(x)
771+
assert eigen_values.dtype == "complex128"
772+
assert eigen_vectors.dtype == "complex128"
773+
fn = function([x], [eigen_values, eigen_vectors])
774+
eigen_values_res, eigen_vectors_res = fn(np.full((2, 3, 3), np.eye(3)))
775+
np.testing.assert_allclose(
776+
eigen_values_res,
777+
np.ones((2, 3), dtype="complex128"),
778+
strict=True,
779+
)
780+
np.testing.assert_allclose(
781+
eigen_vectors_res,
782+
np.full((2, 3, 3), np.eye(3), dtype="complex128"),
783+
strict=True,
784+
)

0 commit comments

Comments
 (0)