Skip to content

Commit 6416fbb

Browse files
committed
Numba linalg: Fix obj fallback raise -> return
1 parent 695574b commit 6416fbb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def numba_funcify_Solve(op, node, **kwargs):
246246
out_dtype = node.outputs[0].type.numpy_dtype
247247

248248
if A_dtype.kind == "c" or b_dtype.kind == "c":
249-
raise generate_fallback_impl(op, node=node, **kwargs)
249+
return generate_fallback_impl(op, node=node, **kwargs)
250250
must_cast_A = A_dtype != out_dtype
251251
if must_cast_A and config.compiler_verbose:
252252
print("Solve requires casting first input `A`") # noqa: T201
@@ -320,7 +320,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
320320
out_dtype = node.outputs[0].type.numpy_dtype
321321

322322
if A_dtype.kind == "c" or b_dtype.kind == "c":
323-
raise generate_fallback_impl(op, node=node, **kwargs)
323+
return generate_fallback_impl(op, node=node, **kwargs)
324324
must_cast_A = A_dtype != out_dtype
325325
if must_cast_A and config.compiler_verbose:
326326
print("SolveTriangular requires casting first input `A`") # noqa: T201
@@ -371,7 +371,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
371371
out_dtype = node.outputs[0].type.numpy_dtype
372372

373373
if c_dtype.kind == "c" or b_dtype.kind == "c":
374-
raise generate_fallback_impl(op, node=node, **kwargs)
374+
return generate_fallback_impl(op, node=node, **kwargs)
375375
must_cast_c = c_dtype != out_dtype
376376
if must_cast_c and config.compiler_verbose:
377377
print("CholeskySolve requires casting first input `c`") # noqa: T201

0 commit comments

Comments
 (0)