Skip to content

Commit e449e1e

Browse files
committed
Add support for TypedList in numba backend
1 parent 42eec08 commit e449e1e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pytensor.sparse import SparseTensorType
2323
from pytensor.tensor.type import TensorType
2424
from pytensor.tensor.utils import hash_from_ndarray
25+
from pytensor.typed_list import TypedListType
2526

2627

2728
# Disable loud / incorrect warnings from Numba
@@ -122,8 +123,8 @@ def get_numba_type(
122123
return CSRMatrixType(numba_dtype)
123124
if pytensor_type.format == "csc":
124125
return CSCMatrixType(numba_dtype)
125-
126-
raise NotImplementedError()
126+
elif isinstance(pytensor_type, TypedListType):
127+
return numba.types.List(get_numba_type(pytensor_type.ttype))
127128
else:
128129
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
129130

0 commit comments

Comments
 (0)