Skip to content

Commit 08b795f

Browse files
xalelaxdionhaefner
andauthored
fix: Issue #74 (#75)
#### Relevant issue or PR Fix #74 #### Description of changes Check that x in is_leaf(x) is a dict before calling methods like .keys() #### Testing done Local checks, added e2e test #### License - [ x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE). - [ x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: Alessandro Angioi <alessandro.angioi@simulation.science> --------- Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
1 parent a269a43 commit 08b795f

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

examples/vectoradd_jax/tesseract_api.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,29 @@ def apply(inputs: InputSchema) -> OutputSchema:
8989

9090
def abstract_eval(abstract_inputs):
9191
"""Calculate output shape of apply from the shape of its inputs."""
92+
is_shapedtye_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
93+
is_shapedtye_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
94+
9295
jaxified_inputs = jax.tree.map(
93-
lambda x: jax.ShapeDtypeStruct(**x),
96+
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtye_dict(x) else x,
9497
abstract_inputs.model_dump(),
95-
is_leaf=lambda x: (x.keys() == {"shape", "dtype"}),
98+
is_leaf=is_shapedtye_dict,
99+
)
100+
dynamic_inputs, static_inputs = eqx.partition(
101+
jaxified_inputs, filter_spec=is_shapedtye_struct
96102
)
97-
jax_shapes = jax.eval_shape(apply_jit, jaxified_inputs)
103+
104+
def wrapped_apply(dynamic_inputs):
105+
inputs = eqx.combine(static_inputs, dynamic_inputs)
106+
return apply_jit(inputs)
107+
108+
jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
98109
return jax.tree.map(
99-
lambda sd: {"shape": sd.shape, "dtype": str(sd.dtype)}, jax_shapes
110+
lambda x: {"shape": x.shape, "dtype": str(x.dtype)}
111+
if is_shapedtye_struct(x)
112+
else x,
113+
jax_shapes,
114+
is_leaf=is_shapedtye_struct,
100115
)
101116

102117

tesseract_core/sdk/templates/jax/tesseract_api.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,29 @@ def vector_jacobian_product(
107107

108108
def abstract_eval(abstract_inputs):
109109
"""Calculate output shape of apply from the shape of its inputs."""
110+
is_shapedtye_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
111+
is_shapedtye_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)
112+
110113
jaxified_inputs = jax.tree.map(
111-
lambda x: jax.ShapeDtypeStruct(**x),
114+
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtye_dict(x) else x,
112115
abstract_inputs.model_dump(),
113-
is_leaf=lambda x: (x.keys() == {"shape", "dtype"}),
116+
is_leaf=is_shapedtye_dict,
117+
)
118+
dynamic_inputs, static_inputs = eqx.partition(
119+
jaxified_inputs, filter_spec=is_shapedtye_struct
114120
)
115-
jax_shapes = jax.eval_shape(apply_jit, jaxified_inputs)
121+
122+
def wrapped_apply(dynamic_inputs):
123+
inputs = eqx.combine(static_inputs, dynamic_inputs)
124+
return apply_jit(inputs)
125+
126+
jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
116127
return jax.tree.map(
117-
lambda sd: {"shape": sd.shape, "dtype": str(sd.dtype)}, jax_shapes
128+
lambda x: (
129+
{"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtye_struct(x) else x
130+
),
131+
jax_shapes,
132+
is_leaf=is_shapedtye_struct,
118133
)
119134

120135

tests/endtoend_tests/test_examples.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,21 @@ class Config:
202202
},
203203
output_contains_array=np.array([7.0, 11.0, 15.0], dtype="float32"),
204204
),
205+
SampleRequest(
206+
endpoint="abstract_eval",
207+
payload={
208+
"inputs": {
209+
"a": {
210+
"v": {"shape": [3], "dtype": "float32"},
211+
"s": {"shape": [], "dtype": "float32"},
212+
},
213+
"b": {
214+
"v": {"shape": [3], "dtype": "float32"},
215+
"s": {"shape": [], "dtype": "float32"},
216+
},
217+
}
218+
},
219+
),
205220
SampleRequest(
206221
endpoint="apply",
207222
payload={"inputs": {}},

0 commit comments

Comments
 (0)