Skip to content

Commit 590155c

Browse files
authored
feat: Validate endpoint argument names before building (#95)
BREAKING CHANGE: Custom argument names for any endpoint are no longer supported. All endpoint argument names and their order must exactly match the standard. #### Relevant issue or PR - #84 #### Description of changes We already validate whether endpoint signatures have the correct number of arguments in `sdk.api_parse`. In addition to suppressing the unnecessary traceback and improving the error message when this is violated, this PR provides validation of the endpoint signature argument _names_ and order before building. This guards against the possibility that arguments may have been misordered, leading to potential silent failures, and guarantees that natural dev assumptions about the endpoint signatures hold. Note that it is not possible to provide this with `tesseract-runtime check` (or other endpoints), as the runtime does not have access to the SDK. I'm comfortable with this, as we can add runtime errors in a different way by reviving #80. Alternatively, we could move `api_parse` to `runtime`, allowing it to be accessed by `tesseract-runtime check`. #### Missing argument error message ``` [-] Error validating tesseract_api.py: jacobian_vector_product must have 4 arguments: inputs, jvp_inputs, jvp_outputs, tangent_vector. However, tesseract_api.py specifies 3 arguments: inputs, jvp_outputs, tangent_vector. [x] Aborting ``` #### Wrong argument name error message ``` [-] Error validating tesseract_api.py: The second argument of jacobian_vector_product must be named jvp_inputs, but tesseract_api.py has named it jvp_input. [x] Aborting ``` #### Testing done - [x] Removed an input, error message as above - [x] Misspelled an input, error message as above - [x] CI passes #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. 
<details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: Jonathan Brodrick <jonathan.brodrick@simulation.science>
1 parent 2ebafd0 commit 590155c

File tree

6 files changed

+65
-20
lines changed

6 files changed

+65
-20
lines changed

examples/vectoradd_torch/tesseract_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,14 @@ def jacobian_vector_product(
150150
inputs: InputSchema,
151151
jvp_inputs: set[str],
152152
jvp_outputs: set[str],
153-
tangent: dict[str, Any],
153+
tangent_vector: dict[str, Any],
154154
):
155-
# Make ordering of tangent identical to jvp_inputs
156-
tangent = {key: tangent[key] for key in jvp_inputs}
155+
# Make ordering of tangent_vector identical to jvp_inputs
156+
tangent_vector = {key: tangent_vector[key] for key in jvp_inputs}
157157

158158
# convert all numbers and arrays to torch tensors
159159
tensor_inputs = tree_map(to_tensor, inputs.model_dump())
160-
pos_tangent = tree_map(to_tensor, tangent).values()
160+
pos_tangent = tree_map(to_tensor, tangent_vector).values()
161161

162162
# flatten the dictionaries such that they can be accessed by paths
163163
pos_inputs = flatten_with_paths(tensor_inputs, jvp_inputs).values()

tesseract_core/sdk/api_parse.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,44 @@ class _ApiObject(NamedTuple):
2121
name: str
2222
expected_type: type
2323
num_args: int | None = None
24+
arg_names: tuple[str, ...] | None = None
2425
optional: bool = False
2526

2627

28+
ORDINALS = ["first", "second", "third", "fourth", "fifth", "sixth", "seventh", "eighth"]
29+
2730
EXPECTED_OBJECTS = (
28-
_ApiObject("apply", ast.FunctionDef, 1),
31+
_ApiObject("apply", ast.FunctionDef, 1, arg_names=("inputs",)),
2932
_ApiObject("InputSchema", ast.ClassDef),
3033
_ApiObject("OutputSchema", ast.ClassDef),
31-
_ApiObject("jacobian", ast.FunctionDef, 3, optional=True),
32-
_ApiObject("jacobian_vector_product", ast.FunctionDef, 4, optional=True),
33-
_ApiObject("vector_jacobian_product", ast.FunctionDef, 4, optional=True),
34-
_ApiObject("abstract_eval", ast.FunctionDef, 1, optional=True),
34+
_ApiObject(
35+
"jacobian",
36+
ast.FunctionDef,
37+
3,
38+
arg_names=("inputs", "jac_inputs", "jac_outputs"),
39+
optional=True,
40+
),
41+
_ApiObject(
42+
"jacobian_vector_product",
43+
ast.FunctionDef,
44+
4,
45+
arg_names=("inputs", "jvp_inputs", "jvp_outputs", "tangent_vector"),
46+
optional=True,
47+
),
48+
_ApiObject(
49+
"vector_jacobian_product",
50+
ast.FunctionDef,
51+
4,
52+
arg_names=("inputs", "vjp_inputs", "vjp_outputs", "cotangent_vector"),
53+
optional=True,
54+
),
55+
_ApiObject(
56+
"abstract_eval",
57+
ast.FunctionDef,
58+
1,
59+
arg_names=("abstract_inputs",),
60+
optional=True,
61+
),
3562
)
3663

3764

@@ -112,8 +139,8 @@ class ValidationError(Exception):
112139
pass
113140

114141

115-
def _get_func_argnums(func: ast.FunctionDef) -> int:
116-
"""Get the number of the arguments of a function node.
142+
def _get_func_argnames(func: ast.FunctionDef) -> tuple[str, ...]:
143+
"""Get the names of the arguments of a function node.
117144
118145
See:
119146
https://docs.python.org/3/library/ast.html#ast.FunctionDef
@@ -128,7 +155,7 @@ def _get_func_argnums(func: ast.FunctionDef) -> int:
128155
raise ValidationError(
129156
f"Function {func.name} must not have positional-only arguments"
130157
)
131-
return len(func_args.args)
158+
return tuple(arg.arg for arg in func_args.args)
132159

133160

134161
def validate_tesseract_api(src_dir: Path) -> None:
@@ -190,8 +217,23 @@ def validate_tesseract_api(src_dir: Path) -> None:
190217
)
191218

