We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 42eec08 commit e449e1eCopy full SHA for e449e1e
pytensor/link/numba/dispatch/basic.py
@@ -22,6 +22,7 @@
22
from pytensor.sparse import SparseTensorType
23
from pytensor.tensor.type import TensorType
24
from pytensor.tensor.utils import hash_from_ndarray
25
+from pytensor.typed_list import TypedListType
26
27
28
# Disable loud / incorrect warnings from Numba
@@ -122,8 +123,8 @@ def get_numba_type(
122
123
return CSRMatrixType(numba_dtype)
124
if pytensor_type.format == "csc":
125
return CSCMatrixType(numba_dtype)
-
126
- raise NotImplementedError()
+ elif isinstance(pytensor_type, TypedListType):
127
+ return numba.types.List(get_numba_type(pytensor_type.ttype))
128
else:
129
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
130
0 commit comments