From 73749e2c807538736a68e30ef01b32754e128fc4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 21 Nov 2025 16:44:00 +0100 Subject: [PATCH] Numba Dot: Handle complex inputs --- pytensor/link/numba/dispatch/elemwise.py | 44 +++++++++++++++++------- tests/link/numba/test_elemwise.py | 19 ++++++++++ 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9b2c9f514c..53cacf65c4 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -8,6 +8,7 @@ from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple from numpy.lib.stride_tricks import as_strided +from pytensor import config from pytensor.graph.op import Op from pytensor.link.numba.cache import ( compile_numba_function_src, @@ -608,37 +609,54 @@ def numba_funcify_Dot(op, node, **kwargs): x, y = node.inputs [out] = node.outputs - x_dtype = x.type.dtype - y_dtype = y.type.dtype - dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" - out_dtype = out.type.dtype + x_dtype = x.type.numpy_dtype + y_dtype = y.type.numpy_dtype - if x_dtype == dot_dtype and y_dtype == dot_dtype: + numba_dot_dtype = out_dtype = out.type.numpy_dtype + if out_dtype.kind not in "fc": + # Numba alawys returns non-integral outputs, we need to cast to float + numba_dot_dtype = np.dtype( + f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" + ) + + if config.compiler_verbose and not ( + x_dtype == y_dtype == out_dtype == numba_dot_dtype + ): + print( # noqa: T201 + "Numba Dot requires a type casting of inputs and/or output: " + f"{x_dtype=}, {y_dtype=}, {out_dtype=}, {numba_dot_dtype=}" + ) + + if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype: @numba_basic.numba_njit def dot(x, y): return np.asarray(np.dot(x, y)) - elif x_dtype == dot_dtype and y_dtype != dot_dtype: + elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype: @numba_basic.numba_njit def dot(x, y): - return np.asarray(np.dot(x, y.astype(dot_dtype))) + return np.asarray(np.dot(x, y.astype(numba_dot_dtype))) - elif x_dtype != dot_dtype and y_dtype == dot_dtype: + elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype: @numba_basic.numba_njit def dot(x, y): - return np.asarray(np.dot(x.astype(dot_dtype), y)) + return np.asarray(np.dot(x.astype(numba_dot_dtype), y)) else: @numba_basic.numba_njit def dot(x, y): - return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) + return np.asarray( + np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype)) + ) + + cache_version = 1 - if out_dtype == dot_dtype: - return dot + if out_dtype == numba_dot_dtype: + return dot, cache_version else: @@ -646,7 +664,7 @@ def dot(x, y): def dot_with_cast(x, y): return dot(x, y).astype(out_dtype) - return dot_with_cast + return dot_with_cast, cache_version @register_funcify_default_op_cache_key(BatchedDot) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 614d5a092e..ea7892fd0b 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -718,6 +718,25 @@ def test_dimshuffle(self, c_contiguous, benchmark): (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)), (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)), ), + # Viewing the array with 2 last dimensions as complex128 means + # the first entry will be real part and the second entry the imaginary part + ( + ( + pt.matrix(dtype="complex128"), + rng.random(size=(5, 4, 2)).view("complex128").squeeze(-1), + ), + ( + pt.matrix(dtype="complex128"), + rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1), + ), + ), + ( + (pt.matrix(dtype="int64"), rng.random(size=(5, 4)).astype("int64")), + ( + pt.matrix(dtype="complex128"), + rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1), + ), + ), ], ) def test_Dot(x, y):