Skip to content

Commit 02bf0a0

Browse files
committed
Add support for TypedList in numba backend
1 parent 643209c commit 02bf0a0

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pytensor.tensor.slinalg import Solve
4040
from pytensor.tensor.type import TensorType
4141
from pytensor.tensor.type_other import MakeSlice, NoneConst
42+
from pytensor.typed_list import TypedListType
4243

4344

4445
def global_numba_func(func):
@@ -121,6 +122,8 @@ def get_numba_type(
121122
return CSCMatrixType(numba_dtype)
122123

123124
raise NotImplementedError()
125+
elif isinstance(pytensor_type, TypedListType):
126+
return numba.types.List(get_numba_type(pytensor_type.ttype))
124127
else:
125128
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
126129

0 commit comments

Comments
 (0)