
Commit 24ea1bc

feat: implement check_gradients runtime command (#72)
#### Relevant issue or PR

Fixes #18

#### Description of changes

Introduce `tesseract-runtime check-gradients` command line tool to compare user-defined gradient endpoints (`jacobian` + jvp / vjp) against a finite difference approximation (computed by repeatedly calling `apply`).

Example session:

```bash
# default settings
tesseract run univariate check-gradients '{"inputs": {}}'
Checking gradients for jacobian... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
✅ Gradient check for jacobian passed ✅ (0 failures / 2 checks)
Checking gradients for jacobian_vector_product... ━━━━━━━━━━━━━━━━━ 100% 0:00:00
✅ Gradient check for jacobian_vector_product passed ✅ (0 failures / 2 checks)
Checking gradients for vector_jacobian_product... ━━━━━━━━━━━━━━━━━ 100% 0:00:00
✅ Gradient check for vector_jacobian_product passed ✅ (0 failures / 2 checks)

# eps too small
$ tesseract run univariate check-gradients '{"inputs": {}}' --eps 1e-22
[-] Error running Tesseract. Command 'check-gradients '{"inputs": {}}' --eps 1e-22' in image 'univariate' returned non-zero exit status 1:
Checking gradients for jacobian... (failures: 1) ━━━━━━━━━━━━━━━━━━ 100% 0:00:00

⚠️ Gradient check for jacobian failed ⚠️ (1 failures / 2 checks)
First 10 failures:
  Input path: 'x', Output path: 'result', Index: ()
    jacobian value: -2.0
    Finite difference value: 0.0

Checking gradients for jacobian_vector_product... (failures: 1) ━━━━ 100% 0:00:…

⚠️ Gradient check for jacobian_vector_product failed ⚠️ (1 failures / 2 checks)
First 10 failures:
  Input path: 'x', Output path: 'result', Index: ()
    jacobian_vector_product value: -2.0
    Finite difference value: 0.0

Checking gradients for vector_jacobian_product... (failures: 1) ━━━━ 100% 0:00:…

⚠️ Gradient check for vector_jacobian_product failed ⚠️ (1 failures / 2 checks)
First 10 failures:
  Input path: 'x', Output path: 'result', Index: ()
    vector_jacobian_product value: -2.0
    Finite difference value: 0.0

❌ Some gradient checks failed ❌
[x] Aborting
```

Command `--help`:

```bash
Usage: tesseract run univariate check-gradients [OPTIONS] JSON_PAYLOAD

  Check gradients of endpoints against a finite difference approximation.

  This is an automated way to check the correctness of the gradients of the
  different AD endpoints (jacobian, jacobian_vector_product,
  vector_jacobian_product) of a ``tesseract_api.py`` module. It will sample
  random indices and compare the gradients computed by the AD endpoints with
  the finite difference approximation.

  Warning:
      Finite differences are not exact and the comparison is done with a
      tolerance. This means that the check may fail even if the gradients
      are correct, and vice versa.

Options:
  --endpoints TEXT        Endpoints to check gradients for (default: check all).
  --input-paths TEXT      Paths to differentiable inputs to check gradients for (default: check all).
  --output-paths TEXT     Paths to differentiable outputs to check gradients for (default: check all).
  --eps FLOAT             Step size for finite differences. [default: 0.0001]
  --rtol FLOAT            Relative tolerance when comparing finite differences to gradients. [default: 0.1]
  --max-evals INTEGER     Maximum number of evaluations per input. [default: 1000]
  --max-failures INTEGER  Maximum number of failures to report per endpoint. [default: 10]
  --seed INTEGER          Seed for random number generator. If not set, a random seed is used.
  --show-progress         Show progress bar.
  -h, --help              Show this message and exit.
```

#### Testing done

CI, added `check-gradients` to e2e test of all examples that have gradients defined + a unit test on the `check_gradients` function.
#### License

- [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE).
- [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line.

<details>
<summary><b>Developer Certificate of Origin</b></summary>

```text
Developer Certificate of Origin
Version 1.1

Copyright (C) 2004, 2006 The Linux Foundation and its contributors.

Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.


Developer's Certificate of Origin 1.1

By making a contribution to this project, I certify that:

(a) The contribution was created in whole or in part by me and I
    have the right to submit it under the open source license
    indicated in the file; or

(b) The contribution is based upon previous work that, to the best
    of my knowledge, is covered under an appropriate open source
    license and I have the right under that license to submit that
    work with modifications, whether created in whole or in part
    by me, under the same open source license (unless I am
    permitted to submit under a different license), as indicated
    in the file; or

(c) The contribution was provided directly to me by some other
    person who certified (a), (b) or (c) and I have not modified
    it.

(d) I understand and agree that this project and the contribution
    are public and that a record of the contribution (including all
    personal information I submit with it, including my sign-off) is
    maintained indefinitely and may be redistributed consistent with
    this project or the open source license(s) involved.
```

</details>

Signed-off-by: Dion Häfner <dion.haefner@simulation.science>
1 parent 9c69143 commit 24ea1bc
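
The failing session in the description can be reproduced by hand. The sketch below is not the code added in this commit: it assumes the `univariate` example implements the standard Rosenbrock function, and it uses a central difference and a relative-tolerance test (`central_difference` and `gradients_agree` are hypothetical helpers, not part of `tesseract_core`). It shows why a step of `1e-22` yields a finite difference of exactly `0.0` while the analytic derivative at the default inputs is `-2.0`.

```python
def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
    # Standard Rosenbrock function (assumed to match the univariate example).
    return (a - x) ** 2 + b * (y - x**2) ** 2


def rosenbrock_dx(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
    # Analytic partial derivative with respect to x.
    return -2.0 * (a - x) - 4.0 * b * x * (y - x**2)


def central_difference(f, x: float, eps: float) -> float:
    # Hypothetical helper: symmetric finite difference around x.
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def gradients_agree(grad_val: float, ref_val: float, rtol: float) -> bool:
    # Hypothetical comparison rule: relative to the larger magnitude.
    scale = max(abs(grad_val), abs(ref_val))
    return abs(grad_val - ref_val) <= rtol * scale


grad = rosenbrock_dx(0.0, 0.0)
for eps in (1e-4, 1e-22):
    fd = central_difference(lambda x: rosenbrock(x, 0.0), 0.0, eps)
    verdict = "pass" if gradients_agree(grad, fd, rtol=0.1) else "fail"
    print(f"eps={eps:g}: gradient={grad}, finite difference={fd}, {verdict}")
```

With `eps=1e-22` the perturbation is far below double-precision resolution near `1.0`, so both function evaluations return the same value and the difference collapses to zero.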

11 files changed: +1057 -31 lines changed

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+{
+  "inputs": {
+    "mesh": {
+      "n_points": 5,
+      "n_cells": 2,
+      "points": [
+        [
+          0.0,
+          123.0,
+          0.0
+        ],
+        [
+          1.0,
+          0.0,
+          0.0
+        ],
+        [
+          0.0,
+          1.0,
+          0.0
+        ],
+        [
+          1.0,
+          1.0,
+          0.0
+        ],
+        [
+          0.5,
+          0.5,
+          1.0
+        ]
+      ],
+      "num_points_per_cell": [
+        4,
+        4
+      ],
+      "cell_connectivity": [
+        0,
+        1,
+        2,
+        3,
+        1,
+        2,
+        3,
+        4
+      ],
+      "cell_data": {
+        "temperature": [
+          [
+            100.0,
+            105.0
+          ],
+          [
+            110.0,
+            115.0
+          ]
+        ],
+        "pressure": [
+          [
+            1.0,
+            1.2
+          ],
+          [
+            1.1,
+            1.3
+          ]
+        ]
+      },
+      "point_data": {
+        "displacement": [
+          [
+            0.0,
+            0.1,
+            0.2
+          ],
+          [
+            0.1,
+            0.0,
+            0.2
+          ],
+          [
+            0.2,
+            0.1,
+            0.0
+          ],
+          [
+            0.1,
+            0.2,
+            0.1
+          ],
+          [
+            0.2,
+            0.1,
+            0.1
+          ]
+        ],
+        "velocity": [
+          [
+            0.0,
+            0.0,
+            0.0
+          ],
+          [
+            0.1,
+            0.0,
+            0.0
+          ],
+          [
+            0.0,
+            0.1,
+            0.0
+          ],
+          [
+            0.0,
+            0.0,
+            0.1
+          ],
+          [
+            0.1,
+            0.1,
+            0.1
+          ]
+        ]
+      }
+    }
+  }
+}

examples/meshstats/tesseract_api.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@ class VolumetricMeshData(BaseModel):
     num_points_per_cell: Array[(None,), Float32]  # should have length == n_cells
     cell_connectivity: Array[(None,), Int32]  # length == sum(num_points_per_cell)
 
-    cell_data: dict[str, Differentiable[Array[(None, None), Float32]]]
-    point_data: dict[str, Differentiable[Array[(None, None), Float32]]]
+    cell_data: dict[str, Array[(None, None), Float32]]
+    point_data: dict[str, Array[(None, None), Float32]]
 
     @model_validator(mode="after")
     def validate_num_points_per_cell(self):
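
The inline comments in this hunk state two length invariants for the mesh payload. The body of `validate_num_points_per_cell` is not part of the diff, so the following is only a standalone sketch of what such a consistency check could look like, written against plain dictionaries. `mesh_consistency_problems` is a hypothetical helper, and the `n_points` check is an extra assumption not stated in the comments. The sample values are taken from the 127-line JSON payload added earlier in this commit.

```python
def mesh_consistency_problems(mesh: dict) -> list[str]:
    """Return a list of human-readable problems; empty if the mesh is consistent."""
    problems = []
    # Assumption: n_points should equal the number of point coordinates.
    if len(mesh["points"]) != mesh["n_points"]:
        problems.append("len(points) != n_points")
    # From the diff comment: num_points_per_cell should have length == n_cells.
    if len(mesh["num_points_per_cell"]) != mesh["n_cells"]:
        problems.append("len(num_points_per_cell) != n_cells")
    # From the diff comment: cell_connectivity length == sum(num_points_per_cell).
    if len(mesh["cell_connectivity"]) != sum(mesh["num_points_per_cell"]):
        problems.append("len(cell_connectivity) != sum(num_points_per_cell)")
    return problems


# Structural fields copied from the example payload in this commit.
mesh = {
    "n_points": 5,
    "n_cells": 2,
    "points": [
        [0.0, 123.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
        [0.5, 0.5, 1.0],
    ],
    "num_points_per_cell": [4, 4],
    "cell_connectivity": [0, 1, 2, 3, 1, 2, 3, 4],
}

print(mesh_consistency_problems(mesh) or "mesh is consistent")
```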

examples/univariate/tesseract_api.py

Lines changed: 7 additions & 7 deletions
@@ -4,7 +4,7 @@
 import jax
 from pydantic import BaseModel, Field
 
-from tesseract_core.runtime import Differentiable, Float64, ShapeDType
+from tesseract_core.runtime import Differentiable, Float32, ShapeDType
 
 
 def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
@@ -17,14 +17,14 @@ def rosenbrock(x: float, y: float, a: float = 1.0, b: float = 100.0) -> float:
 
 
 class InputSchema(BaseModel):
-    x: Differentiable[Float64] = Field(description="Scalar value x.", default=0.0)
-    y: Differentiable[Float64] = Field(description="Scalar value y.", default=0.0)
-    a: Float64 = Field(description="Scalar parameter a.", default=1.0)
-    b: Float64 = Field(description="Scalar parameter b.", default=100.0)
+    x: Differentiable[Float32] = Field(description="Scalar value x.", default=0.0)
+    y: Differentiable[Float32] = Field(description="Scalar value y.", default=0.0)
+    a: Float32 = Field(description="Scalar parameter a.", default=1.0)
+    b: Float32 = Field(description="Scalar parameter b.", default=100.0)
 
 
 class OutputSchema(BaseModel):
-    result: Differentiable[Float64] = Field(
+    result: Differentiable[Float32] = Field(
         description="Result of Rosenbrock function evaluation."
     )
 
@@ -91,4 +91,4 @@ def vector_jacobian_product(
 
 def abstract_eval(abstract_inputs):
     """Calculate output shape of apply from the shape of its inputs."""
-    return {"result": ShapeDType(shape=(), dtype="float64")}
+    return {"result": ShapeDType(shape=(), dtype="float32")}
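
Switching this example from `Float64` to `Float32` matters for gradient checking, because a finite-difference step that is fine in double precision can vanish entirely in single precision. The sketch below illustrates this with the standard library only; `to_float32` and `step_survives` are hypothetical helpers for illustration and do not exist in `tesseract_core`.

```python
import struct


def to_float32(value: float) -> float:
    # Round a Python float (double precision) to the nearest float32 value.
    return struct.unpack("f", struct.pack("f", value))[0]


def step_survives(x: float, eps: float) -> bool:
    # True if x + eps is still distinguishable from x after float32 rounding.
    return to_float32(to_float32(x) + eps) != to_float32(x)


for eps in (1e-4, 1e-8):
    print(
        f"eps={eps:g}: survives in float32: {step_survives(1.0, eps)}, "
        f"survives in float64: {1.0 + eps != 1.0}"
    )
```

A step of `1e-8` around `1.0` is lost in float32 but preserved in float64, which is one reason a larger `--eps` may be needed for single-precision inputs.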

examples/vectoradd/tesseract_api.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 class InputSchema(BaseModel):
     a: Differentiable[Array[(None,), Float32]] = Field(
-        description="An arbitrary vector normalized according to [...]"
+        description="An arbitrary vector."
     )
     b: Differentiable[Array[(None,), Float32]] = Field(
         description="An arbitrary vector. Needs to have the same dimensions as a."

examples/vectoradd_jax/tesseract_api.py

Lines changed: 0 additions & 2 deletions
@@ -19,8 +19,6 @@ class Vector_and_Scalar(BaseModel):
     )
     s: Differentiable[Float32] = Field(description="A scalar", default=1.0)
 
-    # we lose the ability to use methods such as this when using model_dump
-    # unless we reconstruct nested models
     def scale(self) -> Differentiable[Array[(None,), Float32]]:
         return self.s * self.v
 
2624

ruff.toml

Lines changed: 4 additions & 4 deletions
@@ -46,10 +46,10 @@ ignore = [
 
 [lint.extend-per-file-ignores]
 # Ignore missing docstrings and type annotations for selected directories
-"./*.py" = ["D101", "D102", "D103", "ANN"]
-"tests/*" = ["D101", "D102", "D103", "ANN"]
-"examples/**/*" = ["D101", "D102", "D103", "ANN"]
-"tesseract_core/sdk/templates/*" = ["D101", "D102", "D103", "ANN"]
+"./*.py" = ["D101", "D102", "D103", "D106", "ANN"]
+"tests/*" = ["D101", "D102", "D103", "D106", "ANN"]
+"examples/**/*" = ["D101", "D102", "D103", "D106", "ANN"]
+"tesseract_core/sdk/templates/*" = ["D101", "D102", "D103", "D106", "ANN"]
 
 [lint.pydocstyle]
 convention = "google"

tesseract_core/runtime/cli.py

Lines changed: 151 additions & 0 deletions
@@ -26,6 +26,9 @@
     read_from_path,
     write_to_path,
 )
+from tesseract_core.runtime.finite_differences import (
+    check_gradients as check_gradients_,
+)
 from tesseract_core.runtime.serve import create_rest_api
 from tesseract_core.runtime.serve import serve as serve_
 
@@ -152,6 +155,154 @@ def check() -> None:
     typer.echo("✅ Tesseract API check successful ✅")
 
 
+@tesseract_runtime.command()
+@click.argument(
+    "payload",
+    type=click.STRING,
+    required=True,
+    metavar="JSON_PAYLOAD",
+    callback=_parse_arg_callback,
+)
+@click.option(
+    "--endpoints",
+    type=click.STRING,
+    required=False,
+    multiple=True,
+    help="Endpoints to check gradients for (default: check all).",
+)
+@click.option(
+    "--input-paths",
+    type=click.STRING,
+    required=False,
+    multiple=True,
+    help="Paths to differentiable inputs to check gradients for (default: check all).",
+)
+@click.option(
+    "--output-paths",
+    type=click.STRING,
+    required=False,
+    multiple=True,
+    help="Paths to differentiable outputs to check gradients for (default: check all).",
+)
+@click.option(
+    "--eps",
+    type=click.FLOAT,
+    required=False,
+    help="Step size for finite differences.",
+    default=1e-4,
+    show_default=True,
+)
+@click.option(
+    "--rtol",
+    type=click.FLOAT,
+    required=False,
+    help="Relative tolerance when comparing finite differences to gradients.",
+    default=0.1,
+    show_default=True,
+)
+@click.option(
+    "--max-evals",
+    type=click.INT,
+    required=False,
+    help="Maximum number of evaluations per input.",
+    default=1000,
+    show_default=True,
+)
+@click.option(
+    "--max-failures",
+    type=click.INT,
+    required=False,
+    help="Maximum number of failures to report per endpoint.",
+    default=10,
+    show_default=True,
+)
+@click.option(
+    "--seed",
+    type=click.INT,
+    required=False,
+    help="Seed for random number generator. If not set, a random seed is used.",
+    default=None,
+)
+@click.option(
+    "--show-progress",
+    is_flag=True,
+    default=True,
+    help="Show progress bar.",
+)
+def check_gradients(
+    payload,
+    input_paths,
+    output_paths,
+    endpoints,
+    eps,
+    rtol,
+    max_evals,
+    max_failures,
+    seed,
+    show_progress,
+) -> None:
+    """Check gradients of endpoints against a finite difference approximation.
+
+    This is an automated way to check the correctness of the gradients of the different AD endpoints
+    (jacobian, jacobian_vector_product, vector_jacobian_product) of a ``tesseract_api.py`` module.
+    It will sample random indices and compare the gradients computed by the AD endpoints with the
+    finite difference approximation.
+
+    Warning:
+        Finite differences are not exact and the comparison is done with a tolerance. This means
+        that the check may fail even if the gradients are correct, and vice versa.
+
+    Finite difference approximations are sensitive to numerical precision. When finite differences
+    are reported incorrectly as 0.0, it is likely that the chosen `eps` is too small, especially for
+    inputs that do not use float64 precision.
+    """
+    api_module = get_tesseract_api()
+    inputs, base_dir = payload
+
+    result_iter = check_gradients_(
+        api_module,
+        inputs,
+        base_dir=base_dir,
+        input_paths=input_paths,
+        output_paths=output_paths,
+        endpoints=endpoints,
+        max_evals=max_evals,
+        eps=eps,
+        rtol=rtol,
+        seed=seed,
+        show_progress=show_progress,
+    )
+
+    failed = False
+    for endpoint, failures, num_evals in result_iter:
+        if not failures:
+            typer.echo(
+                f"✅ Gradient check for {endpoint} passed ✅ ({len(failures)} failures / {num_evals} checks)"
+            )
+        else:
+            failed = True
+            typer.echo()
+            typer.echo(
+                f"⚠️ Gradient check for {endpoint} failed ⚠️ ({len(failures)} failures / {num_evals} checks)"
+            )
+            printed_failures = min(len(failures), max_failures)
+            typer.echo(f"First {printed_failures} failures:")
+            for failure in failures[:printed_failures]:
+                typer.echo(
+                    f"  Input path: '{failure.in_path}', Output path: '{failure.out_path}', Index: {failure.idx}"
+                )
+                if failure.exception:
+                    typer.echo(f"    Encountered exception: {failure.exception}")
+                else:
+                    typer.echo(f"    {endpoint} value: {failure.grad_val}")
+                    typer.echo(f"    Finite difference value: {failure.ref_val}")
+                typer.echo()
+
+    if failed:
+        typer.echo("❌ Some gradient checks failed ❌")
+        sys.exit(1)
+
+
 @tesseract_runtime.command()
 @click.option("-p", "--port", default=8000, help="Port number")
 @click.option("-h", "--host", default="0.0.0.0", help="Host IP address")
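
The reporting loop in `check_gradients` consumes an iterator of `(endpoint, failures, num_evals)` tuples and maps any failure to a non-zero exit status. The sketch below restates that control flow as a pure function that is easy to test without the CLI. It is an illustration only: `Failure` is a hypothetical stand-in whose field names mirror the attributes accessed in the diff, and `summarize` is not part of `tesseract_core`.

```python
from typing import NamedTuple, Optional


class Failure(NamedTuple):
    # Hypothetical record; field names follow the attributes used in the diff.
    in_path: str
    out_path: str
    idx: tuple
    grad_val: Optional[float] = None
    ref_val: Optional[float] = None
    exception: Optional[str] = None


def summarize(results, max_failures: int = 10):
    """Return (report_lines, exit_code) for an iterable of check results."""
    lines = []
    failed = False
    for endpoint, failures, num_evals in results:
        status = "failed" if failures else "passed"
        lines.append(
            f"Gradient check for {endpoint} {status} "
            f"({len(failures)} failures / {num_evals} checks)"
        )
        if failures:
            failed = True
        # Report at most max_failures entries per endpoint.
        for failure in failures[:max_failures]:
            lines.append(
                f"  Input path: '{failure.in_path}', "
                f"Output path: '{failure.out_path}', Index: {failure.idx}"
            )
            if failure.exception:
                lines.append(f"    Encountered exception: {failure.exception}")
            else:
                lines.append(f"    {endpoint} value: {failure.grad_val}")
                lines.append(f"    Finite difference value: {failure.ref_val}")
    return lines, (1 if failed else 0)


results = [
    ("jacobian", [], 2),
    ("jacobian_vector_product", [Failure("x", "result", (), -2.0, 0.0)], 2),
]
report, exit_code = summarize(results)
print("\n".join(report))
print("exit code:", exit_code)
```

Separating formatting from `sys.exit` in this way keeps the exit-code logic testable on its own; the command in the diff performs both steps inline.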
