Skip to content

Commit aebfe93

Browse files
committed
Work-around numba pow failure was not robust enough
1 parent 6416fbb commit aebfe93

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,18 +170,17 @@ def {binary_op_name}({input_signature}):
170170
@register_funcify_and_cache_key(Pow)
171171
def numba_funcify_Pow(op, node, **kwargs):
172172
pow_dtype = node.inputs[1].type.dtype
173-
if pow_dtype.startswith("int"):
174-
# Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True
175-
# https://github.com/numba/numba/issues/9554
176173

177-
def pow(x, y):
178-
return x ** np.asarray(y, dtype=np.int64).item()
179-
else:
174+
def pow(x, y):
175+
return x**y
180176

181-
def pow(x, y):
182-
return x**y
177+
# Numba power fails when exponents are discrete integers and fasthmath=True
178+
# https://github.com/numba/numba/issues/9554
179+
fastmath = False if pow_dtype.startswith("int") else None
183180

184-
return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
181+
return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key(
182+
op, cache_version=1
183+
)
185184

186185

187186
@register_funcify_and_cache_key(Add)

tests/link/numba/test_scalar.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,14 @@ def test_Softplus(dtype):
185185

186186
def test_discrete_power():
187187
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
188-
x = pt.scalar("x", dtype="float64")
188+
x1 = pt.scalar("x1", dtype="float64")
189+
x2 = pt.scalar("x2", dtype="int16")
189190
exponent = pt.scalar("exponent", dtype="int8")
190-
out = pt.power(x, exponent)
191+
outs = [pt.power(x1, exponent), pt.power(x2, exponent)]
191192
compare_numba_and_py(
192-
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
193+
[x1, x2, exponent],
194+
outs,
195+
[np.array(0.5), np.array(2, dtype="int16"), np.array(2, dtype="int8")],
193196
)
194197

195198

0 commit comments

Comments
 (0)