
Commit 8073dff

Authored by jpbrodrick89, MatteoSalvador, and dionhaefner
doc: Fix RBF fitting example (#25)
#### Relevant issue or PR

- #3

#### Description of changes

- [x] Python API (i.e. `tesseract_core.sdk.tesseract`) now used to launch and interact with Tesseract
- [x] Jacobian argument order corrected
- [x] Only `'weights'` is provided in `diff_inputs` now (`length_scale` not used)
- [x] Offered options for using either jacobian or vjp
- [x] Pydantic model validators now return `self` instead of `None`
- [x] Removed port argument
- [x] Documentation updated
- [x] Tidied up imports and namespaces for jax recipe and `vectoradd_jax` unit tesseract

#### Testing done

- [ ] CI passes
- [x] Tesseract builds
- [x] `optimization_routine.py` runs successfully with graphs

#### License

- [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE).
- [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line.

<details>
<summary><b>Developer Certificate of Origin</b></summary>

```text
Developer Certificate of Origin
Version 1.1

Copyright (C) 2004, 2006 The Linux Foundation and its contributors.

Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

Developer's Certificate of Origin 1.1

By making a contribution to this project, I certify that:

(a) The contribution was created in whole or in part by me and I have
    the right to submit it under the open source license indicated in
    the file; or

(b) The contribution is based upon previous work that, to the best of
    my knowledge, is covered under an appropriate open source license
    and I have the right under that license to submit that work with
    modifications, whether created in whole or in part by me, under
    the same open source license (unless I am permitted to submit
    under a different license), as indicated in the file; or

(c) The contribution was provided directly to me by some other person
    who certified (a), (b) or (c) and I have not modified it.

(d) I understand and agree that this project and the contribution are
    public and that a record of the contribution (including all
    personal information I submit with it, including my sign-off) is
    maintained indefinitely and may be redistributed consistent with
    this project or the open source license(s) involved.
```

</details>

Signed-off-by: Jonathan Brodrick <jonathan.brodrick@simulation.science>

---------

Co-authored-by: Matteo Salvador <44835004+MatteoSalvador@users.noreply.github.com>
Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
1 parent ae9828b commit 8073dff
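
For orientation, the client-side pattern this change adopts is sketched below. This is a minimal, hedged illustration based on the diff that follows: the image name `rbf_fitting` matches the example, but the input values are placeholders (see `optimization_routine.py` below for how the real inputs are built).

```python
# Sketch of the tesseract_core Python API usage introduced by this PR.
# Field names come from the example's InputSchema; the values here are
# illustrative placeholders, and the "rbf_fitting" image must be built first.
from tesseract_core.sdk.tesseract import Tesseract

inputs = {
    "x_centers": [0.0, 0.5, 1.0],
    "weights": [0.0, 0.0, 0.0],
    "length_scale": 0.2,
    "x_target": [0.0, 0.25, 0.5, 0.75, 1.0],
    "y_target": [0.0, 0.2, 0.5, 0.8, 1.0],
}

with Tesseract.from_image(image="rbf_fitting") as tess:
    outputs = tess.apply(inputs)          # returns {"mse": ...}
    grad = tess.vector_jacobian_product(  # gradient of "mse" w.r.t. "weights"
        inputs, ["weights"], ["mse"], {"mse": 1.0}
    )["weights"]
```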

File tree: 4 files changed (+139 / −84 lines)


examples/sciml/rbf_fitting/optimization_routine.py

Lines changed: 32 additions & 46 deletions

@@ -3,25 +3,13 @@

 """This is an optimisation scripts for the RBF fitting example."""

-import argparse
-
 import jax.numpy as jnp
+import matplotlib.pyplot as plt
 import optax
 from jax import random
-from matplotlib import pyplot as plt
 from plotting import plot_loss, plot_rbf
-from tesseract_client import Client
-
-parser = argparse.ArgumentParser(description="RBF tesseract optimisation script")
-parser.add_argument(
-    "-p", "--port", help="Port at which RBF tesseract is being served", required=True
-)
-args = vars(parser.parse_args())

-if "port" not in args:
-    raise ValueError("Port at which RBF tesseract is being served is required.")
-
-port = int(args["port"])
+from tesseract_core.sdk.tesseract import Tesseract

 # JAX random keys
 key = random.PRNGKey(42)

@@ -54,11 +42,8 @@ def ground_truth(x):
     "y_target": y_target.tolist(),
 }

-jac_inputs = ["weights", "length_scale"]
-jac_outputs = ["mse"]
-
-# Initialize Tesseract client
-client = Client(host="127.0.0.1", port=port)
+diff_inputs = ["weights"]
+diff_outputs = ["mse"]

 # Initialize optimizer
 optimizer = optax.adam(learning_rate=0.5)

@@ -68,34 +53,35 @@ def ground_truth(x):
 loss_history = []
 max_iterations = 100
 print(f"Starting the optimization process for {max_iterations} iterations.")
-for n_iteration in range(max_iterations):
-    if n_iteration % 10 == 0:
-        print(f" ---- iteration {n_iteration} / {max_iterations}")
-
-    # Compute loss
-    apply_response = client.request("apply", method="POST", payload={"inputs": inputs})
-    loss = apply_response["mse"]["data"]["buffer"]
-    loss_history.append(loss)
-
-    # Compute gradients
-    jacobian_response = client.request(
-        "jacobian",
-        method="POST",
-        payload={
-            "inputs": inputs,
-            "jac_inputs": jac_inputs,
-            "jac_outputs": jac_outputs,
-        },
-    )

-    buffer = jacobian_response["weights"]["mse"]["data"]["buffer"]
-    grad_weights = jnp.array(buffer, dtype=jnp.float32)
-
-    # Update weights
-    weights = jnp.array(inputs["weights"])
-    updates, opt_state = optimizer.update(grad_weights, opt_state, weights)
-    weights = optax.apply_updates(weights, updates)
-    inputs["weights"] = weights.tolist()
+# Initialize Tesseract client
+with Tesseract.from_image(image="rbf_fitting") as tess:
+    for n_iteration in range(max_iterations):
+        if n_iteration % 10 == 0:
+            print(f" ---- iteration {n_iteration} / {max_iterations}")
+
+        # Compute loss
+        apply_response = tess.apply(inputs)
+        loss = apply_response["mse"]
+        loss_history.append(loss)
+
+        # Compute gradients
+        # Option 1: Use Jacobian
+        # grad_weights = tess.jacobian(
+        #     inputs,
+        #     diff_inputs,
+        #     diff_outputs,
+        # )["mse"]["weights"]
+        # Option 2: Use VJP
+        grad_weights = tess.vector_jacobian_product(
+            inputs, diff_inputs, diff_outputs, {"mse": 1.0}
+        )["weights"]
+
+        # Update weights
+        weights = jnp.array(inputs["weights"])
+        updates, opt_state = optimizer.update(grad_weights, opt_state, weights)
+        weights = optax.apply_updates(weights, updates)
+        inputs["weights"] = weights.tolist()

 print("Optimisation completed!")

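A note on the two gradient options left in the loop above: `mse` is a scalar output, so the Jacobian entry with respect to `weights` and the VJP with cotangent `{"mse": 1.0}` carry the same values. A small hedged check of that equivalence, assuming a running `tess` handle and the populated `inputs` dict from the script above:

```python
import numpy as np

# Illustrative only: for the scalar output "mse", the full Jacobian entry
# and the VJP with a unit cotangent should agree.
g_jac = tess.jacobian(inputs, ["weights"], ["mse"])["mse"]["weights"]
g_vjp = tess.vector_jacobian_product(
    inputs, ["weights"], ["mse"], {"mse": 1.0}
)["weights"]
assert np.allclose(np.asarray(g_jac), np.asarray(g_vjp))
```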
examples/sciml/rbf_fitting/tesseract_api.py

Lines changed: 103 additions & 30 deletions

@@ -1,11 +1,17 @@
 # Copyright 2025 Pasteur Labs. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0

-import jax
+from functools import partial
+from typing import Any
+
 import jax.numpy as jnp
+import jax.tree
+from jax import ShapeDtypeStruct, eval_shape, jacrev, jit, jvp, vjp
 from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Self

 from tesseract_core.runtime import Array, Differentiable, Float32
+from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths


 def gaussian_rbf(x: float, c: float, length_scale: float) -> float:

@@ -25,7 +31,13 @@ def mse_error(
     for coeff, center in zip(weights, x_centers, strict=False):
         y_hat += coeff * gaussian_rbf(x_target, center, length_scale)

-    return jnp.mean((y_target - y_hat) ** 2)
+    return {"mse": jnp.mean((y_target - y_hat) ** 2)}
+
+
+@jit
+def apply_jit(inputs: dict) -> dict:
+    ordered_keys = ["x_centers", "weights", "length_scale", "x_target", "y_target"]
+    return mse_error(*(inputs[key] for key in ordered_keys))


 #

@@ -49,20 +61,22 @@ class InputSchema(BaseModel):
     )

     @model_validator(mode="after")
-    def validate_shape_targets(self) -> None:
+    def validate_shape_targets(self) -> Self:
         if self.x_target.shape != self.y_target.shape:
             raise ValueError(
                 f"x_target and y_target must have the same shape. "
                 f"Got {self.x_target.shape} and {self.y_target.shape} instead."
             )
+        return self

     @model_validator(mode="after")
-    def validate_shape_weights(self) -> None:
+    def validate_shape_weights(self) -> Self:
         if self.x_centers.shape != self.weights.shape:
             raise ValueError(
                 f"x_centers and weights must have the same shape. "
                 f"Got {self.x_centers.shape} and {self.weights.shape} instead."
             )
+        return self


 class OutputSchema(BaseModel):

@@ -77,18 +91,11 @@ class OutputSchema(BaseModel):


 def apply(inputs: InputSchema) -> OutputSchema:
-    mse = mse_error(
-        inputs.x_centers,
-        inputs.weights,
-        inputs.length_scale,
-        inputs.x_target,
-        inputs.y_target,
-    )
-    return OutputSchema(mse=mse)
+    return apply_jit(inputs.model_dump())


 #
-# Optional endpoints
+# Jax-handled AD endpoints (no need to modify)
 #


@@ -97,20 +104,86 @@ def jacobian(
     jac_inputs: set[str],
     jac_outputs: set[str],
 ):
-    mse_signature = ["x_centers", "weights", "length_scale", "x_target", "y_target"]
-
-    jac_result = {}
-    for dx in jac_inputs:
-        grad_func = jax.jacrev(mse_error, argnums=mse_signature.index(dx))
-        for dy in jac_outputs:
-            jac_result[dx] = {
-                dy: grad_func(
-                    inputs.x_centers,
-                    inputs.weights,
-                    inputs.length_scale,
-                    inputs.x_target,
-                    inputs.y_target,
-                )
-            }
-
-    return jac_result
+    return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs))
+
+
+def jacobian_vector_product(
+    inputs: InputSchema,
+    jvp_inputs: set[str],
+    jvp_outputs: set[str],
+    tangent_vector: dict[str, Any],
+):
+    return jvp_jit(
+        inputs.model_dump(),
+        tuple(jvp_inputs),
+        tuple(jvp_outputs),
+        tangent_vector,
+    )
+
+
+def vector_jacobian_product(
+    inputs: InputSchema,
+    vjp_inputs: set[str],
+    vjp_outputs: set[str],
+    cotangent_vector: dict[str, Any],
+):
+    return vjp_jit(
+        inputs.model_dump(),
+        tuple(vjp_inputs),
+        tuple(vjp_outputs),
+        cotangent_vector,
+    )
+
+
+def abstract_eval(abstract_inputs):
+    """Calculate output shape of apply from the shape of its inputs."""
+    jaxified_inputs = jax.tree.map(
+        lambda x: ShapeDtypeStruct(**x),
+        abstract_inputs.model_dump(),
+        is_leaf=lambda x: (x.keys() == {"shape", "dtype"}),
+    )
+    jax_shapes = eval_shape(apply_jit, jaxified_inputs)
+    return jax.tree.map(
+        lambda sd: {"shape": sd.shape, "dtype": str(sd.dtype)}, jax_shapes
+    )
+
+
+#
+# Helper functions
+#
+
+
+@partial(jit, static_argnames=["jac_inputs", "jac_outputs"])
+def jac_jit(
+    inputs: dict,
+    jac_inputs: tuple[str],
+    jac_outputs: tuple[str],
+):
+    filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
+    return jacrev(filtered_apply)(flatten_with_paths(inputs, include_paths=jac_inputs))
+
+
+@partial(jit, static_argnames=["jvp_inputs", "jvp_outputs"])
+def jvp_jit(
+    inputs: dict, jvp_inputs: tuple[str], jvp_outputs: tuple[str], tangent_vector: dict
+):
+    filtered_apply = filter_func(apply_jit, inputs, jvp_outputs)
+    return jvp(
+        filtered_apply,
+        [flatten_with_paths(inputs, include_paths=jvp_inputs)],
+        [tangent_vector],
+    )[1]
+
+
+@partial(jit, static_argnames=["vjp_inputs", "vjp_outputs"])
+def vjp_jit(
+    inputs: dict,
+    vjp_inputs: tuple[str],
+    vjp_outputs: tuple[str],
+    cotangent_vector: dict,
+):
+    filtered_apply = filter_func(apply_jit, inputs, vjp_outputs)
+    _, vjp_func = vjp(
+        filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
+    )
+    return vjp_func(cotangent_vector)[0]

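One of the fixes listed in the description is that the Pydantic model validators now return `self`. In Pydantic v2, a `@model_validator(mode="after")` method receives the constructed model and its return value is used as the validation result, so the instance must be returned explicitly. A standalone sketch with a hypothetical model (not one of the Tesseract schemas):

```python
from pydantic import BaseModel, model_validator
from typing_extensions import Self


class Pair(BaseModel):
    x: list[float]
    y: list[float]

    @model_validator(mode="after")
    def check_lengths(self) -> Self:
        # After-validators must hand the (possibly modified) instance back.
        if len(self.x) != len(self.y):
            raise ValueError("x and y must have the same length.")
        return self


Pair(x=[1.0, 2.0], y=[3.0, 4.0])  # validates and returns the instance
```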
examples/unit_tesseracts/vectoradd_jax/tesseract_api.py

Lines changed: 2 additions & 4 deletions

@@ -6,7 +6,7 @@

 import jax.numpy as jnp
 import jax.tree
-from jax import ShapeDtypeStruct, eval_shape, jit, jvp, vjp
+from jax import ShapeDtypeStruct, eval_shape, jacrev, jit, jvp, vjp
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self


@@ -164,6 +164,4 @@ def jac_jit(
     jac_outputs: tuple[str],
 ):
     filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
-    return jax.jacrev(filtered_apply)(
-        flatten_with_paths(inputs, include_paths=jac_inputs)
-    )
+    return jacrev(filtered_apply)(flatten_with_paths(inputs, include_paths=jac_inputs))

tesseract_core/sdk/templates/jax/tesseract_api.py

Lines changed: 2 additions & 4 deletions

@@ -8,7 +8,7 @@
 from typing import Any

 import jax.tree
-from jax import ShapeDtypeStruct, eval_shape, jit, jvp, vjp
+from jax import ShapeDtypeStruct, eval_shape, jacrev, jit, jvp, vjp
 from pydantic import BaseModel

 from tesseract_core.runtime import Differentiable, Float32

@@ -123,9 +123,7 @@ def jac_jit(
     jac_outputs: tuple[str],
 ):
     filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
-    return jax.jacrev(filtered_apply)(
-        flatten_with_paths(inputs, include_paths=jac_inputs)
-    )
+    return jacrev(filtered_apply)(flatten_with_paths(inputs, include_paths=jac_inputs))


 @partial(jit, static_argnames=["jvp_inputs", "jvp_outputs"])
