Merged
44 changes: 31 additions & 13 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -8,6 +8,7 @@
 from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
 from numpy.lib.stride_tricks import as_strided
 
+from pytensor import config
 from pytensor.graph.op import Op
 from pytensor.link.numba.cache import (
     compile_numba_function_src,
@@ -608,45 +609,62 @@ def numba_funcify_Dot(op, node, **kwargs):
     x, y = node.inputs
     [out] = node.outputs
 
-    x_dtype = x.type.dtype
-    y_dtype = y.type.dtype
-    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
-    out_dtype = out.type.dtype
+    x_dtype = x.type.numpy_dtype
+    y_dtype = y.type.numpy_dtype
 
-    if x_dtype == dot_dtype and y_dtype == dot_dtype:
+    numba_dot_dtype = out_dtype = out.type.numpy_dtype
+    if out_dtype.kind not in "fc":
+        # Numba alawys returns non-integral outputs, we need to cast to float
    [Review comment from Copilot AI, Dec 6, 2025]
    Typo in comment: "alawys" should be "always".
    Suggested change:
    -        # Numba alawys returns non-integral outputs, we need to cast to float
    +        # Numba always returns non-integral outputs, we need to cast to float
+        numba_dot_dtype = np.dtype(
+            f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
+        )
+
+    if config.compiler_verbose and not (
+        x_dtype == y_dtype == out_dtype == numba_dot_dtype
+    ):
+        print(  # noqa: T201
+            "Numba Dot requires a type casting of inputs and/or output: "
+            f"{x_dtype=}, {y_dtype=}, {out_dtype=}, {numba_dot_dtype=}"
+        )
+
+    if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype:
 
         @numba_basic.numba_njit
         def dot(x, y):
             return np.asarray(np.dot(x, y))
 
-    elif x_dtype == dot_dtype and y_dtype != dot_dtype:
+    elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype:
 
         @numba_basic.numba_njit
         def dot(x, y):
-            return np.asarray(np.dot(x, y.astype(dot_dtype)))
+            return np.asarray(np.dot(x, y.astype(numba_dot_dtype)))
 
-    elif x_dtype != dot_dtype and y_dtype == dot_dtype:
+    elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype:
 
         @numba_basic.numba_njit
         def dot(x, y):
-            return np.asarray(np.dot(x.astype(dot_dtype), y))
+            return np.asarray(np.dot(x.astype(numba_dot_dtype), y))
 
     else:
 
         @numba_basic.numba_njit
         def dot(x, y):
-            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
+            return np.asarray(
+                np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype))
+            )
 
-    if out_dtype == dot_dtype:
-        return dot
+    cache_version = 1
+
+    if out_dtype == numba_dot_dtype:
+        return dot, cache_version
 
     else:
 
         @numba_basic.numba_njit
         def dot_with_cast(x, y):
             return dot(x, y).astype(out_dtype)
 
-        return dot_with_cast
+        return dot_with_cast, cache_version
 
 
 @register_funcify_default_op_cache_key(BatchedDot)
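Aside (illustration, not part of the PR): the dtype-selection rule introduced above can be sketched standalone. The helper name pick_numba_dot_dtype below is hypothetical, but the rule mirrors the diff: float/complex outputs keep their dtype, while anything else is computed in a float dtype at least 32 bits wide and cast back afterwards, since (per the comment in the diff) Numba's dot does not produce integral outputs.

import numpy as np

# Hypothetical helper mirroring the rule in numba_funcify_Dot above.
def pick_numba_dot_dtype(out_dtype: np.dtype) -> np.dtype:
    if out_dtype.kind in "fc":  # float or complex output: compute as-is
        return out_dtype
    # integral/bool output: widen to a float dtype of at least 32 bits
    return np.dtype(f"float{max(32, out_dtype.itemsize * 8)}")

assert pick_numba_dot_dtype(np.dtype("float32")) == np.dtype("float32")
assert pick_numba_dot_dtype(np.dtype("complex128")) == np.dtype("complex128")
assert pick_numba_dot_dtype(np.dtype("int64")) == np.dtype("float64")
assert pick_numba_dot_dtype(np.dtype("int16")) == np.dtype("float32")

When the computation dtype differs from the output dtype, the diff wraps dot in dot_with_cast to convert the result back; both paths now also return a cache_version so cached compilations can be invalidated if this dispatch logic changes.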
19 changes: 19 additions & 0 deletions tests/link/numba/test_elemwise.py
@@ -718,6 +718,25 @@ def test_dimshuffle(self, c_contiguous, benchmark):
             (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
             (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
         ),
+        # Viewing the last axis of length 2 as complex128 means the first
+        # entry becomes the real part and the second entry the imaginary part
+        (
+            (
+                pt.matrix(dtype="complex128"),
+                rng.random(size=(5, 4, 2)).view("complex128").squeeze(-1),
+            ),
+            (
+                pt.matrix(dtype="complex128"),
+                rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1),
+            ),
+        ),
+        (
+            (pt.matrix(dtype="int64"), rng.random(size=(5, 4)).astype("int64")),
+            (
+                pt.matrix(dtype="complex128"),
+                rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1),
+            ),
+        ),
     ],
 )
 def test_Dot(x, y):
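A note on the test data above (plain NumPy illustration, not part of the PR): .view("complex128").squeeze(-1) reinterprets a trailing real axis of length 2 as one complex number per pair, so a (5, 4, 2) float64 array becomes a (5, 4) complex128 matrix:

import numpy as np

rng = np.random.default_rng(0)
real = rng.random(size=(5, 4, 2))  # trailing axis holds (real, imag) pairs
cplx = real.view("complex128").squeeze(-1)  # shape (5, 4), dtype complex128

assert cplx.shape == (5, 4)
assert np.allclose(cplx.real, real[..., 0])  # first entry: real part
assert np.allclose(cplx.imag, real[..., 1])  # second entry: imaginary part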