192219
if obj.num_args is not None:
193-
if _get_func_argnums(toplevel_objects[obj.name]) != obj.num_args:
194-
raise ValidationError(f"{obj.name} must have {obj.num_args} arguments")
220+
func_argnames = _get_func_argnames(toplevel_objects[obj.name])
221+
func_argnums = len(func_argnames)
222+
if func_argnums != obj.num_args:
223+
raise ValidationError(
224+
f"{obj.name} must have {obj.num_args} arguments: {', '.join(obj.arg_names)}.\n"
225+
f"However, {tesseract_api_location} specifies {func_argnums} "
226+
f"arguments: {', '.join(func_argnames)}."
227+
)
228+
msgs = []
229+
for i in range(obj.num_args):
230+
if func_argnames[i] != obj.arg_names[i]:
231+
msgs.append(
232+
f"The {ORDINALS[i]} argument (argument {i}) of {obj.name} must be named {obj.arg_names[i]}, "
233+
f"but {tesseract_api_location} has named it {func_argnames[i]}."
234+
)
235+
if msgs:
236+
raise ValidationError("\n".join(msgs))
195237

196238
# Check InputSchema and OutputSchema are pydantic BaseModels
197239
for schema in ("InputSchema", "OutputSchema"):

tesseract_core/sdk/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
EXPECTED_OBJECTS,
3434
TesseractBuildConfig,
3535
TesseractConfig,
36+
ValidationError,
3637
get_non_base_fields_in_tesseract_config,
3738
)
3839
from .exceptions import UserError
@@ -307,6 +308,8 @@ def build_image(
307308
raise UserError(f"Input error building Tesseract: {e}") from e
308309
except PermissionError as e:
309310
raise UserError(f"Permission denied: {e}") from e
311+
except ValidationError as e:
312+
raise UserError(f"Error validating tesseract_api.py: {e}") from e
310313

311314
if generate_only:
312315
# output is the path to the build context

tesseract_core/sdk/templates/pytorch/tesseract_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,14 @@ def jacobian_vector_product(
109109
inputs: InputSchema,
110110
jvp_inputs: set[str],
111111
jvp_outputs: set[str],
112-
tangent: dict[str, Any],
112+
tangent_vector: dict[str, Any],
113113
):
114-
# Make ordering of tangent identical to jvp_inputs
115-
tangent = {key: tangent[key] for key in jvp_inputs}
114+
# Make ordering of tangent_vector identical to jvp_inputs
115+
tangent_vector = {key: tangent_vector[key] for key in jvp_inputs}
116116

117117
# convert all numbers and arrays to torch tensors
118118
tensor_inputs = tree_map(to_tensor, inputs.model_dump())
119-
pos_tangent = tree_map(to_tensor, tangent).values()
119+
pos_tangent = tree_map(to_tensor, tangent_vector).values()
120120

121121
# flatten the dictionaries such that they can be accessed by paths
122122
pos_inputs = flatten_with_paths(tensor_inputs, jvp_inputs).values()

tests/runtime_tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def vector_jacobian_product(
105105

106106
return out
107107

108-
def abstract_eval(self, inputs):
108+
def abstract_eval(self, abstract_inputs):
109109
return {
110110
"result_seq": [{"shape": (10,), "dtype": "float32"}] * 3,
111111
"result_arr": {"shape": (5,), "dtype": "float32"},

tests/runtime_tests/test_serve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def apply(input: InputSchema) -> OutputSchema:
179179
assert threading.current_thread() == threading.main_thread()
180180
return OutputSchema()
181181
182-
def abstract_eval(inputs: dict) -> dict:
182+
def abstract_eval(abstract_inputs: dict) -> dict:
183183
pass
184184
"""
185185
)

0 commit comments

Comments
 (0)