28 changes: 15 additions & 13 deletions pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
@@ -356,8 +356,10 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
overwrite_du = op.overwrite_du
out_dtype = node.outputs[1].type.numpy_dtype

must_cast_inputs = tuple(inp.type.numpy_dtype != out_dtype for inp in node.inputs)
if any(must_cast_inputs) and config.compiler_verbose:
cast_inputs = (cast_dl, cast_d, cast_du) = tuple(
inp.type.numpy_dtype != out_dtype for inp in node.inputs
)
if any(cast_inputs) and config.compiler_verbose:
print("LUFactorTridiagonal requires casting at least one input") # noqa: T201

@numba_basic.numba_njit(cache=False)
@@ -371,11 +373,11 @@ def lu_factor_tridiagonal(dl, d, du):
np.zeros(d.shape, dtype="int32"),
)

if must_cast_inputs[0]:
if cast_d:
d = d.astype(out_dtype)
if must_cast_inputs[1]:
if cast_dl:
dl = dl.astype(out_dtype)
if must_cast_inputs[2]:
if cast_du:
du = du.astype(out_dtype)
dl, d, du, du2, ipiv, _ = _gttrf(
dl,
@@ -402,7 +404,7 @@ def numba_funcify_SolveLUFactorTridiagonal(
overwrite_b = op.overwrite_b
transposed = op.transposed

must_cast_inputs = tuple(
must_cast_inputs = (cast_dl, cast_d, cast_du, cast_du2, cast_ipiv, cast_b) = tuple(
inp.type.numpy_dtype != (np.int32 if i == 4 else out_dtype)
for i, inp in enumerate(node.inputs)
)
@@ -417,17 +419,17 @@ def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
else:
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)

if must_cast_inputs[0]:
if cast_dl:
dl = dl.astype(out_dtype)
if must_cast_inputs[1]:
if cast_d:
d = d.astype(out_dtype)
if must_cast_inputs[2]:
if cast_du:
du = du.astype(out_dtype)
if must_cast_inputs[3]:
if cast_du2:
du2 = du2.astype(out_dtype)
if must_cast_inputs[4]:
ipiv = ipiv.astype("int32")
if must_cast_inputs[5]:
if cast_ipiv:
ipiv = ipiv.astype(np.int32)
if cast_b:
b = b.astype(out_dtype)
x, _ = _gttrs(
dl,
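Note on the `cast_inputs = (cast_dl, cast_d, cast_du) = tuple(...)` pattern above: chained assignment binds the same tuple twice, once as a whole (so `any(cast_inputs)` can still gate the verbose message) and once unpacked into named flags that read better inside the jitted closure than positional indexing. A minimal standalone sketch of the idiom (the dtype values here are made up for illustration):

```python
import numpy as np

out_dtype = np.dtype("float64")
input_dtypes = [np.dtype("int16"), np.dtype("float32"), np.dtype("float64")]

# Chained assignment: `cast_inputs` keeps the whole tuple, while the inner
# target unpacks it into one boolean flag per input.
cast_inputs = (cast_dl, cast_d, cast_du) = tuple(dt != out_dtype for dt in input_dtypes)

assert any(cast_inputs)  # at least one input needs casting
assert (cast_dl, cast_d, cast_du) == (True, True, False)
```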
17 changes: 8 additions & 9 deletions pytensor/link/numba/dispatch/scalar.py
@@ -170,18 +170,17 @@ def {binary_op_name}({input_signature}):
@register_funcify_and_cache_key(Pow)
def numba_funcify_Pow(op, node, **kwargs):
pow_dtype = node.inputs[1].type.dtype
if pow_dtype.startswith("int"):
# Numba power fails when exponents are non-64-bit discrete integers and fastmath=True
# https://github.com/numba/numba/issues/9554

def pow(x, y):
return x ** np.asarray(y, dtype=np.int64).item()
else:
def pow(x, y):
return x**y

def pow(x, y):
return x**y
# Numba power fails when exponents are discrete integers and fastmath=True
# https://github.com/numba/numba/issues/9554
fastmath = False if np.dtype(pow_dtype).kind in "ibu" else None

return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key(
op, cache_version=1
)


@register_funcify_and_cache_key(Add)
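The rewritten `numba_funcify_Pow` no longer special-cases integer exponents by round-tripping them through `np.int64`; it always compiles a plain `x**y` and instead disables `fastmath` whenever the exponent dtype is boolean, signed, or unsigned integer (`kind in "ibu"`), the combination that triggers numba/numba#9554. A rough sketch of the same idea with plain Numba (it skips PyTensor's wrapper, which is what actually interprets `fastmath=None` as "use the configured default"):

```python
import numpy as np
import numba


def pow_(x, y):
    return x**y


# Discrete exponent dtypes ("b"ool, "i"nt, "u"nsigned) get fastmath disabled to
# dodge numba#9554; floating-point exponents keep fastmath enabled here.
exponent_dtype = np.dtype("int8")
use_fastmath = exponent_dtype.kind not in "ibu"

jitted = numba.njit(pow_, fastmath=use_fastmath)
print(jitted(0.5, np.int8(2)))  # 0.25
```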
6 changes: 3 additions & 3 deletions pytensor/link/numba/dispatch/slinalg.py
@@ -246,7 +246,7 @@ def numba_funcify_Solve(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype

if A_dtype.kind == "c" or b_dtype.kind == "c":
raise generate_fallback_impl(op, node=node, **kwargs)
return generate_fallback_impl(op, node=node, **kwargs)
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("Solve requires casting first input `A`") # noqa: T201
@@ -320,7 +320,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype

if A_dtype.kind == "c" or b_dtype.kind == "c":
raise generate_fallback_impl(op, node=node, **kwargs)
return generate_fallback_impl(op, node=node, **kwargs)
must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose:
print("SolveTriangular requires casting first input `A`") # noqa: T201
@@ -371,7 +371,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype

if c_dtype.kind == "c" or b_dtype.kind == "c":
raise generate_fallback_impl(op, node=node, **kwargs)
return generate_fallback_impl(op, node=node, **kwargs)
must_cast_c = c_dtype != out_dtype
if must_cast_c and config.compiler_verbose:
print("CholeskySolve requires casting first input `c`") # noqa: T201
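The three `raise generate_fallback_impl(...)` → `return generate_fallback_impl(...)` fixes matter because the helper returns a compiled fallback callable, not an exception, so the old code would have died with a `TypeError` instead of falling back to object mode for complex inputs. A tiny illustration using a hypothetical stand-in for the helper:

```python
def generate_fallback_impl_stub():
    # Hypothetical stand-in: the real helper returns a compiled fallback function.
    return lambda *args: None


try:
    raise generate_fallback_impl_stub()  # what the old code effectively did
except TypeError as exc:
    print(exc)  # exceptions must derive from BaseException

fallback = generate_fallback_impl_stub()  # the fixed path: just return the impl
```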
3 changes: 1 addition & 2 deletions pytensor/tensor/_linalg/solve/tridiagonal.py
@@ -146,9 +146,8 @@ def make_node(self, dl, d, du, du2, ipiv, b):
n = nb

dummy_arrays = [
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, b)
]
# Seems to always be float64?
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
if self.b_ndim == 1:
output_shape = (n,)
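Swapping `ipiv` for `b` in the dummy arrays changes which LAPACK precision `get_lapack_funcs` selects: the selection is driven by the dtypes of the arrays passed in, and (assuming SciPy's usual prefix-selection rules) an integer array such as `ipiv` pushes it to double precision regardless of the float inputs, which is what the removed "Seems to always be float64?" comment was observing. A quick check of that assumption:

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

float32_args = [np.zeros((), dtype="float32") for _ in range(5)]
# All-float32 dummy arrays select the single-precision routine.
print(get_lapack_funcs("gttrs", float32_args).dtype)  # expected: float32

with_ipiv = [np.zeros((), dtype="int32"), *float32_args[:4]]
# An int32 array in the mix (like ipiv) bumps the selection to double precision.
print(get_lapack_funcs("gttrs", with_ipiv).dtype)  # expected: float64
```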
7 changes: 6 additions & 1 deletion pytensor/tensor/slinalg.py
@@ -508,7 +508,12 @@ def make_node(self, x):
p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
return Apply(self, inputs=[x], outputs=[p_indices, L, U])

P = tensor(shape=x.type.shape, dtype=out_dtype)
if out_dtype.startswith("complex"):
P_dtype = "float64" if out_dtype == "complex128" else "float32"
else:
P_dtype = out_dtype

P = tensor(shape=x.type.shape, dtype=P_dtype)
return Apply(self, inputs=[x], outputs=[P, L, U])

def perform(self, node, inputs, outputs):
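The new branch in `make_node` makes the symbolic permutation matrix `P` real (`float64` for `complex128`, `float32` for `complex64`) even when `x` is complex; the dtype assertions added in `test_lu_decomposition` below exercise exactly this. A quick way to check the underlying SciPy behaviour this assumes:

```python
import numpy as np
from scipy.linalg import lu

A = np.array([[0.0, 1.0], [1.0j, 0.0]], dtype="complex128")
P, L, U = lu(A)
# The dtype mapping above assumes P comes back real while L and U stay complex.
print(P.dtype, L.dtype, U.dtype)  # expected: float64 complex128 complex128
```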
31 changes: 31 additions & 0 deletions tests/link/numba/linalg/solve/test_tridiagonal.py
@@ -112,3 +112,34 @@ def test_tridiagonal_lu_solve(b_ndim, transposed, inplace):
assert (res_non_contig == res).all()
# b must be copied when not contiguous so it can't be inplaced
assert (b_test == b_test_non_contig).all()


def test_cast_needed():
dl = pt.vector("dl", shape=(4,), dtype="int16")
d = pt.vector("d", shape=(5,), dtype="float32")
du = pt.vector("du", shape=(4,), dtype="float64")
b = pt.vector("b", shape=(5,), dtype="float32")

lu_factor_outs = LUFactorTridiagonal()(dl, d, du)
for i, out in enumerate(lu_factor_outs):
if i == 4:
assert out.type.dtype == "int32" # ipiv is int32
else:
assert out.type.dtype == "float64"

lu_solve_out = SolveLUFactorTridiagonal(b_ndim=1, transposed=False)(
*lu_factor_outs, b
)
assert lu_solve_out.type.dtype == "float64"

compare_numba_and_py(
[dl, d, du, b],
lu_solve_out,
test_inputs=[
np.array([1, 2, 3, 4], dtype="int16"),
np.array([1, 2, 3, 4, 5], dtype="float32"),
np.array([1, 2, 3, 4], dtype="float64"),
np.array([1, 2, 3, 4, 5], dtype="float32"),
],
eval_obj_mode=False,
)
20 changes: 15 additions & 5 deletions tests/link/numba/test_scalar.py
@@ -183,13 +183,23 @@ def test_Softplus(dtype):
)


def test_discrete_power():
@pytest.mark.parametrize(
"test_base",
[np.bool(True), np.int16(3), np.uint16(3), np.float32(0.5), np.float64(0.5)],
)
@pytest.mark.parametrize(
"test_exponent",
[np.bool(True), np.int16(2), np.uint16(2), np.float32(2.0), np.float64(2.0)],
)
def test_power_fastmath_bug(test_base, test_exponent):
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
x = pt.scalar("x", dtype="float64")
exponent = pt.scalar("exponent", dtype="int8")
out = pt.power(x, exponent)
base = pt.scalar("base", dtype=test_base.dtype)
exponent = pt.scalar("exponent", dtype=test_exponent.dtype)
out = pt.power(base, exponent)
compare_numba_and_py(
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
[base, exponent],
[out],
[test_base, test_exponent],
)


6 changes: 4 additions & 2 deletions tests/tensor/test_slinalg.py
@@ -648,16 +648,18 @@ def test_lu_decomposition(
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"

A = tensor("A", shape=shape, dtype=dtype)
out = lu(A, permute_l=permute_l, p_indices=p_indices)
pt_out = lu(A, permute_l=permute_l, p_indices=p_indices)

f = function([A], out)
f = function([A], pt_out)

rng = np.random.default_rng(utt.fetch_seed())
x = rng.normal(size=shape).astype(config.floatX)
if complex:
x = x + 1j * rng.normal(size=shape).astype(config.floatX)

out = f(x)
for numerical_out, symbolic_out in zip(out, pt_out):
assert numerical_out.dtype == symbolic_out.type.dtype

if permute_l:
PL, U = out