We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3602e49 commit 2caab13Copy full SHA for 2caab13
pytensor/link/numba/dispatch/basic.py
@@ -21,6 +21,7 @@
21
from pytensor.sparse import SparseTensorType
22
from pytensor.tensor.type import TensorType
23
from pytensor.tensor.utils import hash_from_ndarray
24
+from pytensor.typed_list import TypedListType
25
26
27
def numba_njit(
@@ -99,8 +100,8 @@ def get_numba_type(
99
100
return CSRMatrixType(numba_dtype)
101
if pytensor_type.format == "csc":
102
return CSCMatrixType(numba_dtype)
-
103
- raise NotImplementedError()
+ elif isinstance(pytensor_type, TypedListType):
104
+ return numba.types.List(get_numba_type(pytensor_type.ttype))
105
else:
106
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
107
0 commit comments