Skip to content

Commit 2caab13

Browse files
committed
Add support for TypedList in numba backend
1 parent 3602e49 commit 2caab13

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pytensor.sparse import SparseTensorType
2222
from pytensor.tensor.type import TensorType
2323
from pytensor.tensor.utils import hash_from_ndarray
24+
from pytensor.typed_list import TypedListType
2425

2526

2627
def numba_njit(
@@ -99,8 +100,8 @@ def get_numba_type(
99100
return CSRMatrixType(numba_dtype)
100101
if pytensor_type.format == "csc":
101102
return CSCMatrixType(numba_dtype)
102-
103-
raise NotImplementedError()
103+
elif isinstance(pytensor_type, TypedListType):
104+
return numba.types.List(get_numba_type(pytensor_type.ttype))
104105
else:
105106
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
106107

0 commit comments

Comments
 (0)