Skip to content

Commit 9de97c6

Browse files
committed
Simplify Numba implementation of Alloc
1 parent 6a80058 commit 9de97c6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
6767
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
6868
shapes_to_items_src = indent(
6969
"\n".join(
70-
f"{item_name} = to_scalar({shape_name})"
70+
f"{item_name} = {shape_name}.item()"
7171
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
7272
),
7373
" " * 4,
@@ -83,7 +83,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
8383

8484
alloc_def_src = f"""
8585
def alloc(val, {", ".join(shape_var_names)}):
86-
val_np = np.asarray(val)
86+
val_np = val
8787
{shapes_to_items_src}
8888
scalar_shape = {create_tuple_string(shape_var_item_names)}
8989
{check_runtime_broadcast_src}

0 commit comments

Comments
 (0)