Skip to content

Commit 0768858

Browse files
committed
Work-around numba pow failure was not robust enough
1 parent 6416fbb commit 0768858

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,18 +170,17 @@ def {binary_op_name}({input_signature}):
170170
@register_funcify_and_cache_key(Pow)
171171
def numba_funcify_Pow(op, node, **kwargs):
172172
pow_dtype = node.inputs[1].type.dtype
173-
if pow_dtype.startswith("int"):
174-
# Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True
175-
# https://github.com/numba/numba/issues/9554
176173

177-
def pow(x, y):
178-
return x ** np.asarray(y, dtype=np.int64).item()
179-
else:
174+
def pow(x, y):
175+
return x**y
180176

181-
def pow(x, y):
182-
return x**y
177+
# Numba power fails when exponents are discrete integers and fasthmath=True
178+
# https://github.com/numba/numba/issues/9554
179+
fastmath = False if np.dtype(pow_dtype).kind in "ibu" else None
183180

184-
return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
181+
return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key(
182+
op, cache_version=1
183+
)
185184

186185

187186
@register_funcify_and_cache_key(Add)

tests/link/numba/test_scalar.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,23 @@ def test_Softplus(dtype):
183183
)
184184

185185

186-
def test_discrete_power():
186+
@pytest.mark.parametrize(
187+
"test_base",
188+
[np.bool(True), np.int16(3), np.uint16(3), np.float32(0.5), np.float64(0.5)],
189+
)
190+
@pytest.mark.parametrize(
191+
"test_exponent",
192+
[np.bool(True), np.int16(2), np.uint16(2), np.float32(2.0), np.float64(2.0)],
193+
)
194+
def test_power_fastmath_bug(test_base, test_exponent):
187195
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
188-
x = pt.scalar("x", dtype="float64")
189-
exponent = pt.scalar("exponent", dtype="int8")
190-
out = pt.power(x, exponent)
196+
base = pt.scalar("base", dtype=test_base.dtype)
197+
exponent = pt.scalar("exponent", dtype=test_exponent.dtype)
198+
out = pt.power(base, exponent)
191199
compare_numba_and_py(
192-
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
200+
[base, exponent],
201+
[out],
202+
[test_base, test_exponent],
193203
)
194204

195205

0 commit comments

Comments
 (0)