Skip to content

Commit 42eec08

Browse files
committed
Add __repr__ for Numba and JAX linkers
1 parent f8b4b6c commit 42eec08

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

pytensor/link/jax/linker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,6 @@ def create_thunk_inputs(self, storage_map):
124124
thunk_inputs.append(sinput)
125125

126126
return thunk_inputs
127+
128+
def __repr__(self):
129+
return "JAXLinker()"

pytensor/link/numba/linker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ def jit_compile(self, fn_and_cache):
1919

2020
def create_thunk_inputs(self, storage_map):
2121
return [storage_map[n] for n in self.fgraph.inputs]
22+
23+
def __repr__(self):
24+
return "NumbaLinker()"

0 commit comments

Comments
 (0)